diff --git a/.githooks/commit-msg/commit-msg b/.githooks/commit-msg/commit-msg
deleted file mode 100755
index 0e6d69b4..00000000
--- a/.githooks/commit-msg/commit-msg
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/bin/sh
-#
-# Add a specific emoji to the end of the first line in every commit message
-# based on the conventional commits keyword.
-
-if [ ! -f "$1" ] || grep -q "fixup!" "$1"; then
- # Exit if we didn't get a target file for some reason
- # or we have a fixup commit
- exit 0
-fi
-
-KEYWORD=$(head -n 1 "$1" | awk '{print $1}' | sed -e 's/://')
-
-case $KEYWORD in
- "feat"|"feat("*)
- EMOJI=":sparkles:"
- ;;
- "fix"|"fix("*)
- EMOJI=":bug:"
- ;;
- "docs"|"docs("*)
- EMOJI=":books:"
- ;;
- "style"|"style("*)
- EMOJI=":gem:"
- ;;
- "refactor"|"refactor("*)
- EMOJI=":hammer:"
- ;;
- "perf"|"perf("*)
- EMOJI=":rocket:"
- ;;
- "test"|"test("*)
- EMOJI=":rotating_light:"
- ;;
- "build"|"build("*)
- EMOJI=":package:"
- ;;
- "ci"|"ci("*)
- EMOJI=":construction_worker:"
- ;;
- "chore"|"chore("*)
- EMOJI=":wrench:"
- ;;
- *)
- EMOJI=""
- ;;
-esac
-
-MESSAGE=$(sed -E "1s/(.*)/\\1 $EMOJI/" <"$1")
-
-echo "$MESSAGE" >"$1"
diff --git a/.githooks/pre-commit/pre-commit b/.githooks/pre-commit/pre-commit
deleted file mode 100755
index 6a8c8ac1..00000000
--- a/.githooks/pre-commit/pre-commit
+++ /dev/null
@@ -1,56 +0,0 @@
-#!/bin/bash
-
-REPO_ROOT=$(git rev-parse --show-toplevel)
-source "$REPO_ROOT"/commons/shell/colors.sh 2>/dev/null || true
-
-branch=$(git rev-parse --abbrev-ref HEAD)
-
-if [[ $branch == "main" || $branch == "develop" || $branch == release/* ]]; then
- echo "${bold:-}You can't commit directly to protected branches${normal:-}"
- exit 1
-fi
-
-# Check license headers in source files
-if [ -x "$REPO_ROOT/scripts/check-license-header.sh" ]; then
- "$REPO_ROOT/scripts/check-license-header.sh" || exit 1
-fi
-
-commit_msg_type_regex='feature|fix|refactor|style|test|docs|build'
-commit_msg_scope_regex='.{1,20}'
-commit_msg_description_regex='.{1,100}'
-commit_msg_regex="^(${commit_msg_type_regex})(\(${commit_msg_scope_regex}\))?: (${commit_msg_description_regex})\$"
-merge_msg_regex="^Merge branch '.+'\$"
-
-zero_commit="0000000000000000000000000000000000000000"
-
-# Do not traverse over commits that are already in the repository
-excludeExisting="--not --all"
-
-error=""
-while read oldrev newrev refname; do
- # branch or tag get deleted
- if [ "$newrev" = "$zero_commit" ]; then
- continue
- fi
-
- # Check for new branch or tag
- if [ "$oldrev" = "$zero_commit" ]; then
- rev_span=$(git rev-list $newrev $excludeExisting)
- else
- rev_span=$(git rev-list $oldrev..$newrev $excludeExisting)
- fi
-
- for commit in $rev_span; do
- commit_msg_header=$(git show -s --format=%s $commit)
- if ! [[ "$commit_msg_header" =~ (${commit_msg_regex})|(${merge_msg_regex}) ]]; then
- echo "$commit" >&2
- echo "ERROR: Invalid commit message format" >&2
- echo "$commit_msg_header" >&2
- error="true"
- fi
- done
-done
-
-if [ -n "$error" ]; then
- exit 1
-fi
\ No newline at end of file
diff --git a/.githooks/pre-push/pre-push b/.githooks/pre-push/pre-push
deleted file mode 100755
index d33494ac..00000000
--- a/.githooks/pre-push/pre-push
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/bin/bash
-
-source "$PWD"/pkg/shell/colors.sh
-source "$PWD"/pkg/shell/ascii.sh
-
-while read local_ref local_sha remote_ref remote_sha; do
- if [[ "$local_ref" =~ ^refs/heads/ ]]; then
- branch_name=$(echo "$local_ref" | sed 's|^refs/heads/||')
-
- if [[ ! "$branch_name" =~ ^(feature|fix|hotfix|docs|refactor|build|test)/.*$ ]]; then
- echo "${bold}Branch names must start with 'feature/', 'fix/', 'refactor/', 'docs/', 'test/' or 'hotfix/' followed by either a task id or feature name."
- exit 1
- fi
- fi
-done
-
-exit 0
diff --git a/.githooks/pre-receive/pre-receive b/.githooks/pre-receive/pre-receive
deleted file mode 100644
index 6e1aa30d..00000000
--- a/.githooks/pre-receive/pre-receive
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/usr/bin/env bash
-
-zero_commit="0000000000000000000000000000000000000000"
-
-while read oldrev newrev refname; do
-
- if [[ $oldrev == $zero_commit ]]; then
- continue
- fi
-
- if [[ $refname == "refs/heads/main" && $newrev != $zero_commit ]]; then
- branch_name=$(basename $refname)
-
- if [[ $branch_name == release/* ]]; then
- continue
- else
- echo "Error: You can only merge branches that start with 'release/' into the main branch."
- exit 1
- fi
- fi
-done
\ No newline at end of file
diff --git a/.github/workflows/go-combined-analysis.yml b/.github/workflows/go-combined-analysis.yml
index 6f374b0c..f0990f2a 100644
--- a/.github/workflows/go-combined-analysis.yml
+++ b/.github/workflows/go-combined-analysis.yml
@@ -38,7 +38,7 @@ jobs:
lerian_ci_cd_user_email: ${{ secrets.LERIAN_CI_CD_USER_EMAIL }}
go_version: '1.25'
github_token: ${{ secrets.GITHUB_TOKEN }}
- golangci_lint_version: 'v2.4.0'
+ golangci_lint_version: 'v2.11.2'
GoSec:
name: Run GoSec to SDK
diff --git a/.gitignore b/.gitignore
index 83dcc5d4..bcfa8e5d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,4 @@
.idea/*
-CLAUDE.md
.DS_Store
.claude/
.mcp.json
@@ -12,4 +11,8 @@ coverage.html
*_coverage.html
# Security scan reports
-gosec-report.sarif
\ No newline at end of file
+gosec-report.sarif
+
+docs/codereview/
+.codegraph/
+vendor/
diff --git a/.golangci.yml b/.golangci.yml
index 66b670c1..fe97dce5 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -3,6 +3,7 @@ run:
tests: false
linters:
enable:
+ # --- Existing linters ---
- bodyclose
- depguard
- dogsled
@@ -20,18 +21,56 @@ linters:
- reassign
- revive
- staticcheck
- - thelper
- - tparallel
- unconvert
- unparam
- usestdlibvars
- wastedassign
- wsl_v5
+
+ # --- Tier 1: Safety & Correctness ---
+ - errorlint # type assertions on errors, missing %w
+ - exhaustive # non-exhaustive enum switches
+ - fatcontext # context growing in loop
+ - forcetypeassert # unchecked type assertions (sync.Map etc.)
+ - gosec # security issues (math/rand for jitter etc.)
+ - nilnil # return nil, nil ambiguity
+ - noctx # net.Listen/exec.Command without context
+
+ # --- Tier 2: Code Quality & Modernization ---
+ - goconst # repeated string literals
+ - gocritic # if-else→switch, deprecated comments
+ - inamedparam # unnamed interface params
+ - intrange # integer range loops
+ - mirror # allocation savings (bytes.Equal etc.)
+ - modernize # Go modernization suggestions
+ - perfsprint # fmt.Errorf→errors.New where applicable
+
+ # --- Tier 3: Zero-Issue Guards ---
+ - asasalint # variadic any argument passing
+ - copyloopvar # loop variable capture prevention
+ - durationcheck # time.Duration math bugs
+ - exptostd # x/exp to stdlib migration
+ - gocheckcompilerdirectives # malformed //go: comments
+ - makezero # make with non-zero length passed to append
+ - musttag # struct tag validation for marshaling
+ - nilnesserr # subtle nil error patterns
+ - recvcheck # receiver consistency
+ - rowserrcheck # SQL rows.Err() checks
+ - spancheck # OTEL span lifecycle
+ - sqlclosecheck # SQL resource close
+ - testifylint # testify assertion patterns
+
settings:
- wsl_v5:
- allow-first-in-block: true
- allow-whole-block: false
- branch-max-lines: 2
+ # --- New settings ---
+ exhaustive:
+ default-signifies-exhaustive: true
+
+ goconst:
+ min-len: 3
+ min-occurrences: 3
+ ignore-tests: true
+
+ # --- Existing settings (unchanged) ---
depguard:
rules:
main:
@@ -60,6 +99,11 @@ linters:
- name: use-any
severity: warning
disabled: false
+ wsl_v5:
+ allow-first-in-block: true
+ allow-whole-block: false
+ branch-max-lines: 2
+
exclusions:
generated: lax
rules:
diff --git a/.goreleaser.yml b/.goreleaser.yml
index f01929fe..4f9ca9b1 100644
--- a/.goreleaser.yml
+++ b/.goreleaser.yml
@@ -1,42 +1,10 @@
version: 2
-builds:
- - id: "auth"
- env:
- - CGO_ENABLED=0
- main: ./cmd
- binary: auth
-
- goos:
- - linux
- - windows
- - darwin
- - freebsd
-
- goarch:
- - "386"
- - amd64
- - arm
- - ppc64
+# lib-commons/v4 is a Go library (no binary to build).
+# GoReleaser is used only for changelog generation and GitHub release creation.
- goarm:
- - "7"
-
-archives:
- - format: zip
-
-nfpms:
- - id: packages
- license: Apache-2.0 license
- maintainer: Lerian Studio Technologies
- package_name: auth
- homepage: https://github.com/LerianStudio/auth
- bindir: /usr/local/bin
- formats:
- - apk
- - deb
- - rpm
- - archlinux
+builds:
+ - skip: true
changelog:
sort: asc
diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 00000000..c4a35fa2
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,246 @@
+# AGENTS
+
+This file provides repository-specific guidance for coding agents working on `lib-commons`.
+
+## Project snapshot
+
+- Module: `github.com/LerianStudio/lib-commons/v4`
+- Language: Go
+- Go version: `1.25.7` (see `go.mod`)
+- Current API generation: v4
+
+## Primary objective for changes
+
+- Preserve v4 public API contracts unless a task explicitly asks for breaking changes.
+- Prefer explicit error returns over panic paths in production code.
+- Keep behavior nil-safe and concurrency-safe by default.
+
+## Repository shape
+
+Root:
+- `commons/`: root shared helpers (`app`, `context`, `errors`, utilities, time, string, os)
+
+Observability and logging:
+- `commons/opentelemetry`: telemetry bootstrap, propagation, redaction, span helpers
+- `commons/opentelemetry/metrics`: metric factory + fluent builders (Counter, Gauge, Histogram)
+- `commons/log`: logging abstraction (`Logger` interface), typed `Field` constructors, log-injection prevention, sanitizer
+- `commons/zap`: zap adapter for `commons/log` with OTEL bridge support
+
+Data and messaging:
+- `commons/postgres`: Postgres connector with `dbresolver`, migrations, OTEL spans, backoff-based lazy-connect
+- `commons/mongo`: MongoDB connector with functional options, URI builder, index helpers, OTEL spans
+- `commons/redis`: Redis connector with topology-based config (standalone/sentinel/cluster), GCP IAM auth, distributed locking (Redsync), backoff-based reconnect
+- `commons/rabbitmq`: AMQP connection/channel/health helpers with context-aware methods
+
+HTTP and server:
+- `commons/net/http`: Fiber HTTP helpers (response, error rendering, cursor/offset/sort pagination, validation, SSRF-protected reverse proxy, CORS, basic auth, telemetry middleware, health checks, access logging)
+- `commons/net/http/ratelimit`: Redis-backed rate limit storage for Fiber
+- `commons/server`: `ServerManager`-based graceful shutdown and lifecycle helpers
+
+Resilience and safety:
+- `commons/circuitbreaker`: circuit breaker manager with preset configs and health checker
+- `commons/backoff`: exponential backoff with jitter and context-aware sleep
+- `commons/runtime`: panic recovery, panic metrics, safe goroutine wrappers, error reporter, production mode
+- `commons/assert`: production-safe assertions with telemetry integration and domain predicates
+- `commons/safe`: panic-free math/regex/slice operations with error returns
+- `commons/security`: sensitive field detection and handling
+- `commons/errgroup`: goroutine coordination with panic recovery
+
+Domain and support:
+- `commons/transaction`: intent-based transaction planning, balance eligibility validation, posting flow
+- `commons/crypto`: hashing and symmetric encryption with credential-safe `fmt` output
+- `commons/jwt`: HMAC-based JWT signing, verification, and time-claim validation
+- `commons/license`: license validation and enforcement with functional options
+- `commons/pointers`: pointer conversion helpers
+- `commons/cron`: cron expression parsing and scheduling
+- `commons/constants`: shared constants (headers, errors, pagination, transactions, metadata, datasource status, OTEL attributes, obfuscation)
+
+Build and shell:
+- `commons/shell/`: Makefile include helpers (`makefile_colors.mk`, `makefile_utils.mk`), shell scripts, ASCII art
+
+## API invariants to respect
+
+### Telemetry (`commons/opentelemetry`)
+
+- Initialization is explicit with `opentelemetry.NewTelemetry(cfg TelemetryConfig) (*Telemetry, error)`.
+- Global OTEL providers are opt-in via `(*Telemetry).ApplyGlobals()`.
+- `(*Telemetry).Tracer(name) (trace.Tracer, error)` and `(*Telemetry).Meter(name) (metric.Meter, error)` for named providers.
+- Shutdown via `ShutdownTelemetry()` or `ShutdownTelemetryWithContext(ctx) error`.
+- `TelemetryConfig` includes `InsecureExporter`, `Propagator`, and `Redactor` fields.
+- Redaction uses `Redactor` with `RedactionRule` patterns; `NewDefaultRedactor()` and `NewRedactor(rules, mask)`. Old `FieldObfuscator` interface is removed.
+- `RedactingAttrBagSpanProcessor` redacts sensitive span attributes using a `Redactor`.
+
+### Metrics (`commons/opentelemetry/metrics`)
+
+- Metric factory/builder operations return errors and should not be silently ignored.
+- Supports Counter, Histogram, and Gauge instrument types.
+- `NewMetricsFactory(meter, logger) (*MetricsFactory, error)`.
+- `NewNopFactory() *MetricsFactory` for tests / disabled metrics.
+- Builder pattern: `.WithLabels(map)` or `.WithAttributes(attrs...)` then `.Add()` / `.Set()` / `.Record()`.
+- Convenience recorders: `RecordAccountCreated`, `RecordTransactionProcessed`, etc. (no more org/ledger positional args).
+
+### Logging (`commons/log`)
+
+- `Logger` interface: 5 methods -- `Log(ctx, level, msg, fields...)`, `With(fields...)`, `WithGroup(name)`, `Enabled(level)`, `Sync(ctx)`.
+- Level constants: `LevelError` (0), `LevelWarn` (1), `LevelInfo` (2), `LevelDebug` (3).
+- Field constructors: `String()`, `Int()`, `Bool()`, `Err()`, `Any()`.
+- `NewNop() Logger` for test/disabled logging.
+- `GoLogger` provides a stdlib-based implementation with CWE-117 log-injection prevention.
+- Sanitizer: `SafeError()` and `SanitizeExternalResponse()`.
+
+### Zap adapter (`commons/zap`)
+
+- `New(cfg Config) (*Logger, error)` for construction.
+- `Logger` implements `log.Logger` and also exposes `Raw() *zap.Logger`, `Level() zap.AtomicLevel`.
+- Direct zap convenience: `Debug()`, `Info()`, `Warn()`, `Error()`, `WithZapFields()`.
+- `Config` has `Environment` (typed string), `Level`, `OTelLibraryName` fields.
+- Field constructors: `Any()`, `String()`, `Int()`, `Bool()`, `Duration()`, `ErrorField()`.
+
+### HTTP helpers (`commons/net/http`)
+
+- Response: `Respond`, `RespondStatus`, `RespondError`, `RenderError`, `FiberErrorHandler`. Individual status helpers (BadRequestError, etc.) are removed.
+- Health: `Ping` (returns `"pong"`), `HealthWithDependencies(deps...)` with AND semantics (both circuit breaker and health check must pass).
+- Reverse proxy: `ServeReverseProxy(target, policy, res, req) error` with `ReverseProxyPolicy` for SSRF protection.
+- Pagination: offset-based (`ParsePagination`), opaque cursor (`ParseOpaqueCursorPagination`), timestamp cursor, and sort cursor APIs. All encode functions return errors.
+- Validation: `ParseBodyAndValidate`, `ValidateStruct`, `GetValidator`, `ValidateSortDirection`, `ValidateLimit`, `ValidateQueryParamLength`.
+- Context/ownership: `ParseAndVerifyTenantScopedID`, `ParseAndVerifyResourceScopedID` with `TenantOwnershipVerifier` and `ResourceOwnershipVerifier` func types.
+- Middleware: `WithHTTPLogging`, `WithGrpcLogging`, `WithCORS`, `AllowFullOptionsWithCORS`, `WithBasicAuth`, `NewTelemetryMiddleware`.
+- `ErrorResponse` has `Code int` (not string), `Title`, `Message`; implements `error`.
+
+### Server lifecycle (`commons/server`)
+
+- `ServerManager` exclusively; `GracefulShutdown` is removed.
+- `NewServerManager(licenseClient, telemetry, logger) *ServerManager`.
+- Chainable config: `WithHTTPServer`, `WithGRPCServer`, `WithShutdownChannel`, `WithShutdownTimeout`, `WithShutdownHook`.
+- `StartWithGracefulShutdown()` (exits on error) or `StartWithGracefulShutdownWithError() error` (returns error).
+- `ServersStarted() <-chan struct{}` for test coordination.
+
+### Circuit breaker (`commons/circuitbreaker`)
+
+- `Manager` interface with `NewManager(logger, opts...) (Manager, error)` constructor.
+- `GetOrCreate` returns `(CircuitBreaker, error)` and validates config.
+- Preset configs: `DefaultConfig()`, `AggressiveConfig()`, `ConservativeConfig()`, `HTTPServiceConfig()`, `DatabaseConfig()`.
+- Metrics via `WithMetricsFactory` option.
+- `NewHealthCheckerWithValidation(manager, interval, timeout, logger) (HealthChecker, error)`.
+
+### Assertions (`commons/assert`)
+
+- `New(ctx, logger, component, operation) *Asserter` and return errors instead of panicking.
+- Methods: `That()`, `NotNil()`, `NotEmpty()`, `NoError()`, `Never()`, `Halt()`.
+- Metrics: `InitAssertionMetrics(factory)`, `GetAssertionMetrics()`, `ResetAssertionMetrics()`.
+- Predicates library (`predicates.go`): `Positive`, `NonNegative`, `InRange`, `ValidUUID`, `ValidAmount`, `PositiveDecimal`, `NonNegativeDecimal`, `ValidPort`, `ValidSSLMode`, `DebitsEqualCredits`, `TransactionCanTransitionTo`, `BalanceSufficientForRelease`, and more.
+
+### Runtime (`commons/runtime`)
+
+- Recovery: `RecoverAndLog`, `RecoverAndCrash`, `RecoverWithPolicy` (and `*WithContext` variants).
+- Safe goroutines: `SafeGo`, `SafeGoWithContext`, `SafeGoWithContextAndComponent` with `PanicPolicy` (KeepRunning/CrashProcess).
+- Panic metrics: `InitPanicMetrics(factory[, logger])`, `GetPanicMetrics()`, `ResetPanicMetrics()`.
+- Span recording: `RecordPanicToSpan`, `RecordPanicToSpanWithComponent`.
+- Error reporter: `SetErrorReporter(reporter)`, `GetErrorReporter()`.
+- Production mode: `SetProductionMode(bool)`, `IsProductionMode() bool`.
+
+### Safe operations (`commons/safe`)
+
+- Math: `Divide`, `DivideRound`, `DivideOrZero`, `DivideOrDefault`, `Percentage`, `PercentageOrZero`, `DivideFloat64`, `DivideFloat64OrZero`.
+- Regex: `Compile`, `CompilePOSIX`, `MatchString`, `FindString`, `ClearCache` (all with caching).
+- Slices: `First[T]`, `Last[T]`, `At[T]` with error returns and `*OrDefault` variants.
+
+### JWT (`commons/jwt`)
+
+- `Parse(token, secret, allowedAlgs) (*Token, error)` -- signature verification only.
+- `ParseAndValidate(token, secret, allowedAlgs) (*Token, error)` -- signature + time claims.
+- `Sign(claims, secret, alg) (string, error)`.
+- `ValidateTimeClaims(claims)` and `ValidateTimeClaimsAt(claims, now)`.
+- `Token.SignatureValid` (bool) -- replaces v1 `Token.Valid`; clarifies signature-only scope.
+- Algorithms: `AlgHS256`, `AlgHS384`, `AlgHS512`.
+
+### Data connectors
+
+- **Postgres:** `New(cfg Config) (*Client, error)` with explicit `Config`; `Resolver(ctx)` replaces `GetDB()`. `Primary() (*sql.DB, error)` for raw access. Migrations via `NewMigrator(cfg)`.
+- **Mongo:** `NewClient(ctx, cfg, opts...) (*Client, error)`; methods `Client(ctx)`, `ResolveClient(ctx)`, `Database(ctx)`, `Ping(ctx)`, `Close(ctx)`, `EnsureIndexes(ctx, collection, indexes...)`.
+- **Redis:** `New(ctx, cfg) (*Client, error)` with topology-based `Config` (standalone/sentinel/cluster). `GetClient(ctx)`, `Close()`, `Status()`, `IsConnected()`, `LastRefreshError()`. `SetPackageLogger(logger)` for nil-receiver diagnostics.
+- **Redis locking:** `NewRedisLockManager(conn) (*RedisLockManager, error)` and `LockManager` interface. `LockHandle` for acquired locks. `DefaultLockOptions()`, `RateLimiterLockOptions()`.
+- **RabbitMQ:** `*Context()` variants of all lifecycle methods; `HealthCheck() (bool, error)`.
+
+### Other packages
+
+- **Backoff:** `ExponentialWithJitter()` and `WaitContext()`. Used by redis and postgres for retry rate-limiting.
+- **Errgroup:** `WithContext(ctx) (*Group, context.Context)`; `Go(fn)` with panic recovery; `SetLogger(logger)`.
+- **Crypto:** `Crypto` struct with `GenerateHash`, `InitializeCipher`, `Encrypt`, `Decrypt`. `String()` / `GoString()` redact credentials.
+- **License:** `New(opts...) *ManagerShutdown` with `WithLogger()` option. `SetHandler()`, `Terminate()`, `TerminateWithError()`, `TerminateSafe()`.
+- **Pointers:** `String()`, `Bool()`, `Time()`, `Int()`, `Int64()`, `Float64()`.
+- **Cron:** `Parse(expr) (Schedule, error)`; `Schedule.Next(t) (time.Time, error)`.
+- **Security:** `IsSensitiveField(name)`, `DefaultSensitiveFields()`, `DefaultSensitiveFieldsMap()`.
+- **Transaction:** `BuildIntentPlan()` + `ValidateBalanceEligibility()` + `ApplyPosting()` with typed `IntentPlan`, `Posting`, `LedgerTarget`. `ResolveOperation(pending, isSource, status) (Operation, error)`.
+- **Constants:** `SanitizeMetricLabel(value) string` for OTEL label safety.
+
+## Coding rules
+
+- Do not add `panic(...)` in production paths.
+- Do not swallow errors; return or handle with context.
+- Keep exported docs aligned with behavior.
+- Reuse existing package patterns before introducing new abstractions.
+- Avoid introducing high-cardinality telemetry labels by default.
+- Use the structured log interface (`Log(ctx, level, msg, fields...)`) -- do not add printf-style methods.
+
+## Testing and validation
+
+### Core commands
+
+- `make test` -- run unit tests (uses gotestsum if available)
+- `make test-unit` -- run unit tests excluding integration
+- `make test-integration` -- run integration tests with testcontainers (requires Docker)
+- `make test-all` -- run all tests (unit + integration)
+- `make ci` -- run the local fix + verify pipeline (`lint-fix`, `format`, `tidy`, `check-tests`, `sec`, `vet`, `test-unit`, `test-integration`)
+- `make lint` -- run lint checks (read-only)
+- `make lint-fix` -- auto-fix lint issues
+- `make build` -- build all packages
+- `make format` -- format code with gofmt
+- `make tidy` -- clean dependencies
+- `make vet` -- run `go vet` on all packages
+- `make sec` -- run security checks using gosec (`SARIF=1` for SARIF output)
+- `make clean` -- clean build artifacts
+
+### Coverage
+
+- `make coverage-unit` -- unit tests with coverage report (respects `.ignorecoverunit`)
+- `make coverage-integration` -- integration tests with coverage
+- `make coverage` -- run all coverage targets
+
+### Test flags
+
+- `LOW_RESOURCE=1` -- sets `-p=1 -parallel=1`, disables `-race` for constrained machines
+- `RETRY_ON_FAIL=1` -- retries failed tests once
+- `RUN=` -- filter integration tests by name pattern
+- `PKG=` -- filter to specific package(s)
+- `DISABLE_OSX_LINKER_WORKAROUND=1` -- disable macOS ld_classic workaround
+
+### Integration test conventions
+
+- Test files: `*_integration_test.go`
+- Test functions: `TestIntegration_`
+- Build tag: `integration`
+
+### Other
+
+- `make tools` -- install gotestsum
+- `make check-tests` -- verify test coverage for packages
+- `make setup-git-hooks` -- install git hooks
+- `make check-hooks` -- verify git hooks installation
+- `make check-envs` -- check hooks + environment file security
+- `make goreleaser` -- create release snapshot
+
+## Migration awareness
+
+- If a task touches renamed/removed v1 symbols, update `MIGRATION_MAP.md`.
+- If a task changes package-level behavior or API expectations, update `README.md`.
+
+## Project rules
+
+- Full coding standards, architecture patterns, and development guidelines are in [`docs/PROJECT_RULES.md`](docs/PROJECT_RULES.md).
+
+## Documentation policy
+
+- Keep docs factual and code-backed.
+- Avoid speculative roadmap text.
+- Prefer concise package-level examples that compile with current API names.
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3e4230b6..4d71fbd7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,1134 +1,3 @@
-## [2.5.0](https://github.com/LerianStudio/lib-commons/compare/v2.4.0...v2.5.0) (2025-11-07)
+# Changelog
-
-### Bug Fixes
-
-* improve SafeIntToUint32 function by using uint64 for overflow checks :bug: ([4340367](https://github.com/LerianStudio/lib-commons/commit/43403675c46dc513cbfa12102929de0387f026cd))
-
-## [2.4.0](https://github.com/LerianStudio/lib-commons/compare/v2.3.0...v2.4.0) (2025-10-30)
-
-
-### Features
-
-* **redis:** add RateLimiterLockOptions helper function ([6535d18](https://github.com/LerianStudio/lib-commons/commit/6535d18146a36eaf23584893b7ff4fdef0d6fe61))
-* **ratelimit:** add Redis-based rate limiting with global middleware support ([9a976c3](https://github.com/LerianStudio/lib-commons/commit/9a976c3267adc45f77482f68a3e1ebc65c6baa42))
-* **commons:** add SafeIntToUint32 utility with overflow protection and logging ([5a13d45](https://github.com/LerianStudio/lib-commons/commit/5a13d45f0a3cd2fafdb3debf99017bac473083f7))
-* add service unavailable error code and standardize rate limit responses ([f65af5a](https://github.com/LerianStudio/lib-commons/commit/f65af5a258b3d7659e3b5afc0854036d8ace14b5))
-* **circuitbreaker:** add state change notifications and immediate health checks ([2532b8b](https://github.com/LerianStudio/lib-commons/commit/2532b8b9605619b8b3a6f0f6e1ec0b3574de5516))
-* Adding datasource constants. ([5a04f8a](https://github.com/LerianStudio/lib-commons/commit/5a04f8a5eb139318b7b71c1fef9d966bfd296f50))
-* **circuitbreaker:** extend HealthChecker interface to include state change notifications ([9087254](https://github.com/LerianStudio/lib-commons/commit/90872540cf2aad78d642596652789747075e71c7))
-* **circuitbreaker:** implement circuit breaker package with health checks and state management ([d93b161](https://github.com/LerianStudio/lib-commons/commit/d93b1610c0cae3be263be4e684afc157c88e93b4))
-* **redis:** implement distributed locking with RedLock algorithm ([5ee1bdb](https://github.com/LerianStudio/lib-commons/commit/5ee1bdb96af56371309231323f4be7e09c98e6b5))
-* improve distributed locking and rate limiting reliability ([79dbad3](https://github.com/LerianStudio/lib-commons/commit/79dbad34e600d27a512c2f99104b91a77e6f0f3e))
-* update OperateBalances to include balance versioning :sparkles: ([3a75235](https://github.com/LerianStudio/lib-commons/commit/3a75235256893ea35ea94edfe84789a84b620b2f))
-
-
-### Bug Fixes
-
-* add nil check for circuit breaker state change listener registration ([55da00b](https://github.com/LerianStudio/lib-commons/commit/55da00b081dcc0251433dcb702b14e98486348cd))
-* add nil logger check and change warn to debug level in SafeIntToUint32 ([a72880c](https://github.com/LerianStudio/lib-commons/commit/a72880ca0525c05cf61802c0f976e7b872f85b51))
-* add panic recovery to circuit breaker state change listeners ([96fe07e](https://github.com/LerianStudio/lib-commons/commit/96fe07eff47627fde636fbf814b687cdab3ecac7))
-* **redis:** correct benchmark loop and test naming in rate limiter tests ([4622c78](https://github.com/LerianStudio/lib-commons/commit/4622c783412d81408697413d1e70d1ced6c6c3be))
-* **redis:** correct goroutine test assertions in distributed lock tests ([b9e6d70](https://github.com/LerianStudio/lib-commons/commit/b9e6d703de7893cec558bb673632559175e4604f))
-* update OperateBalances to handle unknown operations without changing balance version :bug: ([2f4369d](https://github.com/LerianStudio/lib-commons/commit/2f4369d1b73eaaf66bd2b9a430584c2f9a840ac4))
-
-## [2.4.0-beta.9](https://github.com/LerianStudio/lib-commons/compare/v2.4.0-beta.8...v2.4.0-beta.9) (2025-10-30)
-
-
-### Features
-
-* improve distributed locking and rate limiting reliability ([79dbad3](https://github.com/LerianStudio/lib-commons/commit/79dbad34e600d27a512c2f99104b91a77e6f0f3e))
-
-## [2.4.0-beta.8](https://github.com/LerianStudio/lib-commons/compare/v2.4.0-beta.7...v2.4.0-beta.8) (2025-10-29)
-
-
-### Bug Fixes
-
-* add panic recovery to circuit breaker state change listeners ([96fe07e](https://github.com/LerianStudio/lib-commons/commit/96fe07eff47627fde636fbf814b687cdab3ecac7))
-
-## [2.4.0-beta.7](https://github.com/LerianStudio/lib-commons/compare/v2.4.0-beta.6...v2.4.0-beta.7) (2025-10-27)
-
-
-### Features
-
-* **commons:** add SafeIntToUint32 utility with overflow protection and logging ([5a13d45](https://github.com/LerianStudio/lib-commons/commit/5a13d45f0a3cd2fafdb3debf99017bac473083f7))
-* **circuitbreaker:** add state change notifications and immediate health checks ([2532b8b](https://github.com/LerianStudio/lib-commons/commit/2532b8b9605619b8b3a6f0f6e1ec0b3574de5516))
-* **circuitbreaker:** extend HealthChecker interface to include state change notifications ([9087254](https://github.com/LerianStudio/lib-commons/commit/90872540cf2aad78d642596652789747075e71c7))
-
-
-### Bug Fixes
-
-* add nil check for circuit breaker state change listener registration ([55da00b](https://github.com/LerianStudio/lib-commons/commit/55da00b081dcc0251433dcb702b14e98486348cd))
-* add nil logger check and change warn to debug level in SafeIntToUint32 ([a72880c](https://github.com/LerianStudio/lib-commons/commit/a72880ca0525c05cf61802c0f976e7b872f85b51))
-
-## [2.4.0-beta.6](https://github.com/LerianStudio/lib-commons/compare/v2.4.0-beta.5...v2.4.0-beta.6) (2025-10-24)
-
-
-### Features
-
-* **circuitbreaker:** implement circuit breaker package with health checks and state management ([d93b161](https://github.com/LerianStudio/lib-commons/commit/d93b1610c0cae3be263be4e684afc157c88e93b4))
-
-## [2.4.0-beta.5](https://github.com/LerianStudio/lib-commons/compare/v2.4.0-beta.4...v2.4.0-beta.5) (2025-10-21)
-
-
-### Features
-
-* **redis:** add RateLimiterLockOptions helper function ([6535d18](https://github.com/LerianStudio/lib-commons/commit/6535d18146a36eaf23584893b7ff4fdef0d6fe61))
-* **redis:** implement distributed locking with RedLock algorithm ([5ee1bdb](https://github.com/LerianStudio/lib-commons/commit/5ee1bdb96af56371309231323f4be7e09c98e6b5))
-
-
-### Bug Fixes
-
-* **redis:** correct benchmark loop and test naming in rate limiter tests ([4622c78](https://github.com/LerianStudio/lib-commons/commit/4622c783412d81408697413d1e70d1ced6c6c3be))
-* **redis:** correct goroutine test assertions in distributed lock tests ([b9e6d70](https://github.com/LerianStudio/lib-commons/commit/b9e6d703de7893cec558bb673632559175e4604f))
-
-## [2.4.0-beta.4](https://github.com/LerianStudio/lib-commons/compare/v2.4.0-beta.3...v2.4.0-beta.4) (2025-10-17)
-
-
-### Features
-
-* add service unavailable error code and standardize rate limit responses ([f65af5a](https://github.com/LerianStudio/lib-commons/commit/f65af5a258b3d7659e3b5afc0854036d8ace14b5))
-
-## [2.4.0-beta.3](https://github.com/LerianStudio/lib-commons/compare/v2.4.0-beta.2...v2.4.0-beta.3) (2025-10-16)
-
-
-### Features
-
-* **ratelimit:** add Redis-based rate limiting with global middleware support ([9a976c3](https://github.com/LerianStudio/lib-commons/commit/9a976c3267adc45f77482f68a3e1ebc65c6baa42))
-
-## [2.4.0-beta.2](https://github.com/LerianStudio/lib-commons/compare/v2.4.0-beta.1...v2.4.0-beta.2) (2025-10-15)
-
-
-### Features
-
-* update OperateBalances to include balance versioning :sparkles: ([3a75235](https://github.com/LerianStudio/lib-commons/commit/3a75235256893ea35ea94edfe84789a84b620b2f))
-
-
-### Bug Fixes
-
-* update OperateBalances to handle unknown operations without changing balance version :bug: ([2f4369d](https://github.com/LerianStudio/lib-commons/commit/2f4369d1b73eaaf66bd2b9a430584c2f9a840ac4))
-
-## [2.4.0-beta.1](https://github.com/LerianStudio/lib-commons/compare/v2.3.0...v2.4.0-beta.1) (2025-10-14)
-
-
-### Features
-
-* Adding datasource constants. ([5a04f8a](https://github.com/LerianStudio/lib-commons/commit/5a04f8a5eb139318b7b71c1fef9d966bfd296f50))
-
-## [2.3.0](https://github.com/LerianStudio/lib-commons/compare/v2.2.0...v2.3.0) (2025-09-18)
-
-
-### Features
-
-* **rabbitmq:** add EnsureChannel method to manage RabbitMQ connection and channel lifecycle :sparkles: ([9e6ebf8](https://github.com/LerianStudio/lib-commons/commit/9e6ebf89c727e52290e83754ed89303557f6f69d))
-* add telemetry and logging to transaction validation and gRPC middleware ([0aabecc](https://github.com/LerianStudio/lib-commons/commit/0aabeccb0a7bb2f50dfc3cf9544cfe6b2dcddf91))
-* Adding the crypto package of encryption and decryption. ([f309c23](https://github.com/LerianStudio/lib-commons/commit/f309c233404a56ca1bd3f27e7a9a28bd839fac37))
-* Adding the crypto package of encryption and decryption. ([577b746](https://github.com/LerianStudio/lib-commons/commit/577b746c0dfad3dc863027bbe6f5508b194f7578))
-* **transaction:** implement balanceKey support in operations :sparkles: ([38ac489](https://github.com/LerianStudio/lib-commons/commit/38ac489a64c11810bf406d7a2141b4aed3ca6746))
-* **rabbitmq:** improve error logging in EnsureChannel method for connection and channel failures :sparkles: ([266febc](https://github.com/LerianStudio/lib-commons/commit/266febc427996da526abc7e50c53675b8abe2f18))
-* some adjusts; ([60b206a](https://github.com/LerianStudio/lib-commons/commit/60b206a8bf1c8a299648a5df09aea76191dbea0c))
-
-
-### Bug Fixes
-
-* add error handling for short ciphertext in Decrypt method :bug: ([bc73d51](https://github.com/LerianStudio/lib-commons/commit/bc73d510bb21e5cc18a450d616746d21fbf85a3d))
-* add nil check for uninitialized cipher in Decrypt method :bug: ([e1934a2](https://github.com/LerianStudio/lib-commons/commit/e1934a26e5e2b6012f3bfdcf4378f70f21ec659a))
-* add nil check for uninitialized cipher in Encrypt method :bug: ([207cae6](https://github.com/LerianStudio/lib-commons/commit/207cae617e34bcf9ece83b61fbfbac308b935b44))
-* Adjusting instance when telemetry is off. ([68504a7](https://github.com/LerianStudio/lib-commons/commit/68504a7080ce4f437a9f551ae4c259ed7c0daaa6))
-* ensure nil check for values in AttributesFromContext function :bug: ([38f8c77](https://github.com/LerianStudio/lib-commons/commit/38f8c7725f9e91eff04c79b69983497f9ea5c86c))
-* go.mod and go.sum; ([cda49e7](https://github.com/LerianStudio/lib-commons/commit/cda49e7e7d7a9b5da91155c43bdb9966826a7f4c))
-* initialize no-op providers in InitializeTelemetry when telemetry is disabled to prevent nil-pointer panics :bug: ([c40310d](https://github.com/LerianStudio/lib-commons/commit/c40310d90f06952877f815238e33cc382a4eafbd))
-* make lint ([ec9fc3a](https://github.com/LerianStudio/lib-commons/commit/ec9fc3ac4c39996b2e5ce308032f269380df32ee))
-* **otel:** reorder shutdown sequence to ensure proper telemetry export and add span attributes from request params id ([44fc4c9](https://github.com/LerianStudio/lib-commons/commit/44fc4c996e2f322244965bb31c79e069719a1e1f))
-* **cursor:** resolve first page prev_cursor bug and infinite loop issues; ([b0f8861](https://github.com/LerianStudio/lib-commons/commit/b0f8861c22521b6ec742a365560a439e28b866c4))
-* **cursor:** resolve pagination logic errors and add comprehensive UUID v7 tests ([2d48453](https://github.com/LerianStudio/lib-commons/commit/2d4845332e94b8225e781b267eec9f405519a7f6))
-* return TelemetryConfig in InitializeTelemetry when telemetry is disabled :bug: ([62bd90b](https://github.com/LerianStudio/lib-commons/commit/62bd90b525978ea2540746b367775143d39ca922))
-* **http:** use HasPrefix instead of Contains for route exclusion matching ([9891eac](https://github.com/LerianStudio/lib-commons/commit/9891eacbd75dfce11ba57ebf2a6f38144dc04505))
-
-## [2.3.0-beta.10](https://github.com/LerianStudio/lib-commons/compare/v2.3.0-beta.9...v2.3.0-beta.10) (2025-09-18)
-
-
-### Bug Fixes
-
-* add error handling for short ciphertext in Decrypt method :bug: ([bc73d51](https://github.com/LerianStudio/lib-commons/commit/bc73d510bb21e5cc18a450d616746d21fbf85a3d))
-* add nil check for uninitialized cipher in Decrypt method :bug: ([e1934a2](https://github.com/LerianStudio/lib-commons/commit/e1934a26e5e2b6012f3bfdcf4378f70f21ec659a))
-* add nil check for uninitialized cipher in Encrypt method :bug: ([207cae6](https://github.com/LerianStudio/lib-commons/commit/207cae617e34bcf9ece83b61fbfbac308b935b44))
-* ensure nil check for values in AttributesFromContext function :bug: ([38f8c77](https://github.com/LerianStudio/lib-commons/commit/38f8c7725f9e91eff04c79b69983497f9ea5c86c))
-* initialize no-op providers in InitializeTelemetry when telemetry is disabled to prevent nil-pointer panics :bug: ([c40310d](https://github.com/LerianStudio/lib-commons/commit/c40310d90f06952877f815238e33cc382a4eafbd))
-* return TelemetryConfig in InitializeTelemetry when telemetry is disabled :bug: ([62bd90b](https://github.com/LerianStudio/lib-commons/commit/62bd90b525978ea2540746b367775143d39ca922))
-
-## [2.3.0-beta.9](https://github.com/LerianStudio/lib-commons/compare/v2.3.0-beta.8...v2.3.0-beta.9) (2025-09-18)
-
-## [2.3.0-beta.8](https://github.com/LerianStudio/lib-commons/compare/v2.3.0-beta.7...v2.3.0-beta.8) (2025-09-15)
-
-
-### Features
-
-* **rabbitmq:** add EnsureChannel method to manage RabbitMQ connection and channel lifecycle :sparkles: ([9e6ebf8](https://github.com/LerianStudio/lib-commons/commit/9e6ebf89c727e52290e83754ed89303557f6f69d))
-* **rabbitmq:** improve error logging in EnsureChannel method for connection and channel failures :sparkles: ([266febc](https://github.com/LerianStudio/lib-commons/commit/266febc427996da526abc7e50c53675b8abe2f18))
-
-## [2.3.0-beta.7](https://github.com/LerianStudio/lib-commons/compare/v2.3.0-beta.6...v2.3.0-beta.7) (2025-09-10)
-
-
-### Features
-
-* **transaction:** implement balanceKey support in operations :sparkles: ([38ac489](https://github.com/LerianStudio/lib-commons/commit/38ac489a64c11810bf406d7a2141b4aed3ca6746))
-
-## [2.3.0-beta.6](https://github.com/LerianStudio/lib-commons/compare/v2.3.0-beta.5...v2.3.0-beta.6) (2025-08-21)
-
-
-### Features
-
-* some adjusts; ([60b206a](https://github.com/LerianStudio/lib-commons/commit/60b206a8bf1c8a299648a5df09aea76191dbea0c))
-
-
-### Bug Fixes
-
-* go.mod and go.sum; ([cda49e7](https://github.com/LerianStudio/lib-commons/commit/cda49e7e7d7a9b5da91155c43bdb9966826a7f4c))
-* make lint ([ec9fc3a](https://github.com/LerianStudio/lib-commons/commit/ec9fc3ac4c39996b2e5ce308032f269380df32ee))
-* **cursor:** resolve first page prev_cursor bug and infinite loop issues; ([b0f8861](https://github.com/LerianStudio/lib-commons/commit/b0f8861c22521b6ec742a365560a439e28b866c4))
-* **cursor:** resolve pagination logic errors and add comprehensive UUID v7 tests ([2d48453](https://github.com/LerianStudio/lib-commons/commit/2d4845332e94b8225e781b267eec9f405519a7f6))
-
-## [2.3.0-beta.5](https://github.com/LerianStudio/lib-commons/compare/v2.3.0-beta.4...v2.3.0-beta.5) (2025-08-20)
-
-## [2.3.0-beta.4](https://github.com/LerianStudio/lib-commons/compare/v2.3.0-beta.3...v2.3.0-beta.4) (2025-08-20)
-
-
-### Features
-
-* add telemetry and logging to transaction validation and gRPC middleware ([0aabecc](https://github.com/LerianStudio/lib-commons/commit/0aabeccb0a7bb2f50dfc3cf9544cfe6b2dcddf91))
-
-## [2.3.0-beta.3](https://github.com/LerianStudio/lib-commons/compare/v2.3.0-beta.2...v2.3.0-beta.3) (2025-08-19)
-
-
-### Bug Fixes
-
-* Adjusting instance when telemetry is off. ([68504a7](https://github.com/LerianStudio/lib-commons/commit/68504a7080ce4f437a9f551ae4c259ed7c0daaa6))
-
-## [2.3.0-beta.2](https://github.com/LerianStudio/lib-commons/compare/v2.3.0-beta.1...v2.3.0-beta.2) (2025-08-18)
-
-
-### Features
-
-* Adding the crypto package of encryption and decryption. ([f309c23](https://github.com/LerianStudio/lib-commons/commit/f309c233404a56ca1bd3f27e7a9a28bd839fac37))
-* Adding the crypto package of encryption and decryption. ([577b746](https://github.com/LerianStudio/lib-commons/commit/577b746c0dfad3dc863027bbe6f5508b194f7578))
-
-## [2.3.0-beta.1](https://github.com/LerianStudio/lib-commons/compare/v2.2.0...v2.3.0-beta.1) (2025-08-18)
-
-
-### Bug Fixes
-
-* **otel:** reorder shutdown sequence to ensure proper telemetry export and add span attributes from request params id ([44fc4c9](https://github.com/LerianStudio/lib-commons/commit/44fc4c996e2f322244965bb31c79e069719a1e1f))
-* **http:** use HasPrefix instead of Contains for route exclusion matching ([9891eac](https://github.com/LerianStudio/lib-commons/commit/9891eacbd75dfce11ba57ebf2a6f38144dc04505))
-
-## [2.2.0](https://github.com/LerianStudio/lib-commons/compare/v2.1.0...v2.2.0) (2025-08-08)
-
-
-### Features
-
-* add new field transaction date to be used to make past transactions; ([fcb4704](https://github.com/LerianStudio/lib-commons/commit/fcb47044c5b11d0da0eb53a75fc31f26ae6f7fb6))
-* add span events, UUID conversion and configurable log obfuscation ([d92bb13](https://github.com/LerianStudio/lib-commons/commit/d92bb13aabeb0b49b30a4ed9161182d73aab300f))
-* merge pull request [#182](https://github.com/LerianStudio/lib-commons/issues/182) from LerianStudio/feat/COMMONS-1155 ([931fdcb](https://github.com/LerianStudio/lib-commons/commit/931fdcb9c5cdeabf1602108db813855162b8e655))
-
-
-### Bug Fixes
-
-* go get -u ./... && make tidy; ([a18914f](https://github.com/LerianStudio/lib-commons/commit/a18914fd032c639bf06732ccbd0c66eabd89753d))
-* **otel:** add nil checks and remove unnecessary error handling in span methods ([3f9d468](https://github.com/LerianStudio/lib-commons/commit/3f9d46884dad366520eb1b95a5ee032a2992b959))
-
-## [2.2.0-beta.4](https://github.com/LerianStudio/lib-commons/compare/v2.2.0-beta.3...v2.2.0-beta.4) (2025-08-08)
-
-## [2.2.0-beta.3](https://github.com/LerianStudio/lib-commons/compare/v2.2.0-beta.2...v2.2.0-beta.3) (2025-08-08)
-
-## [2.2.0-beta.2](https://github.com/LerianStudio/lib-commons/compare/v2.2.0-beta.1...v2.2.0-beta.2) (2025-08-08)
-
-
-### Features
-
-* add span events, UUID conversion and configurable log obfuscation ([d92bb13](https://github.com/LerianStudio/lib-commons/commit/d92bb13aabeb0b49b30a4ed9161182d73aab300f))
-
-
-### Bug Fixes
-
-* **otel:** add nil checks and remove unnecessary error handling in span methods ([3f9d468](https://github.com/LerianStudio/lib-commons/commit/3f9d46884dad366520eb1b95a5ee032a2992b959))
-
-## [2.2.0-beta.1](https://github.com/LerianStudio/lib-commons/compare/v2.1.0...v2.2.0-beta.1) (2025-08-06)
-
-
-### Features
-
-* add new field transaction date to be used to make past transactions; ([fcb4704](https://github.com/LerianStudio/lib-commons/commit/fcb47044c5b11d0da0eb53a75fc31f26ae6f7fb6))
-* merge pull request [#182](https://github.com/LerianStudio/lib-commons/issues/182) from LerianStudio/feat/COMMONS-1155 ([931fdcb](https://github.com/LerianStudio/lib-commons/commit/931fdcb9c5cdeabf1602108db813855162b8e655))
-
-
-### Bug Fixes
-
-* go get -u ./... && make tidy; ([a18914f](https://github.com/LerianStudio/lib-commons/commit/a18914fd032c639bf06732ccbd0c66eabd89753d))
-
-## [2.1.0](https://github.com/LerianStudio/lib-commons/compare/v2.0.0...v2.1.0) (2025-08-01)
-
-
-### Bug Fixes
-
-* add UTF-8 sanitization for span attributes and error handling improvements ([e69dae8](https://github.com/LerianStudio/lib-commons/commit/e69dae8728c7c2ae669c96e102a811febc45de14))
-
-## [2.1.0-beta.2](https://github.com/LerianStudio/lib-commons/compare/v2.1.0-beta.1...v2.1.0-beta.2) (2025-08-01)
-
-## [2.1.0-beta.1](https://github.com/LerianStudio/lib-commons/compare/v2.0.0...v2.1.0-beta.1) (2025-08-01)
-
-
-### Bug Fixes
-
-* add UTF-8 sanitization for span attributes and error handling improvements ([e69dae8](https://github.com/LerianStudio/lib-commons/commit/e69dae8728c7c2ae669c96e102a811febc45de14))
-
-## [2.0.0](https://github.com/LerianStudio/lib-commons/compare/v1.18.0...v2.0.0) (2025-07-30)
-
-
-### ⚠ BREAKING CHANGES
-
-* change version and paths to v2
-
-### Features
-
-* **security:** add accesstoken and refreshtoken to sensitive fields list ([9e884c7](https://github.com/LerianStudio/lib-commons/commit/9e884c784e686c15354196fa09526371570f01e1))
-* **security:** add accesstoken and refreshtoken to sensitive fields ([ede9b9b](https://github.com/LerianStudio/lib-commons/commit/ede9b9ba17b7f98ffe53a927d42cfb7b0f867f29))
-* **telemetry:** add metrics factory with fluent API for counter, gauge and histogram metrics ([517352b](https://github.com/LerianStudio/lib-commons/commit/517352b95111de59613d9b2f15429c751302b779))
-* **telemetry:** add request ID to HTTP span attributes ([3c60b29](https://github.com/LerianStudio/lib-commons/commit/3c60b29f9432c012219f0c08b1403594ea54069b))
-* **telemetry:** add telemetry queue propagation ([610c702](https://github.com/LerianStudio/lib-commons/commit/610c702c3f927d08bcd3f5279caf99b75127dfd8))
-* adjust internal keys on redis to use generic one; ([c0e4556](https://github.com/LerianStudio/lib-commons/commit/c0e45566040c9da35043601b8128b3792c43cb61))
-* create a new balance internal key to lock balance on redis; ([715e2e7](https://github.com/LerianStudio/lib-commons/commit/715e2e72b47c681064fd83dcef89c053c1d33d1c))
-* extract logger separator constant and enhance telemetry span attributes ([2f611bb](https://github.com/LerianStudio/lib-commons/commit/2f611bb808f4fb68860b9745490a3ffdf8ba37a9))
-* **security:** implement sensitive field obfuscation for telemetry and logging ([b98bd60](https://github.com/LerianStudio/lib-commons/commit/b98bd604259823c733711ef552d23fb347a86956))
-* Merge pull request [#166](https://github.com/LerianStudio/lib-commons/issues/166) from LerianStudio/feat/add-new-redis-key ([3199765](https://github.com/LerianStudio/lib-commons/commit/3199765d6832d8a068f8e925773ea44acce5291e))
-* Merge pull request [#168](https://github.com/LerianStudio/lib-commons/issues/168) from LerianStudio/feat/COMMONS-redis-balance-key ([2b66484](https://github.com/LerianStudio/lib-commons/commit/2b66484703bb7551fbe5264cc8f20618fe61bd5b))
-* merge pull request [#176](https://github.com/LerianStudio/lib-commons/issues/176) from LerianStudio/develop ([69fd3fa](https://github.com/LerianStudio/lib-commons/commit/69fd3face5ada8718fe290ac951e89720c253980))
-
-
-### Bug Fixes
-
-* Add NormalizeDateTime helper for date offset and time bounds formatting ([838c5f1](https://github.com/LerianStudio/lib-commons/commit/838c5f1940fd06c109ba9480f30781553e80ff45))
-* Merge pull request [#164](https://github.com/LerianStudio/lib-commons/issues/164) from LerianStudio/fix/COMMONS-1111 ([295ca40](https://github.com/LerianStudio/lib-commons/commit/295ca4093e919513bfcf7a0de50108c9e5609eb2))
-* remove commets; ([333fe49](https://github.com/LerianStudio/lib-commons/commit/333fe499e1a8a43654cd6c0f0546e3a1c5279bc9))
-
-
-### Code Refactoring
-
-* update module to v2 ([1c20f97](https://github.com/LerianStudio/lib-commons/commit/1c20f97279dd7ab0c59e447b4e1ffc1595077deb))
-
-## [2.0.0-beta.1](https://github.com/LerianStudio/lib-commons/compare/v1.19.0-beta.11...v2.0.0-beta.1) (2025-07-30)
-
-
-### ⚠ BREAKING CHANGES
-
-* change version and paths to v2
-
-### Features
-
-* **security:** add accesstoken and refreshtoken to sensitive fields list ([9e884c7](https://github.com/LerianStudio/lib-commons/commit/9e884c784e686c15354196fa09526371570f01e1))
-
-
-### Code Refactoring
-
-* update module to v2 ([1c20f97](https://github.com/LerianStudio/lib-commons/commit/1c20f97279dd7ab0c59e447b4e1ffc1595077deb))
-
-## [1.19.0-beta.11](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.10...v1.19.0-beta.11) (2025-07-30)
-
-
-### Features
-
-* **telemetry:** add request ID to HTTP span attributes ([3c60b29](https://github.com/LerianStudio/lib-commons/v2/commit/3c60b29f9432c012219f0c08b1403594ea54069b))
-
-## [1.19.0-beta.10](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.9...v1.19.0-beta.10) (2025-07-30)
-
-
-### Features
-
-* **security:** add accesstoken and refreshtoken to sensitive fields ([ede9b9b](https://github.com/LerianStudio/lib-commons/v2/commit/ede9b9ba17b7f98ffe53a927d42cfb7b0f867f29))
-
-## [1.19.0-beta.9](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.8...v1.19.0-beta.9) (2025-07-30)
-
-## [1.19.0-beta.8](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.7...v1.19.0-beta.8) (2025-07-29)
-
-## [1.19.0-beta.7](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.6...v1.19.0-beta.7) (2025-07-29)
-
-
-### Features
-
-* extract logger separator constant and enhance telemetry span attributes ([2f611bb](https://github.com/LerianStudio/lib-commons/v2/commit/2f611bb808f4fb68860b9745490a3ffdf8ba37a9))
-
-## [1.19.0-beta.6](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.5...v1.19.0-beta.6) (2025-07-28)
-
-
-### Features
-
-* **telemetry:** add metrics factory with fluent API for counter, gauge and histogram metrics ([517352b](https://github.com/LerianStudio/lib-commons/v2/commit/517352b95111de59613d9b2f15429c751302b779))
-
-## [1.19.0-beta.5](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.4...v1.19.0-beta.5) (2025-07-28)
-
-
-### Features
-
-* adjust internal keys on redis to use generic one; ([c0e4556](https://github.com/LerianStudio/lib-commons/v2/commit/c0e45566040c9da35043601b8128b3792c43cb61))
-* Merge pull request [#168](https://github.com/LerianStudio/lib-commons/v2/issues/168) from LerianStudio/feat/COMMONS-redis-balance-key ([2b66484](https://github.com/LerianStudio/lib-commons/v2/commit/2b66484703bb7551fbe5264cc8f20618fe61bd5b))
-
-## [1.19.0-beta.4](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.3...v1.19.0-beta.4) (2025-07-28)
-
-
-### Features
-
-* create a new balance internal key to lock balance on redis; ([715e2e7](https://github.com/LerianStudio/lib-commons/v2/commit/715e2e72b47c681064fd83dcef89c053c1d33d1c))
-* Merge pull request [#166](https://github.com/LerianStudio/lib-commons/v2/issues/166) from LerianStudio/feat/add-new-redis-key ([3199765](https://github.com/LerianStudio/lib-commons/v2/commit/3199765d6832d8a068f8e925773ea44acce5291e))
-
-## [1.19.0-beta.3](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.2...v1.19.0-beta.3) (2025-07-25)
-
-
-### Features
-
-* **telemetry:** add telemetry queue propagation ([610c702](https://github.com/LerianStudio/lib-commons/v2/commit/610c702c3f927d08bcd3f5279caf99b75127dfd8))
-
-## [1.19.0-beta.2](https://github.com/LerianStudio/lib-commons/v2/compare/v1.19.0-beta.1...v1.19.0-beta.2) (2025-07-25)
-
-
-### Bug Fixes
-
-* Add NormalizeDateTime helper for date offset and time bounds formatting ([838c5f1](https://github.com/LerianStudio/lib-commons/v2/commit/838c5f1940fd06c109ba9480f30781553e80ff45))
-* Merge pull request [#164](https://github.com/LerianStudio/lib-commons/v2/issues/164) from LerianStudio/fix/COMMONS-1111 ([295ca40](https://github.com/LerianStudio/lib-commons/v2/commit/295ca4093e919513bfcf7a0de50108c9e5609eb2))
-* remove commets; ([333fe49](https://github.com/LerianStudio/lib-commons/v2/commit/333fe499e1a8a43654cd6c0f0546e3a1c5279bc9))
-
-## [1.19.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.18.0...v1.19.0-beta.1) (2025-07-23)
-
-
-### Features
-
-* **security:** implement sensitive field obfuscation for telemetry and logging ([b98bd60](https://github.com/LerianStudio/lib-commons/v2/commit/b98bd604259823c733711ef552d23fb347a86956))
-
-## [1.18.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0...v1.18.0) (2025-07-22)
-
-
-### Features
-
-* Improve Redis client configuration with UniversalOptions and connection pool tuning ([1587047](https://github.com/LerianStudio/lib-commons/v2/commit/158704738d1c823af6fbf3bc37f97d9e9734ed8e))
-* Merge pull request [#159](https://github.com/LerianStudio/lib-commons/v2/issues/159) from LerianStudio/feat/COMMONS-REDIS-RETRY ([e279ae9](https://github.com/LerianStudio/lib-commons/v2/commit/e279ae92be1464100e7f11c236afa9df408834cb))
-* Merge pull request [#162](https://github.com/LerianStudio/lib-commons/v2/issues/162) from LerianStudio/develop ([f0778f0](https://github.com/LerianStudio/lib-commons/v2/commit/f0778f040d2e0ec776a5e7ca796578b1a01bd869))
-
-
-### Bug Fixes
-
-* add on const magic numbers; ([ff4d39b](https://github.com/LerianStudio/lib-commons/v2/commit/ff4d39b9ae209ce83827d5ba8b73f1e54692caad))
-* add redis values default; ([7fe8252](https://github.com/LerianStudio/lib-commons/v2/commit/7fe8252291623f0c148155c60e33e48c7e2722ec))
-* add variables default config; ([3c0b0a8](https://github.com/LerianStudio/lib-commons/v2/commit/3c0b0a8d5a07979ed668885d9799fb5c1c60aa3b))
-* change default values to regular size; ([42ff053](https://github.com/LerianStudio/lib-commons/v2/commit/42ff053d9545be847d7f6033c6e3afd8f4fd4bf0))
-* remove alias concat on operation route assignment :bug: ([ddf7530](https://github.com/LerianStudio/lib-commons/v2/commit/ddf7530692f9e1121b986b1c4d7cc27022b22f24))
-
-## [1.18.0-beta.4](https://github.com/LerianStudio/lib-commons/v2/compare/v1.18.0-beta.3...v1.18.0-beta.4) (2025-07-22)
-
-
-### Bug Fixes
-
-* add redis values default; ([7fe8252](https://github.com/LerianStudio/lib-commons/v2/commit/7fe8252291623f0c148155c60e33e48c7e2722ec))
-
-## [1.18.0-beta.3](https://github.com/LerianStudio/lib-commons/v2/compare/v1.18.0-beta.2...v1.18.0-beta.3) (2025-07-22)
-
-
-### Bug Fixes
-
-* add variables default config; ([3c0b0a8](https://github.com/LerianStudio/lib-commons/v2/commit/3c0b0a8d5a07979ed668885d9799fb5c1c60aa3b))
-* change default values to regular size; ([42ff053](https://github.com/LerianStudio/lib-commons/v2/commit/42ff053d9545be847d7f6033c6e3afd8f4fd4bf0))
-
-## [1.18.0-beta.2](https://github.com/LerianStudio/lib-commons/v2/compare/v1.18.0-beta.1...v1.18.0-beta.2) (2025-07-22)
-
-
-### Features
-
-* Improve Redis client configuration with UniversalOptions and connection pool tuning ([1587047](https://github.com/LerianStudio/lib-commons/v2/commit/158704738d1c823af6fbf3bc37f97d9e9734ed8e))
-* Merge pull request [#159](https://github.com/LerianStudio/lib-commons/v2/issues/159) from LerianStudio/feat/COMMONS-REDIS-RETRY ([e279ae9](https://github.com/LerianStudio/lib-commons/v2/commit/e279ae92be1464100e7f11c236afa9df408834cb))
-
-
-### Bug Fixes
-
-* add on const magic numbers; ([ff4d39b](https://github.com/LerianStudio/lib-commons/v2/commit/ff4d39b9ae209ce83827d5ba8b73f1e54692caad))
-
-## [1.18.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0...v1.18.0-beta.1) (2025-07-21)
-
-
-### Bug Fixes
-
-* remove alias concat on operation route assignment :bug: ([ddf7530](https://github.com/LerianStudio/lib-commons/v2/commit/ddf7530692f9e1121b986b1c4d7cc27022b22f24))
-
-## [1.17.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.16.0...v1.17.0) (2025-07-17)
-
-
-### Features
-
-* **transaction:** add accounting routes to Responses struct :sparkles: ([5f36263](https://github.com/LerianStudio/lib-commons/v2/commit/5f36263e6036d5e993d17af7d846c10c9290e610))
-* **utils:** add ExtractTokenFromHeader function to parse Authorization headers ([c91ea16](https://github.com/LerianStudio/lib-commons/v2/commit/c91ea16580bba21118a726c3ad0751752fe59e5b))
-* **http:** add Fiber error handler with OpenTelemetry span management ([5c7deed](https://github.com/LerianStudio/lib-commons/v2/commit/5c7deed8216321edd0527b10bad220dde1492d2e))
-* add gcp credentials to use passing by app like base64 string; ([326ff60](https://github.com/LerianStudio/lib-commons/v2/commit/326ff601e7eccbfd9aa7a31a54488cd68d8d2bbb))
-* add new internal key generation functions for settings and accounting routes :sparkles: ([d328f29](https://github.com/LerianStudio/lib-commons/v2/commit/d328f29ef095c8ca2e3741744918da4761a1696f))
-* add some refactors ([8cd3f91](https://github.com/LerianStudio/lib-commons/v2/commit/8cd3f915f3b136afe9d2365b36a3cc96934e1c52))
-* add TTL support to Redis/Valkey and support cluster + sentinel modes alongside standalone ([1d825df](https://github.com/LerianStudio/lib-commons/v2/commit/1d825dfefbf574bfe3db0bc718b9d0876aec5e03))
-* add variable tableAlias variadic to ApplyCursorPagination; ([1579a9e](https://github.com/LerianStudio/lib-commons/v2/commit/1579a9e25eae1da3247422ccd64e48730c59ba31))
-* adjust to use only one host; ([22696b0](https://github.com/LerianStudio/lib-commons/v2/commit/22696b0f989eff5db22aeeff06d82df3b16230e4))
-* change cacert to string to receive base64; ([a24f5f4](https://github.com/LerianStudio/lib-commons/v2/commit/a24f5f472686e39b44031e00fcc2b7989f1cf6b7))
-* create a new const called x-idempotency-replayed; ([df9946c](https://github.com/LerianStudio/lib-commons/v2/commit/df9946c830586ed80577495cc653109b636b4575))
-* **otel:** enhance trace context propagation with tracestate support for grpc ([f6f65ee](https://github.com/LerianStudio/lib-commons/v2/commit/f6f65eec7999c9bb4d6c14b2314c5c7e5d7f76ea))
-* implements IAM refresh token; ([3d21e04](https://github.com/LerianStudio/lib-commons/v2/commit/3d21e04194a10710a1b9de46a3f3aba89804c8b8))
-* Merge pull request [#118](https://github.com/LerianStudio/lib-commons/v2/issues/118) from LerianStudio/feat/COMMONS-52 ([e8f8917](https://github.com/LerianStudio/lib-commons/v2/commit/e8f8917b5c828c487f6bf2236b391dd4f8da5623))
-* merge pull request [#120](https://github.com/LerianStudio/lib-commons/v2/issues/120) from LerianStudio/feat/COMMONS-52-2 ([4293e11](https://github.com/LerianStudio/lib-commons/v2/commit/4293e11ae36942afd7a376ab3ee3db3981922ebf))
-* merge pull request [#124](https://github.com/LerianStudio/lib-commons/v2/issues/124) from LerianStudio/feat/COMMONS-52-6 ([8aaaf65](https://github.com/LerianStudio/lib-commons/v2/commit/8aaaf652e399746c67c0b8699c57f4a249271ef0))
-* merge pull request [#127](https://github.com/LerianStudio/lib-commons/v2/issues/127) from LerianStudio/feat/COMMONS-52-9 ([12ee2a9](https://github.com/LerianStudio/lib-commons/v2/commit/12ee2a947d2fc38e8957b9b9f6e129b65e4b87a2))
-* Merge pull request [#128](https://github.com/LerianStudio/lib-commons/v2/issues/128) from LerianStudio/feat/COMMONS-52-10 ([775f24a](https://github.com/LerianStudio/lib-commons/v2/commit/775f24ac85da8eb5e08a6e374ee61f327e798094))
-* Merge pull request [#132](https://github.com/LerianStudio/lib-commons/v2/issues/132) from LerianStudio/feat/COMMOS-1023 ([e2cce46](https://github.com/LerianStudio/lib-commons/v2/commit/e2cce46b11ca9172f45769dae444de48e74e051f))
-* Merge pull request [#152](https://github.com/LerianStudio/lib-commons/v2/issues/152) from LerianStudio/develop ([9e38ece](https://github.com/LerianStudio/lib-commons/v2/commit/9e38ece58cac8458cf3aed44bd2e210510424a61))
-* merge pull request [#153](https://github.com/LerianStudio/lib-commons/v2/issues/153) from LerianStudio/feat/COMMONS-1055 ([1cc6cb5](https://github.com/LerianStudio/lib-commons/v2/commit/1cc6cb53c71515bd0c574ece0bb6335682aab953))
-* Preallocate structures and isolate channels per goroutine for CalculateTotal ([8e92258](https://github.com/LerianStudio/lib-commons/v2/commit/8e922587f4b88f93434dfac5e16f0e570bef4a98))
-* revert code that was on the main; ([c2f1772](https://github.com/LerianStudio/lib-commons/v2/commit/c2f17729bde8d2f5bbc36381173ad9226640d763))
-
-
-### Bug Fixes
-
-* .golangci.yml ([038bedd](https://github.com/LerianStudio/lib-commons/v2/commit/038beddbe9ed4a867f6ed93dd4e84480ed65bb1b))
-* add fallback logging when logger is nil in shutdown handler ([800d644](https://github.com/LerianStudio/lib-commons/v2/commit/800d644d920bd54abf787d3be457cc0a1117c7a1))
-* add new check channel is closed; ([e3956c4](https://github.com/LerianStudio/lib-commons/v2/commit/e3956c46eb8a87e637e035d7676d5c592001b509))
-* adjust camel case time name; ([5ba77b9](https://github.com/LerianStudio/lib-commons/v2/commit/5ba77b958a0386a2ab9f8197503bbd4bd57235f0))
-* adjust decimal values from remains and percentage; ([e1dc4b1](https://github.com/LerianStudio/lib-commons/v2/commit/e1dc4b183d0ca2d1247f727b81f8f27d4ddcc3c7))
-* adjust redis key to use {} to calculate slot on cluster; ([318f269](https://github.com/LerianStudio/lib-commons/v2/commit/318f26947ee847aebfc600ed6e21cb903ee6a795))
-* adjust some code and test; ([c6aca75](https://github.com/LerianStudio/lib-commons/v2/commit/c6aca756499e8b9875e1474e4f7949bb9cc9f60c))
-* adjust to create tls on redis using variable; ([e78ae20](https://github.com/LerianStudio/lib-commons/v2/commit/e78ae2035b5583ce59654e3c7f145d93d86051e7))
-* gitactions; ([7f9ebeb](https://github.com/LerianStudio/lib-commons/v2/commit/7f9ebeb1a9328a902e82c8c60428b2a8246793cf))
-* go lint ([2499476](https://github.com/LerianStudio/lib-commons/v2/commit/249947604ed5d5382cd46e28e03c7396b9096d63))
-* improve error handling and prevent deadlocks in server and license management ([24282ee](https://github.com/LerianStudio/lib-commons/v2/commit/24282ee9a411e0d5bf1977447a97e1e3fb260835))
-* Merge pull request [#119](https://github.com/LerianStudio/lib-commons/v2/issues/119) from LerianStudio/feat/COMMONS-52 ([3ba9ca0](https://github.com/LerianStudio/lib-commons/v2/commit/3ba9ca0e284cf36797772967904d21947f8856a5))
-* Merge pull request [#121](https://github.com/LerianStudio/lib-commons/v2/issues/121) from LerianStudio/feat/COMMONS-52-3 ([69c9e00](https://github.com/LerianStudio/lib-commons/v2/commit/69c9e002ab0a4fcd24622c79c5da7857eb22c922))
-* Merge pull request [#122](https://github.com/LerianStudio/lib-commons/v2/issues/122) from LerianStudio/feat/COMMONS-52-4 ([46f5140](https://github.com/LerianStudio/lib-commons/v2/commit/46f51404f5f472172776abb1fbfd3bab908fc540))
-* Merge pull request [#123](https://github.com/LerianStudio/lib-commons/v2/issues/123) from LerianStudio/fix/COMMONS-52-5 ([788915b](https://github.com/LerianStudio/lib-commons/v2/commit/788915b8c333156046e1d79860f80dc84f9aa08b))
-* Merge pull request [#126](https://github.com/LerianStudio/lib-commons/v2/issues/126) from LerianStudio/fix-COMMONS-52-8 ([cfe9bbd](https://github.com/LerianStudio/lib-commons/v2/commit/cfe9bbde1bcf97847faf3fdc7e72e20ff723d586))
-* rabbit hearthbeat and log type of client conn on redis/valkey; ([9607bf5](https://github.com/LerianStudio/lib-commons/v2/commit/9607bf5c0abf21603372d32ea8d66b5d34c77ec0))
-* revert to original rabbit source; ([351c6ea](https://github.com/LerianStudio/lib-commons/v2/commit/351c6eac3e27301e4a65fce293032567bfd88807))
-* **otel:** simplify resource creation to solve schema merging conflict ([318a38c](https://github.com/LerianStudio/lib-commons/v2/commit/318a38c07ca8c3bd6e2345c78302ad0c515d39a3))
-
-## [1.17.0-beta.31](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.30...v1.17.0-beta.31) (2025-07-17)
-
-## [1.17.0-beta.30](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.29...v1.17.0-beta.30) (2025-07-17)
-
-
-### Bug Fixes
-
-* improve error handling and prevent deadlocks in server and license management ([24282ee](https://github.com/LerianStudio/lib-commons/v2/commit/24282ee9a411e0d5bf1977447a97e1e3fb260835))
-
-## [1.17.0-beta.29](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.28...v1.17.0-beta.29) (2025-07-16)
-
-
-### Features
-
-* merge pull request [#153](https://github.com/LerianStudio/lib-commons/v2/issues/153) from LerianStudio/feat/COMMONS-1055 ([1cc6cb5](https://github.com/LerianStudio/lib-commons/v2/commit/1cc6cb53c71515bd0c574ece0bb6335682aab953))
-* Preallocate structures and isolate channels per goroutine for CalculateTotal ([8e92258](https://github.com/LerianStudio/lib-commons/v2/commit/8e922587f4b88f93434dfac5e16f0e570bef4a98))
-
-## [1.17.0-beta.28](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.27...v1.17.0-beta.28) (2025-07-15)
-
-
-### Features
-
-* **http:** add Fiber error handler with OpenTelemetry span management ([5c7deed](https://github.com/LerianStudio/lib-commons/v2/commit/5c7deed8216321edd0527b10bad220dde1492d2e))
-
-## [1.17.0-beta.27](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.26...v1.17.0-beta.27) (2025-07-15)
-
-## [1.17.0-beta.26](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.25...v1.17.0-beta.26) (2025-07-15)
-
-## [1.17.0-beta.25](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.24...v1.17.0-beta.25) (2025-07-11)
-
-
-### Features
-
-* **transaction:** add accounting routes to Responses struct :sparkles: ([5f36263](https://github.com/LerianStudio/lib-commons/v2/commit/5f36263e6036d5e993d17af7d846c10c9290e610))
-
-## [1.17.0-beta.24](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.23...v1.17.0-beta.24) (2025-07-07)
-
-
-### Bug Fixes
-
-* **otel:** simplify resource creation to solve schema merging conflict ([318a38c](https://github.com/LerianStudio/lib-commons/v2/commit/318a38c07ca8c3bd6e2345c78302ad0c515d39a3))
-
-## [1.17.0-beta.23](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.22...v1.17.0-beta.23) (2025-07-07)
-
-## [1.17.0-beta.22](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.21...v1.17.0-beta.22) (2025-07-07)
-
-
-### Features
-
-* **otel:** enhance trace context propagation with tracestate support for grpc ([f6f65ee](https://github.com/LerianStudio/lib-commons/v2/commit/f6f65eec7999c9bb4d6c14b2314c5c7e5d7f76ea))
-
-## [1.17.0-beta.21](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.20...v1.17.0-beta.21) (2025-07-02)
-
-
-### Features
-
-* **utils:** add ExtractTokenFromHeader function to parse Authorization headers ([c91ea16](https://github.com/LerianStudio/lib-commons/v2/commit/c91ea16580bba21118a726c3ad0751752fe59e5b))
-
-## [1.17.0-beta.20](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.19...v1.17.0-beta.20) (2025-07-01)
-
-
-### Features
-
-* add new internal key generation functions for settings and accounting routes :sparkles: ([d328f29](https://github.com/LerianStudio/lib-commons/v2/commit/d328f29ef095c8ca2e3741744918da4761a1696f))
-
-## [1.17.0-beta.19](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.18...v1.17.0-beta.19) (2025-06-30)
-
-
-### Features
-
-* create a new const called x-idempotency-replayed; ([df9946c](https://github.com/LerianStudio/lib-commons/v2/commit/df9946c830586ed80577495cc653109b636b4575))
-* Merge pull request [#132](https://github.com/LerianStudio/lib-commons/v2/issues/132) from LerianStudio/feat/COMMOS-1023 ([e2cce46](https://github.com/LerianStudio/lib-commons/v2/commit/e2cce46b11ca9172f45769dae444de48e74e051f))
-
-## [1.17.0-beta.18](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.17...v1.17.0-beta.18) (2025-06-27)
-
-## [1.17.0-beta.17](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.16...v1.17.0-beta.17) (2025-06-27)
-
-## [1.17.0-beta.16](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.15...v1.17.0-beta.16) (2025-06-26)
-
-
-### Features
-
-* add gcp credentials to use passing by app like base64 string; ([326ff60](https://github.com/LerianStudio/lib-commons/v2/commit/326ff601e7eccbfd9aa7a31a54488cd68d8d2bbb))
-
-## [1.17.0-beta.15](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.14...v1.17.0-beta.15) (2025-06-25)
-
-
-### Features
-
-* add some refactors ([8cd3f91](https://github.com/LerianStudio/lib-commons/v2/commit/8cd3f915f3b136afe9d2365b36a3cc96934e1c52))
-* Merge pull request [#128](https://github.com/LerianStudio/lib-commons/v2/issues/128) from LerianStudio/feat/COMMONS-52-10 ([775f24a](https://github.com/LerianStudio/lib-commons/v2/commit/775f24ac85da8eb5e08a6e374ee61f327e798094))
-
-## [1.17.0-beta.14](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.13...v1.17.0-beta.14) (2025-06-25)
-
-
-### Features
-
-* change cacert to string to receive base64; ([a24f5f4](https://github.com/LerianStudio/lib-commons/v2/commit/a24f5f472686e39b44031e00fcc2b7989f1cf6b7))
-* merge pull request [#127](https://github.com/LerianStudio/lib-commons/v2/issues/127) from LerianStudio/feat/COMMONS-52-9 ([12ee2a9](https://github.com/LerianStudio/lib-commons/v2/commit/12ee2a947d2fc38e8957b9b9f6e129b65e4b87a2))
-
-## [1.17.0-beta.13](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.12...v1.17.0-beta.13) (2025-06-25)
-
-
-### Bug Fixes
-
-* Merge pull request [#126](https://github.com/LerianStudio/lib-commons/v2/issues/126) from LerianStudio/fix-COMMONS-52-8 ([cfe9bbd](https://github.com/LerianStudio/lib-commons/v2/commit/cfe9bbde1bcf97847faf3fdc7e72e20ff723d586))
-* revert to original rabbit source; ([351c6ea](https://github.com/LerianStudio/lib-commons/v2/commit/351c6eac3e27301e4a65fce293032567bfd88807))
-
-## [1.17.0-beta.12](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.11...v1.17.0-beta.12) (2025-06-25)
-
-
-### Bug Fixes
-
-* add new check channel is closed; ([e3956c4](https://github.com/LerianStudio/lib-commons/v2/commit/e3956c46eb8a87e637e035d7676d5c592001b509))
-
-## [1.17.0-beta.11](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.10...v1.17.0-beta.11) (2025-06-25)
-
-
-### Features
-
-* merge pull request [#124](https://github.com/LerianStudio/lib-commons/v2/issues/124) from LerianStudio/feat/COMMONS-52-6 ([8aaaf65](https://github.com/LerianStudio/lib-commons/v2/commit/8aaaf652e399746c67c0b8699c57f4a249271ef0))
-
-
-### Bug Fixes
-
-* rabbit hearthbeat and log type of client conn on redis/valkey; ([9607bf5](https://github.com/LerianStudio/lib-commons/v2/commit/9607bf5c0abf21603372d32ea8d66b5d34c77ec0))
-
-## [1.17.0-beta.10](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.9...v1.17.0-beta.10) (2025-06-24)
-
-
-### Bug Fixes
-
-* adjust camel case time name; ([5ba77b9](https://github.com/LerianStudio/lib-commons/v2/commit/5ba77b958a0386a2ab9f8197503bbd4bd57235f0))
-* Merge pull request [#123](https://github.com/LerianStudio/lib-commons/v2/issues/123) from LerianStudio/fix/COMMONS-52-5 ([788915b](https://github.com/LerianStudio/lib-commons/v2/commit/788915b8c333156046e1d79860f80dc84f9aa08b))
-
-## [1.17.0-beta.9](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.8...v1.17.0-beta.9) (2025-06-24)
-
-
-### Bug Fixes
-
-* adjust redis key to use {} to calculate slot on cluster; ([318f269](https://github.com/LerianStudio/lib-commons/v2/commit/318f26947ee847aebfc600ed6e21cb903ee6a795))
-* Merge pull request [#122](https://github.com/LerianStudio/lib-commons/v2/issues/122) from LerianStudio/feat/COMMONS-52-4 ([46f5140](https://github.com/LerianStudio/lib-commons/v2/commit/46f51404f5f472172776abb1fbfd3bab908fc540))
-
-## [1.17.0-beta.8](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.7...v1.17.0-beta.8) (2025-06-24)
-
-
-### Features
-
-* implements IAM refresh token; ([3d21e04](https://github.com/LerianStudio/lib-commons/v2/commit/3d21e04194a10710a1b9de46a3f3aba89804c8b8))
-
-
-### Bug Fixes
-
-* Merge pull request [#121](https://github.com/LerianStudio/lib-commons/v2/issues/121) from LerianStudio/feat/COMMONS-52-3 ([69c9e00](https://github.com/LerianStudio/lib-commons/v2/commit/69c9e002ab0a4fcd24622c79c5da7857eb22c922))
-
-## [1.17.0-beta.7](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.6...v1.17.0-beta.7) (2025-06-24)
-
-
-### Features
-
-* merge pull request [#120](https://github.com/LerianStudio/lib-commons/v2/issues/120) from LerianStudio/feat/COMMONS-52-2 ([4293e11](https://github.com/LerianStudio/lib-commons/v2/commit/4293e11ae36942afd7a376ab3ee3db3981922ebf))
-
-
-### Bug Fixes
-
-* adjust to create tls on redis using variable; ([e78ae20](https://github.com/LerianStudio/lib-commons/v2/commit/e78ae2035b5583ce59654e3c7f145d93d86051e7))
-* go lint ([2499476](https://github.com/LerianStudio/lib-commons/v2/commit/249947604ed5d5382cd46e28e03c7396b9096d63))
-
-## [1.17.0-beta.6](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.5...v1.17.0-beta.6) (2025-06-23)
-
-
-### Features
-
-* adjust to use only one host; ([22696b0](https://github.com/LerianStudio/lib-commons/v2/commit/22696b0f989eff5db22aeeff06d82df3b16230e4))
-
-
-### Bug Fixes
-
-* Merge pull request [#119](https://github.com/LerianStudio/lib-commons/v2/issues/119) from LerianStudio/feat/COMMONS-52 ([3ba9ca0](https://github.com/LerianStudio/lib-commons/v2/commit/3ba9ca0e284cf36797772967904d21947f8856a5))
-
-## [1.17.0-beta.5](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.4...v1.17.0-beta.5) (2025-06-23)
-
-
-### Features
-
-* add TTL support to Redis/Valkey and support cluster + sentinel modes alongside standalone ([1d825df](https://github.com/LerianStudio/lib-commons/v2/commit/1d825dfefbf574bfe3db0bc718b9d0876aec5e03))
-* Merge pull request [#118](https://github.com/LerianStudio/lib-commons/v2/issues/118) from LerianStudio/feat/COMMONS-52 ([e8f8917](https://github.com/LerianStudio/lib-commons/v2/commit/e8f8917b5c828c487f6bf2236b391dd4f8da5623))
-
-
-### Bug Fixes
-
-* .golangci.yml ([038bedd](https://github.com/LerianStudio/lib-commons/v2/commit/038beddbe9ed4a867f6ed93dd4e84480ed65bb1b))
-* gitactions; ([7f9ebeb](https://github.com/LerianStudio/lib-commons/v2/commit/7f9ebeb1a9328a902e82c8c60428b2a8246793cf))
-
-## [1.17.0-beta.4](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.3...v1.17.0-beta.4) (2025-06-20)
-
-
-### Bug Fixes
-
-* adjust decimal values from remains and percentage; ([e1dc4b1](https://github.com/LerianStudio/lib-commons/v2/commit/e1dc4b183d0ca2d1247f727b81f8f27d4ddcc3c7))
-* adjust some code and test; ([c6aca75](https://github.com/LerianStudio/lib-commons/v2/commit/c6aca756499e8b9875e1474e4f7949bb9cc9f60c))
-
-## [1.17.0-beta.3](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.2...v1.17.0-beta.3) (2025-06-20)
-
-
-### Bug Fixes
-
-* add fallback logging when logger is nil in shutdown handler ([800d644](https://github.com/LerianStudio/lib-commons/v2/commit/800d644d920bd54abf787d3be457cc0a1117c7a1))
-
-## [1.17.0-beta.2](https://github.com/LerianStudio/lib-commons/v2/compare/v1.17.0-beta.1...v1.17.0-beta.2) (2025-06-20)
-
-
-### Features
-
-* add variable tableAlias variadic to ApplyCursorPagination; ([1579a9e](https://github.com/LerianStudio/lib-commons/v2/commit/1579a9e25eae1da3247422ccd64e48730c59ba31))
-
-## [1.17.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.16.0...v1.17.0-beta.1) (2025-06-16)
-
-
-### Features
-
-* revert code that was on the main; ([c2f1772](https://github.com/LerianStudio/lib-commons/v2/commit/c2f17729bde8d2f5bbc36381173ad9226640d763))
-
-## [1.12.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.11.0...v1.12.0) (2025-06-13)
-
-
-### Features
-
-* add log test; ([7ad741f](https://github.com/LerianStudio/lib-commons/v2/commit/7ad741f558e7a725e95dab257500d5d24b2536e5))
-* add shutdown test ([9d5fb77](https://github.com/LerianStudio/lib-commons/v2/commit/9d5fb77893e10a708136767eda3f9bac99363ba4))
-
-
-### Bug Fixes
-
-* Add integer overflow protection to transaction operations; :bug: ([32904de](https://github.com/LerianStudio/lib-commons/v2/commit/32904def9bee6388f12a6e2cc997c20a594db696))
-* add url for health check to read. from envs; update testes; update go mod and go sum; ([e9b8333](https://github.com/LerianStudio/lib-commons/v2/commit/e9b83330834c7c2949dfb05a4dc46f4786cd509d))
-* create redis test; ([3178547](https://github.com/LerianStudio/lib-commons/v2/commit/317854731e550d222713503eecbdf26e2c26fa90))
-
-## [1.12.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.11.0...v1.12.0-beta.1) (2025-06-13)
-
-
-### Features
-
-* add log test; ([7ad741f](https://github.com/LerianStudio/lib-commons/v2/commit/7ad741f558e7a725e95dab257500d5d24b2536e5))
-* add shutdown test ([9d5fb77](https://github.com/LerianStudio/lib-commons/v2/commit/9d5fb77893e10a708136767eda3f9bac99363ba4))
-
-
-### Bug Fixes
-
-* Add integer overflow protection to transaction operations; :bug: ([32904de](https://github.com/LerianStudio/lib-commons/v2/commit/32904def9bee6388f12a6e2cc997c20a594db696))
-* add url for health check to read. from envs; update testes; update go mod and go sum; ([e9b8333](https://github.com/LerianStudio/lib-commons/v2/commit/e9b83330834c7c2949dfb05a4dc46f4786cd509d))
-* create redis test; ([3178547](https://github.com/LerianStudio/lib-commons/v2/commit/317854731e550d222713503eecbdf26e2c26fa90))
-
-## [1.12.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.11.0...v1.12.0) (2025-06-13)
-
-
-### Features
-
-* add log test; ([7ad741f](https://github.com/LerianStudio/lib-commons/v2/commit/7ad741f558e7a725e95dab257500d5d24b2536e5))
-
-
-### Bug Fixes
-
-* Add integer overflow protection to transaction operations; :bug: ([32904de](https://github.com/LerianStudio/lib-commons/v2/commit/32904def9bee6388f12a6e2cc997c20a594db696))
-* add url for health check to read. from envs; update testes; update go mod and go sum; ([e9b8333](https://github.com/LerianStudio/lib-commons/v2/commit/e9b83330834c7c2949dfb05a4dc46f4786cd509d))
-* create redis test; ([3178547](https://github.com/LerianStudio/lib-commons/v2/commit/317854731e550d222713503eecbdf26e2c26fa90))
-
-## [1.12.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.11.0...v1.12.0-beta.1) (2025-06-13)
-
-
-### Features
-
-* add log test; ([7ad741f](https://github.com/LerianStudio/lib-commons/v2/commit/7ad741f558e7a725e95dab257500d5d24b2536e5))
-
-
-### Bug Fixes
-
-* Add integer overflow protection to transaction operations; :bug: ([32904de](https://github.com/LerianStudio/lib-commons/v2/commit/32904def9bee6388f12a6e2cc997c20a594db696))
-* add url for health check to read. from envs; update testes; update go mod and go sum; ([e9b8333](https://github.com/LerianStudio/lib-commons/v2/commit/e9b83330834c7c2949dfb05a4dc46f4786cd509d))
-* create redis test; ([3178547](https://github.com/LerianStudio/lib-commons/v2/commit/317854731e550d222713503eecbdf26e2c26fa90))
-
-## [1.12.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.11.0...v1.12.0-beta.1) (2025-06-13)
-
-
-### Bug Fixes
-
-* Add integer overflow protection to transaction operations; :bug: ([32904de](https://github.com/LerianStudio/lib-commons/v2/commit/32904def9bee6388f12a6e2cc997c20a594db696))
-* add url for health check to read. from envs; update testes; update go mod and go sum; ([e9b8333](https://github.com/LerianStudio/lib-commons/v2/commit/e9b83330834c7c2949dfb05a4dc46f4786cd509d))
-* create redis test; ([3178547](https://github.com/LerianStudio/lib-commons/v2/commit/317854731e550d222713503eecbdf26e2c26fa90))
-
-## [1.12.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.11.0...v1.12.0) (2025-06-13)
-
-
-### Bug Fixes
-
-* Add integer overflow protection to transaction operations; :bug: ([32904de](https://github.com/LerianStudio/lib-commons/v2/commit/32904def9bee6388f12a6e2cc997c20a594db696))
-* add url for health check to read. from envs; update testes; update go mod and go sum; ([e9b8333](https://github.com/LerianStudio/lib-commons/v2/commit/e9b83330834c7c2949dfb05a4dc46f4786cd509d))
-
-## [1.12.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.11.0...v1.12.0-beta.1) (2025-06-13)
-
-
-### Bug Fixes
-
-* Add integer overflow protection to transaction operations; :bug: ([32904de](https://github.com/LerianStudio/lib-commons/v2/commit/32904def9bee6388f12a6e2cc997c20a594db696))
-* add url for health check to read. from envs; update testes; update go mod and go sum; ([e9b8333](https://github.com/LerianStudio/lib-commons/v2/commit/e9b83330834c7c2949dfb05a4dc46f4786cd509d))
-
-## [1.11.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.10.0...v1.11.0) (2025-05-19)
-
-
-### Features
-
-* add info and debug log levels to zap logger initializer by env name ([c132299](https://github.com/LerianStudio/lib-commons/v2/commit/c13229910647081facf9f555e4b4efa74aff60ec))
-* add start app with graceful shutdown module ([21d9697](https://github.com/LerianStudio/lib-commons/v2/commit/21d9697c35686e82adbf3f41744ce25c369119ce))
-* bump lib-license-go version to v1.0.8 ([4d93834](https://github.com/LerianStudio/lib-commons/v2/commit/4d93834af0dd4d4d48564b98f9d2dc766369c1be))
-* move license shutdown to the end of execution and add recover from panic in graceful shutdown ([6cf1171](https://github.com/LerianStudio/lib-commons/v2/commit/6cf117159cc10b3fa97200c53fbb6a058566c7d6))
-
-
-### Bug Fixes
-
-* fix lint - remove cuddled if blocks ([cd6424b](https://github.com/LerianStudio/lib-commons/v2/commit/cd6424b741811ec119a2bf35189760070883b993))
-* import corret lib license go uri ([f55338f](https://github.com/LerianStudio/lib-commons/v2/commit/f55338fa2c9ed1d974ab61f28b1c70101b35eb61))
-
-## [1.11.0-beta.2](https://github.com/LerianStudio/lib-commons/v2/compare/v1.11.0-beta.1...v1.11.0-beta.2) (2025-05-19)
-
-
-### Features
-
-* add start app with graceful shutdown module ([21d9697](https://github.com/LerianStudio/lib-commons/v2/commit/21d9697c35686e82adbf3f41744ce25c369119ce))
-* bump lib-license-go version to v1.0.8 ([4d93834](https://github.com/LerianStudio/lib-commons/v2/commit/4d93834af0dd4d4d48564b98f9d2dc766369c1be))
-* move license shutdown to the end of execution and add recover from panic in graceful shutdown ([6cf1171](https://github.com/LerianStudio/lib-commons/v2/commit/6cf117159cc10b3fa97200c53fbb6a058566c7d6))
-
-
-### Bug Fixes
-
-* fix lint - remove cuddled if blocks ([cd6424b](https://github.com/LerianStudio/lib-commons/v2/commit/cd6424b741811ec119a2bf35189760070883b993))
-* import corret lib license go uri ([f55338f](https://github.com/LerianStudio/lib-commons/v2/commit/f55338fa2c9ed1d974ab61f28b1c70101b35eb61))
-
-## [1.11.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.10.0...v1.11.0-beta.1) (2025-05-19)
-
-
-### Features
-
-* add info and debug log levels to zap logger initializer by env name ([c132299](https://github.com/LerianStudio/lib-commons/v2/commit/c13229910647081facf9f555e4b4efa74aff60ec))
-
-## [1.10.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.9.0...v1.10.0) (2025-05-14)
-
-
-### Features
-
-* **postgres:** sets migrations path from environment variable :sparkles: ([7f9d40e](https://github.com/LerianStudio/lib-commons/v2/commit/7f9d40e88a9e9b94a8d6076121e73324421bd6e8))
-
-## [1.10.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.9.0...v1.10.0-beta.1) (2025-05-14)
-
-
-### Features
-
-* **postgres:** sets migrations path from environment variable :sparkles: ([7f9d40e](https://github.com/LerianStudio/lib-commons/v2/commit/7f9d40e88a9e9b94a8d6076121e73324421bd6e8))
-
-## [1.9.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.8.0...v1.9.0) (2025-05-14)
-
-
-### Bug Fixes
-
-* add check if account is empty using accountAlias; :bug: ([d2054d8](https://github.com/LerianStudio/lib-commons/v2/commit/d2054d8e0924accd15cfcac95ef1be6e58abae93))
-* **transaction:** add index variable to loop iteration ([e2974f0](https://github.com/LerianStudio/lib-commons/v2/commit/e2974f0c2cc87f39417bf42943e143188c3f9fc8))
-* final adjust to use multiple identical accounts; :bug: ([b2165de](https://github.com/LerianStudio/lib-commons/v2/commit/b2165de3642c9c9949cda25d370cad9358e5f5be))
-* **transaction:** improve validation in send source and distribute calculations ([625f2f9](https://github.com/LerianStudio/lib-commons/v2/commit/625f2f9598a61dbb4227722f605e1d4798a9a881))
-* **transaction:** improve validation in send source and distribute calculations ([2b05323](https://github.com/LerianStudio/lib-commons/v2/commit/2b05323b81eea70278dbb2326423dedaf5078373))
-* **transaction:** improve validation in send source and distribute calculations ([4a8f3f5](https://github.com/LerianStudio/lib-commons/v2/commit/4a8f3f59da5563842e0785732ad5b05989f62fb7))
-* **transaction:** improve validation in send source and distribute calculations ([1cf5b04](https://github.com/LerianStudio/lib-commons/v2/commit/1cf5b04fb510594c5d13989c137cc8401ea2e23d))
-* **transaction:** optimize balance operations in UpdateBalances function ([524fe97](https://github.com/LerianStudio/lib-commons/v2/commit/524fe975d125742d10920236e055db879809b01e))
-* **transaction:** optimize balance operations in UpdateBalances function ([63201dd](https://github.com/LerianStudio/lib-commons/v2/commit/63201ddeb00835d8b8b9269f8a32850e4f28374e))
-* **transaction:** optimize balance operations in UpdateBalances function ([8b6397d](https://github.com/LerianStudio/lib-commons/v2/commit/8b6397df3261cc0f5af190c69b16a55e215952ed))
-* some more adjusts; :bug: ([af69b44](https://github.com/LerianStudio/lib-commons/v2/commit/af69b447658b0f4dfcd2e2f252dd2d0d68753094))
-
-## [1.9.0-beta.8](https://github.com/LerianStudio/lib-commons/v2/compare/v1.9.0-beta.7...v1.9.0-beta.8) (2025-05-14)
-
-
-### Bug Fixes
-
-* final adjust to use multiple identical accounts; :bug: ([b2165de](https://github.com/LerianStudio/lib-commons/v2/commit/b2165de3642c9c9949cda25d370cad9358e5f5be))
-
-## [1.9.0-beta.7](https://github.com/LerianStudio/lib-commons/v2/compare/v1.9.0-beta.6...v1.9.0-beta.7) (2025-05-13)
-
-
-### Bug Fixes
-
-* add check if account is empty using accountAlias; :bug: ([d2054d8](https://github.com/LerianStudio/lib-commons/v2/commit/d2054d8e0924accd15cfcac95ef1be6e58abae93))
-* some more adjusts; :bug: ([af69b44](https://github.com/LerianStudio/lib-commons/v2/commit/af69b447658b0f4dfcd2e2f252dd2d0d68753094))
-
-## [1.9.0-beta.6](https://github.com/LerianStudio/lib-commons/v2/compare/v1.9.0-beta.5...v1.9.0-beta.6) (2025-05-12)
-
-
-### Bug Fixes
-
-* **transaction:** optimize balance operations in UpdateBalances function ([524fe97](https://github.com/LerianStudio/lib-commons/v2/commit/524fe975d125742d10920236e055db879809b01e))
-* **transaction:** optimize balance operations in UpdateBalances function ([63201dd](https://github.com/LerianStudio/lib-commons/v2/commit/63201ddeb00835d8b8b9269f8a32850e4f28374e))
-
-## [1.9.0-beta.5](https://github.com/LerianStudio/lib-commons/v2/compare/v1.9.0-beta.4...v1.9.0-beta.5) (2025-05-12)
-
-
-### Bug Fixes
-
-* **transaction:** optimize balance operations in UpdateBalances function ([8b6397d](https://github.com/LerianStudio/lib-commons/v2/commit/8b6397df3261cc0f5af190c69b16a55e215952ed))
-
-## [1.9.0-beta.4](https://github.com/LerianStudio/lib-commons/v2/compare/v1.9.0-beta.3...v1.9.0-beta.4) (2025-05-09)
-
-
-### Bug Fixes
-
-* **transaction:** add index variable to loop iteration ([e2974f0](https://github.com/LerianStudio/lib-commons/v2/commit/e2974f0c2cc87f39417bf42943e143188c3f9fc8))
-
-## [1.9.0-beta.3](https://github.com/LerianStudio/lib-commons/v2/compare/v1.9.0-beta.2...v1.9.0-beta.3) (2025-05-09)
-
-
-### Bug Fixes
-
-* **transaction:** improve validation in send source and distribute calculations ([625f2f9](https://github.com/LerianStudio/lib-commons/v2/commit/625f2f9598a61dbb4227722f605e1d4798a9a881))
-
-## [1.9.0-beta.2](https://github.com/LerianStudio/lib-commons/v2/compare/v1.9.0-beta.1...v1.9.0-beta.2) (2025-05-09)
-
-
-### Bug Fixes
-
-* **transaction:** improve validation in send source and distribute calculations ([2b05323](https://github.com/LerianStudio/lib-commons/v2/commit/2b05323b81eea70278dbb2326423dedaf5078373))
-
-## [1.9.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.8.0...v1.9.0-beta.1) (2025-05-09)
-
-
-### Bug Fixes
-
-* **transaction:** improve validation in send source and distribute calculations ([4a8f3f5](https://github.com/LerianStudio/lib-commons/v2/commit/4a8f3f59da5563842e0785732ad5b05989f62fb7))
-* **transaction:** improve validation in send source and distribute calculations ([1cf5b04](https://github.com/LerianStudio/lib-commons/v2/commit/1cf5b04fb510594c5d13989c137cc8401ea2e23d))
-
-## [1.8.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.7.0...v1.8.0) (2025-04-24)
-
-
-### Features
-
-* update go mod and go sum and change method health visibility; :sparkles: ([355991f](https://github.com/LerianStudio/lib-commons/v2/commit/355991f4416722ee51356139ed3c4fe08e1fe47e))
-
-## [1.8.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.7.0...v1.8.0-beta.1) (2025-04-24)
-
-
-### Features
-
-* update go mod and go sum and change method health visibility; :sparkles: ([355991f](https://github.com/LerianStudio/lib-commons/v2/commit/355991f4416722ee51356139ed3c4fe08e1fe47e))
-
-## [1.7.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.6.0...v1.7.0) (2025-04-16)
-
-
-### Bug Fixes
-
-* fix lint cuddled code ([dcbf7c6](https://github.com/LerianStudio/lib-commons/v2/commit/dcbf7c6f26f379cec9790e14b76ee2e6868fb142))
-* lint complexity over 31 in getBodyObfuscatedString ([0f9eb4a](https://github.com/LerianStudio/lib-commons/v2/commit/0f9eb4a82a544204119500db09d38fd6ec003c7e))
-* obfuscate password field in the body before logging ([e35bfa3](https://github.com/LerianStudio/lib-commons/v2/commit/e35bfa36424caae3f90b351ed979d2c6e6e143f5))
-
-## [1.7.0-beta.3](https://github.com/LerianStudio/lib-commons/v2/compare/v1.7.0-beta.2...v1.7.0-beta.3) (2025-04-16)
-
-
-### Bug Fixes
-
-* lint complexity over 31 in getBodyObfuscatedString ([0f9eb4a](https://github.com/LerianStudio/lib-commons/v2/commit/0f9eb4a82a544204119500db09d38fd6ec003c7e))
-
-## [1.7.0-beta.2](https://github.com/LerianStudio/lib-commons/v2/compare/v1.7.0-beta.1...v1.7.0-beta.2) (2025-04-16)
-
-## [1.7.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.6.0...v1.7.0-beta.1) (2025-04-16)
-
-
-### Bug Fixes
-
-* fix lint cuddled code ([dcbf7c6](https://github.com/LerianStudio/lib-commons/v2/commit/dcbf7c6f26f379cec9790e14b76ee2e6868fb142))
-* obfuscate password field in the body before logging ([e35bfa3](https://github.com/LerianStudio/lib-commons/v2/commit/e35bfa36424caae3f90b351ed979d2c6e6e143f5))
-
-## [1.6.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.5.0...v1.6.0) (2025-04-11)
-
-
-### Bug Fixes
-
-* **transaction:** correct percentage calculation in CalculateTotal ([02b939c](https://github.com/LerianStudio/lib-commons/v2/commit/02b939c3abf1834de2078c2d0ae40b4fd9095bca))
-
-## [1.6.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.5.0...v1.6.0-beta.1) (2025-04-11)
-
-
-### Bug Fixes
-
-* **transaction:** correct percentage calculation in CalculateTotal ([02b939c](https://github.com/LerianStudio/lib-commons/v2/commit/02b939c3abf1834de2078c2d0ae40b4fd9095bca))
-
-## [1.5.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.4.0...v1.5.0) (2025-04-10)
-
-
-### Features
-
-* adding accountAlias field to keep backward compatibility ([81bf528](https://github.com/LerianStudio/lib-commons/v2/commit/81bf528dfa8ceb5055714589745c1d3987cfa6da))
-
-## [1.5.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.4.0...v1.5.0-beta.1) (2025-04-09)
-
-
-### Features
-
-* adding accountAlias field to keep backward compatibility ([81bf528](https://github.com/LerianStudio/lib-commons/v2/commit/81bf528dfa8ceb5055714589745c1d3987cfa6da))
-
-## [1.4.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.3.0...v1.4.0) (2025-04-08)
-
-## [1.4.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.3.1-beta.1...v1.4.0-beta.1) (2025-04-08)
-
-## [1.3.1-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.3.0...v1.3.1-beta.1) (2025-04-08)
-
-## [1.3.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.2.0...v1.3.0) (2025-04-08)
-
-## [1.3.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.2.0...v1.3.0-beta.1) (2025-04-08)
-
-## [1.2.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.1.0...v1.2.0) (2025-04-03)
-
-
-### Bug Fixes
-
-* update safe uint convertion to convert int instead of int64 ([a85628b](https://github.com/LerianStudio/lib-commons/v2/commit/a85628bb031d64d542b378180c2254c198e9ae59))
-* update safe uint convertion to convert max int to uint first to validate ([c7dee02](https://github.com/LerianStudio/lib-commons/v2/commit/c7dee026532f42712eabdb3fde0c8d2b8ec7cdd8))
-
-## [1.2.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.1.0...v1.2.0-beta.1) (2025-04-03)
-
-
-### Bug Fixes
-
-* update safe uint convertion to convert int instead of int64 ([a85628b](https://github.com/LerianStudio/lib-commons/v2/commit/a85628bb031d64d542b378180c2254c198e9ae59))
-* update safe uint convertion to convert max int to uint first to validate ([c7dee02](https://github.com/LerianStudio/lib-commons/v2/commit/c7dee026532f42712eabdb3fde0c8d2b8ec7cdd8))
-
-## [1.1.0](https://github.com/LerianStudio/lib-commons/v2/compare/v1.0.0...v1.1.0) (2025-04-03)
-
-
-### Features
-
-* add safe uint convertion ([0d9e405](https://github.com/LerianStudio/lib-commons/v2/commit/0d9e4052ebbd70b18508d68906296c35b881d85e))
-* organize golangci-lint module ([8d71f3b](https://github.com/LerianStudio/lib-commons/v2/commit/8d71f3bb2079457617a5ff8a8290492fd885b30d))
-
-
-### Bug Fixes
-
-* golang lint fixed version to v1.64.8; go mod and sum update packages; :bug: ([6b825c1](https://github.com/LerianStudio/lib-commons/v2/commit/6b825c1a0162326df2abb93b128419f2ea9a4175))
-
-## [1.1.0-beta.3](https://github.com/LerianStudio/lib-commons/v2/compare/v1.1.0-beta.2...v1.1.0-beta.3) (2025-04-03)
-
-
-### Features
-
-* add safe uint convertion ([0d9e405](https://github.com/LerianStudio/lib-commons/v2/commit/0d9e4052ebbd70b18508d68906296c35b881d85e))
-
-## [1.1.0-beta.2](https://github.com/LerianStudio/lib-commons/v2/compare/v1.1.0-beta.1...v1.1.0-beta.2) (2025-03-27)
-
-
-### Features
-
-* organize golangci-lint module ([8d71f3b](https://github.com/LerianStudio/lib-commons/v2/commit/8d71f3bb2079457617a5ff8a8290492fd885b30d))
-
-## [1.1.0-beta.1](https://github.com/LerianStudio/lib-commons/v2/compare/v1.0.0...v1.1.0-beta.1) (2025-03-25)
-
-
-### Bug Fixes
-
-* golang lint fixed version to v1.64.8; go mod and sum update packages; :bug: ([6b825c1](https://github.com/LerianStudio/lib-commons/v2/commit/6b825c1a0162326df2abb93b128419f2ea9a4175))
-
-## 1.0.0 (2025-03-19)
-
-
-### Features
-
-* add transaction validations to the lib-commons; :sparkles: ([098b730](https://github.com/LerianStudio/lib-commons/v2/commit/098b730fa1686b2f683faec69fabd6aa1607cf0b))
-* initial commit to lib commons; ([7d49924](https://github.com/LerianStudio/lib-commons/v2/commit/7d4992494a1328fd1c0afc4f5814fa5c63cb0f9c))
-* initiate new implements from lib-commons; ([18dff5c](https://github.com/LerianStudio/lib-commons/v2/commit/18dff5cbde19bd2659368ce5665a01f79119e7ef))
-
-
-### Bug Fixes
-
-* remove midaz reference; :bug: ([27cbdaa](https://github.com/LerianStudio/lib-commons/v2/commit/27cbdaa5ad103edf903fb24d2b652e7e9f15d909))
-* remove wrong tests; :bug: ([9f9d30f](https://github.com/LerianStudio/lib-commons/v2/commit/9f9d30f0d783ab3f9f4f6e7141981e3b266ba600))
-* update message withBasicAuth.go ([d1dcdbc](https://github.com/LerianStudio/lib-commons/v2/commit/d1dcdbc7dfd4ef829b94de19db71e273452be425))
-* update some places and adjust golint; :bug: ([db18dbb](https://github.com/LerianStudio/lib-commons/v2/commit/db18dbb7270675e87c150f3216ac9be1b2610c1c))
-* update to return err instead of nil; :bug: ([8aade18](https://github.com/LerianStudio/lib-commons/v2/commit/8aade18d65bf6fe0d4e925f3bf178c51672fd7f4))
-* update to use one response json objetc; :bug: ([2e42859](https://github.com/LerianStudio/lib-commons/v2/commit/2e428598b1f41f9c2de369a34510c5ed2ba21569))
-
-## [1.0.0-beta.2](https://github.com/LerianStudio/lib-commons/v2/compare/v1.0.0-beta.1...v1.0.0-beta.2) (2025-03-19)
-
-
-### Features
-
-* add transaction validations to the lib-commons; :sparkles: ([098b730](https://github.com/LerianStudio/lib-commons/v2/commit/098b730fa1686b2f683faec69fabd6aa1607cf0b))
-
-
-### Bug Fixes
-
-* update some places and adjust golint; :bug: ([db18dbb](https://github.com/LerianStudio/lib-commons/v2/commit/db18dbb7270675e87c150f3216ac9be1b2610c1c))
-* update to use one response json objetc; :bug: ([2e42859](https://github.com/LerianStudio/lib-commons/v2/commit/2e428598b1f41f9c2de369a34510c5ed2ba21569))
-
-## 1.0.0-beta.1 (2025-03-18)
-
-
-### Features
-
-* initial commit to lib commons; ([7d49924](https://github.com/LerianStudio/lib-commons/v2/commit/7d4992494a1328fd1c0afc4f5814fa5c63cb0f9c))
-* initiate new implements from lib-commons; ([18dff5c](https://github.com/LerianStudio/lib-commons/v2/commit/18dff5cbde19bd2659368ce5665a01f79119e7ef))
-
-
-### Bug Fixes
-
-* remove midaz reference; :bug: ([27cbdaa](https://github.com/LerianStudio/lib-commons/v2/commit/27cbdaa5ad103edf903fb24d2b652e7e9f15d909))
-* remove wrong tests; :bug: ([9f9d30f](https://github.com/LerianStudio/lib-commons/v2/commit/9f9d30f0d783ab3f9f4f6e7141981e3b266ba600))
-* update message withBasicAuth.go ([d1dcdbc](https://github.com/LerianStudio/lib-commons/v2/commit/d1dcdbc7dfd4ef829b94de19db71e273452be425))
-* update to return err instead of nil; :bug: ([8aade18](https://github.com/LerianStudio/lib-commons/v2/commit/8aade18d65bf6fe0d4e925f3bf178c51672fd7f4))
-
-## 1.0.0 (2025-03-06)
-
-
-### Features
-
-* configuration of CI/CD ([1bb1c4c](https://github.com/LerianStudio/lib-boilerplate/commit/1bb1c4ca0659e593ff22b3b5bf919163366301a7))
-* set configuration of boilerplate ([138a60c](https://github.com/LerianStudio/lib-boilerplate/commit/138a60c7947a9e82e4808fa16cc53975e27e7de5))
-
-## 1.0.0-beta.1 (2025-03-06)
-
-
-### Features
-
-* configuration of CI/CD ([1bb1c4c](https://github.com/LerianStudio/lib-boilerplate/commit/1bb1c4ca0659e593ff22b3b5bf919163366301a7))
-* set configuration of boilerplate ([138a60c](https://github.com/LerianStudio/lib-boilerplate/commit/138a60c7947a9e82e4808fa16cc53975e27e7de5))
+All notable changes to lib-commons will be documented in this file.
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 120000
index 00000000..47dc3e3d
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1 @@
+AGENTS.md
\ No newline at end of file
diff --git a/MIGRATION_MAP.md b/MIGRATION_MAP.md
new file mode 100644
index 00000000..1fc97af0
--- /dev/null
+++ b/MIGRATION_MAP.md
@@ -0,0 +1,964 @@
+# lib-commons Migration Map (v3 -> v4)
+
+This document maps notable `lib-commons/v3` APIs to the unified `lib-commons/v4` APIs. Use it as a lookup reference when migrating consumer code from the previous `lib-commons` line to the new unified major version.
+
+---
+
+## commons/opentelemetry
+
+### Initialization
+
+| v3 | v4 | Notes |
+|----|----|----|
+| `InitializeTelemetryWithError(*TelemetryConfig)` | `NewTelemetry(TelemetryConfig) (*Telemetry, error)` | Config passed by value, not pointer |
+| `InitializeTelemetry(*TelemetryConfig)` | removed | Use `NewTelemetry` (no silent-failure variant) |
+| implicit globals on init | explicit `(*Telemetry).ApplyGlobals()` | Globals are opt-in now |
+
+### Span helpers (pointer -> value receivers on span)
+
+| v3 | v4 |
+|----|----|
+| `HandleSpanError(*trace.Span, ...)` | `HandleSpanError(trace.Span, ...)` |
+| `HandleSpanEvent(*trace.Span, ...)` | `HandleSpanEvent(trace.Span, ...)` |
+| `HandleSpanBusinessErrorEvent(*trace.Span, ...)` | `HandleSpanBusinessErrorEvent(trace.Span, ...)` |
+
+### Span attributes
+
+| v3 | v4 |
+|----|----|
+| `SetSpanAttributesFromStruct(...)` | removed; use `SetSpanAttributesFromValue(...)` |
+| `SetSpanAttributesFromStructWithObfuscation(...)` | removed; use `SetSpanAttributesFromValue(...)` |
+| `SetSpanAttributesFromStructWithCustomObfuscation(...)` | removed; use `SetSpanAttributesFromValue(...)` |
+
+### Struct and field changes
+
+| v3 | v4 |
+|----|----|
+| `Telemetry.MetricProvider` field | renamed to `Telemetry.MeterProvider` |
+| `ErrNilTelemetryConfig` | removed; replaced by `ErrNilTelemetryLogger`, `ErrEmptyEndpoint`, `ErrNilTelemetry`, `ErrNilShutdown` |
+
+### New in v4
+
+- `TelemetryConfig` gains fields: `InsecureExporter bool`, `Propagator propagation.TextMapPropagator`, `Redactor *Redactor`
+- New method: `(*Telemetry).Tracer(name) (trace.Tracer, error)`
+- New method: `(*Telemetry).Meter(name) (metric.Meter, error)`
+- New method: `(*Telemetry).ShutdownTelemetryWithContext(ctx) error` -- context-aware shutdown (alternative to `ShutdownTelemetry()`)
+- New type: `RedactingAttrBagSpanProcessor` (span processor that redacts sensitive span attributes)
+
+### Obfuscation -> Redaction
+
+The former obfuscation subsystem has been replaced by the redaction subsystem in v4.
+
+| v3 | v4 |
+|----|----|
+| `FieldObfuscator` interface | removed entirely |
+| `DefaultObfuscator` struct | removed |
+| `CustomObfuscator` struct | removed |
+| `NewDefaultObfuscator()` | `NewDefaultRedactor()` |
+| `NewCustomObfuscator([]string)` | `NewRedactor([]RedactionRule, maskValue)` |
+| `ObfuscateStruct(any, FieldObfuscator)` | `ObfuscateStruct(any, *Redactor)` |
+
+New types:
+
+- `RedactionAction` (string type)
+- `RedactionRule` struct
+- `Redactor` struct
+- Constants: `RedactionMask`, `RedactionHash`, `RedactionDrop`
+
+### Propagation
+
+All propagation functions now follow the `context-first` convention.
+
+| v3 | v4 |
+|----|----|
+| `InjectHTTPContext(*http.Header, context.Context)` | `InjectHTTPContext(context.Context, http.Header)` |
+| `ExtractHTTPContext(*fiber.Ctx)` | `ExtractHTTPContext(context.Context, *fiber.Ctx)` |
+| `InjectGRPCContext(context.Context)` | `InjectGRPCContext(context.Context, metadata.MD) metadata.MD` |
+| `ExtractGRPCContext(context.Context)` | `ExtractGRPCContext(context.Context, metadata.MD) context.Context` |
+
+New low-level APIs:
+
+- `InjectTraceContext(context.Context, propagation.TextMapCarrier)`
+- `ExtractTraceContext(context.Context, propagation.TextMapCarrier) context.Context`
+
+---
+
+## commons/opentelemetry/metrics
+
+### Factory and builders now return errors
+
+| v3 | v4 |
+|----|----|
+| `NewMetricsFactory(meter, logger) *MetricsFactory` | `NewMetricsFactory(meter, logger) (*MetricsFactory, error)` |
+| `(*MetricsFactory).Counter(m) *CounterBuilder` | `(*MetricsFactory).Counter(m) (*CounterBuilder, error)` |
+| `(*MetricsFactory).Gauge(m) *GaugeBuilder` | `(*MetricsFactory).Gauge(m) (*GaugeBuilder, error)` |
+| `(*MetricsFactory).Histogram(m) *HistogramBuilder` | `(*MetricsFactory).Histogram(m) (*HistogramBuilder, error)` |
+
+### Builder operations now return errors
+
+| v3 | v4 |
+|----|----|
+| `(*CounterBuilder).Add(ctx, value)` | now returns `error` |
+| `(*CounterBuilder).AddOne(ctx)` | now returns `error` |
+| `(*GaugeBuilder).Set(ctx, value)` | now returns `error` |
+| `(*GaugeBuilder).Record(ctx, value)` | removed (was deprecated; use `Set`) |
+| `(*HistogramBuilder).Record(ctx, value)` | now returns `error` |
+
+### Removed label helpers
+
+| v3 | v4 |
+|----|----|
+| `WithOrganizationLabels(...)` | removed |
+| `WithLedgerLabels(...)` | removed |
+
+### Convenience recorders (organization/ledger args removed)
+
+| v3 | v4 |
+|----|----|
+| `RecordAccountCreated(ctx, organizationID, ledgerID, attrs...)` | `RecordAccountCreated(ctx, attrs...) error` |
+| `RecordTransactionProcessed(ctx, organizationID, ledgerID, attrs...)` | `RecordTransactionProcessed(ctx, attrs...) error` |
+| `RecordOperationRouteCreated(ctx, organizationID, ledgerID, attrs...)` | `RecordOperationRouteCreated(ctx, attrs...) error` |
+| `RecordTransactionRouteCreated(ctx, organizationID, ledgerID, attrs...)` | `RecordTransactionRouteCreated(ctx, attrs...) error` |
+
+**Migration note:** The `organizationID` and `ledgerID` positional parameters and the internal `WithLedgerLabels()` call were removed in v4. Callers must now pass these labels explicitly via OpenTelemetry attributes:
+
+```go
+// v3
+factory.RecordAccountCreated(ctx, orgID, ledgerID)
+
+// v4
+factory.RecordAccountCreated(ctx,
+ attribute.String("organization_id", orgID),
+ attribute.String("ledger_id", ledgerID),
+)
+```
+
+### New in v4
+
+- `NewNopFactory() *MetricsFactory` -- no-op fallback for tests / disabled metrics
+- New sentinel errors: `ErrNilMeter`, `ErrNilCounter`, `ErrNilGauge`, `ErrNilHistogram`
+
+---
+
+## commons/log
+
+### Interface rewrite (18 methods -> 5)
+
+The `Logger` interface has been completely redesigned.
+
+**v3 interface (18 methods):**
+
+```
+Info / Infof / Infoln
+Error / Errorf / Errorln
+Warn / Warnf / Warnln
+Debug / Debugf / Debugln
+Fatal / Fatalf / Fatalln
+WithFields(fields ...any) Logger
+WithDefaultMessageTemplate(message string) Logger
+Sync() error
+```
+
+**v4 interface (5 methods):**
+
+```
+Log(ctx context.Context, level Level, msg string, fields ...Field)
+With(fields ...Field) Logger
+WithGroup(name string) Logger
+Enabled(level Level) bool
+Sync(ctx context.Context) error
+```
+
+### Level type and constants
+
+| v3 | v4 |
+|----|----|
+| `LogLevel` type (int8) | `Level` type (uint8) |
+| `PanicLevel` | removed entirely |
+| `FatalLevel` | removed entirely |
+| `ErrorLevel` | `LevelError` |
+| `WarnLevel` | `LevelWarn` |
+| `InfoLevel` | `LevelInfo` |
+| `DebugLevel` | `LevelDebug` |
+| `ParseLevel(string) (LogLevel, error)` | `ParseLevel(string) (Level, error)` (no longer accepts "panic" or "fatal") |
+
+### Logger helpers
+
+| v3 | v4 |
+|----|----|
+| `NoneLogger` | `NopLogger` |
+| (no constructor) | `NewNop() Logger` |
+| `WithFields(fields ...any) Logger` | `With(fields ...Field) Logger` |
+| `WithDefaultMessageTemplate(message string) Logger` | removed |
+| `Sync() error` | `Sync(ctx context.Context) error` |
+
+### New `Field` type
+
+v4 introduces a structured `Field` type with constructors:
+
+- `Field` struct: `Key string`, `Value any`
+- `Any(key, value) Field`
+- `String(key, value) Field`
+- `Int(key, value) Field`
+- `Bool(key, value) Field`
+- `Err(err) Field`
+
+### Level constants
+
+- `LevelError` (0), `LevelWarn` (1), `LevelInfo` (2), `LevelDebug` (3), `LevelUnknown` (255)
+
+### GoLogger
+
+`GoLogger` moved from `log.go` to `go_logger.go`, fully reimplemented with the v4 interface. Includes CWE-117 log-injection prevention.
+
+### Sanitizer (package move)
+
+| v3 | v4 |
+|----|----|
+| `commons/logging` package | removed entirely |
+| `logging.SafeErrorf(...)` | `log.SafeError(logger, ctx, msg, err, production)` |
+| `logging.SanitizeExternalResponse(...)` | `log.SanitizeExternalResponse(statusCode) string` |
+
+---
+
+## commons/zap
+
+| v3 | v4 |
+|----|----|
+| `ZapWithTraceLogger` struct | `Logger` struct (renamed, restructured) |
+| `InitializeLoggerWithError() (log.Logger, error)` | removed (use `New(...)`) |
+| `InitializeLogger() log.Logger` | removed (use `New(...)`) |
+| `InitializeLoggerFromConfig(...)` | `New(cfg Config) (*Logger, error)` |
+| `hydrateArgs` / template-based logging | removed |
+
+### New in v4
+
+- New types: `Config`, `Environment` (string type with constants: `EnvironmentProduction`, `EnvironmentStaging`, `EnvironmentUAT`, `EnvironmentDevelopment`, `EnvironmentLocal`)
+- `Logger.Raw() *zap.Logger` -- access underlying zap logger
+- `Logger.Level() zap.AtomicLevel` -- access dynamic log level
+- Direct zap convenience methods: `Debug()`, `Info()`, `Warn()`, `Error()`, `WithZapFields()`
+- Field constructors: `Any(key, value)`, `String(key, value)`, `Int(key, value)`, `Bool(key, value)`, `Duration(key, value)`, `ErrorField(err)`
+
+---
+
+## commons/net/http
+
+### Response helpers consolidated
+
+All individual status helpers have been removed in favor of two generic functions.
+
+| v3 | v4 |
+|----|----|
+| `WriteError(c, status, title, message)` | `RespondError(c, status, title, message)` |
+| `HandleFiberError(c, err)` | `FiberErrorHandler(c, err)` |
+| `JSONResponse(c, status, s)` | `Respond(c, status, payload)` |
+| `JSONResponseError(c, err)` | removed (use `RespondError`) |
+| `NoContent(c)` | `RespondStatus(c, status)` |
+
+**Removed individual status helpers** (use `Respond` / `RespondError` / `RespondStatus` instead):
+
+`BadRequestError`, `UnauthorizedError`, `ForbiddenError`, `NotFoundError`, `ConflictError`, `RequestEntityTooLargeError`, `UnprocessableEntityError`, `SimpleInternalServerError`, `InternalServerErrorWithTitle`, `ServiceUnavailableError`, `ServiceUnavailableErrorWithTitle`, `GatewayTimeoutError`, `GatewayTimeoutErrorWithTitle`, `Unauthorized`, `Forbidden`, `BadRequest`, `Created`, `OK`, `Accepted`, `PartialContent`, `RangeNotSatisfiable`, `NotFound`, `Conflict`, `NotImplemented`, `UnprocessableEntity`, `InternalServerError`
+
+### Cursor pagination
+
+| v3 | v4 |
+|----|----|
+| `Cursor.PointsNext` (bool) | `Cursor.Direction` (string: `"next"` / `"prev"`) |
+| `CreateCursor(id, pointsNext)` | removed (construct `Cursor` directly) |
+| `ApplyCursorPagination(squirrel.SelectBuilder, ...)` | removed (use `CursorDirectionRules(sortDir, cursorDir)`) |
+| `PaginateRecords[T](..., pointsNext bool, ..., orderUsed string)` | `PaginateRecords[T](..., cursorDirection string, ...) ` (orderUsed removed) |
+| `CalculateCursor(..., pointsNext bool, ...)` | `CalculateCursor(..., cursorDirection string, ...)` |
+| `EncodeCursor(cursor) string` | `EncodeCursor(cursor) (string, error)` (now validates) |
+
+New constants: `CursorDirectionNext`, `CursorDirectionPrev`
+New error: `ErrInvalidCursorDirection`
+
+### Validation / context
+
+| v3 | v4 |
+|----|----|
+| `ParseAndVerifyContextParam(...)` | `ParseAndVerifyTenantScopedID(...)` |
+| `ParseAndVerifyContextQuery(...)` | `ParseAndVerifyResourceScopedID(...)` |
+| `ParseAndVerifyExceptionParam(...)` | removed |
+| `ParseAndVerifyDisputeParam(...)` | removed |
+| `ContextOwnershipVerifier` interface | `TenantOwnershipVerifier` func type |
+| `ExceptionOwnershipVerifier` interface | removed |
+| `DisputeOwnershipVerifier` interface | removed |
+
+New types: `ResourceOwnershipVerifier` func type, `IDLocation` type, `ErrInvalidIDLocation`, `ErrLookupFailed`
+
+### Error types
+
+| v3 | v4 |
+|----|----|
+| `ErrorResponse.Code` (string) | `ErrorResponse.Code` (int) |
+| `ErrorResponse.Error` field | removed |
+| `WithError(ctx, err)` | `RenderError(ctx, err)` |
+| `HealthSimple` var | removed (use `Ping` directly) |
+
+`ErrorResponse` now implements the `error` interface.
+
+**Wire format impact:** `ErrorResponse.Code` changed from `string` to `int`, which changes the JSON serialization from `"code": "400"` to `"code": 400`. Any downstream consumer that unmarshals error responses with `Code` as a string type will break. Callers must update their response parsing structs to use `int` (or a numeric JSON type) for the `code` field.
+
+### Proxy
+
+| v3 | v4 |
+|----|----|
+| `ServeReverseProxy(target, res, req)` | `ServeReverseProxy(target, policy, res, req) error` |
+
+New: `DefaultReverseProxyPolicy()`, `ReverseProxyPolicy` struct with SSRF protection.
+
+### Pagination (v4 refinement)
+
+| v4 (previous) | v4 (current) |
+|---|---|
+| `EncodeTimestampCursor(time, uuid) string` | `EncodeTimestampCursor(time, uuid) (string, error)` |
+| `EncodeSortCursor(col, val, id, next) string` | `EncodeSortCursor(col, val, id, next) (string, error)` |
+| `CalculateSortCursorPagination(...) (next, prev string)` | `CalculateSortCursorPagination(...) (next, prev string, err error)` |
+| `ErrOffsetMustBePositive` sentinel | removed (negative offset silently coerced to `DefaultOffset=0`; see note below) |
+| `type Order string` + `Asc Order = "asc"` / `Desc Order = "desc"` | removed; replaced by `SortDirASC = "ASC"` / `SortDirDESC = "DESC"` (untyped `string`, uppercase) |
+
+**Migration note (offset coercion):** The `ErrOffsetMustBePositive` sentinel error is removed. In v4, negative offsets are silently coerced to `DefaultOffset=0` instead of returning an error. This tradeoff avoids breaking callers that relied on the previous behavior and preserves backward compatibility. However, callers should validate offsets before calling pagination functions (e.g., reject negative offsets at the handler level) since the pagination codepaths that previously returned `ErrOffsetMustBePositive` will now silently accept any negative value.
+
+**Migration note (cursor/sort):** The cursor encode functions now return errors. The `Order` type is removed; use the `SortDirASC`/`SortDirDESC` constants directly. Note the **case change** from lowercase `"asc"`/`"desc"` to uppercase `"ASC"`/`"DESC"` — any consumer that stores or compares these values must be updated.
+
+New pagination defaults in `constants/pagination.go`: `DefaultLimit=20`, `DefaultOffset=0`, `MaxLimit=200`.
+
+### Handler
+
+| v4 (previous) | v4 (current) |
+|---|---|
+| `Ping` handler returns `"healthy"` | `Ping` handler returns `"pong"` |
+
+**Migration note:** Any health check monitor that string-matches the response body for `"healthy"` must be updated. Use `HealthWithDependencies` for production health endpoints.
+
+### Health check semantics
+
+| v4 (previous) | v4 (current) |
+|---|---|
+| `HealthWithDependencies`: HealthCheck overrides CircuitBreaker status | Both must report healthy (AND semantics) |
+
+**Migration note:** An open circuit breaker can no longer be overridden by a passing HealthCheck function. This is the correct reliability behavior but may surface previously-hidden unhealthy states.
+
+### Rate limit storage
+
+| v3 | v4 |
+|----|----|
+| `NewRedisStorage(conn *RedisConnection)` | `NewRedisStorage(conn *Client)` |
+| Nil storage operations silently return nil | Now return `ErrStorageUnavailable` |
+
+---
+
+## commons/server
+
+| v3 | v4 |
+|----|----|
+| `GracefulShutdown` struct | removed entirely |
+| `NewGracefulShutdown(...)` | removed |
+| `(*GracefulShutdown).HandleShutdown()` | removed |
+
+Use `ServerManager` (already existed in v3) with `StartWithGracefulShutdown()`.
+
+### New in v4
+
+- `(*ServerManager).WithShutdownTimeout(d) *ServerManager` -- configures max wait for gRPC GracefulStop before hard stop (default: 30s)
+- `(*ServerManager).WithShutdownHook(hook func(context.Context) error) *ServerManager` -- registers cleanup callbacks executed during graceful shutdown (nil hooks are silently ignored)
+- `(*ServerManager).WithShutdownChannel(ch <-chan struct{}) *ServerManager` -- custom shutdown trigger for tests (instead of relying on OS signals)
+- `(*ServerManager).StartWithGracefulShutdownWithError() error` -- returns error on config failure instead of calling `os.Exit(1)`
+- `(*ServerManager).ServersStarted() <-chan struct{}` -- closed when server goroutines have been launched (for test coordination)
+- `ErrNoServersConfigured` sentinel error
+
+---
+
+## commons/mongo
+
+| v3 | v4 |
+|----|----|
+| `MongoConnection` struct | `Client` struct |
+| `BuildConnectionString(scheme, user, password, host, port, parameters, logger) string` | `BuildURI(URIConfig) (string, error)` |
+| `MongoConnection{}` + `Connect(ctx)` | `NewClient(ctx, cfg Config, opts ...Option) (*Client, error)` |
+| `GetDB(ctx) (*mongo.Client, error)` | `Client(ctx) (*mongo.Client, error)` |
+| `EnsureIndexes(ctx, collection, index)` | `EnsureIndexes(ctx, collection, indexes...) error` (variadic) |
+
+### Error sentinels (v4 refinement)
+
+| v4 (previous) | v4 (current) | Notes |
+|---|---|---|
+| `ErrClientClosed` (nil receiver) | `ErrNilClient` | Nil receiver now returns `ErrNilClient`; `ErrClientClosed` reserved for closed/not-connected state |
+
+### New in v4
+
+- Methods: `Database(ctx)`, `DatabaseName()`, `Ping(ctx)`, `Close(ctx)`, `ResolveClient(ctx)` (alias for `Client(ctx)`)
+- Types: `Config`, `URIConfig`, `Option`, `TLSConfig`
+- Sentinel errors: `ErrNilClient`, `ErrNilDependency`, `ErrInvalidConfig`, `ErrEmptyURI`, `ErrEmptyDatabaseName`, `ErrEmptyCollectionName`, `ErrEmptyIndexes`, `ErrConnect`, `ErrPing`, `ErrDisconnect`, `ErrCreateIndex`, `ErrNilMongoClient`, `ErrNilContext`
+- URI builder errors: `ErrInvalidScheme`, `ErrEmptyHost`, `ErrInvalidPort`, `ErrPortNotAllowedForSRV`, `ErrPasswordWithoutUser`
+- `Config.TLS` field — optional `*TLSConfig` for TLS connections (mirrors redis `TLSConfig`)
+- Non-TLS connection warning — logs at `Warn` level when connecting without TLS
+- `Config.MaxPoolSize` silently clamped to 1000 (mirrors redis `maxPoolSize` pattern)
+- Credential clearing — `Config.URI` is cleared after successful `Connect()` to reduce credential exposure
+
+---
+
+## commons/redis
+
+| v3 | v4 |
+|----|----|
+| `RedisConnection` struct | `Client` struct |
+| `Mode` type | removed |
+| `RedisConnection{}` + `Connect(ctx)` | `New(ctx, cfg Config) (*Client, error)` |
+| `NewDistributedLock(conn *RedisConnection)` | `NewDistributedLock(conn *Client)` |
+| `WithLock(ctx, key, func() error)` | `WithLock(ctx, key, func(context.Context) error)` (context propagated to callback) |
+| `WithLockOptions(ctx, key, opts, func() error)` | `WithLockOptions(ctx, key, opts, func(context.Context) error)` |
+| `InitVariables()` | removed (handled by constructor) |
+| `BuildTLSConfig()` | removed (handled internally) |
+
+### Behavioral changes
+
+| Behavior | v4 |
+|----------|-----|
+| TLS minimum version | `normalizeTLSDefaults` enforces `tls.VersionTLS12` as the minimum TLS version. Explicit `tls.VersionTLS10` or `tls.VersionTLS11` values in `TLSConfig.MinVersion` are upgraded to TLS 1.2 and a warning is logged. If you still need legacy endpoints temporarily, set `TLSConfig.AllowLegacyMinVersion=true` as an explicit compatibility override and plan removal. |
+
+Recommended rollout:
+
+- First deploy with explicit `TLSConfig.MinVersion=tls.VersionTLS12` where endpoints are compatible.
+- Use `TLSConfig.AllowLegacyMinVersion=true` only for temporary exceptions and monitor warning logs.
+- Remove legacy override after endpoint upgrades to restore strict floor enforcement.
+
+### Interface and lock handle changes
+
+| v4 (previous) | v4 (current) |
+|----|----|
+| `TryLock(ctx, key) (*redsync.Mutex, bool, error)` | `TryLock(ctx, key) (LockHandle, bool, error)` |
+| `Unlock(ctx, *redsync.Mutex) error` | `LockHandle.Unlock(ctx) error` |
+| `DistributedLocker` interface (4 methods, imports `redsync`) | `LockManager` interface (3 methods, no `redsync` dependency) |
+| `DistributedLock` struct | `RedisLockManager` struct |
+| `NewDistributedLock(conn)` | `NewRedisLockManager(conn) (*RedisLockManager, error)` |
+
+**Migration note:** `TryLock` now returns an opaque `LockHandle` instead of `*redsync.Mutex`. Call `handle.Unlock(ctx)` directly instead of `lock.Unlock(ctx, mutex)`. The standalone `Unlock` method on `DistributedLock` is deprecated -- it now accepts `LockHandle` instead of `*redsync.Mutex`. Consumers no longer need to import `github.com/go-redsync/redsync/v4` to use the `DistributedLocker` interface.
+
+### New in v4
+
+- Config types: `Config`, `Topology`, `StandaloneTopology`, `SentinelTopology`, `ClusterTopology`, `TLSConfig`, `Auth`, `StaticPasswordAuth`, `GCPIAMAuth`, `ConnectionOptions`
+- Methods: `GetClient(ctx) (redis.UniversalClient, error)`, `Close() error`, `Status() (Status, error)`, `IsConnected() (bool, error)`, `LastRefreshError() error`
+- `SetPackageLogger(log.Logger)` -- configures package-level logger for nil-receiver assertion diagnostics
+- `LockHandle` interface -- opaque lock token with self-contained `Unlock(ctx) error`
+- `DefaultLockOptions() LockOptions` -- sensible defaults for general-purpose locking
+- `RateLimiterLockOptions() LockOptions` -- optimized for rate limiter use case
+- `StaticPasswordAuth.String()` / `GCPIAMAuth.String()` -- credential redaction in `fmt` output
+- Config validation: `RefreshEvery < TokenLifetime` enforced, `PoolSize` capped at 1000, `LockOptions.Tries` capped at 1000
+- Lazy pool adapter: `DistributedLock` survives IAM token refresh reconnections
+
+---
+
+## commons/postgres
+
+| v3 | v4 |
+|----|----|
+| `PostgresConnection` struct | `Client` struct |
+| `PostgresConnection{}` + field assignment | `New(cfg Config) (*Client, error)` |
+| `Connect() error` | `Connect(ctx context.Context) error` |
+| `GetDB() (dbresolver.DB, error)` | `Resolver(ctx context.Context) (dbresolver.DB, error)` |
+| `Pagination` struct | removed (moved to `commons/net/http`) |
+| `squirrel` dependency | removed |
+
+### Error wrapping (v4 refinement)
+
+`SanitizedError.Unwrap()` returns `nil` to prevent error chain traversal from leaking database credentials. `Error()` returns the sanitized text. Because `Unwrap()` is intentionally blocked, `errors.Is/errors.As` do not match the hidden original cause through `SanitizedError`.
+
+### New in v4
+
+- Methods: `Primary() (*sql.DB, error)`, `Close() error`, `IsConnected() (bool, error)`
+- Types: `Config`, `MigrationConfig`, `SanitizedError`
+- Migration: `NewMigrator(cfg MigrationConfig) (*Migrator, error)` and `(*Migrator).Up(ctx) error`
+
+---
+
+## commons/rabbitmq
+
+### Context-aware methods added alongside existing ones
+
+| Existing (kept) | New context-aware variant |
+|----|----|
+| `Connect()` | `ConnectContext(ctx) error` |
+| `EnsureChannel()` | `EnsureChannelContext(ctx) error` |
+| `GetNewConnect()` | `GetNewConnectContext(ctx) (*amqp.Channel, error)` |
+
+### Changed signatures
+
+| v3 | v4 |
+|----|----|
+| `HealthCheck() bool` | `HealthCheck() (bool, error)` (now returns error) |
+
+### New in v4
+
+- `HealthCheckContext(ctx) (bool, error)`
+- `Close() error`, `CloseContext(ctx) error`
+- New errors: `ErrInsecureTLS`, `ErrNilConnection`, `ErrInsecureHealthCheck`, `ErrHealthCheckHostNotAllowed`, `ErrHealthCheckAllowedHostsRequired`
+
+### Health check rollout/security knobs
+
+- Basic auth over plain HTTP is rejected by default; set `AllowInsecureHealthCheck=true` only as temporary compatibility override.
+- Basic-auth health checks now require `HealthCheckAllowedHosts` unless `AllowInsecureHealthCheck=true` is explicitly set.
+- Host allowlist controls: `HealthCheckAllowedHosts` (accepts `host` or `host:port`) and `RequireHealthCheckAllowedHosts`.
+- Recommended rollout: configure `HealthCheckAllowedHosts` first, then enable `RequireHealthCheckAllowedHosts=true`.
+
+---
+
+## commons/outbox
+
+The root `commons/outbox` package is newly available in the unified `lib-commons/v4` line.
+
+Key APIs now available to consumers:
+
+- `NewOutboxEvent(...)` / `NewOutboxEventWithID(...)` -- validated outbox event construction
+- `Dispatcher`, `DispatcherConfig`, `DefaultDispatcherConfig()` -- dispatcher orchestration and tuning
+- Dispatcher options such as `WithBatchSize`, `WithDispatchInterval`, `WithPublishMaxAttempts`, `WithRetryWindow`, `WithProcessingTimeout`, `WithPriorityEventTypes`, and `WithTenantMetricAttributes`
+- Tenant helpers: `ContextWithTenantID`, `TenantIDFromContext`, `TenantResolver`, `TenantDiscoverer`
+
+Use `commons/outbox/postgres` for PostgreSQL-backed repository and tenant resolution implementations.
+
+---
+
+## commons/outbox/postgres
+
+### Behavioral changes
+
+| Behavior | v4 |
+|----------|-----|
+| Schema resolver tenant enforcement | `SchemaResolver` now requires tenant context by default. Use `WithAllowEmptyTenant()` only for explicit public-schema/single-tenant flows. |
+| Schema resolver tenant ID validation | `SchemaResolver.ApplyTenant` and `NewSchemaResolver` now trim whitespace from tenant IDs **and** validate them as UUIDs. Previously, whitespace was silently accepted. In v4, whitespace is trimmed but non-UUID values are rejected with an error (`"invalid tenant id format"` from `ApplyTenant`, `ErrDefaultTenantIDInvalid` from `NewSchemaResolver`). Callers must ensure tenant IDs passed to outbox functions are valid UUIDs — any code using non-UUID tenant identifiers (e.g., plain strings or slugs) will break. |
+| Column migration primary key | `migrations/column/000001_outbox_events_column.up.sql` uses composite primary key `(tenant_id, id)` to avoid cross-tenant key coupling. |
+
+---
+
+## commons/transaction
+
+### Types restructured
+
+**Removed types:** `Responses`, `Metadata`, `Amount`, `Share`, `Send`, `Source`, `Rate`, `FromTo`, `Distribute`, `Transaction`
+
+**New types:** `Operation`, `TransactionStatus`, `AccountType`, `ErrorCode`, `DomainError`, `LedgerTarget`, `Allocation`, `TransactionIntentInput`, `Posting`, `IntentPlan`
+
+New constructor: `NewDomainError(code, field, message) error`
+
+`Balance` struct changes: removed fields `Alias`, `Key`, `AssetCode`; added field `Asset` (replaces `AssetCode`). `AccountType` changed from `string` to typed `AccountType` enum.
+
+New operation types: `OperationDebit`, `OperationCredit`, `OperationOnHold`, `OperationRelease`
+New status types: `StatusCreated`, `StatusApproved`, `StatusPending`, `StatusCanceled`
+New function: `ResolveOperation(pending, isSource bool, status TransactionStatus) (Operation, error)`
+
+### Validation flow
+
+| v3 | v4 |
+|----|----|
+| `ValidateBalancesRules(ctx, transaction, validate, balances) error` | `BuildIntentPlan(input, status) (IntentPlan, error)` + `ValidateBalanceEligibility(plan, balances) error` |
+| `ValidateFromToOperation(ft, validate, balance) (Amount, Balance, error)` | `ApplyPosting(balance, posting) (Balance, error)` |
+
+**Removed helpers:** `SplitAlias`, `ConcatAlias`, `AliasKey`, `SplitAliasWithKey`, `OperateBalances`
+
+---
+
+## commons/circuitbreaker
+
+| v3 | v4 |
+|----|----|
+| `NewManager(logger) Manager` | `NewManager(logger, opts...) (Manager, error)` (returns error on nil logger; accepts options) |
+| `(*Manager).GetOrCreate(serviceName, config) CircuitBreaker` | `(*Manager).GetOrCreate(serviceName, config) (CircuitBreaker, error)` (validates config) |
+
+New: `Config.Validate() error`
+New: `WithMetricsFactory(f *metrics.MetricsFactory) ManagerOption` -- emits `circuit_breaker_state_transitions_total` and `circuit_breaker_executions_total` counters
+
+---
+
+## commons/errors
+
+| v3 | v4 |
+|----|----|
+| `ValidateBusinessError(err, entityType, args...)` | Variadic `args` now appended to error message (previously ignored extra args) |
+
+---
+
+## commons/app
+
+| v3 | v4 |
+|----|----|
+| `(*Launcher).Add(appName, app) *Launcher` | `(*Launcher).Add(appName, app) error` (no more method chaining) |
+
+New sentinel errors: `ErrNilLauncher`, `ErrEmptyApp`, `ErrNilApp`
+
+---
+
+## commons/context (removals)
+
+| v3 | v4 |
+|----|----|
+| `NewTracerFromContext(ctx)` | removed (was deprecated; use `NewTrackingFromContext`) |
+| `NewMetricFactoryFromContext(ctx)` | removed (was deprecated; use `NewTrackingFromContext`) |
+| `NewHeaderIDFromContext(ctx)` | removed (was deprecated; use `NewTrackingFromContext`) |
+| `WithTimeout(parent, timeout)` | removed (was deprecated; use `WithTimeoutSafe`) |
+| All `NoneLogger{}` references | `NopLogger{}` |
+
+---
+
+## commons/os
+
+| v3 | v4 |
+|----|----|
+| `EnsureConfigFromEnvVars(s any) any` | removed (use `SetConfigFromEnvVars(s any) error`) |
+
+---
+
+## commons/utils
+
+### Signature changes
+
+| v3 | v4 |
+|----|----|
+| `GenerateUUIDv7() uuid.UUID` | `GenerateUUIDv7() (uuid.UUID, error)` |
+
+**Migration note:** In v3, `GenerateUUIDv7()` internally used `uuid.Must(uuid.NewV7())`, which panics if `crypto/rand` fails. In v4 the panic path is removed: the function returns `(uuid.UUID, error)` so callers can handle the (rare but possible) entropy-source failure gracefully. All call sites must now check the returned error.
+
+### Removed deprecated functions (moved to Midaz)
+
+- `ValidateCountryAddress`, `ValidateAccountType`, `ValidateType`, `ValidateCode`, `ValidateCurrency`
+- `GenericInternalKey`, `TransactionInternalKey`, `IdempotencyInternalKey`, `BalanceInternalKey`, `AccountingRoutesInternalKey`
+
+---
+
+## commons/crypto
+
+| v3 | v4 |
+|----|----|
+| `Crypto.Logger` field (`*zap.Logger`) | `Crypto.Logger` field (`log.Logger`) |
+
+Direct `go.uber.org/zap` dependency removed from this package.
+
+---
+
+## commons/jwt
+
+### Token validation semantics
+
+| v3 | v4 |
+|----|----|
+| `Token.Valid` (bool) -- full validation | `Token.SignatureValid` (bool) -- signature-only verification |
+| (no separate time validation) | `ValidateTimeClaims(claims) error` |
+| (no separate time validation) | `ValidateTimeClaimsAt(claims, now) error` |
+| (no combined parse+validate) | `ParseAndValidate(token, secret, allowedAlgs) (*Token, error)` |
+
+**Migration note:** In v3, the `Token.Valid` field was set to `true` after `Parse()` succeeded, which callers commonly interpreted as "the token is fully valid." In v4, `Token.SignatureValid` clarifies that only the cryptographic HMAC signature was verified -- it does **not** cover time-based claims (`exp`, `nbf`, `iat`). Callers relying on `Token.Valid` for authorization decisions must either:
+
+1. Switch to `ParseAndValidate()`, which performs both signature verification and time-claim validation in one call, or
+2. Call `ValidateTimeClaims(token.Claims)` (or `ValidateTimeClaimsAt(token.Claims, now)` for deterministic testing) after `Parse()`.
+
+New sentinel errors for time validation: `ErrTokenExpired`, `ErrTokenNotYetValid`, `ErrTokenIssuedInFuture`.
+
+---
+
+## commons/license
+
+| v3 | v4 |
+|----|----|
+| `DefaultHandler(reason)` panics | `DefaultHandler(reason)` records assertion failure (no panic) |
+| `ManagerShutdown.Terminate(reason)` panics on nil handler | Records assertion failure, returns without panic |
+| Direct struct construction `&ManagerShutdown{}` | `New(opts ...ManagerOption) *ManagerShutdown` constructor with functional options |
+
+### New in v4
+
+- `New(opts ...ManagerOption) *ManagerShutdown` -- constructor with default handler and functional options
+- `WithLogger(l log.Logger) ManagerOption` -- provides structured logger for assertion and validation logging
+- `DefaultHandlerWithError(reason string) error` -- returns `ErrLicenseValidationFailed` instead of panicking
+- `(*ManagerShutdown).TerminateWithError(reason) error` -- returns error instead of invoking handler (for validation checks)
+- `(*ManagerShutdown).TerminateSafe(reason) error` -- invokes handler but returns error if manager is uninitialized
+- Sentinel errors: `ErrLicenseValidationFailed`, `ErrManagerNotInitialized`
+
+---
+
+## commons/cron
+
+| v3 | v4 |
+|----|----|
+| `schedule.Next(from)` on nil receiver | returns `(time.Time{}, nil)` -> now returns `(time.Time{}, ErrNilSchedule)` |
+
+New error: `ErrNilSchedule`
+
+---
+
+## commons/security
+
+| v3 | v4 |
+|----|----|
+| `DefaultSensitiveFieldsMap()` | still available (reimplemented with lazy init + `sync.Once`) |
+
+Field list expanded with additional financial and PII identifiers.
+
+---
+
+## commons/constants
+
+The `commons/constants` package remains available in v4 and is materially expanded in the unified line.
+
+Notable additions used across the migrated packages:
+
+- OpenTelemetry attribute and metric constants for connectors and runtime packages
+- `SanitizeMetricLabel(value string) string` for bounded metric-label values
+- Shared datasource, header, metadata, pagination, transaction, and obfuscation constants consolidated under one package tree
+
+---
+
+## commons/pointers
+
+The `commons/pointers` package remains available at the same path in v4.
+
+Exported helpers:
+
+- `String()`, `Bool()`, `Time()`, `Int()`, `Int64()`, `Float64()`
+
+---
+
+## commons/secretsmanager
+
+The `commons/secretsmanager` package remains available in the unified v4 line.
+
+Core APIs:
+
+- `GetM2MCredentials(ctx, client, env, tenantOrgID, applicationName, targetService)`
+- `M2MCredentials`
+- `SecretsManagerClient`
+- Sentinel errors such as `ErrM2MCredentialsNotFound`, `ErrM2MVaultAccessDenied`, `ErrM2MRetrievalFailed`, `ErrM2MUnmarshalFailed`, `ErrM2MInvalidInput`, and `ErrM2MInvalidCredentials`
+
+No import-path change is required for consumers already using `commons/secretsmanager`.
+
+---
+
+## Added or newly available in v4
+
+### commons/circuitbreaker
+
+- `NewManager(logger, opts...) (Manager, error)` -- circuit breaker manager for service-level resilience
+- `WithMetricsFactory(f *metrics.MetricsFactory) ManagerOption` -- emits state transition and execution counters
+- `NewHealthCheckerWithValidation(manager, interval, timeout, logger) (HealthChecker, error)` -- periodic health checks with recovery and config validation
+- Preset configs: `DefaultConfig()`, `AggressiveConfig()`, `ConservativeConfig()`, `HTTPServiceConfig()`, `DatabaseConfig()`
+- `Config.Validate() error` -- validates circuit breaker configuration
+- Core types: `Config`, `State`, `Counts`, `CircuitBreaker` interface, `Manager` interface, `HealthChecker` interface
+- State constants: `StateClosed`, `StateOpen`, `StateHalfOpen`, `StateUnknown`
+- Sentinel errors: `ErrInvalidConfig`, `ErrNilLogger`, `ErrNilCircuitBreaker`, `ErrNilManager`, `ErrInvalidHealthCheckInterval`, `ErrInvalidHealthCheckTimeout`
+
+### commons/assert
+
+- `New(ctx, logger, component, operation) *Asserter` -- production-safe assertions
+- Methods: `That()`, `NotNil()`, `NotEmpty()`, `NoError()`, `Never()`, `Halt()`
+- Returns errors + emits telemetry instead of panicking
+- Metrics: `InitAssertionMetrics(factory)`, `GetAssertionMetrics()`, `ResetAssertionMetrics()`
+- Predicates library (`predicates.go`): `Positive`, `NonNegative`, `NotZero`, `InRange`, `PositiveInt`, `InRangeInt`, `ValidUUID`, `ValidAmount`, `ValidScale`, `PositiveDecimal`, `NonNegativeDecimal`, `ValidPort`, `ValidSSLMode`, `DebitsEqualCredits`, `NonZeroTotals`, `ValidTransactionStatus`, `TransactionCanTransitionTo`, `TransactionCanBeReverted`, `BalanceSufficientForRelease`, `DateNotInFuture`, `DateAfter`, `BalanceIsZero`, `TransactionHasOperations`, `TransactionOperationsMatch`
+- Sentinel error: `ErrAssertionFailed`
+
+### commons/runtime
+
+- Recovery: `RecoverAndLog`, `RecoverAndCrash`, `RecoverWithPolicy` (and `*WithContext` variants)
+- Safe goroutines: `SafeGo`, `SafeGoWithContext`, `SafeGoWithContextAndComponent` with `PanicPolicy` (KeepRunning/CrashProcess)
+- Panic metrics: `InitPanicMetrics(factory[, logger])`, `GetPanicMetrics()`, `ResetPanicMetrics()`
+- Span recording: `RecordPanicToSpan`, `RecordPanicToSpanWithComponent`
+- Error reporter: `SetErrorReporter(reporter)`, `GetErrorReporter()` with `ErrorReporter` interface
+- Production mode: `SetProductionMode(bool)`, `IsProductionMode() bool`
+- Sentinel error: `ErrPanic`
+
+### commons/safe
+
+- **Math:** `Divide()`, `DivideRound()`, `DivideOrZero()`, `DivideOrDefault()`, `Percentage()`, `PercentageOrZero()` on `decimal.Decimal` with zero-division safety; `DivideFloat64()`, `DivideFloat64OrZero()` for float64
+- **Regex:** `Compile()`, `CompilePOSIX()`, `MatchString()`, `FindString()`, `ClearCache()` with caching
+- **Slices:** `First[T]()`, `Last[T]()`, `At[T]()` with error returns and `*OrDefault` variants
+- Sentinel errors: `ErrDivisionByZero`, `ErrInvalidRegex`, `ErrEmptySlice`, `ErrIndexOutOfBounds`
+
+### commons/security
+
+- `IsSensitiveField(name) bool` -- case-insensitive sensitive field detection
+- `DefaultSensitiveFields() []string` -- default sensitive field patterns
+- `DefaultSensitiveFieldsMap() map[string]bool` -- map version for lookups
+
+### commons/jwt
+
+- `Parse(token, secret, allowedAlgs) (*Token, error)` -- HMAC JWT signature verification only
+- `ParseAndValidate(token, secret, allowedAlgs) (*Token, error)` -- signature + time claim validation
+- `Sign(claims, secret, alg) (string, error)` -- HMAC JWT creation
+- `ValidateTimeClaims(claims) error` -- exp/nbf/iat validation against current UTC time
+- `ValidateTimeClaimsAt(claims, now) error` -- exp/nbf/iat validation against a specific time (for deterministic testing)
+- `Token.SignatureValid` (bool) -- replaces v3 `Token.Valid`; clarifies signature-only scope
+- Algorithms: `AlgHS256`, `AlgHS384`, `AlgHS512`
+- Sentinel errors: `ErrTokenExpired`, `ErrTokenNotYetValid`, `ErrTokenIssuedInFuture`
+
+### commons/backoff
+
+- `Exponential(base, attempt) time.Duration` -- exponential delay calculation
+- `FullJitter(delay) time.Duration` -- crypto/rand-based jitter
+- `ExponentialWithJitter(base, attempt) time.Duration` -- combined helper
+- `WaitContext(ctx, delay) error` -- context-aware sleep (renamed from `SleepWithContext`)
+
+### commons/cron
+
+- `Parse(expr) (Schedule, error)` -- 5-field cron expression parser
+- `Schedule.Next(t) (time.Time, error)` -- next execution time
+
+### commons/errgroup
+
+- `WithContext(ctx) (*Group, context.Context)` -- goroutine group with cancellation
+- `(*Group).Go(fn)` -- launch goroutine with panic recovery
+- `(*Group).Wait() error` -- wait and return first error
+- `(*Group).SetLogger(logger)` -- configure logger for panic recovery diagnostics
+- Sentinel error: `ErrPanicRecovered`
+
+### commons/tenant-manager
+
+The `tenant-manager` package tree provides multi-tenant connection management, preserved and expanded in unified `lib-commons/v4`.
+
+#### New packages
+
+| Package | Purpose |
+|---------|---------|
+| `tenant-manager/core` | Shared types (`TenantConfig`), context helpers (`ContextWithTenantID`, `GetTenantIDFromContext`), error types |
+| `tenant-manager/cache` | Exported config cache contract and in-memory cache implementation for tenant settings |
+| `tenant-manager/client` | HTTP client for Tenant Manager API with circuit breaker, caching, and invalidation helpers |
+| `tenant-manager/consumer` | `MultiTenantConsumer` — goroutine-per-tenant lifecycle management |
+| `tenant-manager/middleware` | Fiber middleware for tenant extraction (`TenantMiddleware`) and multi-pool routing (`MultiPoolMiddleware`) |
+| `tenant-manager/postgres` | `Manager` — per-tenant PostgreSQL connection pool management with LRU eviction |
+| `tenant-manager/mongo` | `Manager` — per-tenant MongoDB connection management with LRU eviction |
+| `tenant-manager/rabbitmq` | `Manager` — per-tenant RabbitMQ connection management |
+| `tenant-manager/s3` | Tenant-scoped S3 object storage key prefixing |
+| `tenant-manager/valkey` | Tenant-scoped Redis/Valkey key prefixing |
+
+#### Breaking changes
+
+**1. Removed `NewMultiTenantConsumer`**
+
+| v3 | v4 |
+|---|---|
+| `consumer.NewMultiTenantConsumer(cfg, logger) *MultiTenantConsumer` | removed; use `consumer.NewMultiTenantConsumerWithError(cfg, logger) (*MultiTenantConsumer, error)` |
+
+The deprecated panicking constructor has been removed. `NewMultiTenantConsumerWithError` returns an error on invalid configuration instead of calling `panic()`.
+
+**2. Tenant client caching remains available through exported cache APIs**
+
+| v3 | v4 |
+|---|---|
+| cache package exposed at `tenant-manager/cache` | still available at `tenant-manager/cache` |
+| `client.WithCache(...)` / `client.WithCacheTTL(...)` | still supported |
+| per-call cache bypass | `client.WithSkipCache()` |
+| cache eviction | `(*Client).InvalidateConfig(ctx, tenantID, service) error` |
+
+**3. S3 function signature changes**
+
+Three S3 functions now return `(string, error)` instead of `string` to support delimiter validation:
+
+| v3 | v4 |
+|---|---|
+| `s3.GetObjectStorageKey(tenantID, key) string` | `s3.GetObjectStorageKey(tenantID, key) (string, error)` |
+| `s3.GetObjectStorageKeyForTenant(ctx, key) string` | `s3.GetObjectStorageKeyForTenant(ctx, key) (string, error)` |
+| `s3.StripObjectStoragePrefix(tenantID, prefixedKey) string` | `s3.StripObjectStoragePrefix(tenantID, prefixedKey) (string, error)` |
+
+**4. Valkey function signature changes**
+
+Five Valkey functions now return `(string, error)` instead of `string` to support delimiter validation:
+
+| v3 | v4 |
+|---|---|
+| `valkey.GetKey(tenantID, key) string` | `valkey.GetKey(tenantID, key) (string, error)` |
+| `valkey.GetKeyFromContext(ctx, key) string` | `valkey.GetKeyFromContext(ctx, key) (string, error)` |
+| `valkey.GetPattern(tenantID, pattern) string` | `valkey.GetPattern(tenantID, pattern) (string, error)` |
+| `valkey.GetPatternFromContext(ctx, pattern) string` | `valkey.GetPatternFromContext(ctx, pattern) (string, error)` |
+| `valkey.StripTenantPrefix(tenantID, prefixedKey) string` | `valkey.StripTenantPrefix(tenantID, prefixedKey) (string, error)` |
+
+**5. `hasUpstreamAuthAssertion` behavioral change**
+
+| Behavior | v4 |
+|----------|-----|
+| Auth assertion via HTTP header | The middleware no longer checks the `X-User-ID` HTTP header for auth assertion (headers are client-spoofable). Only `c.Locals("user_id")` set by upstream lib-auth middleware is checked. |
+
+**Migration note:** Applications relying on the `X-User-ID` header for auth assertion must ensure upstream auth middleware sets the Fiber local `user_id` value instead. The header path was removed because HTTP headers are client-spoofable and cannot be trusted for authorization decisions.
+
+**6. `isPublicPath` boundary-aware matching**
+
+| Behavior | v3 | v4 |
+|----------|---|---|
+| `isPublicPath` matching | `strings.HasPrefix(path, prefix)` | `path == prefix \|\| strings.HasPrefix(path, prefix+"/")` |
+
+**Before:** `/healthy` matched public path `/health` because `strings.HasPrefix("/healthy", "/health")` is true.
+
+**After:** `/healthy` does **not** match public path `/health`. Only exact matches (`/health`) or sub-paths (`/health/live`) match.
+
+**Migration note:** Services using `WithPublicPaths()` that relied on the previous prefix-only matching behavior may need to adjust their configured paths. For example, if a service had `WithPublicPaths("/health")` and expected `/healthz` to be treated as public, it must now explicitly add `/healthz` to the public paths list. This change prevents unintended route matching where a public path prefix accidentally exempted unrelated endpoints from tenant resolution.
+
+**7. PostgreSQL SSL default changed**
+
+| Behavior | v3 | v4 |
+|----------|---|---|
+| `buildConnectionString` SSL mode | `sslmode=disable` | `sslmode=prefer` |
+
+Connections will now attempt TLS when available with graceful fallback to plaintext. Set `SSLMode: "disable"` explicitly in `PostgreSQLConfig` to restore the previous behavior.
+
+**8. Tenant ID format validation**
+
+| Behavior | v4 |
+|----------|-----|
+| Tenant ID format | Middleware and consumer now validate tenant IDs against `^[a-zA-Z0-9][a-zA-Z0-9_-]*$` with a 256-character limit. |
+
+Tenant IDs containing dots, spaces, or special characters will be rejected. This applies to both `TenantMiddleware` and `MultiTenantConsumer` tenant lifecycle management.
+
+**9. `WorkersPerQueue` default changed**
+
+| Config field | v3 | v4 |
+|---|---|---|
+| `DefaultMultiTenantConfig().WorkersPerQueue` | `1` | `0` |
+
+The field is reserved for future use and currently a no-op.
+
+**10. Client error message format**
+
+| Behavior | v4 |
+|----------|-----|
+| Error messages from tenant manager HTTP client | No longer include raw response body content. Response bodies are now logged separately via `truncateBody` for security. |
+
+**Migration note:** Any error-message parsing that relied on response body content embedded in the error string will no longer match. Use structured logging output to inspect response bodies.
+
+#### Behavioral changes in outbox/tenant.go
+
+- `ContextWithTenantID` now writes to both the new `core.tenantIDKey` context key AND the legacy `TenantIDContextKey` for backward compatibility.
+- `TenantIDFromContext` reads the new `core.tenantIDKey` first, then falls back to the legacy key.
+- Tenant IDs with leading/trailing whitespace are now **rejected** (v3 behavior was to silently trim). Callers must pre-trim tenant IDs.
+
+---
+
+## Deleted files in v4
+
+The following files were removed during v4 consolidation:
+
+| File | Reason |
+|------|--------|
+| `mk/tests.mk` | test targets inlined into main Makefile |
+| `commons/logging/sanitizer.go` + `sanitizer_test.go` | package removed; moved to `commons/log/sanitizer.go` |
+| `commons/opentelemetry/metrics/labels.go` | organization/ledger label helpers removed |
+| `commons/opentelemetry/metrics/metrics_test.go` | replaced by v4 test suite |
+| `commons/opentelemetry/otel_test.go` | replaced by v4 test suite |
+| `commons/opentelemetry/extract_queue_test.go` | consolidated |
+| `commons/opentelemetry/inject_trace_test.go` | consolidated |
+| `commons/opentelemetry/queue_trace_test.go` | consolidated |
+| `commons/postgres/pagination.go` | `Pagination` moved to `commons/net/http` |
+| `commons/runtime/log_mode_link.go` | functionality inlined into runtime package |
+| `commons/server/grpc_test.go` | removed |
+| `commons/zap/sanitize.go` + `sanitize_test.go` | CWE-117 sanitization moved into zap core |
+
+---
+
+## Suggested verification command
+
+```bash
+# Check for removed v3 patterns
+rg -n "InitializeTelemetryWithError|InitializeTelemetry\(|SetSpanAttributesFromStruct|WithLedgerLabels|WithOrganizationLabels|NoneLogger|BuildConnectionString\(|WriteError\(|HandleFiberError\(|ValidateBalancesRules\(|DetermineOperation\(|ValidateFromToOperation\(|NewTracerFromContext\(|NewMetricFactoryFromContext\(|NewHeaderIDFromContext\(|EnsureConfigFromEnvVars\(|WithTimeout\(|GracefulShutdown|MongoConnection|PostgresConnection|RedisConnection|ZapWithTraceLogger|FieldObfuscator|LogLevel|NoneLogger|WithFields\(|InitializeLogger\b" .
+
+# Check for v3 patterns that changed signature or semantics in v4
+rg -n "uuid\.Must\(uuid\.NewV7|GenerateUUIDv7\(\)" . --type go # should now return (uuid.UUID, error)
+rg -n "Token\.Valid\b" . --type go # renamed to Token.SignatureValid
+rg -n "\"code\":\s*\"[0-9]" . --type go # ErrorResponse.Code is now int, not string
+
+# Check for added or newly available v4 packages
+rg -n "commons/circuitbreaker|commons/assert|commons/safe|commons/security|commons/jwt|commons/backoff|commons/pointers|commons/cron|commons/errgroup|commons/secretsmanager|commons/tenant-manager" . --type go
+```
diff --git a/Makefile b/Makefile
index 2ac5525d..14563eb1 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,8 @@
-# Define the root directory of the project
-LIB_COMMONS := $(shell pwd)
+# Default target when running bare `make`
+.DEFAULT_GOAL := help
+
+# Define the root directory of the project (resolves correctly even with make -f)
+LIB_COMMONS := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
# Include shared color definitions and utility functions
include $(LIB_COMMONS)/commons/shell/makefile_colors.mk
@@ -13,9 +16,82 @@ define print_title
@echo "------------------------------------------"
endef
-# Include test targets
-MK_DIR := $(abspath mk)
-include $(MK_DIR)/tests.mk
+# ------------------------------------------------------
+# Test configuration for lib-commons
+# ------------------------------------------------------
+
+# Integration test filter
+# RUN: specific test name pattern (e.g., TestIntegration_FeatureName)
+# PKG: specific package to test (e.g., ./commons/...)
+# Usage: make test-integration RUN=TestIntegration_FeatureName
+# make test-integration PKG=./commons/...
+RUN ?=
+PKG ?=
+
+# Computed run pattern: uses RUN if set, otherwise defaults to '^TestIntegration'
+ifeq ($(RUN),)
+ RUN_PATTERN := ^TestIntegration
+else
+ RUN_PATTERN := $(RUN)
+endif
+
+# Low-resource mode for limited machines (sets -p=1 -parallel=1, disables -race)
+# Usage: make test LOW_RESOURCE=1
+# make test-unit LOW_RESOURCE=1
+# make test-integration LOW_RESOURCE=1
+# make coverage-unit LOW_RESOURCE=1
+# make coverage-integration LOW_RESOURCE=1
+LOW_RESOURCE ?= 0
+
+# Computed flags for low-resource mode
+ifeq ($(LOW_RESOURCE),1)
+ LOW_RES_P_FLAG := -p 1
+ LOW_RES_PARALLEL_FLAG := -parallel 1
+ LOW_RES_RACE_FLAG :=
+else
+ LOW_RES_P_FLAG :=
+ LOW_RES_PARALLEL_FLAG :=
+ LOW_RES_RACE_FLAG := -race
+endif
+
+# macOS ld64 workaround: newer ld emits noisy LC_DYSYMTAB warnings when linking test binaries with -race.
+# If available, prefer Apple's classic linker to silence them.
+UNAME_S := $(shell uname -s)
+ifeq ($(UNAME_S),Darwin)
+ # Prefer classic mode to suppress LC_DYSYMTAB warnings on macOS.
+ # Set DISABLE_OSX_LINKER_WORKAROUND=1 to disable this behavior.
+ ifneq ($(DISABLE_OSX_LINKER_WORKAROUND),1)
+ GO_TEST_LDFLAGS := -ldflags="-linkmode=external -extldflags=-ld_classic"
+ else
+ GO_TEST_LDFLAGS :=
+ endif
+else
+ GO_TEST_LDFLAGS :=
+endif
+
+# ------------------------------------------------------
+# Test tooling configuration
+# ------------------------------------------------------
+
+# Pinned tool versions for reproducibility (update as needed)
+GOTESTSUM_VERSION ?= v1.12.0
+GOSEC_VERSION ?= v2.22.4
+GOLANGCI_LINT_VERSION ?= v2.1.6
+
+TEST_REPORTS_DIR ?= ./reports
+GOTESTSUM = $(shell command -v gotestsum 2>/dev/null)
+RETRY_ON_FAIL ?= 0
+
+.PHONY: tools tools-gotestsum
+tools: tools-gotestsum ## Install helpful dev/test tools
+
+tools-gotestsum:
+ @if [ -z "$(GOTESTSUM)" ]; then \
+ echo "Installing gotestsum..."; \
+ GO111MODULE=on go install gotest.tools/gotestsum@$(GOTESTSUM_VERSION); \
+ else \
+ echo "gotestsum already installed: $(GOTESTSUM)"; \
+ fi
#-------------------------------------------------------
# Help Command
@@ -30,13 +106,14 @@ help:
@echo ""
@echo "Core Commands:"
@echo " make help - Display this help message"
- @echo " make test - Run all tests"
+ @echo " make test - Run unit tests (without integration)"
+ @echo " make ci - Run the local fix + verify pipeline"
@echo " make build - Build all packages"
@echo " make clean - Clean all build artifacts"
@echo ""
@echo ""
@echo "Test Suite Commands:"
- @echo " make test-unit - Run unit tests"
+ @echo " make test-unit - Run unit tests (LOW_RESOURCE=1 supported)"
@echo " make test-integration - Run integration tests with testcontainers (RUN=, LOW_RESOURCE=1)"
@echo " make test-all - Run all tests (unit + integration)"
@echo ""
@@ -52,10 +129,12 @@ help:
@echo ""
@echo ""
@echo "Code Quality Commands:"
- @echo " make lint - Run linting on all packages"
+ @echo " make lint - Run linting on all packages (read-only check)"
+ @echo " make lint-fix - Run linting with auto-fix on all packages"
@echo " make format - Format code in all packages"
@echo " make tidy - Clean dependencies"
@echo " make check-tests - Verify test coverage for packages"
+ @echo " make vet - Run go vet on all packages"
@echo " make sec - Run security checks using gosec"
@echo " make sec SARIF=1 - Run security checks with SARIF output"
@echo ""
@@ -86,34 +165,333 @@ build:
.PHONY: clean
clean:
$(call print_title,Cleaning build artifacts)
- @rm -rf ./bin ./dist ./reports coverage.out coverage.html
+ @rm -rf ./bin ./dist $(TEST_REPORTS_DIR) coverage.out coverage.html gosec-report.sarif
@go clean -cache -testcache
@echo "$(GREEN)$(BOLD)[ok]$(NC) All build artifacts cleaned$(GREEN) ✔️$(NC)"
+.PHONY: ci
+ci:
+ $(call print_title,Running local CI preflight pipeline)
+ @printf "This target normalizes the worktree before verification.\n"
+ $(MAKE) lint-fix
+ $(MAKE) format
+ $(MAKE) tidy
+ $(MAKE) check-tests
+ $(MAKE) sec
+ $(MAKE) vet
+ $(MAKE) test-unit
+ $(MAKE) test-integration
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) Local CI pipeline completed successfully$(GREEN) ✔️$(NC)"
+
+#-------------------------------------------------------
+# Core Test Commands
+#-------------------------------------------------------
+
+.PHONY: test
+test:
+ $(call print_title,Running all tests)
+ $(call check_command,go,"Install Go from https://golang.org/doc/install")
+ @set -e; mkdir -p $(TEST_REPORTS_DIR); \
+ if [ -n "$(GOTESTSUM)" ]; then \
+ echo "Running tests with gotestsum"; \
+ gotestsum --format testname -- -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) ./... || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying tests once..."; \
+ gotestsum --format testname -- -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) ./...; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ else \
+ go test -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) ./... || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying tests once..."; \
+ go test -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) ./...; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ fi
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) All tests passed$(GREEN) ✔️$(NC)"
+
+#-------------------------------------------------------
+# Test Suite Aliases
+#-------------------------------------------------------
+
+# Unit tests (excluding integration tests)
+.PHONY: test-unit
+test-unit:
+ $(call print_title,Running Go unit tests)
+ $(call check_command,go,"Install Go from https://golang.org/doc/install")
+ @set -e; mkdir -p $(TEST_REPORTS_DIR); \
+ pkgs=$$(go list ./... | grep -v '/tests'); \
+ if [ -z "$$pkgs" ]; then \
+ echo "No unit test packages found"; \
+ else \
+ if [ -n "$(GOTESTSUM)" ]; then \
+ echo "Running unit tests with gotestsum"; \
+ gotestsum --format testname -- -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) $$pkgs || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying unit tests once..."; \
+ gotestsum --format testname -- -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) $$pkgs; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ else \
+ go test -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) $$pkgs || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying unit tests once..."; \
+ go test -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) $$pkgs; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ fi; \
+ fi
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) Unit tests passed$(GREEN) ✔️$(NC)"
+
+# Integration tests with testcontainers (no coverage)
+# These tests use the `integration` build tag and testcontainers-go to spin up
+# ephemeral containers. No external Docker stack is required.
+#
+# Requirements:
+# - Test files must follow the naming convention: *_integration_test.go
+# - Test functions must start with TestIntegration_ (e.g., TestIntegration_MyFeature_Works)
+.PHONY: test-integration
+test-integration:
+ $(call print_title,Running integration tests with testcontainers)
+ $(call check_command,go,"Install Go from https://golang.org/doc/install")
+ $(call check_command,docker,"Install Docker from https://docs.docker.com/get-docker/")
+ @set -e; mkdir -p $(TEST_REPORTS_DIR); \
+ if [ -n "$(PKG)" ]; then \
+ echo "Using specified package: $(PKG)"; \
+ pkgs=$$(go list $(PKG) 2>/dev/null | tr '\n' ' '); \
+ else \
+ echo "Finding packages with *_integration_test.go files..."; \
+ dirs=$$(find . -name '*_integration_test.go' -not -path './vendor/*' -exec dirname {} \; 2>/dev/null | sort -u | tr '\n' ' '); \
+ pkgs=$$(if [ -n "$$dirs" ]; then go list $$dirs 2>/dev/null | tr '\n' ' '; fi); \
+ fi; \
+ if [ -z "$$pkgs" ]; then \
+ echo "No integration test packages found"; \
+ else \
+ echo "Packages: $$pkgs"; \
+ echo "Running packages sequentially (-p=1) to avoid Docker container conflicts"; \
+ if [ "$(LOW_RESOURCE)" = "1" ]; then \
+ echo "LOW_RESOURCE mode: -parallel=1, race detector disabled"; \
+ fi; \
+ if [ -n "$(GOTESTSUM)" ]; then \
+ echo "Running integration tests with gotestsum"; \
+ gotestsum --format testname -- \
+ -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
+ -p 1 $(LOW_RES_PARALLEL_FLAG) \
+ -run '$(RUN_PATTERN)' $$pkgs || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying integration tests once..."; \
+ gotestsum --format testname -- \
+ -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
+ -p 1 $(LOW_RES_PARALLEL_FLAG) \
+ -run '$(RUN_PATTERN)' $$pkgs; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ else \
+ go test -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
+ -p 1 $(LOW_RES_PARALLEL_FLAG) \
+ -run '$(RUN_PATTERN)' $$pkgs || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying integration tests once..."; \
+ go test -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
+ -p 1 $(LOW_RES_PARALLEL_FLAG) \
+ -run '$(RUN_PATTERN)' $$pkgs; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ fi; \
+ fi
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) Integration tests passed$(GREEN) ✔️$(NC)"
+
+# Run all tests (unit + integration)
+.PHONY: test-all
+test-all:
+ $(call print_title,Running all tests (unit + integration))
+ $(call print_title,Running unit tests)
+ $(MAKE) test-unit
+ $(call print_title,Running integration tests)
+ $(MAKE) test-integration
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) All tests passed$(GREEN) ✔️$(NC)"
+
+#-------------------------------------------------------
+# Coverage Commands
+#-------------------------------------------------------
+
+# Unit tests with coverage (uses covermode=atomic)
+# Supports PKG parameter to filter packages (e.g., PKG=./commons/...)
+# Supports .ignorecoverunit file to exclude patterns from coverage stats
+.PHONY: coverage-unit
+coverage-unit:
+ $(call print_title,Running Go unit tests with coverage)
+ $(call check_command,go,"Install Go from https://golang.org/doc/install")
+ @set -e; mkdir -p $(TEST_REPORTS_DIR); \
+ if [ -n "$(PKG)" ]; then \
+ echo "Using specified package: $(PKG)"; \
+ pkgs=$$(go list $(PKG) 2>/dev/null | grep -v '/tests' | tr '\n' ' '); \
+ else \
+ pkgs=$$(go list ./... | grep -v '/tests'); \
+ fi; \
+ if [ -z "$$pkgs" ]; then \
+ echo "No unit test packages found"; \
+ else \
+ echo "Packages: $$pkgs"; \
+ if [ -n "$(GOTESTSUM)" ]; then \
+ echo "Running unit tests with gotestsum (coverage enabled)"; \
+ gotestsum --format testname -- -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/unit_coverage.out $$pkgs || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying unit tests once..."; \
+ gotestsum --format testname -- -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/unit_coverage.out $$pkgs; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ else \
+ go test -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/unit_coverage.out $$pkgs || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying unit tests once..."; \
+ go test -tags=unit -v $(LOW_RES_P_FLAG) $(LOW_RES_RACE_FLAG) $(LOW_RES_PARALLEL_FLAG) -count=1 $(GO_TEST_LDFLAGS) -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/unit_coverage.out $$pkgs; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ fi; \
+ if [ -f .ignorecoverunit ]; then \
+ echo "Filtering coverage with .ignorecoverunit patterns..."; \
+ patterns=$$(grep -v '^#' .ignorecoverunit | grep -v '^$$' | tr '\n' '|' | sed 's/|$$//'); \
+ if [ -n "$$patterns" ]; then \
+ regex_patterns=$$(echo "$$patterns" | sed 's/[][(){}+?^$$\\|]/\\&/g' | sed 's/\./\\./g' | sed 's/\*/.*/g'); \
+ head -1 $(TEST_REPORTS_DIR)/unit_coverage.out > $(TEST_REPORTS_DIR)/unit_coverage_filtered.out; \
+ tail -n +2 $(TEST_REPORTS_DIR)/unit_coverage.out | grep -vE "$$regex_patterns" >> $(TEST_REPORTS_DIR)/unit_coverage_filtered.out || true; \
+ mv $(TEST_REPORTS_DIR)/unit_coverage_filtered.out $(TEST_REPORTS_DIR)/unit_coverage.out; \
+ echo "Excluded patterns: $$patterns"; \
+ fi; \
+ fi; \
+ echo "----------------------------------------"; \
+ go tool cover -func=$(TEST_REPORTS_DIR)/unit_coverage.out | grep total | awk '{print "Total coverage: " $$3}'; \
+ echo "----------------------------------------"; \
+ fi
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) Unit coverage report generated$(GREEN) ✔️$(NC)"
+
+# Integration tests with testcontainers (with coverage, uses covermode=atomic)
+.PHONY: coverage-integration
+coverage-integration:
+ $(call print_title,Running integration tests with testcontainers (coverage enabled))
+ $(call check_command,go,"Install Go from https://golang.org/doc/install")
+ $(call check_command,docker,"Install Docker from https://docs.docker.com/get-docker/")
+ @set -e; mkdir -p $(TEST_REPORTS_DIR); \
+ if [ -n "$(PKG)" ]; then \
+ echo "Using specified package: $(PKG)"; \
+ pkgs=$$(go list $(PKG) 2>/dev/null | tr '\n' ' '); \
+ else \
+ echo "Finding packages with *_integration_test.go files..."; \
+ dirs=$$(find . -name '*_integration_test.go' -not -path './vendor/*' -exec dirname {} \; 2>/dev/null | sort -u | tr '\n' ' '); \
+ pkgs=$$(if [ -n "$$dirs" ]; then go list $$dirs 2>/dev/null | tr '\n' ' '; fi); \
+ fi; \
+ if [ -z "$$pkgs" ]; then \
+ echo "No integration test packages found"; \
+ else \
+ echo "Packages: $$pkgs"; \
+ echo "Running packages sequentially (-p=1) to avoid Docker container conflicts"; \
+ if [ "$(LOW_RESOURCE)" = "1" ]; then \
+ echo "LOW_RESOURCE mode: -parallel=1, race detector disabled"; \
+ fi; \
+ if [ -n "$(GOTESTSUM)" ]; then \
+ echo "Running testcontainers integration tests with gotestsum (coverage enabled)"; \
+ gotestsum --format testname -- \
+ -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
+ -p 1 $(LOW_RES_PARALLEL_FLAG) \
+ -run '$(RUN_PATTERN)' -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/integration_coverage.out \
+ $$pkgs || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying integration tests once..."; \
+ gotestsum --format testname -- \
+ -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
+ -p 1 $(LOW_RES_PARALLEL_FLAG) \
+ -run '$(RUN_PATTERN)' -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/integration_coverage.out \
+ $$pkgs; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ else \
+ go test -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
+ -p 1 $(LOW_RES_PARALLEL_FLAG) \
+ -run '$(RUN_PATTERN)' -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/integration_coverage.out \
+ $$pkgs || { \
+ if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
+ echo "Retrying integration tests once..."; \
+ go test -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
+ -p 1 $(LOW_RES_PARALLEL_FLAG) \
+ -run '$(RUN_PATTERN)' -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/integration_coverage.out \
+ $$pkgs; \
+ else \
+ exit 1; \
+ fi; \
+ }; \
+ fi; \
+ echo "----------------------------------------"; \
+ go tool cover -func=$(TEST_REPORTS_DIR)/integration_coverage.out | grep total | awk '{print "Total coverage: " $$3}'; \
+ echo "----------------------------------------"; \
+ fi
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) Integration coverage report generated$(GREEN) ✔️$(NC)"
+
+# Run all coverage targets
+.PHONY: coverage
+coverage:
+ $(call print_title,Running all coverage targets)
+ $(MAKE) coverage-unit
+ $(MAKE) coverage-integration
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) All coverage reports generated$(GREEN) ✔️$(NC)"
+
#-------------------------------------------------------
# Code Quality Commands
#-------------------------------------------------------
.PHONY: lint
lint:
- $(call print_title,Running linters on all packages)
- $(call check_command,golangci-lint,"go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest")
- @out=$$(golangci-lint run --fix ./... 2>&1); \
+ $(call print_title,Running linters on all packages (read-only))
+ $(call check_command,golangci-lint,"go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@$(GOLANGCI_LINT_VERSION)")
+ @out=$$(golangci-lint run ./... 2>&1); \
out_err=$$?; \
- perf_out=$$(perfsprint ./... 2>&1); \
- perf_err=$$?; \
+ if command -v perfsprint >/dev/null 2>&1; then \
+ perf_out=$$(perfsprint ./... 2>&1); \
+ perf_err=$$?; \
+ else \
+ perf_out=""; \
+ perf_err=0; \
+ echo "Note: perfsprint not installed, skipping performance checks (go install github.com/catenacyber/perfsprint@latest)"; \
+ fi; \
echo "$$out"; \
- echo "$$perf_out"; \
+ if [ -n "$$perf_out" ]; then echo "$$perf_out"; fi; \
if [ $$out_err -ne 0 ]; then \
- echo -e "\n$(BOLD)$(RED)An error has occurred during the lint process: \n $$out\n"; \
+ printf "\n%s\n" "$(BOLD)$(RED)An error has occurred during the lint process:$(NC)"; \
+ printf "%s\n" "$$out"; \
exit 1; \
fi; \
if [ $$perf_err -ne 0 ]; then \
- echo -e "\n$(BOLD)$(RED)An error has occurred during the performance check: \n $$perf_out\n"; \
+ printf "\n%s\n" "$(BOLD)$(RED)An error has occurred during the performance check:$(NC)"; \
+ printf "%s\n" "$$perf_out"; \
exit 1; \
fi
@echo "$(GREEN)$(BOLD)[ok]$(NC) Lint and performance checks passed successfully$(GREEN) ✔️$(NC)"
+.PHONY: lint-fix
+lint-fix:
+ $(call print_title,Running linters with auto-fix on all packages)
+ $(call check_command,golangci-lint,"go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@$(GOLANGCI_LINT_VERSION)")
+ @golangci-lint run --fix ./...
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) Lint auto-fix completed$(GREEN) ✔️$(NC)"
+
.PHONY: format
format:
$(call print_title,Formatting code in all packages)
@@ -132,6 +510,13 @@ check-tests:
fi
@echo "$(GREEN)$(BOLD)[ok]$(NC) Test coverage verification completed$(GREEN) ✔️$(NC)"
+.PHONY: vet
+vet:
+ $(call print_title,Running go vet on all packages)
+ $(call check_command,go,"Install Go from https://golang.org/doc/install")
+ go vet ./...
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) go vet completed successfully$(GREEN) ✔️$(NC)"
+
#-------------------------------------------------------
# Git Hook Commands
#-------------------------------------------------------
@@ -139,21 +524,37 @@ check-tests:
.PHONY: setup-git-hooks
setup-git-hooks:
$(call print_title,Installing and configuring git hooks)
- @find .githooks -type f -exec cp {} .git/hooks \;
- @chmod +x .git/hooks/*
+ @hooks_dir=$$(git rev-parse --git-path hooks); \
+ if [ ! -d .githooks ]; then \
+ echo "No .githooks directory found, skipping"; \
+ exit 0; \
+ fi; \
+ mkdir -p "$$hooks_dir"; \
+ for hook_dir in .githooks/*/; do \
+ if [ -d "$$hook_dir" ]; then \
+ for FILE in "$$hook_dir"*; do \
+ if [ -f "$$FILE" ]; then \
+ hook_name=$$(basename "$$FILE"); \
+ cp "$$FILE" "$$hooks_dir/$$hook_name"; \
+ chmod +x "$$hooks_dir/$$hook_name"; \
+ fi; \
+ done; \
+ fi; \
+ done
@echo "$(GREEN)$(BOLD)[ok]$(NC) All hooks installed and updated$(GREEN) ✔️$(NC)"
.PHONY: check-hooks
check-hooks:
$(call print_title,Verifying git hooks installation status)
- @err=0; \
+ @hooks_dir=$$(git rev-parse --git-path hooks); \
+ err=0; \
for hook_dir in .githooks/*; do \
if [ -d "$$hook_dir" ]; then \
for FILE in "$$hook_dir"/*; do \
if [ -f "$$FILE" ]; then \
f=$$(basename -- $$hook_dir)/$$(basename -- $$FILE); \
hook_name=$$(basename -- $$FILE); \
- FILE2=.git/hooks/$$hook_name; \
+ FILE2=$$hooks_dir/$$hook_name; \
if [ -f "$$FILE2" ]; then \
if cmp -s "$$FILE" "$$FILE2"; then \
echo "$(GREEN)$(BOLD)[ok]$(NC) Hook file $$f installed and updated$(GREEN) ✔️$(NC)"; \
@@ -170,7 +571,7 @@ check-hooks:
fi; \
done; \
if [ $$err -ne 0 ]; then \
- echo -e "\nRun $(BOLD)make setup-git-hooks$(NC) to setup your development environment, then try again.\n"; \
+ printf "\nRun %smake setup-git-hooks%s to setup your development environment, then try again.\n\n" "$(BOLD)" "$(NC)"; \
exit 1; \
else \
echo "$(GREEN)$(BOLD)[ok]$(NC) All hooks are properly installed$(GREEN) ✔️$(NC)"; \
@@ -181,8 +582,21 @@ check-envs:
$(call print_title,Checking git hooks and environment files for security issues)
$(MAKE) check-hooks
@echo "Checking for exposed secrets in environment files..."
- @if grep -rq "SECRET.*=" --include=".env" .; then \
- echo "$(RED)Warning: Secrets found in environment files. Make sure these are not committed to the repository.$(NC)"; \
+ @found=0; \
+ for pattern in '.env' '.env.*' '*.env'; do \
+ files=$$(find . -name "$$pattern" \
+ -not -name '*.example' -not -name '*.sample' -not -name '*.template' \
+ -not -path './vendor/*' -not -path './.git/*' 2>/dev/null); \
+ if [ -n "$$files" ]; then \
+ if echo "$$files" | xargs grep -iqE '^[[:space:]]*(export[[:space:]]+)?[A-Z0-9_]*(SECRET|PASSWORD|TOKEN|API_KEY|PRIVATE_KEY|CREDENTIAL|AWS_ACCESS_KEY|DB_PASS)[A-Z0-9_]*=[[:space:]]*[^#[:space:]]' 2>/dev/null; then \
+ echo "$(RED)Warning: Potential secrets found in environment files:$(NC)"; \
+ echo "$$files" | xargs grep -ilE '^[[:space:]]*(export[[:space:]]+)?[A-Z0-9_]*(SECRET|PASSWORD|TOKEN|API_KEY|PRIVATE_KEY|CREDENTIAL|AWS_ACCESS_KEY|DB_PASS)[A-Z0-9_]*=[[:space:]]*[^#[:space:]]' 2>/dev/null; \
+ found=1; \
+ fi; \
+ fi; \
+ done; \
+ if [ $$found -ne 0 ]; then \
+ echo "$(RED)Make sure these files are in .gitignore and not committed to the repository.$(NC)"; \
exit 1; \
else \
echo "$(GREEN)No exposed secrets found in environment files$(GREEN) ✔️$(NC)"; \
@@ -209,7 +623,7 @@ sec:
$(call print_title,Running security checks using gosec)
@if ! command -v gosec >/dev/null 2>&1; then \
echo "Installing gosec..."; \
- go install github.com/securego/gosec/v2/cmd/gosec@latest; \
+ go install github.com/securego/gosec/v2/cmd/gosec@$(GOSEC_VERSION); \
fi
@if find . -name "*.go" -type f -not -path './vendor/*' | grep -q .; then \
echo "Running security checks on all packages..."; \
@@ -218,7 +632,7 @@ sec:
if gosec -fmt sarif -out gosec-report.sarif ./...; then \
echo "$(GREEN)$(BOLD)[ok]$(NC) SARIF report generated: gosec-report.sarif$(GREEN) ✔️$(NC)"; \
else \
- echo -e "\n$(BOLD)$(RED)Security issues found by gosec. Please address them before proceeding.$(NC)\n"; \
+ printf "\n%s%sSecurity issues found by gosec. Please address them before proceeding.%s\n\n" "$(BOLD)" "$(RED)" "$(NC)"; \
echo "SARIF report with details: gosec-report.sarif"; \
exit 1; \
fi; \
@@ -226,7 +640,7 @@ sec:
if gosec ./...; then \
echo "$(GREEN)$(BOLD)[ok]$(NC) Security checks completed$(GREEN) ✔️$(NC)"; \
else \
- echo -e "\n$(BOLD)$(RED)Security issues found by gosec. Please address them before proceeding.$(NC)\n"; \
+ printf "\n%s%sSecurity issues found by gosec. Please address them before proceeding.%s\n\n" "$(BOLD)" "$(RED)" "$(NC)"; \
exit 1; \
fi; \
fi; \
@@ -242,5 +656,5 @@ sec:
goreleaser:
$(call print_title,Creating release snapshot with goreleaser)
$(call check_command,goreleaser,"go install github.com/goreleaser/goreleaser@latest")
- goreleaser release --snapshot --skip-publish --clean
- @echo "$(GREEN)$(BOLD)[ok]$(NC) Release snapshot created successfully$(GREEN) ✔️$(NC)"
\ No newline at end of file
+ goreleaser release --snapshot --skip=publish --clean
+ @echo "$(GREEN)$(BOLD)[ok]$(NC) Release snapshot created successfully$(GREEN) ✔️$(NC)"
diff --git a/README.md b/README.md
index eddee557..6e9137d4 100644
--- a/README.md
+++ b/README.md
@@ -1,248 +1,217 @@
# lib-commons
-A comprehensive Go library providing common utilities and components for building robust microservices and applications in the Lerian Studio ecosystem.
+`lib-commons` is Lerian's shared Go toolkit for service primitives, connectors, observability, and runtime safety.
-## Overview
+The current major API surface is **v4**. If you are migrating from older `lib-commons` code, see `MIGRATION_MAP.md`.
-`lib-commons` is a utility library that provides a collection of reusable components and helpers for Go applications. It includes standardized implementations for database connections, message queuing, logging, context management, error handling, transaction processing, and more.
+---
-## Features
+**Migrating from older packages?**
+Use `MIGRATION_MAP.md` as the canonical map for renamed, redesigned, or removed APIs in the unified `lib-commons` line.
-### Core Components
+---
-- **App Management**: Framework for managing application lifecycle and runtime (`app.go`)
-- **Context Utilities**: Enhanced context management with support for logging, tracing, and header IDs (`context.go`)
-- **Error Handling**: Standardized business error handling and responses (`errors.go`)
+## Requirements
-### Database Connectors
+- Go `1.25.7` or newer
-- **PostgreSQL**: Connection management, migrations, and utilities for PostgreSQL databases
-- **MongoDB**: Connection management and utilities for MongoDB
-- **Redis**: Client implementation and utilities for Redis
+## Installation
-### Messaging
+```bash
+go get github.com/LerianStudio/lib-commons/v4
+```
-- **RabbitMQ**: Client implementation and utilities for RabbitMQ
+## What is in this library
+
+### Core (`commons`)
+
+- `app.go`: `Launcher` for concurrent app lifecycle management with `NewLauncher(opts...)` and `RunApp` options
+- `context.go`: request-scoped logger/tracer/metrics/header-id tracking via `ContextWith*` helpers, safe timeout with `WithTimeoutSafe`, span attribute propagation
+- `errors.go`: standardized business error mapping with `ValidateBusinessError`
+- `utils.go`: UUID generation (`GenerateUUIDv7` returns error), struct-to-JSON, map merging, CPU/memory metrics, internal service detection
+- `stringUtils.go`: accent removal, case conversion, UUID placeholder replacement, SHA-256 hashing, server address validation
+- `time.go`: date/time validation, range checking, parsing with end-of-day support
+- `os.go`: environment variable helpers (`GetenvOrDefault`, `GetenvBoolOrDefault`, `GetenvIntOrDefault`), struct population from env tags via `SetConfigFromEnvVars`
+- `commons/constants`: shared constants for datasource status, errors, headers, metadata, pagination, transactions, OTEL attributes, obfuscation values, and `SanitizeMetricLabel` utility
+
+### Observability and logging
+
+- `commons/opentelemetry`: telemetry bootstrap (`NewTelemetry`), propagation (HTTP/gRPC/queue), span helpers, redaction (`Redactor` with `RedactionRule` patterns), struct-to-attribute conversion
+- `commons/opentelemetry/metrics`: fluent metrics factory (`NewMetricsFactory`, `NewNopFactory`) with Counter/Gauge/Histogram builders, explicit error returns, convenience recorders for accounts/transactions
+- `commons/log`: v2 logging interface (`Logger` with `Log`/`With`/`WithGroup`/`Enabled`/`Sync`), typed `Field` constructors (`String`, `Int`, `Bool`, `Err`, `Any`), `GoLogger` with CWE-117 log-injection prevention, sanitizer (`SafeError`, `SanitizeExternalResponse`)
+- `commons/zap`: zap adapter for `commons/log` with OTEL bridge, `Config`-based construction via `New()`, direct zap convenience methods (`Debug`/`Info`/`Warn`/`Error`), underlying access via `Raw()` and `Level()`
+
+### Data and messaging connectors
+
+- `commons/postgres`: `Config`-based constructor (`New`), `Resolver(ctx)` for dbresolver access, `Primary()` for raw `*sql.DB`, `NewMigrator` for schema migrations, backoff-based lazy-connect
+- `commons/mongo`: `Config`-based client with functional options (`NewClient`), URI builder (`BuildURI`), `Client(ctx)`/`ResolveClient(ctx)` for access, `EnsureIndexes` (variadic), TLS support, credential clearing
+- `commons/redis`: topology-based `Config` (standalone/sentinel/cluster), GCP IAM auth with token refresh, distributed locking via `LockManager` interface (`NewRedisLockManager`, `LockHandle`), `SetPackageLogger` for diagnostics, TLS defaults to a TLS1.2 minimum floor with `AllowLegacyMinVersion` as an explicit temporary compatibility override
+- `commons/rabbitmq`: connection/channel/health helpers for AMQP with `*Context()` variants, `HealthCheck() (bool, error)`, `Close()`/`CloseContext()`, confirmable publisher with broker acks and auto-recovery, DLQ topology utilities, and health-check hardening (`AllowInsecureHealthCheck`, `HealthCheckAllowedHosts`, `RequireHealthCheckAllowedHosts`)
+
+### HTTP and server utilities
+
+- `commons/net/http`: Fiber HTTP helpers -- response (`Respond`/`RespondStatus`/`RespondError`/`RenderError`), health (`Ping`/`HealthWithDependencies`), SSRF-protected reverse proxy (`ServeReverseProxy` with `ReverseProxyPolicy`), pagination (offset/opaque cursor/timestamp cursor/sort cursor), validation (`ParseBodyAndValidate`/`ValidateStruct`/`ValidateSortDirection`/`ValidateLimit`), context/ownership (`ParseAndVerifyTenantScopedID`/`ParseAndVerifyResourceScopedID`), middleware (`WithHTTPLogging`/`WithGrpcLogging`/`WithCORS`/`WithBasicAuth`/`NewTelemetryMiddleware`), `FiberErrorHandler`
+- `commons/net/http/ratelimit`: Redis-backed distributed rate limiting middleware for Fiber — `New(conn, opts...)` returns a `*RateLimiter` (nil when disabled, nil-safe for pass-through), `WithDefaultRateLimit(conn, opts...)` as a one-liner that wires `New` + `DefaultTier` into a ready-to-use `fiber.Handler`, fixed-window counter via atomic Lua script (INCR + PEXPIRE), `WithRateLimit(tier)` for static tiers, `WithDynamicRateLimit(TierFunc)` for per-request tier selection, `MethodTierSelector` for write-vs-read split, preset tiers (`DefaultTier` / `AggressiveTier` / `RelaxedTier`) configurable via env vars, identity extractors (`IdentityFromIP` / `IdentityFromHeader` / `IdentityFromIPAndHeader` — uses `#` separator to avoid conflict with IPv6 colons), fail-open/fail-closed policy, `WithOnLimited` callback, and standard `X-RateLimit-*` / `Retry-After` headers; also exports `RedisStorage` (`NewRedisStorage`) for use with third-party Fiber middleware
+- `commons/server`: `ServerManager`-based graceful shutdown with `WithHTTPServer`/`WithGRPCServer`/`WithShutdownChannel`/`WithShutdownTimeout`/`WithShutdownHook`, `StartWithGracefulShutdown()`/`StartWithGracefulShutdownWithError()`, `ServersStarted()` for test coordination
+
+### Resilience and safety
+
+- `commons/circuitbreaker`: `Manager` interface with error-returning constructors (`NewManager`), config validation, preset configs (`DefaultConfig`/`AggressiveConfig`/`ConservativeConfig`/`HTTPServiceConfig`/`DatabaseConfig`), health checker (`NewHealthCheckerWithValidation`), metrics via `WithMetricsFactory`
+- `commons/backoff`: exponential backoff with jitter (`ExponentialWithJitter`) and context-aware sleep (`WaitContext`)
+- `commons/errgroup`: error-group concurrency with panic recovery (`WithContext`, `Go`, `Wait`), configurable logger via `SetLogger`
+- `commons/runtime`: panic recovery (`RecoverAndLog`/`RecoverAndCrash`/`RecoverWithPolicy` with `*WithContext` variants), safe goroutines (`SafeGo`/`SafeGoWithContext`/`SafeGoWithContextAndComponent`), panic metrics (`InitPanicMetrics`), span recording (`RecordPanicToSpan`), error reporter (`SetErrorReporter`/`GetErrorReporter`), production mode (`SetProductionMode`/`IsProductionMode`)
+- `commons/assert`: production-safe assertions (`New` + `That`/`NotNil`/`NotEmpty`/`NoError`/`Never`/`Halt`), assertion metrics (`InitAssertionMetrics`), domain predicates (`Positive`/`ValidUUID`/`ValidAmount`/`DebitsEqualCredits`/`TransactionCanTransitionTo`/`BalanceSufficientForRelease` and more)
+- `commons/safe`: panic-safe math (`Divide`/`DivideRound`/`Percentage` on `decimal.Decimal`, `DivideFloat64`), regex with caching (`Compile`/`MatchString`/`FindString`), slices (`First`/`Last`/`At` with `*OrDefault` variants)
+- `commons/security`: sensitive field detection (`IsSensitiveField`), default field lists (`DefaultSensitiveFields`/`DefaultSensitiveFieldsMap`)
+
+### Domain and support packages
+
+- `commons/transaction`: intent-based transaction planning (`BuildIntentPlan`), balance eligibility validation (`ValidateBalanceEligibility`), posting flow (`ApplyPosting`), operation resolution (`ResolveOperation`), typed domain errors (`NewDomainError`)
+- `commons/outbox`: transactional outbox contracts, dispatcher, sanitizer, and PostgreSQL adapters for schema-per-tenant or column-per-tenant models (schema resolver requires tenant context by default; column migration uses composite key `(tenant_id, id)`)
+- `commons/crypto`: hashing (`GenerateHash`) and symmetric encryption (`InitializeCipher`/`Encrypt`/`Decrypt`) with credential-safe `fmt` output (`String()`/`GoString()` redact secrets)
+- `commons/jwt`: HS256/384/512 JWT signing (`Sign`), signature verification (`Parse`), combined signature + time-claim validation (`ParseAndValidate`), standalone time-claim validation (`ValidateTimeClaims`/`ValidateTimeClaimsAt`)
+- `commons/license`: license validation with functional options (`New(opts...)`, `WithLogger`), handler management (`SetHandler`), termination (`Terminate`/`TerminateWithError`/`TerminateSafe`)
+- `commons/pointers`: pointer conversion helpers (`String`, `Bool`, `Time`, `Int`, `Int64`, `Float64`)
+- `commons/cron`: cron expression parser (`Parse`) and scheduler (`Schedule.Next`)
+- `commons/secretsmanager`: AWS Secrets Manager M2M credential retrieval via `GetM2MCredentials`, typed retrieval errors, and the `SecretsManagerClient` test seam
+
+### Multi-tenant packages
+
+- `commons/tenant-manager/core`: shared tenant types, context helpers (`ContextWithTenantID`, `GetTenantIDFromContext`), and tenant-manager error contracts
+- `commons/tenant-manager/cache`: exported tenant-config cache contract (`ConfigCache`), `ErrCacheMiss`, and in-memory cache implementation used by the HTTP client
+- `commons/tenant-manager/client`: Tenant Manager HTTP client with circuit breaker, cache options (`WithCache`, `WithCacheTTL`, `WithSkipCache`), cache invalidation, and response hardening
+- `commons/tenant-manager/consumer`: dynamic multi-tenant queue consumer lifecycle management with tenant discovery, sync, retry, and per-tenant handlers
+- `commons/tenant-manager/middleware`: Fiber middleware for tenant extraction, upstream auth assertion checks, and tenant-scoped DB resolution
+- `commons/tenant-manager/postgres`: tenant-scoped PostgreSQL connection manager with LRU eviction, async settings revalidation, and pool controls
+- `commons/tenant-manager/mongo`: tenant-scoped MongoDB connection manager with LRU eviction and idle-timeout controls
+- `commons/tenant-manager/rabbitmq`: tenant-scoped RabbitMQ connection manager with soft connection-pool limits and eviction
+- `commons/tenant-manager/s3`: tenant-prefixed S3/object-storage key helpers with delimiter validation
+- `commons/tenant-manager/valkey`: tenant-prefixed Redis/Valkey key and pattern helpers with delimiter validation
+
+### Build and shell
+
+- `commons/shell/`: Makefile include helpers (`makefile_colors.mk`, `makefile_utils.mk`), shell scripts (`colors.sh`, `ascii.sh`), ASCII art (`logo.txt`)
+
+## Minimal v4 usage
+
+```go
+import (
+ "context"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+)
+
+func bootstrap() error {
+ logger := log.NewNop()
+
+ tl, err := opentelemetry.NewTelemetry(opentelemetry.TelemetryConfig{
+ LibraryName: "my-service",
+ ServiceName: "my-service-api",
+ ServiceVersion: "2.0.0",
+ DeploymentEnv: "local",
+ CollectorExporterEndpoint: "localhost:4317",
+ EnableTelemetry: false, // Set to true when collector is available
+ InsecureExporter: true,
+ Logger: logger,
+ })
+ if err != nil {
+ return err
+ }
+ defer tl.ShutdownTelemetry()
+
+ tl.ApplyGlobals()
+
+ _ = context.Background()
+
+ return nil
+}
+```
-### Observability
+## Environment Variables
-- **Logging**: Pluggable logging interface with multiple implementations
-- **Logging Obfuscation**: Dynamic environment variable to obfuscate specific fields from the request payload logging
- - `SECURE_LOG_FIELDS=password,apiKey`
-- **OpenTelemetry**: Integrated tracing, metrics, and logs through OpenTelemetry
-- **Zap**: Integration with Uber's Zap logging library
+The following environment variables are recognized by lib-commons:
-### Utilities
+| Variable | Type | Default | Package | Description |
+| :--- | :--- | :--- | :--- | :--- |
+| `VERSION` | `string` | `"NO-VERSION"` | `commons` | Application version, printed at startup by `InitLocalEnvConfig` |
+| `ENV_NAME` | `string` | `"local"` | `commons` | Environment name; when `"local"`, a `.env` file is loaded automatically |
+| `ENV` | `string` | _(none)_ | `commons/assert` | When set to `"production"`, stack traces are omitted from assertion failures |
+| `GO_ENV` | `string` | _(none)_ | `commons/assert` | Fallback production check (same behavior as `ENV`) |
+| `LOG_LEVEL` | `string` | `"debug"` (dev/local) / `"info"` (other) | `commons/zap` | Log level override (`debug`, `info`, `warn`, `error`); `Config.Level` takes precedence if set |
+| `LOG_ENCODING` | `string` | `"console"` (dev/local) / `"json"` (other) | `commons/zap` | Log output format: `"json"` for structured JSON, `"console"` for human-readable colored output |
+| `LOG_OBFUSCATION_DISABLED` | `bool` | `false` | `commons/net/http` | Set to `"true"` to disable sensitive-field obfuscation in HTTP access logs (**not recommended in production**) |
+| `METRICS_COLLECTION_INTERVAL` | `duration` | `"5s"` | `commons/net/http` | Background system-metrics collection interval (Go duration format, e.g. `"10s"`, `"1m"`) |
+| `ACCESS_CONTROL_ALLOW_CREDENTIALS` | `bool` | `"false"` | `commons/net/http` | CORS `Access-Control-Allow-Credentials` header value |
+| `ACCESS_CONTROL_ALLOW_ORIGIN` | `string` | `"*"` | `commons/net/http` | CORS `Access-Control-Allow-Origin` header value |
+| `ACCESS_CONTROL_ALLOW_METHODS` | `string` | `"POST, GET, OPTIONS, PUT, DELETE, PATCH"` | `commons/net/http` | CORS `Access-Control-Allow-Methods` header value |
+| `ACCESS_CONTROL_ALLOW_HEADERS` | `string` | `"Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization"` | `commons/net/http` | CORS `Access-Control-Allow-Headers` header value |
+| `ACCESS_CONTROL_EXPOSE_HEADERS` | `string` | `""` | `commons/net/http` | CORS `Access-Control-Expose-Headers` header value |
+| `RATE_LIMIT_ENABLED` | `bool` | `"true"` | `commons/net/http/ratelimit` | Set to `"false"` to disable rate limiting globally; `New` returns nil and all requests pass through |
+| `RATE_LIMIT_MAX` | `int` | `500` | `commons/net/http/ratelimit` | Maximum requests per window for `DefaultTier` |
+| `RATE_LIMIT_WINDOW_SEC` | `int` | `60` | `commons/net/http/ratelimit` | Window duration in seconds for `DefaultTier` |
+| `AGGRESSIVE_RATE_LIMIT_MAX` | `int` | `100` | `commons/net/http/ratelimit` | Maximum requests per window for `AggressiveTier` |
+| `AGGRESSIVE_RATE_LIMIT_WINDOW_SEC` | `int` | `60` | `commons/net/http/ratelimit` | Window duration in seconds for `AggressiveTier` |
+| `RELAXED_RATE_LIMIT_MAX` | `int` | `1000` | `commons/net/http/ratelimit` | Maximum requests per window for `RelaxedTier` |
+| `RELAXED_RATE_LIMIT_WINDOW_SEC` | `int` | `60` | `commons/net/http/ratelimit` | Window duration in seconds for `RelaxedTier` |
+| `RATE_LIMIT_REDIS_TIMEOUT_MS` | `int` | `500` | `commons/net/http/ratelimit` | Timeout in milliseconds for Redis operations; exceeded requests follow fail-open/fail-closed policy |
-- **String Utilities**: Common string manipulation functions
-- **Type Conversion**: Safe type conversion utilities
-- **Time Helpers**: Date and time manipulation functions
-- **OS Utilities**: Operating system related utilities
-- **Pointer Utilities**: Helper functions for pointer type operations
-- **Transaction Processing**: Utilities for financial transaction processing and validation
+Additionally, `commons.SetConfigFromEnvVars` populates any struct using `env:"VAR_NAME"` field tags, supporting `string`, `bool`, and integer types. Consuming applications define their own variable names through these tags.
-## Getting Started
+## Development commands
-### Prerequisites
+### Core
-- Go 1.23.2 or higher
+- `make build` -- build all packages
+- `make ci` -- run the local fix + verify pipeline (`lint-fix`, `format`, `tidy`, `check-tests`, `sec`, `vet`, `test-unit`, `test-integration`)
+- `make clean` -- clean build artifacts and caches
+- `make tidy` -- clean dependencies (`go mod tidy`)
+- `make format` -- format code with gofmt
+- `make help` -- display all available commands
-### Installation
+### Testing
-```bash
-go get github.com/LerianStudio/lib-commons/v2
-```
+- `make test` -- run unit tests (uses gotestsum if available)
+- `make test-unit` -- run unit tests excluding integration
+- `make test-integration` -- run integration tests with testcontainers (requires Docker)
+- `make test-all` -- run all tests (unit + integration)
+
+### Coverage
+
+- `make coverage-unit` -- unit tests with coverage report (respects `.ignorecoverunit`)
+- `make coverage-integration` -- integration tests with coverage
+- `make coverage` -- run all coverage targets
+
+### Code quality
+
+- `make lint` -- run lint checks (read-only)
+- `make lint-fix` -- auto-fix lint issues
+- `make vet` -- run `go vet` on all packages
+- `make sec` -- run security checks using gosec (`make sec SARIF=1` for SARIF output)
+- `make check-tests` -- verify test coverage for packages
+
+### Test flags
+
+- `LOW_RESOURCE=1` -- reduces parallelism and disables race detector for constrained machines
+- `RETRY_ON_FAIL=1` -- retries failed tests once
+- `RUN=` -- filter integration tests by name pattern
+- `PKG=` -- filter to specific package(s)
+
+### Git hooks
+
+- `make setup-git-hooks` -- install and configure git hooks
+- `make check-hooks` -- verify git hooks installation
+- `make check-envs` -- check hooks + environment file security
+
+### Tooling and release
+
+- `make tools` -- install test tools (gotestsum)
+- `make goreleaser` -- create release snapshot
+
+## Project Rules
-## API Reference
-
-### Core Components
-
-#### Application Management (`commons`)
-
-| Method | Description |
-| ---------------------------------- | -------------------------------------------------------- |
-| `NewLauncher(...LauncherOption)` | Creates a new application launcher with provided options |
-| `WithLogger(logger)` | LauncherOption that adds a logger to the launcher |
-| `RunApp(name, app)` | LauncherOption that registers an application to run |
-| `Launcher.Add(appName, app)` | Registers an application to run |
-| `Launcher.Run()` | Runs all registered applications in goroutines. |
-
-#### Context Utilities (`commons`)
-
-| Method | Description |
-| -------------------------------------- | ------------------------------- |
-| `ContextWithLogger(ctx, logger)` | Returns a context with logger |
-| `NewLoggerFromContext(ctx)` | Extracts logger from context |
-| `ContextWithTracer(ctx, tracer)` | Returns a context with tracer |
-| `NewTracerFromContext(ctx)` | Extracts tracer from context |
-| `ContextWithHeaderID(ctx, headerID)` | Returns a context with headerID |
-| `NewHeaderIDFromContext(ctx)` | Extracts headerID from context |
-
-#### Error Handling (`commons`)
-
-| Method | Description |
-| --------------------------------------------------- | ---------------------------------------------------------------------------- |
-| `ValidateBusinessError(err, entityType, ...args)` | Maps domain errors to business responses with appropriate codes and messages |
-| `Response.Error()` | Returns the error message from a Response |
-
-### Database Connectors
-
-#### PostgreSQL (`commons/postgres`)
-
-| Method | Description |
-| --------------------------------------------- | -------------------------------------------------------- |
-| `PostgresConnection.Connect()` | Establishes connection to PostgreSQL primary and replica |
-| `PostgresConnection.GetDB()` | Returns the database connection |
-| `PostgresConnection.MigrateUp(sourceDir)` | Runs database migrations |
-| `PostgresConnection.MigrateDown(sourceDir)` | Reverts database migrations |
-| `GetPagination(page, pageSize)` | Gets pagination parameters for SQL queries |
-
-#### MongoDB (`commons/mongo`)
-
-| Method | Description |
-| --------------------------------------- | --------------------------------- |
-| `MongoConnection.Connect(ctx)` | Establishes connection to MongoDB |
-| `MongoConnection.GetDB(ctx)` | Returns the MongoDB client |
-| `MongoConnection.EnsureIndexes(ctx, collection, index)` | Ensures an index exists (idempotent). If the collection does not exist, MongoDB will create it automatically during index creation. |
-
-#### Redis (`commons/redis`)
-
-| Method | Description |
-| ---------------------------------------------------- | ------------------------------- |
-| `RedisConnection.Connect()` | Establishes connection to Redis |
-| `RedisConnection.GetClient()` | Returns the Redis client |
-| `RedisConnection.Set(ctx, key, value, expiration)` | Sets a key-value pair in Redis |
-| `RedisConnection.Get(ctx, key)` | Gets a value from Redis by key |
-| `RedisConnection.Del(ctx, keys...)` | Deletes keys from Redis |
-
-### Messaging
-
-#### RabbitMQ (`commons/rabbitmq`)
-
-| Method | Description |
-| ------------------------------------------------------------- | ---------------------------------- |
-| `RabbitMQConnection.Connect()` | Establishes connection to RabbitMQ |
-| `RabbitMQConnection.GetChannel()` | Returns a RabbitMQ channel |
-| `RabbitMQConnection.DeclareQueue(name)` | Declares a queue |
-| `RabbitMQConnection.DeclareExchange(name, kind)` | Declares an exchange |
-| `RabbitMQConnection.QueueBind(queue, exchange, routingKey)` | Binds a queue to an exchange |
-| `RabbitMQConnection.Publish(exchange, routingKey, body)` | Publishes a message |
-| `RabbitMQConnection.Consume(queue, consumer)` | Consumes messages from a queue |
-
-### Observability
-
-#### Logging (`commons/log`)
-
-| Method | Description |
-| --------------------------------------- | ---------------------------------------------- |
-| `Info(args...)` | Logs info level message |
-| `Infof(format, args...)` | Logs formatted info level message |
-| `Error(args...)` | Logs error level message |
-| `Errorf(format, args...)` | Logs formatted error level message |
-| `Warn(args...)` | Logs warning level message |
-| `Warnf(format, args...)` | Logs formatted warning level message |
-| `Debug(args...)` | Logs debug level message |
-| `Debugf(format, args...)` | Logs formatted debug level message |
-| `Fatal(args...)` | Logs fatal level message and exits |
-| `Fatalf(format, args...)` | Logs formatted fatal level message and exits |
-| `WithFields(fields...)` | Returns a logger with additional fields |
-| `WithDefaultMessageTemplate(message)` | Returns a logger with default message template |
-| `Sync()` | Flushes any buffered log entries |
-
-#### Zap Integration (`commons/zap`)
-
-| Method | Description |
-| ------------------------------------------ | ------------------------------------------- |
-| `NewZapLogger(config)` | Creates a new Zap logger |
-| `ZapLoggerAdapter.Info(args...)` | Logs info level message using Zap |
-| `ZapLoggerAdapter.Error(args...)` | Logs error level message using Zap |
-| `ZapLoggerAdapter.WithFields(fields...)` | Returns a Zap logger with additional fields |
-
-#### OpenTelemetry (`commons/opentelemetry`)
-
-| Method | Description |
-| ----------------------------------------- | --------------------------------------------------------------- |
-| `Telemetry.InitializeTelemetry(logger)` | Initializes OpenTelemetry with trace, metric, and log providers |
-| `Telemetry.ShutdownTelemetry()` | Shuts down OpenTelemetry providers |
-| `Telemetry.GetTracer()` | Returns a tracer from the provider |
-| `Telemetry.GetMeter()` | Returns a meter from the provider |
-| `Telemetry.GetLogger()` | Returns a logger from the provider |
-| `Telemetry.StartSpan(ctx, name)` | Starts a new trace span |
-| `Telemetry.EndSpan(span, err)` | Ends a trace span with optional error |
-
-### Utilities
-
-#### String Utilities (`commons`)
-
-| Method | Description |
-| --------------------------------------- | ---------------------------------------- |
-| `IsNilOrEmpty(s)` | Checks if string pointer is nil or empty |
-| `TruncateString(s, maxLen)` | Truncates string to maximum length |
-| `MaskEmail(email)` | Masks email address for privacy |
-| `MaskLastDigits(value, digitsToShow)` | Masks all but last digits of a string |
-| `StringToObject(s, obj)` | Converts JSON string to object |
-| `ObjectToString(obj)` | Converts object to JSON string |
-
-#### OS Utilities (`commons`)
-
-| Method | Description |
-| ------------------------- | -------------------------------------------- |
-| `GetEnv(key, fallback)` | Gets environment variable with fallback |
-| `MustGetEnv(key)` | Gets required environment variable or panics |
-| `LoadEnvFile(file)` | Loads environment variables from file |
-| `GetMemUsage()` | Gets current memory usage statistics |
-| `GetCPUUsage()` | Gets current CPU usage statistics |
-
-#### Time Utilities (`commons`)
-
-| Method | Description |
-| ------------------------------ | ------------------------------------ |
-| `FormatTime(t, layout)` | Formats time according to layout |
-| `ParseTime(s, layout)` | Parses time from string using layout |
-| `GetCurrentTime()` | Gets current time in UTC |
-| `TimeBetween(t, start, end)` | Checks if time is between two times |
-
-#### Pointer Utilities (`commons/pointers`)
-
-| Method | Description |
-| -------------------- | ------------------------------------------------------ |
-| `ToString(s)` | Creates string pointer from string |
-| `ToInt(i)` | Creates int pointer from int |
-| `ToBool(b)` | Creates bool pointer from bool |
-| `FromStringPtr(s)` | Gets string from string pointer with safe nil handling |
-| `FromIntPtr(i)` | Gets int from int pointer with safe nil handling |
-| `FromBoolPtr(b)` | Gets bool from bool pointer with safe nil handling |
-
-#### Transaction Processing (`commons/transaction`)
-
-| Method | Description |
-| ------------------------------------- | ------------------------------------------------------ |
-| `ValidateTransactionRequest(req)` | Validates transaction request against business rules |
-| `ValidateAccountBalances(accounts)` | Validates account balances for transaction processing |
-| `ValidateAssetCode(code)` | Validates asset code existence and status |
-| `ValidateAccountStatuses(accounts)` | Validates account statuses for transaction eligibility |
-
-#### Shell Utilities (`commons/shell`)
-
-| Method | Description |
-| ----------------------------------------------- | ----------------------------------------- |
-| `ExecuteCommand(command)` | Executes shell command and returns output |
-| `ExecuteCommandWithTimeout(command, timeout)` | Executes shell command with timeout |
-| `ExecuteCommandInBackground(command)` | Executes shell command in background |
-
-#### Network Utilities (`commons/net`)
-
-| Method | Description |
-| -------------------------- | -------------------------------------- |
-| `ValidateURL(url)` | Validates URL format and accessibility |
-| `GetLocalIP()` | Gets local IP address |
-| `IsPortOpen(host, port)` | Checks if port is open on host |
-| `GetFreePort()` | Gets a free port on local machine |
-
-## Contributing
-
-Please read the contributing guidelines before submitting pull requests.
+For coding standards, architecture patterns, testing requirements, and development guidelines, see [`docs/PROJECT_RULES.md`](docs/PROJECT_RULES.md).
## License
-This project is licensed under the terms found in the LICENSE file in the root directory.
+This project is licensed under the terms in `LICENSE`.
diff --git a/REVIEW.md b/REVIEW.md
new file mode 100644
index 00000000..513ecbd6
--- /dev/null
+++ b/REVIEW.md
@@ -0,0 +1,388 @@
+# Review Findings
+
+Generated from 54 reviewer-agent runs (6 reviewers x 9 slices). Empty severity buckets are omitted. Similar findings are intentionally preserved when multiple reviewer lenses surfaced them independently.
+
+## 1. Observability + Metrics
+
+### Critical
+- [nil-safety] `references/lib-commons/commons/opentelemetry/metrics/metrics.go:105`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:119`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:133`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:179`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:214`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:251`, `references/lib-commons/commons/opentelemetry/metrics/account.go:10`, `references/lib-commons/commons/opentelemetry/metrics/transaction.go:10`, `references/lib-commons/commons/opentelemetry/metrics/operation_routes.go:10`, `references/lib-commons/commons/opentelemetry/metrics/transaction_routes.go:10`, `references/lib-commons/commons/opentelemetry/metrics/system.go:25`, `references/lib-commons/commons/opentelemetry/metrics/system.go:35` - exported `*MetricsFactory` methods are not nil-safe and can panic on nil receivers.
+- [nil-safety] `references/lib-commons/commons/opentelemetry/metrics/builders.go:29`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:47`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:63`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:74`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:87`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:105`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:125`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:144`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:162`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:178` - nil builder receivers panic before the intended `ErrNil*` guard can run.
+
+### High
+- [code] `references/lib-commons/commons/opentelemetry/otel.go:134`, `references/lib-commons/commons/opentelemetry/otel.go:139`, `references/lib-commons/commons/opentelemetry/otel.go:144`, `references/lib-commons/commons/opentelemetry/otel.go:153` - `NewTelemetry` allocates exporters/providers incrementally but does not roll back already-created resources if a later step fails.
+- [code] `references/lib-commons/commons/opentelemetry/metrics/metrics.go:180`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:191`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:215`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:226` - counter and gauge caching is keyed only by metric name, so later callers can silently get the wrong description/unit metadata.
+- [business] `references/lib-commons/commons/opentelemetry/obfuscation.go:122`, `references/lib-commons/commons/opentelemetry/obfuscation.go:125`, `references/lib-commons/commons/opentelemetry/obfuscation.go:128`, `references/lib-commons/commons/opentelemetry/obfuscation.go:132` - `PathPattern`-only redaction rules are not truly path-only; if `FieldPattern` is empty, matching falls back to `security.IsSensitiveField`, so custom path-scoped rules for non-default-sensitive keys silently do not apply.
+- [business] `references/lib-commons/commons/opentelemetry/metrics/builders.go:63`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:68` - `CounterBuilder.Add` accepts negative values, violating monotonic counter semantics.
+- [business] `references/lib-commons/commons/opentelemetry/metrics/metrics.go:162`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:163`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:164`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:169`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:170` - default histogram bucket selection prioritizes `transaction` over `latency`/`duration`/`time`, so names like `transaction.processing.latency` get the wrong bucket strategy.
+- [security] `references/lib-commons/commons/opentelemetry/otel.go:366`, `references/lib-commons/commons/opentelemetry/otel.go:384`, `references/lib-commons/commons/opentelemetry/otel.go:385` - unsanitized `err.Error()` content and `span.RecordError(err)` are exported directly into spans, bypassing redaction.
+- [test] `references/lib-commons/commons/opentelemetry/obfuscation_test.go:979`, `references/lib-commons/commons/opentelemetry/obfuscation_test.go:986` - `TestObfuscateStruct_FieldWithDotsInKey` has no real assertion.
+- [test] `references/lib-commons/commons/opentelemetry/otel_test.go:927`, `references/lib-commons/commons/opentelemetry/otel_test.go:938` - processor tests start spans but never inspect exported attributes, so the behaviors they claim to test are not actually validated.
+- [test] `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1088`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1118`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1146`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1175`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1179`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1209`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1213`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1235`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1239`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1265`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1270`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1275` - several concurrency tests silently discard returned errors or return early on failure.
+- [nil-safety] `references/lib-commons/commons/opentelemetry/otel.go:172`, `references/lib-commons/commons/opentelemetry/otel.go:181`, `references/lib-commons/commons/opentelemetry/otel.go:182`, `references/lib-commons/commons/opentelemetry/otel.go:183`, `references/lib-commons/commons/opentelemetry/otel.go:184` - `ApplyGlobals` only rejects a nil `Telemetry` pointer, not a zero-value or partially initialized `Telemetry`, so it can poison global OTEL state.
+- [nil-safety] `references/lib-commons/commons/opentelemetry/otel.go:362`, `references/lib-commons/commons/opentelemetry/otel.go:366`, `references/lib-commons/commons/opentelemetry/otel.go:371`, `references/lib-commons/commons/opentelemetry/otel.go:375`, `references/lib-commons/commons/opentelemetry/otel.go:380`, `references/lib-commons/commons/opentelemetry/otel.go:384`, `references/lib-commons/commons/opentelemetry/otel.go:385`, `references/lib-commons/commons/opentelemetry/otel.go:390`, `references/lib-commons/commons/opentelemetry/otel.go:400` - span helpers use `span == nil` on an interface and can still panic on typed-nil spans.
+- [consequences] `references/lib-commons/commons/opentelemetry/otel.go:172`, `references/lib-commons/commons/opentelemetry/otel.go:184`, `references/lib-commons/commons/opentelemetry/otel.go:498`, `references/lib-commons/commons/opentelemetry/otel.go:507` - propagation helpers are hard-wired to the global propagator, so `TelemetryConfig.Propagator` only takes effect if callers also mutate globals.
+- [consequences] `references/lib-commons/commons/opentelemetry/otel.go:639`, `references/lib-commons/commons/opentelemetry/otel.go:646`, `references/lib-commons/commons/opentelemetry/otel.go:647` - `ExtractTraceContextFromQueueHeaders` only accepts string values and drops valid upstream headers represented as `[]byte` or typed AMQP values.
+- [consequences] `references/lib-commons/commons/opentelemetry/obfuscation.go:59`, `references/lib-commons/commons/opentelemetry/obfuscation.go:64`, `references/lib-commons/commons/opentelemetry/obfuscation.go:104`, `references/lib-commons/commons/opentelemetry/otel.go:92` - if default redactor construction fails, `NewDefaultRedactor()` returns a redactor with no compiled rules instead of failing closed, so sensitive fields may be exported.
+
+### Medium
+- [code] `references/lib-commons/commons/opentelemetry/otel.go:423`, `references/lib-commons/commons/opentelemetry/otel.go:428`, `references/lib-commons/commons/opentelemetry/otel.go:429`, `references/lib-commons/commons/opentelemetry/otel.go:470` - `BuildAttributesFromValue` round-trips through JSON without `UseNumber`, so integers become `float64` and large values lose precision.
+- [code] `references/lib-commons/commons/opentelemetry/otel.go:464`, `references/lib-commons/commons/opentelemetry/otel.go:465`, `references/lib-commons/commons/opentelemetry/otel.go:466` - sanitization happens before byte truncation, so truncation can split a multibyte rune and reintroduce invalid UTF-8.
+- [code] `references/lib-commons/commons/opentelemetry/metrics/metrics.go:252`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:263`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:287`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:295`, `references/lib-commons/commons/opentelemetry/metrics/metrics.go:341` - histogram cache keys sort bucket boundaries, but instrument creation keeps caller order, so semantically different configs collide.
+- [business] `references/lib-commons/commons/opentelemetry/otel.go:423`, `references/lib-commons/commons/opentelemetry/otel.go:428`, `references/lib-commons/commons/opentelemetry/otel.go:429`, `references/lib-commons/commons/opentelemetry/otel.go:470` - trace attributes can carry incorrect business values because numeric precision is lost during JSON flattening.
+- [business] `references/lib-commons/commons/opentelemetry/metrics/system.go:25`, `references/lib-commons/commons/opentelemetry/metrics/system.go:31`, `references/lib-commons/commons/opentelemetry/metrics/system.go:35`, `references/lib-commons/commons/opentelemetry/metrics/system.go:41` - percentage helpers accept any integer and do not validate the 0..100 range.
+- [security] `references/lib-commons/commons/opentelemetry/metrics/builders.go:29`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:47`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:87`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:105`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:144`, `references/lib-commons/commons/opentelemetry/metrics/builders.go:162` - metric builders accept arbitrary caller-supplied labels/attributes with no sanitization or cardinality guard.
+- [security] `references/lib-commons/commons/opentelemetry/otel.go:125`, `references/lib-commons/commons/opentelemetry/otel.go:126`, `references/lib-commons/commons/opentelemetry/otel.go:127`, `references/lib-commons/commons/opentelemetry/otel.go:266`, `references/lib-commons/commons/opentelemetry/otel.go:275`, `references/lib-commons/commons/opentelemetry/otel.go:284` - plaintext OTLP export is allowed in non-dev environments with only a warning instead of failing closed.
+- [test] `references/lib-commons/commons/opentelemetry/otel_test.go:805`, `references/lib-commons/commons/opentelemetry/otel_test.go:818`, `references/lib-commons/commons/opentelemetry/otel_test.go:831` - tests only assert `NotPanics` and do not verify emitted events, recorded errors, or span status.
+- [test] `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1104`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1195`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1254`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1218`, `references/lib-commons/commons/opentelemetry/metrics/v2_test.go:1280` - several concurrency tests mostly equate success with “no race/no panic” and have weak postconditions.
+- [consequences] `references/lib-commons/commons/opentelemetry/otel.go:423`, `references/lib-commons/commons/opentelemetry/otel.go:429`, `references/lib-commons/commons/opentelemetry/otel.go:470` - precision loss in attribute flattening can misalign dashboards and queries that expect exact IDs and counters.
+- [consequences] `references/lib-commons/commons/opentelemetry/otel.go:434`, `references/lib-commons/commons/opentelemetry/otel.go:460`, `references/lib-commons/commons/opentelemetry/otel.go:469`, `references/lib-commons/commons/opentelemetry/otel.go:479` - top-level scalars can emit an empty attribute key and top-level slices can emit keys like `.0`.
+- [consequences] `references/lib-commons/commons/opentelemetry/otel.go:134`, `references/lib-commons/commons/opentelemetry/otel.go:139`, `references/lib-commons/commons/opentelemetry/otel.go:144`, `references/lib-commons/commons/opentelemetry/otel.go:158` - failed `NewTelemetry` calls do not clean up partially created exporters, so retries can accumulate orphaned resources.
+
+### Low
+- [code] `references/lib-commons/commons/opentelemetry/otel.go:459`, `references/lib-commons/commons/opentelemetry/otel.go:460` - flattening a top-level slice with an empty prefix produces keys like `.0`.
+- [security] `references/lib-commons/commons/opentelemetry/otel.go:483`, `references/lib-commons/commons/opentelemetry/otel.go:494` - `SetSpanAttributeForParam` writes raw request parameter values into span attributes without sensitivity checks.
+- [test] `references/lib-commons/commons/opentelemetry/v2_test.go:166` - `TestHandleSpanHelpers_NoPanicsOnNil` bundles multiple helper behaviors into a single no-panic test, reducing failure isolation.
+- [consequences] `references/lib-commons/commons/opentelemetry/otel.go:379`, `references/lib-commons/commons/opentelemetry/otel.go:384` - `HandleSpanError` can emit malformed status descriptions like `": ..."` when message is empty.
+
+## 2. HTTP Surface + Server Lifecycle
+
+### Critical
+- [nil-safety] `references/lib-commons/commons/net/http/proxy.go:119` - `ServeReverseProxy` checks `req != nil` but not `req.URL != nil`, so `&http.Request{}` can panic.
+- [nil-safety] `references/lib-commons/commons/net/http/withTelemetry.go:85`, `references/lib-commons/commons/net/http/withTelemetry.go:164` - middleware dereferences `effectiveTelemetry.TracerProvider` directly, so a partially initialized telemetry instance crashes the first request.
+
+### High
+- [code] `references/lib-commons/commons/server/shutdown.go:181`, `references/lib-commons/commons/server/shutdown.go:334`, `references/lib-commons/commons/server/shutdown.go:345` - `StartWithGracefulShutdownWithError()` logs startup failures but still returns `nil`.
+- [code] `references/lib-commons/commons/net/http/withTelemetry.go:262`, `references/lib-commons/commons/net/http/withTelemetry.go:309`, `references/lib-commons/commons/server/shutdown.go:395` - telemetry middleware starts a process-global metrics collector that is not stopped before telemetry shutdown.
+- [code] `references/lib-commons/commons/server/shutdown.go:395`, `references/lib-commons/commons/server/shutdown.go:402` - shutdown order is inverted for gRPC, so telemetry is torn down before in-flight RPCs finish.
+- [code] `references/lib-commons/commons/net/http/health.go:92`, `references/lib-commons/commons/net/http/health.go:123` - dependencies with a circuit breaker but empty `ServiceName` are silently treated as healthy.
+- [business] `references/lib-commons/commons/server/shutdown.go:181`, `references/lib-commons/commons/server/shutdown.go:246`, `references/lib-commons/commons/server/shutdown.go:271`, `references/lib-commons/commons/server/shutdown.go:331` - `StartWithGracefulShutdownWithError()` cannot distinguish clean shutdown from bind/listen failure.
+- [business] `references/lib-commons/commons/net/http/health.go:87`, `references/lib-commons/commons/net/http/health.go:92`, `references/lib-commons/commons/net/http/health.go:118`, `references/lib-commons/commons/net/http/health.go:124` - `HealthWithDependencies` false-greens misconfigured dependencies when `ServiceName` is missing.
+- [business] `references/lib-commons/commons/net/http/pagination.go:133`, `references/lib-commons/commons/net/http/pagination.go:159` - `EncodeTimestampCursor` accepts `uuid.Nil` even though `DecodeTimestampCursor` rejects it.
+- [business] `references/lib-commons/commons/net/http/pagination.go:216`, `references/lib-commons/commons/net/http/pagination.go:244`, `references/lib-commons/commons/net/http/pagination.go:248` - `EncodeSortCursor` can emit cursors that `DecodeSortCursor` later rejects.
+- [test] `references/lib-commons/commons/net/http/proxy_test.go:794`, `references/lib-commons/commons/net/http/proxy_test.go:897`, `references/lib-commons/commons/net/http/proxy.go:280` - SSRF/DNS rebinding coverage is shallow and misses key `validateResolvedIPs` branches.
+- [test] `references/lib-commons/commons/net/http/withLogging_test.go:229`, `references/lib-commons/commons/net/http/withLogging_test.go:246`, `references/lib-commons/commons/net/http/withLogging_test.go:282` - logging middleware tests never inject/capture a logger or assert logged fields/body obfuscation.
+- [nil-safety] `references/lib-commons/commons/net/http/health.go:92`, `references/lib-commons/commons/net/http/health.go:93`, `references/lib-commons/commons/net/http/health.go:94`, `references/lib-commons/commons/net/http/health.go:103` - interface-nil checks on `CircuitBreaker` miss typed-nil managers and can panic.
+- [nil-safety] `references/lib-commons/commons/net/http/context.go:323`, `references/lib-commons/commons/net/http/context.go:327`, `references/lib-commons/commons/net/http/context.go:336`, `references/lib-commons/commons/net/http/context.go:340`, `references/lib-commons/commons/net/http/context.go:345`, `references/lib-commons/commons/net/http/context.go:349`, `references/lib-commons/commons/net/http/context.go:355`, `references/lib-commons/commons/net/http/context.go:359` - span helpers rely on `span == nil` and can still panic on typed-nil spans.
+- [nil-safety] `references/lib-commons/commons/server/shutdown.go:152`, `references/lib-commons/commons/server/shutdown.go:153` - `ServersStarted()` is not nil-safe; nil receivers panic and zero-value managers can return a nil channel that blocks forever.
+- [consequences] `references/lib-commons/commons/net/http/withTelemetry.go:33`, `references/lib-commons/commons/net/http/withTelemetry.go:249`, `references/lib-commons/commons/net/http/withTelemetry.go:263`, `references/lib-commons/commons/net/http/withTelemetry.go:279`, `references/lib-commons/commons/server/shutdown.go:395` - host-metrics collection is process-global and can leak a collector goroutine / publish against stale telemetry after shutdown.
+- [consequences] `references/lib-commons/commons/net/http/withTelemetry.go:252`, `references/lib-commons/commons/net/http/withTelemetry.go:263`, `references/lib-commons/commons/server/shutdown.go:76`, `references/lib-commons/commons/server/shutdown.go:87`, `references/lib-commons/commons/server/shutdown.go:99` - once the process-global collector starts, later telemetry instances never bind their own meter provider.
+- [consequences] `references/lib-commons/commons/server/shutdown.go:181`, `references/lib-commons/commons/server/shutdown.go:192`, `references/lib-commons/commons/server/shutdown.go:246`, `references/lib-commons/commons/server/shutdown.go:271`, `references/lib-commons/commons/server/shutdown.go:283`, `references/lib-commons/commons/server/shutdown.go:334` - startup/listen failures are logged but not returned to embedders/tests/orchestrators.
+
+### Medium
+- [code] `references/lib-commons/commons/net/http/pagination.go:27`, `references/lib-commons/commons/net/http/pagination.go:38`, `references/lib-commons/commons/net/http/pagination.go:47` - `ParsePagination` documentation says invalid values are coerced to defaults, but malformed numerics actually return errors.
+- [code] `references/lib-commons/commons/net/http/withTelemetry.go:33`, `references/lib-commons/commons/net/http/withTelemetry.go:240` - metrics collector is managed through package-level singleton state, reducing composability and test isolation.
+- [code] `references/lib-commons/commons/net/http/health.go:84`, `references/lib-commons/commons/net/http/health.go:124` - dependency statuses are keyed only by name without validation for empty or duplicate names.
+- [business] `references/lib-commons/commons/net/http/withLogging.go:286` - middleware only echoes a correlation ID if it generated it, not when the client supplied a valid request ID.
+- [business] `references/lib-commons/commons/net/http/pagination.go:27`, `references/lib-commons/commons/net/http/pagination.go:38`, `references/lib-commons/commons/net/http/pagination.go:47` - comment/behavior mismatch can push callers into the wrong error-handling path.
+- [security] `references/lib-commons/commons/net/http/withCORS.go:15`, `references/lib-commons/commons/net/http/withCORS.go:46`, `references/lib-commons/commons/net/http/withCORS.go:66`, `references/lib-commons/commons/net/http/withCORS.go:83` - `WithCORS` defaults `Access-Control-Allow-Origin` to `*` when no trusted origins are configured.
+- [security] `references/lib-commons/commons/net/http/handler.go:52`, `references/lib-commons/commons/net/http/handler.go:61`, `references/lib-commons/commons/net/http/handler.go:67` - `ExtractTokenFromHeader` accepts non-`Bearer` authorization headers and can return the auth scheme itself as a token fallback.
+- [security] `references/lib-commons/commons/net/http/withLogging.go:82`, `references/lib-commons/commons/net/http/withLogging.go:124`, `references/lib-commons/commons/net/http/withLogging.go:224` - raw `Referer` is logged without sanitization.
+- [security] `references/lib-commons/commons/net/http/health.go:33`, `references/lib-commons/commons/net/http/health.go:84`, `references/lib-commons/commons/net/http/health.go:127` - health responses expose dependency names, breaker state, and counters that aid reconnaissance.
+- [test] `references/lib-commons/commons/net/http/handler_test.go:19`, `references/lib-commons/commons/net/http/handler_test.go:26` - `File()` tests are brittle and barely verify served content or missing-file behavior.
+- [test] `references/lib-commons/commons/net/http/withTelemetry_test.go:35` - test setup mutates global OTEL state and does not restore it.
+- [test] `references/lib-commons/commons/server/shutdown_integration_test.go:337` - in-flight shutdown test relies on a fixed sleep and is timing-sensitive.
+- [test] `references/lib-commons/commons/net/http/health_integration_test.go:428` - circuit recovery is validated with a fixed sleep instead of polling.
+- [test] `references/lib-commons/commons/net/http/error_test.go:577` - method-not-allowed test accepts either `404` or `405`, weakening regression detection.
+- [nil-safety] `references/lib-commons/commons/net/http/withTelemetry.go:168`, `references/lib-commons/commons/net/http/withTelemetry.go:177`, `references/lib-commons/commons/net/http/withTelemetry.go:192` - gRPC interceptor assumes `info *grpc.UnaryServerInfo` is always non-nil.
+- [consequences] `references/lib-commons/commons/server/shutdown.go:395`, `references/lib-commons/commons/server/shutdown.go:402`, `references/lib-commons/commons/net/http/withTelemetry.go:177`, `references/lib-commons/commons/net/http/withTelemetry.go:178` - telemetry can be torn down before `grpc.Server.GracefulStop()` drains requests, losing final spans/metrics.
+- [consequences] `references/lib-commons/commons/net/http/withTelemetry.go:71`, `references/lib-commons/commons/net/http/withTelemetry.go:101`, `references/lib-commons/commons/net/http/withTelemetry.go:240`, `references/lib-commons/commons/net/http/withTelemetry.go:323` - `excludedRoutes` are ignored when `WithTelemetry` is called on a nil receiver with an explicit telemetry argument.
+
+### Low
+- [code] `references/lib-commons/commons/net/http/handler.go:61`, `references/lib-commons/commons/net/http/handler.go:63` - `ExtractTokenFromHeader` uses `strings.Split` and permissively accepts malformed authorization headers like `Bearer token extra`.
+- [business] `references/lib-commons/commons/net/http/handler.go:61`, `references/lib-commons/commons/net/http/handler.go:64` - bearer-token parsing is less tolerant than common implementations for flexible whitespace.
+- [security] `references/lib-commons/commons/net/http/handler.go:23` - `Version` publicly exposes the exact deployed version.
+
+## 3. Tenant Manager Domain
+
+### Critical
+- [security] `references/lib-commons/commons/tenant-manager/middleware/tenant.go:116`, `references/lib-commons/commons/tenant-manager/middleware/tenant.go:129`, `references/lib-commons/commons/tenant-manager/middleware/tenant.go:147`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:336`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:340`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:350` - unverified JWT claims are used to choose tenant databases, enabling cross-tenant DB resolution if another auth path merely sets `c.Locals("user_id")`.
+- [nil-safety] `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:278`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:805`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:1012` - `Register` accepts a nil `HandlerFunc`, which later panics on first message delivery.
+- [nil-safety] `references/lib-commons/commons/tenant-manager/client/client.go:130`, `references/lib-commons/commons/tenant-manager/client/client.go:281`, `references/lib-commons/commons/tenant-manager/client/client.go:367`, `references/lib-commons/commons/tenant-manager/client/client.go:487`, `references/lib-commons/commons/tenant-manager/cache/memory.go:61`, `references/lib-commons/commons/tenant-manager/cache/memory.go:87`, `references/lib-commons/commons/tenant-manager/cache/memory.go:104`, `references/lib-commons/commons/tenant-manager/cache/memory.go:114` - `WithCache` accepts typed-nil caches and later panics on method calls.
+- [nil-safety] `references/lib-commons/commons/tenant-manager/postgres/manager.go:826`, `references/lib-commons/commons/tenant-manager/postgres/manager.go:944` - `CreateDirectConnection` dereferences a nil `*core.PostgreSQLConfig`.
+
+### High
+- [code] `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:214`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:1091`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:1145`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:763` - requests can spawn long-lived background consumers for unknown/suspended tenants before tenant resolution succeeds.
+- [code] `references/lib-commons/commons/tenant-manager/client/client.go:323`, `references/lib-commons/commons/tenant-manager/client/client.go:337`, `references/lib-commons/commons/tenant-manager/client/client.go:345` - 403 handling only returns `*core.TenantSuspendedError` when the response body contains a parseable JSON `status`, otherwise it degrades to a generic error.
+- [business] `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:214`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:219`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:1102`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:1128`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:1145` - middleware can start consumers for nonexistent, purged, or unauthorized tenants.
+- [business] `references/lib-commons/commons/tenant-manager/rabbitmq/manager.go:185`, `references/lib-commons/commons/tenant-manager/rabbitmq/manager.go:190`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:869`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:876`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:883` - tenant-manager RabbitMQ connection creation wraps suspension/purge errors as generic retryable failures, causing infinite reconnect loops.
+- [business] `references/lib-commons/commons/tenant-manager/middleware/tenant.go:173`, `references/lib-commons/commons/tenant-manager/middleware/tenant.go:189`, `references/lib-commons/commons/tenant-manager/middleware/tenant.go:207`, `references/lib-commons/commons/tenant-manager/middleware/tenant.go:223`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:479`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:495`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:504` - `TenantMiddleware` and `MultiPoolMiddleware` map the same domain errors to different HTTP status codes.
+- [security] `references/lib-commons/commons/tenant-manager/rabbitmq/manager.go:201`, `references/lib-commons/commons/tenant-manager/rabbitmq/manager.go:205`, `references/lib-commons/commons/tenant-manager/rabbitmq/manager.go:398`, `references/lib-commons/commons/tenant-manager/rabbitmq/manager.go:403` - RabbitMQ connections are hard-wired to plaintext `amqp://` with no TLS/`amqps` path.
+- [security] `references/lib-commons/commons/tenant-manager/client/client.go:147`, `references/lib-commons/commons/tenant-manager/client/client.go:161`, `references/lib-commons/commons/tenant-manager/client/client.go:172`, `references/lib-commons/commons/tenant-manager/client/client.go:433`, `references/lib-commons/commons/tenant-manager/client/client.go:547` - tenant-manager client accepts any URL scheme/host and permits `http://`, so tenant credentials can be fetched over cleartext transport.
+- [test] `references/lib-commons/commons/tenant-manager/middleware/tenant.go:116`, `references/lib-commons/commons/tenant-manager/middleware/tenant.go:156`, `references/lib-commons/commons/tenant-manager/middleware/tenant.go:173`, `references/lib-commons/commons/tenant-manager/middleware/tenant.go:207`, `references/lib-commons/commons/tenant-manager/middleware/tenant_test.go:190` - middleware tests miss fail-closed auth enforcement, invalid `tenantId` format, suspended-tenant mapping, and PG/Mongo resolution failures.
+- [test] `references/lib-commons/commons/tenant-manager/client/client.go:276`, `references/lib-commons/commons/tenant-manager/client/client.go:361`, `references/lib-commons/commons/tenant-manager/client/client.go:480`, `references/lib-commons/commons/tenant-manager/client/client_test.go:152` - client cache tests miss cache-hit, malformed cached JSON, `WithSkipCache`, invalidation, and `Close` paths.
+- [consequences] `references/lib-commons/commons/tenant-manager/client/client.go:323`, `references/lib-commons/commons/tenant-manager/client/client.go:337`, `references/lib-commons/commons/tenant-manager/client/client.go:345`, `references/lib-commons/commons/tenant-manager/postgres/manager.go:381`, `references/lib-commons/commons/tenant-manager/postgres/manager.go:386`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:488` - degraded 403 handling means suspended/purged tenants can be misclassified as generic connection failures and surfaced as 5xx/503.
+- [consequences] `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:111`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:218`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:275`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:282`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:392` - `WithCrossModuleInjection` promises resolution for all registered routes, but only injects PostgreSQL after matched-route PG resolution.
+
+### Medium
+- [code] `references/lib-commons/commons/tenant-manager/postgres/manager.go:633`, `references/lib-commons/commons/tenant-manager/postgres/manager.go:646`, `references/lib-commons/commons/tenant-manager/postgres/manager.go:878`, `references/lib-commons/commons/tenant-manager/postgres/manager.go:896`, `references/lib-commons/commons/tenant-manager/postgres/manager.go:900` - removing tenant `connectionSettings` does not restore defaults; existing pools keep stale limits until recreated.
+- [code] `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:253`, `references/lib-commons/commons/tenant-manager/client/client.go:183`, `references/lib-commons/commons/tenant-manager/cache/memory.go:47`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:1174` - internal fallback `pmClient` allocates an `InMemoryCache` cleanup goroutine that `MultiTenantConsumer.Close` never stops.
+- [business] `references/lib-commons/commons/tenant-manager/core/errors.go:15`, `references/lib-commons/commons/tenant-manager/client/client.go:323`, `references/lib-commons/commons/tenant-manager/client/client.go:337`, `references/lib-commons/commons/tenant-manager/client/client.go:345` - `ErrTenantServiceAccessDenied` is documented as the 403 sentinel but is never actually returned or wrapped.
+- [security] `references/lib-commons/commons/tenant-manager/postgres/manager.go:827`, `references/lib-commons/commons/tenant-manager/postgres/manager.go:829`, `references/lib-commons/commons/tenant-manager/postgres/manager.go:843` - PostgreSQL DSNs default to `sslmode=prefer`, allowing silent non-TLS downgrade.
+- [security] `references/lib-commons/commons/tenant-manager/core/types.go:17`, `references/lib-commons/commons/tenant-manager/core/types.go:29`, `references/lib-commons/commons/tenant-manager/core/types.go:42`, `references/lib-commons/commons/tenant-manager/client/client.go:366`, `references/lib-commons/commons/tenant-manager/client/client.go:367` - full tenant configs, including plaintext DB and RabbitMQ passwords, are cached wholesale for the default 1h TTL.
+- [test] `references/lib-commons/commons/tenant-manager/client/client_test.go:423`, `references/lib-commons/commons/tenant-manager/client/client_test.go:462` - half-open circuit-breaker tests rely on `time.Sleep(cbTimeout + 10*time.Millisecond)` and are timing-sensitive.
+- [test] `references/lib-commons/commons/tenant-manager/consumer/multi_tenant_test.go:535` - lazy sync test waits a fixed `3 * syncInterval` instead of polling.
+- [test] `references/lib-commons/commons/tenant-manager/postgres/manager_test.go:1033`, `references/lib-commons/commons/tenant-manager/postgres/manager_test.go:1191`, `references/lib-commons/commons/tenant-manager/postgres/manager_test.go:1249` - async revalidation tests infer goroutine completion with fixed sleeps.
+- [test] `references/lib-commons/commons/tenant-manager/middleware/tenant_test.go:207`, `references/lib-commons/commons/tenant-manager/middleware/tenant_test.go:232`, `references/lib-commons/commons/tenant-manager/middleware/tenant_test.go:262` - unauthorized-path assertions only check status code plus a generic `Unauthorized` substring instead of structured payload.
+- [consequences] `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:417`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:427`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:434`, `references/lib-commons/commons/tenant-manager/core/context.go:108` - cross-module resolution failures are only logged and then dropped, so downstream code later fails with `ErrTenantContextRequired` and loses the real cause.
+- [consequences] `references/lib-commons/commons/tenant-manager/middleware/tenant.go:116`, `references/lib-commons/commons/tenant-manager/middleware/tenant.go:238`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:336` - both middleware variants hard-code upstream auth to `c.Locals("user_id")`, making integration brittle with alternative auth middleware.
+
+### Low
+- [code] `references/lib-commons/commons/tenant-manager/client/client.go:287`, `references/lib-commons/commons/tenant-manager/client/client.go:296`, `references/lib-commons/commons/tenant-manager/client/client.go:301` - corrupt cached tenant config JSON is logged and refetched, but the bad cache entry is left in place.
+- [code] `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:66`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:299`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:302` - route selection is “first prefix wins” instead of longest-prefix matching.
+- [business] `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:456`, `references/lib-commons/commons/tenant-manager/consumer/multi_tenant.go:540` - `identifyNewTenants` repeatedly logs known-but-not-yet-started lazy tenants as newly discovered.
+- [test] `references/lib-commons/commons/tenant-manager/cache/memory_test.go:224`, `references/lib-commons/commons/tenant-manager/cache/memory_test.go:226`, `references/lib-commons/commons/tenant-manager/cache/memory_test.go:228` - concurrent cache test discards returned errors.
+- [test] `references/lib-commons/commons/tenant-manager/client/client_test.go:417`, `references/lib-commons/commons/tenant-manager/client/client_test.go:456`, `references/lib-commons/commons/tenant-manager/client/client_test.go:634`, `references/lib-commons/commons/tenant-manager/client/client_test.go:635`, `references/lib-commons/commons/tenant-manager/client/client_test.go:636` - several circuit-breaker setup calls intentionally ignore returned errors.
+- [consequences] `references/lib-commons/commons/tenant-manager/core/errors.go:13`, `references/lib-commons/commons/tenant-manager/client/client.go:329`, `references/lib-commons/commons/tenant-manager/client/client.go:345`, `references/lib-commons/commons/tenant-manager/middleware/multi_pool.go:495` - `ErrTenantServiceAccessDenied` is effectively dead contract surface.
+
+## 4. Messaging + Outbox
+
+### Critical
+- [consequences] `references/lib-commons/commons/outbox/postgres/schema_resolver.go:164`, `references/lib-commons/commons/outbox/postgres/schema_resolver.go:167`, `references/lib-commons/commons/outbox/dispatcher.go:461`, `references/lib-commons/commons/outbox/dispatcher.go:481`, `references/lib-commons/commons/outbox/postgres/schema_resolver.go:118` - `DiscoverTenants()` can inject a default tenant schema that is absent, and `ApplyTenant()` then drives unqualified queries against `public.outbox_events`, causing cross-tenant reads/writes.
+
+### High
+- [code] `references/lib-commons/commons/outbox/postgres/schema_resolver.go:141` - `DiscoverTenants` enumerates every UUID-shaped schema without checking whether it actually contains the outbox table.
+- [code] `references/lib-commons/commons/outbox/postgres/schema_resolver.go:110`, `references/lib-commons/commons/outbox/postgres/schema_resolver.go:164` - discovered “default tenant” dispatch cycles can run against the connection’s default `search_path` instead of the configured schema.
+- [code] `references/lib-commons/commons/rabbitmq/rabbitmq.go:925` - `AllowInsecureHealthCheck` disables host allowlist enforcement even when basic-auth credentials are attached.
+- [business] `references/lib-commons/commons/rabbitmq/rabbitmq.go:211`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:222`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:245`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:255` - reconnect failures leave stale `Connected`/`Connection`/`Channel` state visible after a failed reconnect attempt.
+- [business] `references/lib-commons/commons/rabbitmq/publisher.go:724`, `references/lib-commons/commons/rabbitmq/publisher.go:756`, `references/lib-commons/commons/rabbitmq/publisher.go:813` - `Reconnect` restores the channel but never resets publisher health to `HealthStateConnected`.
+- [security] `references/lib-commons/commons/rabbitmq/rabbitmq.go:79`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:552`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:557`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:922`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:940` - health-check client allows any `HealthCheckURL` host when no allowlist is configured and strict mode is off, leaving SSRF open by default.
+- [test] `references/lib-commons/commons/outbox/postgres/repository.go:617`, `references/lib-commons/commons/outbox/postgres/repository_integration_test.go:240` - `ListFailedForRetry` has no direct tests for the core retry-selection query semantics.
+- [test] `references/lib-commons/commons/outbox/postgres/column_resolver.go:120`, `references/lib-commons/commons/outbox/postgres/column_resolver.go:131`, `references/lib-commons/commons/outbox/postgres/column_resolver_test.go:56`, `references/lib-commons/commons/outbox/postgres/repository_integration_test.go:429` - tenant discovery cache-miss, `singleflight`, and timeout behavior are effectively untested.
+- [test] `references/lib-commons/commons/rabbitmq/publisher.go:606`, `references/lib-commons/commons/rabbitmq/publisher.go:611`, `references/lib-commons/commons/rabbitmq/publisher_test.go:221`, `references/lib-commons/commons/rabbitmq/publisher_test.go:678` - timeout/cancel tests assert only the returned error and do not verify the critical invalidation side effect.
+- [nil-safety] `references/lib-commons/commons/rabbitmq/rabbitmq.go:837`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:209`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:371`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:543`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:814` - `logger()` only checks interface-nil and can return a typed-nil logger that later panics.
+- [consequences] `references/lib-commons/commons/outbox/postgres/schema_resolver.go:110`, `references/lib-commons/commons/outbox/postgres/schema_resolver.go:112`, `references/lib-commons/commons/outbox/postgres/repository.go:1243`, `references/lib-commons/commons/outbox/postgres/repository.go:1278` - combining `WithAllowEmptyTenant()` with `WithDefaultTenantID(...)` routes default-tenant repository calls to `public`.
+- [consequences] `references/lib-commons/commons/rabbitmq/publisher.go:606`, `references/lib-commons/commons/rabbitmq/publisher.go:611`, `references/lib-commons/commons/rabbitmq/publisher.go:580`, `references/lib-commons/commons/rabbitmq/publisher.go:588` - one confirm timeout or canceled publish context permanently closes the publisher unless the caller rebuilds it.
+
+### Medium
+- [code] `references/lib-commons/commons/rabbitmq/rabbitmq.go:209`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:213`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:371`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:543` - context-aware API drops caller context for operational logging by hardcoding `context.Background()`.
+- [business] `references/lib-commons/commons/outbox/tenant.go:35`, `references/lib-commons/commons/outbox/tenant.go:50`, `references/lib-commons/commons/outbox/tenant.go:59`, `references/lib-commons/commons/outbox/tenant.go:67` - whitespace-wrapped tenant IDs are silently discarded instead of trimmed or rejected.
+- [security] `references/lib-commons/commons/rabbitmq/dlq.go:15`, `references/lib-commons/commons/rabbitmq/dlq.go:100`, `references/lib-commons/commons/rabbitmq/dlq.go:106`, `references/lib-commons/commons/rabbitmq/dlq.go:107`, `references/lib-commons/commons/rabbitmq/dlq.go:160`, `references/lib-commons/commons/rabbitmq/dlq.go:171` - default DLQ topology uses `#` with no TTL or max-length cap, allowing indefinite poison-message retention.
+- [test] `references/lib-commons/commons/rabbitmq/rabbitmq_integration_test.go:102`, `references/lib-commons/commons/rabbitmq/rabbitmq_integration_test.go:122`, `references/lib-commons/commons/rabbitmq/rabbitmq_integration_test.go:151`, `references/lib-commons/commons/rabbitmq/rabbitmq_integration_test.go:172`, `references/lib-commons/commons/rabbitmq/trace_propagation_integration_test.go:86`, `references/lib-commons/commons/rabbitmq/trace_propagation_integration_test.go:188`, `references/lib-commons/commons/rabbitmq/trace_propagation_integration_test.go:260`, `references/lib-commons/commons/rabbitmq/trace_propagation_integration_test.go:327`, `references/lib-commons/commons/rabbitmq/trace_propagation_integration_test.go:344`, `references/lib-commons/commons/rabbitmq/trace_propagation_integration_test.go:409` - multiple integration tests ignore teardown errors.
+- [test] `references/lib-commons/commons/outbox/event_test.go:33`, `references/lib-commons/commons/outbox/event_test.go:37`, `references/lib-commons/commons/outbox/event_test.go:42`, `references/lib-commons/commons/outbox/event_test.go:47`, `references/lib-commons/commons/outbox/event_test.go:58`, `references/lib-commons/commons/outbox/event_test.go:63` - many validation branches are packed into one test and rely on substring matching.
+- [test] `references/lib-commons/commons/rabbitmq/rabbitmq_test.go:696`, `references/lib-commons/commons/rabbitmq/rabbitmq_test.go:713`, `references/lib-commons/commons/rabbitmq/rabbitmq_test.go:731`, `references/lib-commons/commons/rabbitmq/rabbitmq_test.go:766` - health-check error-path tests use only generic `assert.Error` / `assert.False` assertions.
+- [consequences] `references/lib-commons/commons/rabbitmq/publisher.go:756`, `references/lib-commons/commons/rabbitmq/publisher.go:765`, `references/lib-commons/commons/rabbitmq/publisher.go:814` - `Reconnect()` never restores `health` to `HealthStateConnected`, so health probes can keep treating a recovered publisher as unhealthy.
+
+### Low
+- [security] `references/lib-commons/commons/outbox/postgres/schema_resolver.go:36`, `references/lib-commons/commons/outbox/postgres/schema_resolver.go:40`, `references/lib-commons/commons/outbox/postgres/schema_resolver.go:102`, `references/lib-commons/commons/outbox/postgres/schema_resolver.go:107`, `references/lib-commons/commons/outbox/postgres/schema_resolver.go:110` - `WithAllowEmptyTenant` makes empty tenant ID a silent no-op and can accidentally reuse an active `search_path`.
+- [test] `references/lib-commons/commons/outbox/postgres/repository_integration_test.go:231` - non-priority fixture event is intentionally ignored, so the test only proves the positive match.
+- [test] `references/lib-commons/commons/rabbitmq/trace_propagation_integration_test.go:482` - multiple-message trace test hard-codes FIFO ordering instead of focusing only on trace propagation.
+- [consequences] `references/lib-commons/commons/rabbitmq/rabbitmq.go:151`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:177`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:211`, `references/lib-commons/commons/rabbitmq/rabbitmq.go:255` - `Connect()` opens a new AMQP connection/channel before checking whether an existing live connection is already installed.
+
+## 5. Data Connectors
+
+### High
+- [code] `references/lib-commons/commons/redis/redis.go:426`, `references/lib-commons/commons/redis/redis.go:432`, `references/lib-commons/commons/redis/redis.go:451` - reconnect logic closes the current client before replacement is created and pinged, so a failed reconnect can discard a healthy client and turn recovery into outage.
+- [code] `references/lib-commons/commons/redis/lock.go:372`, `references/lib-commons/commons/redis/lock.go:374`, `references/lib-commons/commons/redis/lock.go:378` - `TryLock` treats any error containing `failed to acquire lock` as normal contention, masking real infrastructure faults.
+- [business] `references/lib-commons/commons/postgres/postgres.go:760`, `references/lib-commons/commons/postgres/postgres.go:850`, `references/lib-commons/commons/postgres/postgres.go:857` - missing migration files are treated as a warning and `Migrator.Up()` returns `nil`, allowing services to boot against unmigrated schemas.
+- [business] `references/lib-commons/commons/redis/lock.go:299`, `references/lib-commons/commons/redis/lock.go:310`, `references/lib-commons/commons/redis/lock.go:319` - `WithLockOptions()` unlocks with the caller context; if it is already canceled, unlock fails and the method still returns success while the lock remains held until TTL expiry.
+- [business] `references/lib-commons/commons/redis/redis.go:911`, `references/lib-commons/commons/redis/redis.go:830`, `references/lib-commons/commons/redis/redis.go:1047` - `AllowLegacyMinVersion=true` is accepted and logged as retained, but runtime TLS construction still forces TLS 1.2 unless exactly TLS 1.3.
+- [test] `references/lib-commons/commons/redis/resilience_integration_test.go:195`, `references/lib-commons/commons/redis/resilience_integration_test.go:223`, `references/lib-commons/commons/backoff/backoff.go:83` - Redis backoff resilience test is nondeterministic because full jitter can legitimately produce repeated zero delays.
+- [test] `references/lib-commons/commons/postgres/resilience_integration_test.go:208`, `references/lib-commons/commons/postgres/resilience_integration_test.go:236`, `references/lib-commons/commons/backoff/backoff.go:83` - Postgres backoff resilience test has the same full-jitter flake vector.
+- [test] `references/lib-commons/commons/mongo/mongo.go:358`, `references/lib-commons/commons/mongo/mongo_integration_test.go:181` - Mongo reconnect-storm protection in `ResolveClient` is effectively untested.
+- [consequences] `references/lib-commons/commons/postgres/postgres.go:760`, `references/lib-commons/commons/postgres/postgres.go:763`, `references/lib-commons/commons/postgres/postgres.go:850`, `references/lib-commons/commons/postgres/postgres.go:857` - missing migrations become warn-and-skip behavior across consuming services.
+- [consequences] `references/lib-commons/commons/postgres/postgres.go:359`, `references/lib-commons/commons/postgres/postgres.go:630`, `references/lib-commons/commons/postgres/postgres.go:679`, `references/lib-commons/commons/postgres/postgres.go:693` - `SanitizedError` wrappers drop unwrap semantics, so `errors.Is` / `errors.As` stop matching driver/network causes.
+- [consequences] `references/lib-commons/commons/redis/redis.go:811`, `references/lib-commons/commons/postgres/postgres.go:834`, `references/lib-commons/commons/redis/redis.go:1047`, `references/lib-commons/commons/redis/redis.go:1052` - explicit legacy TLS compatibility claims do not match actual runtime behavior, breaking integrations that rely on them.
+
+### Medium
+- [code] `references/lib-commons/commons/postgres/postgres.go:841`, `references/lib-commons/commons/postgres/postgres.go:850`, `references/lib-commons/commons/postgres/postgres.go:860` - `migrate.Migrate` created by `migrate.NewWithDatabaseInstance` is never closed.
+- [code] `references/lib-commons/commons/redis/redis.go:176`, `references/lib-commons/commons/redis/redis.go:378`, `references/lib-commons/commons/redis/redis.go:393` - `Status` / `IsConnected` expose a cached connected flag instead of probing real liveness.
+- [code] `references/lib-commons/commons/redis/lock_interface.go:26`, `references/lib-commons/commons/redis/lock_interface.go:45`, `references/lib-commons/commons/redis/lock_interface.go:61` - exported `LockManager` abstraction increases API surface with little demonstrated production value.
+- [business] `references/lib-commons/commons/mongo/connection_string.go:122`, `references/lib-commons/commons/mongo/connection_string.go:128` - `BuildURI()` turns username-only auth into `user:@`, changing semantics for external-auth flows.
+- [security] `references/lib-commons/commons/mongo/mongo.go:272`, `references/lib-commons/commons/mongo/mongo.go:274`, `references/lib-commons/commons/mongo/mongo.go:276`, `references/lib-commons/commons/mongo/mongo.go:283`, `references/lib-commons/commons/mongo/mongo.go:288`, `references/lib-commons/commons/mongo/mongo.go:290` - Mongo connection and ping failures are logged/returned with raw driver errors, which may include URI or auth details.
+- [security] `references/lib-commons/commons/redis/redis.go:120`, `references/lib-commons/commons/redis/redis.go:123`, `references/lib-commons/commons/redis/redis.go:811`, `references/lib-commons/commons/redis/redis.go:830`, `references/lib-commons/commons/redis/redis.go:834`, `references/lib-commons/commons/redis/redis.go:900`, `references/lib-commons/commons/redis/redis.go:911`, `references/lib-commons/commons/redis/redis.go:912` - Redis explicitly allows TLS versions below 1.2 when `AllowLegacyMinVersion=true`.
+- [test] `references/lib-commons/commons/mongo/mongo_test.go:312`, `references/lib-commons/commons/mongo/mongo.go:256` - config propagation test only verifies captured options, not that they were applied.
+- [test] `references/lib-commons/commons/postgres/postgres_test.go:1416`, `references/lib-commons/commons/postgres/postgres.go:175` - `TestValidateDSN` misses malformed URL cases.
+- [test] `references/lib-commons/commons/postgres/postgres_test.go:1448`, `references/lib-commons/commons/postgres/postgres.go:191` - insecure DSN warning test only asserts “does not panic”.
+- [consequences] `references/lib-commons/commons/redis/lock.go:366`, `references/lib-commons/commons/redis/lock.go:372`, `references/lib-commons/commons/redis/lock.go:376`, `references/lib-commons/commons/redis/lock.go:378` - `TryLock` collapses true contention and backend/quorum failures into the same `(nil, false, nil)` outcome.
+- [consequences] `references/lib-commons/commons/mongo/connection_string.go:111`, `references/lib-commons/commons/mongo/connection_string.go:114`, `references/lib-commons/commons/mongo/connection_string.go:119` - `BuildURI` blindly concatenates raw IPv6 literals and can emit invalid Mongo URIs.
+
+### Low
+- [code] `references/lib-commons/commons/mongo/connection_string.go:34`, `references/lib-commons/commons/mongo/connection_string.go:111` - `BuildURI` claims canonical validation but intentionally defers host validation downstream.
+- [security] `references/lib-commons/commons/postgres/postgres.go:151`, `references/lib-commons/commons/postgres/postgres.go:181`, `references/lib-commons/commons/postgres/postgres.go:184`, `references/lib-commons/commons/postgres/postgres.go:191`, `references/lib-commons/commons/postgres/postgres.go:319`, `references/lib-commons/commons/postgres/postgres.go:320` - Postgres allows `sslmode=disable` with only a warning.
+- [security] `references/lib-commons/commons/mongo/mongo.go:91`, `references/lib-commons/commons/mongo/mongo.go:104`, `references/lib-commons/commons/mongo/mongo.go:263`, `references/lib-commons/commons/mongo/mongo.go:269`, `references/lib-commons/commons/mongo/mongo.go:295`, `references/lib-commons/commons/mongo/mongo.go:297` - Mongo connects without TLS whenever the URI/TLS config does not force it, only warning afterward.
+- [security] `references/lib-commons/commons/redis/redis.go:475`, `references/lib-commons/commons/redis/redis.go:476`, `references/lib-commons/commons/redis/redis.go:955`, `references/lib-commons/commons/redis/redis.go:965` - Redis allows non-TLS operation for non-GCP-IAM modes with only a warning.
+- [test] `references/lib-commons/commons/redis/lock_test.go:650`, `references/lib-commons/commons/redis/lock.go:274` - tracing/context propagation test for `WithLock` only checks that callback context is non-nil.
+- [consequences] `references/lib-commons/commons/mongo/mongo.go:660`, `references/lib-commons/commons/mongo/mongo.go:661`, `references/lib-commons/commons/mongo/mongo.go:662` - TLS detection for warning suppression is case-sensitive and can emit misleading warnings.
+
+## 6. Resilience + Execution Safety
+
+### Critical
+- [nil-safety] `references/lib-commons/commons/circuitbreaker/manager.go:145`, `references/lib-commons/commons/circuitbreaker/types.go:117` - `Execute` forwards `fn` without a nil guard, so nil callbacks panic.
+- [nil-safety] `references/lib-commons/commons/backoff/backoff.go:106` - `WaitContext` calls `ctx.Done()` unconditionally and panics on nil context.
+
+### High
+- [code] `references/lib-commons/commons/circuitbreaker/manager.go:307`, `references/lib-commons/commons/circuitbreaker/manager.go:310`, `references/lib-commons/commons/circuitbreaker/types.go:168` - listener timeout is ineffective because derived context is never passed to `OnStateChange` and the listener interface has no context parameter.
+- [code] `references/lib-commons/commons/runtime/tracing.go:72`, `references/lib-commons/commons/runtime/tracing.go:84`, `references/lib-commons/commons/runtime/tracing.go:95` - panic tracing writes raw panic values and full stack traces into span events with no redaction/size cap.
+- [code] `references/lib-commons/commons/circuitbreaker/types.go:64`, `references/lib-commons/commons/circuitbreaker/types.go:69`, `references/lib-commons/commons/circuitbreaker/types.go:73` - `Config.Validate` does not reject negative `Interval` or `Timeout` values.
+- [business] `references/lib-commons/commons/circuitbreaker/types.go:35`, `references/lib-commons/commons/circuitbreaker/manager.go:206`, `references/lib-commons/commons/circuitbreaker/healthchecker.go:159`, `references/lib-commons/commons/circuitbreaker/healthchecker.go:236` - `IsHealthy` is documented as “not open” but implemented as “closed only”, so half-open breakers look unhealthy and can be reset prematurely.
+- [business] `references/lib-commons/commons/circuitbreaker/manager.go:283`, `references/lib-commons/commons/circuitbreaker/manager.go:307` - listener timeout comments/behavior do not match reality.
+- [security] `references/lib-commons/commons/runtime/tracing.go:69-75`, `references/lib-commons/commons/runtime/tracing.go:84-87` - recovered panics are written into OTEL as raw `panic.value`, full `panic.stack`, and `RecordError(...)` payloads.
+- [test] `references/lib-commons/commons/runtime/metrics.go:51`, `references/lib-commons/commons/runtime/metrics.go:86`, `references/lib-commons/commons/runtime/metrics.go:100` - panic-metrics init/reset/recording paths are effectively untested.
+- [test] `references/lib-commons/commons/circuitbreaker/manager.go:154`, `references/lib-commons/commons/circuitbreaker/manager.go:156`, `references/lib-commons/commons/circuitbreaker/manager.go:158` - no test covers half-open `ErrTooManyRequests` rejection or its metric label.
+- [test] `references/lib-commons/commons/circuitbreaker/types.go:73` - config validation lacks negative tests for `MinRequests > 0` with `FailureRatio <= 0`.
+- [nil-safety] `references/lib-commons/commons/circuitbreaker/manager.go:244`, `references/lib-commons/commons/circuitbreaker/manager.go:310` - `RegisterStateChangeListener` accepts typed-nil listeners and can later panic during notification.
+- [nil-safety] `references/lib-commons/commons/runtime/error_reporter.go:149`, `references/lib-commons/commons/runtime/error_reporter.go:170` - typed-nil `error` values can reintroduce panic risk inside panic-reporting code.
+- [nil-safety] `references/lib-commons/commons/errgroup/errgroup.go:61`, `references/lib-commons/commons/errgroup/errgroup.go:90` - `Go` and `Wait` assume non-nil `*Group` and panic on nil receivers.
+- [consequences] `references/lib-commons/commons/circuitbreaker/manager.go:103`, `references/lib-commons/commons/circuitbreaker/manager.go:108`, `references/lib-commons/commons/circuitbreaker/manager.go:120`, `references/lib-commons/commons/circuitbreaker/manager.go:128` - `GetOrCreate` keys breakers only by `serviceName`, so later calls with different config silently reuse stale breaker settings.
+- [consequences] `references/lib-commons/commons/runtime/error_reporter.go:108`, `references/lib-commons/commons/runtime/error_reporter.go:120`, `references/lib-commons/commons/runtime/recover.go:53`, `references/lib-commons/commons/runtime/recover.go:86`, `references/lib-commons/commons/runtime/recover.go:139`, `references/lib-commons/commons/runtime/recover.go:216`, `references/lib-commons/commons/runtime/tracing.go:73`, `references/lib-commons/commons/runtime/tracing.go:74`, `references/lib-commons/commons/runtime/tracing.go:84`, `references/lib-commons/commons/circuitbreaker/manager.go:287`, `references/lib-commons/commons/circuitbreaker/healthchecker.go:99`, `references/lib-commons/commons/errgroup/errgroup.go:64` - `SetProductionMode(true)` redacts the external error-reporter path but not panic logs/spans in recovery flows.
+
+### Medium
+- [code] `references/lib-commons/commons/assert/predicates.go:316`, `references/lib-commons/commons/assert/predicates.go:318`, `references/lib-commons/commons/assert/predicates.go:333` - `TransactionOperationsMatch` checks subset inclusion, but its name/doc imply full matching.
+- [code] `references/lib-commons/commons/assert/assert.go:309`, `references/lib-commons/commons/assert/assert.go:311`, `references/lib-commons/commons/assert/assert.go:315` - assertion failures are emitted as a single multiline string instead of structured fields.
+- [business] `references/lib-commons/commons/circuitbreaker/types.go:62` - `Config.Validate` accepts nonsensical negative durations.
+- [security] `references/lib-commons/commons/runtime/recover.go:156-167` - panic recovery logs raw panic values and full stack traces on every recovery path.
+- [security] `references/lib-commons/commons/assert/assert.go:141-155`, `references/lib-commons/commons/assert/assert.go:188-199`, `references/lib-commons/commons/assert/assert.go:230-243`, `references/lib-commons/commons/assert/assert.go:290-312` - assertion failures log caller-supplied key/value data, `err.Error()`, and stack traces by default, making secret/PII exposure easy.
+- [security] `references/lib-commons/commons/circuitbreaker/healthchecker.go:169-180`, `references/lib-commons/commons/circuitbreaker/healthchecker.go:244-253` - health-check failures are logged verbatim and may include connection strings, usernames, or hostnames.
+- [test] `references/lib-commons/commons/backoff/backoff.go:48`, `references/lib-commons/commons/backoff/backoff.go:50`, `references/lib-commons/commons/backoff/backoff.go:71`, `references/lib-commons/commons/backoff/backoff.go:73` - fallback path for crypto-rand failure is untested.
+- [test] `references/lib-commons/commons/circuitbreaker/types.go:113`, `references/lib-commons/commons/circuitbreaker/types.go:122`, `references/lib-commons/commons/circuitbreaker/types.go:131` - nil/uninitialized `CircuitBreaker` guard paths are uncovered.
+- [test] `references/lib-commons/commons/assert/assert_extended_test.go:294`, `references/lib-commons/commons/assert/assert_extended_test.go:305` - metric-recording test only proves “no panic” and never asserts that a metric was emitted.
+- [test] `references/lib-commons/commons/errgroup/errgroup_test.go:61`, `references/lib-commons/commons/errgroup/errgroup_test.go:63`, `references/lib-commons/commons/errgroup/errgroup_test.go:156`, `references/lib-commons/commons/errgroup/errgroup_test.go:158` - tests use `time.Sleep(50 * time.Millisecond)` to force goroutine ordering.
+- [test] `references/lib-commons/commons/assert/predicates_test.go:205`, `references/lib-commons/commons/assert/predicates_test.go:225`, `references/lib-commons/commons/assert/predicates_test.go:228` - `TestDateNotInFuture` depends on `time.Now()` and a 1 ms tolerance.
+- [nil-safety] `references/lib-commons/commons/runtime/goroutine.go:28`, `references/lib-commons/commons/runtime/goroutine.go:66` - `SafeGo` and `SafeGoWithContextAndComponent` invoke `fn` without validating it.
+- [nil-safety] `references/lib-commons/commons/circuitbreaker/manager.go:74` - `NewManager` executes each `ManagerOption` blindly, so a nil option panics during construction.
+- [consequences] `references/lib-commons/commons/circuitbreaker/manager.go:287`, `references/lib-commons/commons/circuitbreaker/manager.go:307`, `references/lib-commons/commons/circuitbreaker/manager.go:310`, `references/lib-commons/commons/circuitbreaker/types.go:170` - slow/blocking listeners leak one goroutine per state transition because the advertised timeout is ineffective.
+- [consequences] `references/lib-commons/commons/circuitbreaker/healthchecker.go:161`, `references/lib-commons/commons/circuitbreaker/healthchecker.go:176`, `references/lib-commons/commons/circuitbreaker/manager.go:179`, `references/lib-commons/commons/circuitbreaker/manager.go:222` - health checker behavior depends on registration order and can probe forever against missing breakers.
+
+### Low
+- [business] `references/lib-commons/commons/safe/regex.go:119` - `FindString` comment says invalid patterns return empty string, but implementation returns `("", err)`.
+- [security] `references/lib-commons/commons/assert/assert.go:230-243` - stack-trace emission is opt-out rather than opt-in.
+- [test] `references/lib-commons/commons/assert/assert_extended_test.go:22`, `references/lib-commons/commons/assert/assert_extended_test.go:26` - helper panics on setup failure instead of failing the test normally.
+- [test] `references/lib-commons/commons/circuitbreaker/manager_test.go:354`, `references/lib-commons/commons/circuitbreaker/manager_test.go:368` - existing-breaker test only compares state and not instance identity.
+- [consequences] `references/lib-commons/commons/safe/regex.go:40`, `references/lib-commons/commons/safe/regex.go:41`, `references/lib-commons/commons/safe/regex.go:44` - once the regex cache reaches 1024 entries, adding one more pattern flushes the entire shared cache.
+
+## 7. Logging Stack
+
+### Critical
+- [nil-safety] `references/lib-commons/commons/zap/zap.go:166-167` - `(*Logger).Level()` dereferences `l.atomicLevel` without the nil-safe `must()` pattern used elsewhere.
+
+### High
+- [code] `references/lib-commons/commons/log/go_logger.go:135`, `references/lib-commons/commons/log/go_logger.go:145` - `GoLogger` only sanitizes plain `string`, `error`, and `fmt.Stringer`; composite values passed through `log.Any(...)` can still emit raw newlines and forge multi-line entries.
+- [code] `references/lib-commons/commons/zap/injector.go:114`, `references/lib-commons/commons/zap/injector.go:133`, `references/lib-commons/commons/zap/zap.go:44`, `references/lib-commons/commons/zap/zap.go:141` - console encoding permits raw newline messages and bypasses single-entry-per-line assumptions in non-JSON mode.
+- [business] `references/lib-commons/commons/log/go_logger.go:135`, `references/lib-commons/commons/log/go_logger.go:145-155` - `GoLogger`’s injection protection is incomplete for non-string composite values.
+- [security] `references/lib-commons/commons/log/go_logger.go:129`, `references/lib-commons/commons/log/go_logger.go:135`, `references/lib-commons/commons/log/go_logger.go:145` - stdlib logger never consults `commons/security` for key-based redaction, so sensitive fields are emitted verbatim.
+- [security] `references/lib-commons/commons/zap/zap.go:45`, `references/lib-commons/commons/zap/zap.go:221`, `references/lib-commons/commons/zap/zap.go:224` - zap adapter converts all fields with unconditional `zap.Any` and performs no sensitive-field masking.
+- [test] `references/lib-commons/commons/zap/zap_test.go:457` - `TestWithGroupNamespacesFields` never asserts the namespaced field structure.
+- [test] `references/lib-commons/commons/zap/zap.go:107` - panic-recovery branch inside `Sync` is untested.
+- [nil-safety] `references/lib-commons/commons/log/go_logger.go:149-152` - typed-nil `error` or `fmt.Stringer` values can panic when `sanitizeFieldValue` calls `Error()` / `String()`.
+- [nil-safety] `references/lib-commons/commons/log/sanitizer.go:11-24` - `SafeError` only checks `logger == nil`, so a typed-nil `Logger` interface can still panic.
+- [consequences] `references/lib-commons/commons/log/go_logger.go:135`, `references/lib-commons/commons/log/go_logger.go:145`, `references/lib-commons/commons/log/log.go:88` - backend swap does not preserve the same single-line hygiene for `Any` payloads containing nested strings.
+
+### Medium
+- [code] `references/lib-commons/commons/zap/zap.go:83`, `references/lib-commons/commons/log/go_logger.go:82` - `WithGroup("")` has backend-dependent semantics between stdlib and zap implementations.
+- [business] `references/lib-commons/commons/zap/zap.go:83-87`, `references/lib-commons/commons/log/go_logger.go:74-84` - grouped logging behavior changes depending on the backend behind the same `commons/log.Logger` interface.
+- [security] `references/lib-commons/commons/log/go_logger.go:135`, `references/lib-commons/commons/log/go_logger.go:145`, `references/lib-commons/commons/log/go_logger.go:154` - log-injection hardening is incomplete for composite values.
+- [security] `references/lib-commons/commons/log/sanitizer.go:10`, `references/lib-commons/commons/log/sanitizer.go:23`, `references/lib-commons/commons/log/sanitizer.go:28` - `SafeError` depends on a caller-supplied `production` boolean, so one misuse can leak raw upstream error strings.
+- [test] `references/lib-commons/commons/zap/zap_test.go:159`, `references/lib-commons/commons/zap/zap_test.go:182`, `references/lib-commons/commons/zap/zap_test.go:197`, `references/lib-commons/commons/zap/zap_test.go:209`, `references/lib-commons/commons/zap/zap_test.go:220`, `references/lib-commons/commons/zap/zap_test.go:231`, `references/lib-commons/commons/zap/zap_test.go:247`, `references/lib-commons/commons/zap/zap_test.go:265`, `references/lib-commons/commons/zap/zap_test.go:403`, `references/lib-commons/commons/zap/zap_test.go:404` - several tests silently discard returned errors.
+- [test] `references/lib-commons/commons/log/sanitizer_test.go:35` - `TestSafeError_NilGuards` asserts only `NotPanics`.
+- [test] `references/lib-commons/commons/security/sensitive_fields_test.go:435` - concurrent-access test proves only liveness, not correctness of returned values.
+- [consequences] `references/lib-commons/commons/zap/zap.go:83`, `references/lib-commons/commons/zap/zap.go:221`, `references/lib-commons/commons/log/go_logger.go:82`, `references/lib-commons/commons/log/go_logger.go:130` - zap path forwards empty group names and empty field keys that stdlib path drops, creating schema drift for ingestion pipelines.
+- [consequences] `references/lib-commons/commons/zap/zap.go:65`, `references/lib-commons/commons/log/go_logger.go:31`, `references/lib-commons/commons/log/log.go:48`, `references/lib-commons/commons/log/log.go:67` - unknown log levels diverge by backend: stdlib suppresses them while zap downgrades them to `info`.
+
+### Low
+- [code] `references/lib-commons/commons/zap/zap.go:56`, `references/lib-commons/commons/log/go_logger.go:31`, `references/lib-commons/commons/log/log.go:48` - unknown `log.Level` values behave inconsistently between implementations.
+- [business] `references/lib-commons/commons/log/log.go:67-79` - `ParseLevel` lowercases input but does not trim surrounding whitespace.
+- [security] `references/lib-commons/commons/security/sensitive_fields.go:12` - default sensitive-field catalog misses common PII keys like `email`, `phone`, and address-style fields.
+- [test] `references/lib-commons/commons/log/log_test.go:120` - source-text scan test is brittle and implementation-coupled.
+- [test] `references/lib-commons/commons/security/sensitive_fields_test.go:223` - exact field-count assertion makes list evolution noisy.
+- [test] `references/lib-commons/commons/zap/injector_test.go:57` - constant-value assertion tests an implementation detail rather than observable behavior.
+- [test] `references/lib-commons/commons/zap/zap_test.go:100` - `TestSyncReturnsErrorFromUnderlyingLogger` is misleadingly named because it asserts `NoError`.
+
+## 8. Domain + Security Utilities
+
+### Critical
+- [nil-safety] `references/lib-commons/commons/license/manager.go:63` - `New(opts ...ManagerOption)` calls each option without guarding against nil function values.
+- [nil-safety] `references/lib-commons/commons/jwt/jwt.go:258` - `Token.ValidateTimeClaims()` is a value-receiver method on `Token`, so calling it through a nil `*Token` panics before entering the body.
+- [nil-safety] `references/lib-commons/commons/jwt/jwt.go:264` - `Token.ValidateTimeClaimsAt()` has the same nil-pointer panic surface.
+- [nil-safety] `references/lib-commons/commons/crypto/crypto.go:120` - `Encrypt` only checks `c.Cipher == nil`, missing typed-nil `cipher.AEAD` values.
+- [nil-safety] `references/lib-commons/commons/crypto/crypto.go:150` - `Decrypt` has the same typed-nil interface panic risk.
+- [nil-safety] `references/lib-commons/commons/secretsmanager/m2m.go:127` - `GetM2MCredentials` only checks interface-nil client and can still panic on typed-nil implementations.
+- [consequences] `references/lib-commons/commons/transaction/validations.go:263`, `references/lib-commons/commons/transaction/validations.go:268`, `references/lib-commons/commons/transaction/validations.go:209`, `references/lib-commons/commons/transaction/validations.go:219` - planner/applicator contract is internally broken for pending destination cancellations, which resolve to a debit that `applyDebit` rejects for `StatusCanceled`.
+
+### High
+- [code] `references/lib-commons/commons/jwt/jwt.go:274` - token expiry check uses `now.After(exp)`, so a token is still valid at the exact expiration instant.
+- [code] `references/lib-commons/commons/transaction/validations.go:77` - `ValidateBalanceEligibility` never compares `posting.Amount` with source balance availability / hold state.
+- [code] `references/lib-commons/commons/secretsmanager/m2m.go:131`, `references/lib-commons/commons/secretsmanager/m2m.go:198` - path segment validation checks only emptiness, so embedded `/` lets callers escape the intended secret namespace.
+- [business] `references/lib-commons/commons/jwt/jwt.go:273-276` - `exp` semantics are off by one at the exact expiry instant.
+- [business] `references/lib-commons/commons/transaction/validations.go:71-94`, `references/lib-commons/commons/transaction/validations.go:241-248` - balance eligibility never checks whether sources can actually cover the posting amount, so preflight validation can succeed and `ApplyPosting` can still fail for insufficient funds.
+- [business] `references/lib-commons/commons/secretsmanager/m2m.go:131-145`, `references/lib-commons/commons/secretsmanager/m2m.go:192-199` - secret path segments are concatenated without trimming or rejecting embedded `/`.
+- [security] `references/lib-commons/commons/secretsmanager/m2m.go:131-145`, `references/lib-commons/commons/secretsmanager/m2m.go:192-198` - path traversal through secret path building can retrieve the wrong tenant/service secret.
+- [security] `references/lib-commons/commons/license/manager.go:35-40`, `references/lib-commons/commons/license/manager.go:57-60`, `references/lib-commons/commons/license/manager.go:87-112` - default license-failure behavior is fail-open; `DefaultHandler` only records an assertion and does not stop execution.
+- [security] `references/lib-commons/commons/transaction/validations.go:72-121`, `references/lib-commons/commons/transaction/validations.go:146-167`, `references/lib-commons/commons/transaction/transaction.go:109-126` - transaction validation never checks `OrganizationID` or `LedgerID`, so callers can assemble postings across unrelated ledgers/tenants as long as asset and allow flags match.
+- [test] `references/lib-commons/commons/jwt/jwt.go:110`, `references/lib-commons/commons/jwt/jwt.go:116` - `ParseAndValidate` has no direct integration test locking down combined parse + time-claim behavior.
+- [test] `references/lib-commons/commons/crypto/crypto.go:172`, `references/lib-commons/commons/crypto/crypto_test.go:230`, `references/lib-commons/commons/crypto/crypto_test.go:304` - `Decrypt` auth-failure path is not tested with tampered ciphertext or wrong key.
+- [test] `references/lib-commons/commons/secretsmanager/m2m.go:131`, `references/lib-commons/commons/secretsmanager/m2m.go:135`, `references/lib-commons/commons/secretsmanager/m2m.go:139`, `references/lib-commons/commons/secretsmanager/m2m_test.go:393` - input-validation tests cover empty strings only, not whitespace-only values.
+- [consequences] `references/lib-commons/commons/transaction/validations.go:96`, `references/lib-commons/commons/transaction/validations.go:106`, `references/lib-commons/commons/transaction/validations.go:110`, `references/lib-commons/commons/transaction/validations.go:115`, `references/lib-commons/commons/transaction/validations.go:263`, `references/lib-commons/commons/transaction/validations.go:268` - destination validation is hard-coded as receiver-only even when canceled pending destinations are debits.
+- [consequences] `references/lib-commons/commons/transaction/validations.go:77`, `references/lib-commons/commons/transaction/validations.go:87`, `references/lib-commons/commons/transaction/validations.go:124`, `references/lib-commons/commons/transaction/validations.go:141`, `references/lib-commons/commons/transaction/validations.go:242`, `references/lib-commons/commons/transaction/validations.go:247` - `ValidateBalanceEligibility` and `ApplyPosting` disagree on liquidity requirements, increasing late-stage failure risk.
+- [consequences] `references/lib-commons/commons/license/manager.go:82`, `references/lib-commons/commons/license/manager.go:87`, `references/lib-commons/commons/license/manager.go:101`, `references/lib-commons/commons/license/manager.go:108` - `Terminate` can fail open on nil or zero-value managers and has no error channel.
+
+### Medium
+- [code] `references/lib-commons/commons/transaction/validations.go:78`, `references/lib-commons/commons/transaction/validations.go:97` - balance eligibility lookup is keyed only by `BalanceID` and does not verify that resolved balances belong to the posting target account.
+- [code] `references/lib-commons/commons/crypto/crypto.go:75`, `references/lib-commons/commons/crypto/crypto.go:109` - `InitializeCipher` accepts 16/24/32-byte AES keys, but docs describe encryption as requiring a 32-byte key.
+- [code] `references/lib-commons/commons/secretsmanager/m2m.go:156`, `references/lib-commons/commons/secretsmanager/m2m.go:164` - nil/binary/non-string secret payloads are misclassified as JSON unmarshal failures.
+- [code] `references/lib-commons/commons/license/manager.go:117`, `references/lib-commons/commons/license/manager.go:123` - `TerminateWithError` docs promise `ErrLicenseValidationFailed` regardless of initialization state, but nil receiver returns `ErrManagerNotInitialized`.
+- [business] `references/lib-commons/commons/transaction/validations.go:77-80`, `references/lib-commons/commons/transaction/validations.go:96-99`, `references/lib-commons/commons/transaction/validations.go:151-157` - ownership validation is skipped during eligibility precheck, so it can approve a plan that later fails in `ApplyPosting`.
+- [business] `references/lib-commons/commons/secretsmanager/m2m.go:156-166` - binary secrets are treated as malformed JSON instead of unsupported/alternate-format secrets.
+- [security] `references/lib-commons/commons/jwt/jwt.go:272-289`, `references/lib-commons/commons/jwt/jwt.go:304-321` - malformed `exp`, `nbf`, or `iat` values fail open because unsupported types/parse errors simply skip validation.
+- [security] `references/lib-commons/commons/jwt/jwt.go:69-103`, `references/lib-commons/commons/jwt/jwt.go:196-226`, `references/lib-commons/commons/crypto/crypto.go:62-73` - cryptographic operations accept empty secrets and turn misconfiguration into weak-but-valid auth/signing behavior.
+- [security] `references/lib-commons/commons/secretsmanager/m2m.go:165`, `references/lib-commons/commons/secretsmanager/m2m.go:179`, `references/lib-commons/commons/secretsmanager/m2m.go:205-216` - returned errors include the full secret path and leak tenant/service naming metadata.
+- [test] `references/lib-commons/commons/crypto/crypto.go:62`, `references/lib-commons/commons/crypto/crypto_test.go:32`, `references/lib-commons/commons/crypto/crypto_test.go:73` - `GenerateHash` lacks known-vector assertions and only checks length/consistency.
+- [test] `references/lib-commons/commons/transaction/transaction_test.go:786`, `references/lib-commons/commons/transaction/transaction_test.go:796`, `references/lib-commons/commons/transaction/transaction_test.go:809`, `references/lib-commons/commons/transaction/transaction_test.go:817`, `references/lib-commons/commons/transaction/transaction_test.go:826`, `references/lib-commons/commons/transaction/transaction_test.go:845`, `references/lib-commons/commons/transaction/transaction_test.go:854`, `references/lib-commons/commons/transaction/transaction_test.go:866` - several tests ignore `decimal.NewFromString` errors during setup.
+- [test] `references/lib-commons/commons/jwt/jwt.go:274`, `references/lib-commons/commons/jwt/jwt.go:280`, `references/lib-commons/commons/jwt/jwt.go:286`, `references/lib-commons/commons/jwt/jwt_test.go:316`, `references/lib-commons/commons/jwt/jwt_test.go:331` - exact equality boundaries for `exp == now`, `nbf == now`, `iat == now` are not tested.
+- [consequences] `references/lib-commons/commons/license/manager.go:117`, `references/lib-commons/commons/license/manager.go:118`, `references/lib-commons/commons/license/manager.go:122`, `references/lib-commons/commons/license/manager.go:124` - nil-receiver `TerminateWithError` does not satisfy the documented `errors.Is(err, ErrLicenseValidationFailed)` contract.
+- [consequences] `references/lib-commons/commons/jwt/jwt.go:272`, `references/lib-commons/commons/jwt/jwt.go:300`, `references/lib-commons/commons/jwt/jwt.go:310`, `references/lib-commons/commons/jwt/jwt.go:320` - exported time-claim validators only recognize `float64` and `json.Number`, so `int` / `int64` claims in in-memory `MapClaims` are silently skipped.
+
+### Low
+- [code] `references/lib-commons/commons/crypto/crypto.go:62` - `GenerateHash` silently returns `""` for nil receiver/input instead of failing loudly like the rest of the type.
+- [security] `references/lib-commons/commons/license/manager.go:127-133`, `references/lib-commons/commons/license/manager.go:153-158` - warning logs include raw `reason` strings and can leak customer/license details.
+- [test] `references/lib-commons/commons/license/manager_test.go:94` - uninitialized-manager test only asserts no panic, not observable outcome.
+- [consequences] `references/lib-commons/commons/transaction/validations.go:298`, `references/lib-commons/commons/transaction/validations.go:317`, `references/lib-commons/commons/transaction/validations.go:354` - allocation field paths omit whether the failing side was source or destination.
+
+## 9. Shared Primitives + Constants
+
+### Critical
+- [nil-safety] `references/lib-commons/commons/os.go:104`, `references/lib-commons/commons/os.go:106`, `references/lib-commons/commons/os.go:111`, `references/lib-commons/commons/os.go:117` - `SetConfigFromEnvVars` can panic on nil interface, typed-nil pointer, or pointer-to-non-struct instead of returning an error.
+- [nil-safety] `references/lib-commons/commons/context.go:46`, `references/lib-commons/commons/utils.go:192`, `references/lib-commons/commons/utils.go:211` - `NewLoggerFromContext` calls `ctx.Value(...)` without guarding `ctx == nil`, so nil contexts can panic directly or via `GetCPUUsage` / `GetMemUsage`.
+- [nil-safety] `references/lib-commons/commons/app.go:43`, `references/lib-commons/commons/app.go:44` - `WithLogger` option blindly assigns through `l.Logger`, so invoking it with a nil launcher panics.
+- [nil-safety] `references/lib-commons/commons/app.go:52`, `references/lib-commons/commons/app.go:53`, `references/lib-commons/commons/app.go:55` - `RunApp` option appends to launcher state through a nil receiver and can panic.
+- [consequences] `references/lib-commons/commons/cron/cron.go:50`, `references/lib-commons/commons/cron/cron.go:121` - package advertises standard 5-field cron but enforces day-of-month and day-of-week with AND instead of OR, so imported schedules can silently run far less often or never.
+
+### High
+- [code] `references/lib-commons/commons/cron/cron.go:121` - standard day-of-month/day-of-week cron semantics are implemented as AND, not OR.
+- [code] `references/lib-commons/commons/cron/cron.go:113` - `Next` hard-limits its search to 366 days, so valid sparse schedules like leap-day jobs can return `ErrNoMatch`.
+- [code] `references/lib-commons/commons/errors.go:35`, `references/lib-commons/commons/errors.go:73` - `ValidateBusinessError` uses exact error identity instead of `errors.Is`, so wrapped sentinels bypass mapping.
+- [code] `references/lib-commons/commons/os.go:79`, `references/lib-commons/commons/os.go:97` - `InitLocalEnvConfig` returns `nil` outside `ENV_NAME=local`.
+- [code] `references/lib-commons/commons/utils.go:191`, `references/lib-commons/commons/utils.go:204`, `references/lib-commons/commons/utils.go:210`, `references/lib-commons/commons/utils.go:222` - `GetCPUUsage` and `GetMemUsage` dereference `factory` unconditionally.
+- [business] `references/lib-commons/commons/context.go:144`, `references/lib-commons/commons/context.go:191` - `NewTrackingFromContext` generates a fresh UUID whenever `HeaderID` is absent, so two extractions from the same request context can yield different correlation IDs.
+- [business] `references/lib-commons/commons/errors.go:35` - wrapped business errors leak through untranslated because mapping is not `errors.Is`-aware.
+- [business] `references/lib-commons/commons/os.go:72` - DI/provider-style `InitLocalEnvConfig` returns `nil` outside local runs.
+- [business] `references/lib-commons/commons/cron/cron.go:121` - cron `0 0 1 * 1` will run only when the 1st is Monday, not on either condition.
+- [business] `references/lib-commons/commons/cron/cron.go:113` - leap-day schedules can return `ErrNoMatch` even though they are valid.
+- [test] `references/lib-commons/commons/utils.go:181`, `references/lib-commons/commons/utils.go:191`, `references/lib-commons/commons/utils.go:210` - `Syscmd.ExecCmd`, `GetCPUUsage`, and `GetMemUsage` have no test coverage.
+- [consequences] `references/lib-commons/commons/cron/cron.go:32`, `references/lib-commons/commons/cron/cron.go:85` - rejecting day-of-week `7` breaks compatibility with many cron producers.
+- [consequences] `references/lib-commons/commons/cron/cron.go:113` - sparse but valid schedules can be misclassified as no-match.
+- [consequences] `references/lib-commons/commons/errors.go:35`, `references/lib-commons/commons/errors.go:73` - wrapped sentinels stop yielding structured business errors to downstream HTTP/API consumers.
+- [consequences] `references/lib-commons/commons/os.go:79`, `references/lib-commons/commons/os.go:97` - DI consumers can receive nil `*LocalEnvConfig` and fail at startup or first dereference.
+- [consequences] `references/lib-commons/commons/utils.go:191`, `references/lib-commons/commons/utils.go:210` - optional metrics dependencies become panic paths instead of safe degradation.
+
+### Medium
+- [code] `references/lib-commons/commons/os.go:104`, `references/lib-commons/commons/os.go:106`, `references/lib-commons/commons/os.go:117` - `SetConfigFromEnvVars` assumes a non-nil pointer to a struct and is fragile for callers.
+- [code] `references/lib-commons/commons/utils.go:63` - `SafeIntToUint64` converts negative inputs to `1`, which is a surprising semantic default.
+- [code] `references/lib-commons/commons/stringUtils.go:19`, `references/lib-commons/commons/stringUtils.go:181` - `ValidateServerAddress` does not validate port range and rejects valid IPv6 host:port forms.
+- [security] `references/lib-commons/commons/os.go:32-56`, `references/lib-commons/commons/os.go:119-126` - malformed env vars silently fall back to `false` / `0` and can quietly disable protections.
+- [security] `references/lib-commons/commons/errors.go:79-85` - `ValidateBusinessError` appends raw `args` into externally returned business error messages.
+- [security] `references/lib-commons/commons/utils.go:180-187` - `Syscmd.ExecCmd` exposes an arbitrary process execution primitive with no allowlist or validation.
+- [test] `references/lib-commons/commons/context_test.go:58`, `references/lib-commons/commons/context_test.go:80` - time-based assertions around `time.Until(...)` are scheduler-sensitive.
+- [test] `references/lib-commons/commons/os.go:72`, `references/lib-commons/commons/os_test.go:192` - `ENV_NAME=local` branches and `sync.Once` behavior are untested.
+- [test] `references/lib-commons/commons/context.go:76`, `references/lib-commons/commons/context.go:90`, `references/lib-commons/commons/context.go:104`, `references/lib-commons/commons/context.go:118`, `references/lib-commons/commons/context.go:280` - nil-safe branches for several context helpers are not covered.
+- [test] `references/lib-commons/commons/cron/cron.go:233` - malformed range parsing is only partially exercised.
+- [nil-safety] `references/lib-commons/commons/context.go:247`, `references/lib-commons/commons/context.go:249` - `ContextWithSpanAttributes(nil)` with no attrs returns nil instead of normalizing to `context.Background()`.
+- [consequences] `references/lib-commons/commons/os.go:104`, `references/lib-commons/commons/os.go:117` - configuration mistakes become panics in bootstrap/DI code paths.
+- [consequences] `references/lib-commons/commons/context.go:247` - nil context can leak downstream when no attributes are provided.
+
+### Low
+- [code] `references/lib-commons/commons/app.go:71` - `Add` docstring says it runs an application in a goroutine, but it only registers the app.
+- [code] `references/lib-commons/commons/app.go:108`, `references/lib-commons/commons/app.go:118` - `Run` / `RunWithError` comments describe behavior that the implementation cannot provide when logger is nil.
+- [security] `references/lib-commons/commons/context.go:244-260` - `ContextWithSpanAttributes` accepts arbitrary request-wide span attributes with no filtering.
+- [test] `references/lib-commons/commons/pointers/pointers_test.go:42` - `Float64()` lacks a direct unit test.
+- [test] `references/lib-commons/commons/app.go:110` - `Run()` wrapper itself is untested; coverage only hits `RunWithError()`.
+- [test] `references/lib-commons/commons/pointers/pointers.go:26` - `Float64()` is the only exported pointer helper without a corresponding test.
diff --git a/commons/app.go b/commons/app.go
index 9675c692..85187d47 100644
--- a/commons/app.go
+++ b/commons/app.go
@@ -1,19 +1,31 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package commons
import (
+ "context"
"errors"
+ "fmt"
+ "strings"
"sync"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
)
// ErrLoggerNil is returned when the Logger is nil and cannot proceed.
var ErrLoggerNil = errors.New("logger is nil")
+var (
+ // ErrNilLauncher is returned when a launcher method is called on a nil receiver.
+ ErrNilLauncher = errors.New("launcher is nil")
+ // ErrEmptyApp is returned when an app name is empty or whitespace.
+ ErrEmptyApp = errors.New("app name is empty")
+ // ErrNilApp is returned when a nil app instance is provided.
+ ErrNilApp = errors.New("app is nil")
+ // ErrConfigFailed is returned when launcher option application collected errors.
+ ErrConfigFailed = errors.New("launcher configuration failed")
+)
+
// App represents an application that will run as a deployable component.
// It's an entrypoint at main.go.
// RedisRepository provides an interface for redis.
@@ -27,75 +39,153 @@ type App interface {
type LauncherOption func(l *Launcher)
// WithLogger adds a log.Logger component to launcher.
+// If the launcher is nil the option is a no-op, preventing panics when
+// option closures are invoked on a nil receiver.
func WithLogger(logger log.Logger) LauncherOption {
return func(l *Launcher) {
+ if l == nil {
+ return
+ }
+
l.Logger = logger
}
}
-// RunApp start all process registered before to the launcher.
+// RunApp registers an application with the launcher.
+// If registration fails, the error is collected and surfaced when RunWithError is called.
+// If the launcher is nil the option is a no-op, preventing panics when
+// option closures are invoked on a nil receiver.
func RunApp(name string, app App) LauncherOption {
return func(l *Launcher) {
- l.Add(name, app)
+ if l == nil {
+ return
+ }
+
+ if err := l.Add(name, app); err != nil {
+ l.configErrors = append(l.configErrors, fmt.Errorf("add app %q: %w", name, err))
+
+ if l.Logger != nil {
+ l.Logger.Log(context.Background(), log.LevelError, "launcher add app error", log.Err(err))
+ }
+ }
}
}
// Launcher manages apps.
type Launcher struct {
- Logger log.Logger
- apps map[string]App
- wg *sync.WaitGroup
- Verbose bool
+ Logger log.Logger
+ apps map[string]App
+ wg *sync.WaitGroup
+ configErrors []error
+ Verbose bool
}
-// Add runs an application in a goroutine.
-func (l *Launcher) Add(appName string, a App) *Launcher {
+// Add registers an application under the given name for later execution.
+func (l *Launcher) Add(appName string, a App) error {
+ if l == nil {
+ asserter := assert.New(context.Background(), nil, "launcher", "Add")
+ _ = asserter.Never(context.Background(), "launcher receiver is nil")
+
+ return ErrNilLauncher
+ }
+
+ if l.apps == nil {
+ l.apps = make(map[string]App)
+ }
+
+ if l.wg == nil {
+ l.wg = new(sync.WaitGroup)
+ }
+
+ if strings.TrimSpace(appName) == "" {
+ asserter := assert.New(context.Background(), l.Logger, "launcher", "Add")
+ _ = asserter.Never(context.Background(), "app name must not be empty")
+
+ return ErrEmptyApp
+ }
+
+ if a == nil {
+ asserter := assert.New(context.Background(), l.Logger, "launcher", "Add")
+ _ = asserter.Never(context.Background(), "app must not be nil", "app_name", appName)
+
+ return ErrNilApp
+ }
+
l.apps[appName] = a
- return l
+
+ return nil
}
-// Run every application registered before with Run method.
-// Maintains backward compatibility - logs error internally if Logger is nil.
-// For explicit error handling, use RunWithError instead.
+// Run executes every application previously registered via Add.
+// Maintains backward compatibility — logs errors internally when Logger is
+// available. For explicit error handling, use RunWithError instead.
func (l *Launcher) Run() {
if err := l.RunWithError(); err != nil {
if l.Logger != nil {
- l.Logger.Errorf("Launcher error: %v", err)
+ l.Logger.Log(context.Background(), log.LevelError, "launcher error", log.Err(err))
}
}
}
-// RunWithError runs all applications and returns an error if Logger is nil.
-// Use this method when you need explicit error handling for launcher initialization.
+// RunWithError runs all registered applications and returns an error if the
+// launcher is nil, if Logger is nil, or if configuration errors were collected
+// during option application. Safe to call on a Launcher created without
+// NewLauncher (fields are lazy-initialized).
func (l *Launcher) RunWithError() error {
+ if l == nil {
+ return ErrNilLauncher
+ }
+
if l.Logger == nil {
return ErrLoggerNil
}
- count := len(l.apps)
- l.wg.Add(count)
+ // Lazy-init guards: safe to use even if constructed without NewLauncher.
+ if l.wg == nil {
+ l.wg = new(sync.WaitGroup)
+ }
- l.Logger.Infof("Starting %d app(s)\n", count)
+ if l.apps == nil {
+ l.apps = make(map[string]App)
+ }
- for name, app := range l.apps {
- go func(name string, app App) {
- defer l.wg.Done()
+ // Surface any errors collected during option application.
+ if len(l.configErrors) > 0 {
+ return errors.Join(append([]error{ErrConfigFailed}, l.configErrors...)...)
+ }
- l.Logger.Info("--")
- l.Logger.Infof("Launcher: App \u001b[33m(%s)\u001b[0m starting\n", name)
+ count := len(l.apps)
+ l.wg.Add(count)
- if err := app.Run(l); err != nil {
- l.Logger.Infof("Launcher: App (%s) error:", name)
- l.Logger.Infof("\u001b[31m%s\u001b[0m", err)
- }
+ l.Logger.Log(context.Background(), log.LevelInfo, "starting apps", log.Int("count", count))
- l.Logger.Infof("Launcher: App (%s) finished\n", name)
- }(name, app)
+ for name, app := range l.apps {
+ nameCopy := name
+ appCopy := app
+
+ runtime.SafeGoWithContextAndComponent(
+ context.Background(),
+ l.Logger,
+ "launcher",
+ "run_app_"+nameCopy,
+ runtime.KeepRunning,
+ func(_ context.Context) {
+ defer l.wg.Done()
+
+ l.Logger.Log(context.Background(), log.LevelInfo, "app starting", log.String("app", nameCopy))
+
+ if err := appCopy.Run(l); err != nil {
+ l.Logger.Log(context.Background(), log.LevelError, "app error", log.String("app", nameCopy), log.Err(err))
+ }
+
+ l.Logger.Log(context.Background(), log.LevelInfo, "app finished", log.String("app", nameCopy))
+ },
+ )
}
l.wg.Wait()
- l.Logger.Info("Launcher: Terminated")
+ l.Logger.Log(context.Background(), log.LevelInfo, "launcher terminated")
return nil
}
diff --git a/commons/app_test.go b/commons/app_test.go
new file mode 100644
index 00000000..f635f3ef
--- /dev/null
+++ b/commons/app_test.go
@@ -0,0 +1,164 @@
+//go:build unit
+
+package commons
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// stubApp is a minimal App implementation for testing.
+type stubApp struct {
+ err error
+}
+
+func (s *stubApp) Run(_ *Launcher) error {
+ return s.err
+}
+
+func TestNewLauncher(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher()
+ require.NotNil(t, l)
+ assert.True(t, l.Verbose)
+ assert.NotNil(t, l.apps)
+}
+
+func TestLauncher_Add(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var l *Launcher
+ err := l.Add("app", &stubApp{})
+ assert.ErrorIs(t, err, ErrNilLauncher)
+ })
+
+ t.Run("nil_app", func(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher()
+ err := l.Add("app", nil)
+ assert.ErrorIs(t, err, ErrNilApp)
+ })
+
+ t.Run("empty_name", func(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher()
+ err := l.Add("", &stubApp{})
+ assert.ErrorIs(t, err, ErrEmptyApp)
+ })
+
+ t.Run("whitespace_name", func(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher()
+ err := l.Add(" ", &stubApp{})
+ assert.ErrorIs(t, err, ErrEmptyApp)
+ })
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher()
+ err := l.Add("myapp", &stubApp{})
+ assert.NoError(t, err)
+ })
+}
+
+func TestRunAppOption(t *testing.T) {
+ t.Parallel()
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher()
+ opt := RunApp("myapp", &stubApp{})
+ opt(l)
+ assert.Empty(t, l.configErrors)
+ })
+
+ t.Run("failure_nil_app", func(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher(WithLogger(&log.NopLogger{}))
+ opt := RunApp("myapp", nil)
+ opt(l)
+ assert.NotEmpty(t, l.configErrors)
+ })
+}
+
+func TestWithLoggerOption_NilLauncher(t *testing.T) {
+ t.Parallel()
+
+ // WithLogger option applied to nil launcher must not panic.
+ opt := WithLogger(&log.NopLogger{})
+ assert.NotPanics(t, func() { opt(nil) })
+}
+
+func TestRunAppOption_NilLauncher(t *testing.T) {
+ t.Parallel()
+
+ // RunApp option applied to nil launcher must not panic.
+ opt := RunApp("myapp", &stubApp{})
+ assert.NotPanics(t, func() { opt(nil) })
+}
+
+func TestWithLoggerOption(t *testing.T) {
+ t.Parallel()
+
+ logger := &log.NopLogger{}
+ l := NewLauncher(WithLogger(logger))
+ assert.Equal(t, logger, l.Logger)
+}
+
+func TestRunWithError(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_logger_returns_ErrLoggerNil", func(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher()
+ err := l.RunWithError()
+ assert.ErrorIs(t, err, ErrLoggerNil)
+ })
+
+ t.Run("config_errors_surface", func(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher(WithLogger(&log.NopLogger{}))
+ l.configErrors = append(l.configErrors, errors.New("bad config"))
+
+ err := l.RunWithError()
+ assert.ErrorIs(t, err, ErrConfigFailed)
+ })
+
+ t.Run("no_apps_finishes", func(t *testing.T) {
+ t.Parallel()
+
+ l := NewLauncher(WithLogger(&log.NopLogger{}))
+ err := l.RunWithError()
+ assert.NoError(t, err)
+ })
+
+ t.Run("app_run_error_is_handled_gracefully", func(t *testing.T) {
+ t.Parallel()
+
+ sentinel := errors.New("boom")
+
+ l := NewLauncher(WithLogger(&log.NopLogger{}))
+ require.NoError(t, l.Add("failing", &stubApp{err: sentinel}))
+
+ // RunWithError launches apps in goroutines; app errors are logged
+ // but not propagated, so the launcher completes without error.
+ err := l.RunWithError()
+ assert.NoError(t, err)
+ })
+}
diff --git a/commons/assert/assert.go b/commons/assert/assert.go
new file mode 100644
index 00000000..b1500ab6
--- /dev/null
+++ b/commons/assert/assert.go
@@ -0,0 +1,512 @@
+package assert
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "reflect"
+ goruntime "runtime"
+ "runtime/debug"
+ "strconv"
+ "strings"
+ "sync"
+
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/codes"
+ "go.opentelemetry.io/otel/trace"
+
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
+)
+
+// Logger defines the minimal logging interface required by assertions.
+// This interface is satisfied by commons/log.Logger.
+type Logger interface {
+ Log(ctx context.Context, level log.Level, msg string, fields ...log.Field)
+}
+
+// Asserter evaluates invariants and emits telemetry on failure.
+type Asserter struct {
+ ctx context.Context
+ logger Logger
+ component string
+ operation string
+}
+
+// ErrAssertionFailed is the sentinel error for failed assertions.
+var ErrAssertionFailed = errors.New("assertion failed")
+
+// AssertionError represents a failed assertion with rich context.
+type AssertionError struct {
+ Assertion string
+ Message string
+ Component string
+ Operation string
+ Details string
+}
+
+// Error returns the formatted assertion failure message.
+func (entry *AssertionError) Error() string {
+ if entry == nil {
+ return ErrAssertionFailed.Error()
+ }
+
+ if entry.Details == "" {
+ return "assertion failed: " + entry.Message
+ }
+
+ return "assertion failed: " + entry.Message + "\n" + entry.Details
+}
+
+// Unwrap returns the sentinel assertion error for errors.Is.
+func (entry *AssertionError) Unwrap() error {
+ return ErrAssertionFailed
+}
+
+// New creates an Asserter with context, logging, and labels.
+// component and operation are used for telemetry labeling.
+//
+//nolint:contextcheck // Intentionally creates a fallback context when nil is passed
+func New(ctx context.Context, logger Logger, component, operation string) *Asserter {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ return &Asserter{
+ ctx: ctx,
+ logger: logger,
+ component: component,
+ operation: operation,
+ }
+}
+
+// That returns an error if ok is false. Use for general-purpose assertions.
+//
+// Example:
+//
+// if err := asserter.That(ctx, len(items) > 0, "items must not be empty", "count", len(items)); err != nil {
+// return err
+// }
+func (asserter *Asserter) That(ctx context.Context, ok bool, msg string, kv ...any) error {
+ if ok {
+ return nil
+ }
+
+ return asserter.fail(ctx, "That", msg, kv...)
+}
+
+// NotNil returns an error if v is nil. This function correctly handles both untyped nil
+// and typed nil (nil interface values with concrete types).
+//
+// Example:
+//
+// if err := asserter.NotNil(ctx, config, "config must be initialized"); err != nil {
+// return err
+// }
+func (asserter *Asserter) NotNil(ctx context.Context, v any, msg string, kv ...any) error {
+ if !isNil(v) {
+ return nil
+ }
+
+ return asserter.fail(ctx, "NotNil", msg, kv...)
+}
+
+// NotEmpty returns an error if s is an empty string.
+//
+// Example:
+//
+// if err := asserter.NotEmpty(ctx, userID, "userID must be provided"); err != nil {
+// return err
+// }
+func (asserter *Asserter) NotEmpty(ctx context.Context, s, msg string, kv ...any) error {
+ if s != "" {
+ return nil
+ }
+
+ return asserter.fail(ctx, "NotEmpty", msg, kv...)
+}
+
+// NoError returns an error if err is not nil. The error message and type are
+// automatically included in the assertion context for debugging.
+//
+// Example:
+//
+// if err := asserter.NoError(ctx, err, "compute must succeed", "input", input); err != nil {
+// return err
+// }
+func (asserter *Asserter) NoError(ctx context.Context, err error, msg string, kv ...any) error {
+ if err == nil {
+ return nil
+ }
+
+ // Prepend error and error_type to key-value pairs for richer debugging
+ // errorKVPairs: 2 pairs added (error + error_type), each pair = 2 elements
+ const errorKVPairs = 4
+
+ kvWithError := make([]any, 0, len(kv)+errorKVPairs)
+ kvWithError = append(kvWithError, "error", err.Error())
+ kvWithError = append(kvWithError, "error_type", fmt.Sprintf("%T", err))
+ kvWithError = append(kvWithError, kv...)
+
+ return asserter.fail(ctx, "NoError", msg, kvWithError...)
+}
+
+// Never always returns an error. Use for code paths that should be unreachable.
+//
+// Example:
+//
+// return asserter.Never(ctx, "unhandled status", "status", status)
+func (asserter *Asserter) Never(ctx context.Context, msg string, kv ...any) error {
+ return asserter.fail(ctx, "Never", msg, kv...)
+}
+
+// Halt terminates the current goroutine if err is not nil.
+// Use this after a failed assertion in goroutines to prevent further execution.
+func (asserter *Asserter) Halt(err error) {
+ if err != nil {
+ goruntime.Goexit()
+ }
+}
+
+const maxValueLength = 200 // Truncate values longer than this
+
+// truncateValue truncates long values for logging safety.
+// This prevents log bloat and reduces risk of sensitive data exposure.
+func truncateValue(v any) string {
+ s := fmt.Sprintf("%v", v)
+ if len(s) <= maxValueLength {
+ return s
+ }
+
+ return s[:maxValueLength] + "... (truncated " + strconv.Itoa(len(s)-maxValueLength) + " chars)"
+}
+
+func (asserter *Asserter) fail(ctx context.Context, assertion, msg string, kv ...any) error {
+ ctx, logger, component, operation := asserter.values(ctx)
+ contextPairs := withContextPairs(assertion, component, operation, kv)
+ details := formatKeyValueLines(contextPairs)
+
+ stack := []byte(nil)
+ if shouldIncludeStack() {
+ stack = debug.Stack()
+ }
+
+ // Emit structured fields for log aggregation; fall back to single-string format
+ // when no logger is available (stderr path) or for the stack trace supplement.
+ logAssertionStructured(logger, assertion, component, operation, msg, details)
+
+ if len(stack) > 0 && logger != nil {
+ logger.Log(context.Background(), log.LevelError, "assertion stack trace",
+ log.String("assertion_type", assertion),
+ log.String("stack_trace", string(stack)),
+ )
+ }
+
+ recordAssertionObservability(ctx, assertion, msg, stack, component, operation)
+
+ return &AssertionError{
+ Assertion: assertion,
+ Message: msg,
+ Component: component,
+ Operation: operation,
+ Details: details,
+ }
+}
+
+func (asserter *Asserter) values(ctx context.Context) (context.Context, Logger, string, string) {
+ if asserter == nil {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ return ctx, nil, "", ""
+ }
+
+ if ctx == nil {
+ ctx = asserter.ctx
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ return ctx, asserter.logger, asserter.component, asserter.operation
+}
+
+// shouldIncludeStack controls whether assertion failures include a stack trace.
+//
+// Stack traces are opt-out: they are included by default and suppressed when
+// production mode is detected. This is intentional because during development
+// and testing, stack traces are invaluable for debugging assertion failures,
+// while in production they add noise and may expose internal paths.
+//
+// To disable stack traces in production, use either:
+// - runtime.SetProductionMode(true) during application startup (preferred)
+// - Set ENV=production or GO_ENV=production environment variables
+func shouldIncludeStack() bool {
+ // Primary check: use runtime.IsProductionMode() which is explicitly
+ // set during application startup via runtime.SetProductionMode(true).
+ if runtime.IsProductionMode() {
+ return false
+ }
+
+ // Fallback: check environment variables for cases where production mode
+ // has not been explicitly configured via the runtime package.
+ env := strings.TrimSpace(os.Getenv("ENV"))
+ goEnv := strings.TrimSpace(os.Getenv("GO_ENV"))
+
+ return !strings.EqualFold(env, "production") && !strings.EqualFold(goEnv, "production")
+}
+
+// contextPairsCapacity is the capacity for the fixed context pairs (assertion, component, operation).
+const contextPairsCapacity = 6
+
+func withContextPairs(assertion, component, operation string, kv []any) []any {
+ contextPairs := make([]any, 0, len(kv)+contextPairsCapacity)
+ contextPairs = append(contextPairs, "assertion", assertion)
+
+ if component != "" {
+ contextPairs = append(contextPairs, "component", component)
+ }
+
+ if operation != "" {
+ contextPairs = append(contextPairs, "operation", operation)
+ }
+
+ contextPairs = append(contextPairs, kv...)
+
+ return contextPairs
+}
+
+func formatKeyValueLines(kv []any) string {
+ if len(kv) == 0 {
+ return ""
+ }
+
+ var sb strings.Builder
+
+ for i := 0; i < len(kv); i += 2 {
+ if i > 0 {
+ sb.WriteString("\n")
+ }
+
+ var value any
+ if i+1 < len(kv) {
+ value = kv[i+1]
+ } else {
+ value = "MISSING_VALUE"
+ }
+
+ fmt.Fprintf(&sb, " %v=%v", kv[i], truncateValue(value))
+ }
+
+ return sb.String()
+}
+
+// logAssertionStructured emits assertion failure as individual structured log fields
+// for better searchability in log aggregation systems (Loki, Elasticsearch, etc.).
+func logAssertionStructured(logger Logger, assertion, component, operation, msg, details string) {
+ if logger == nil {
+ // Fall back to stderr for emergency visibility
+ fmt.Fprintln(os.Stderr, "ASSERTION FAILED: "+msg)
+
+ return
+ }
+
+ fields := []log.Field{
+ log.String("assertion_type", assertion),
+ log.String("message", msg),
+ }
+
+ if component != "" {
+ fields = append(fields, log.String("component", component))
+ }
+
+ if operation != "" {
+ fields = append(fields, log.String("operation", operation))
+ }
+
+ if details != "" {
+ fields = append(fields, log.String("details", details))
+ }
+
+ logger.Log(context.Background(), log.LevelError, "ASSERTION FAILED", fields...)
+}
+
+// logAssertion is kept for backward compatibility with code paths that only
+// have a pre-formatted message string (e.g. when logger is nil and we write to stderr).
+func logAssertion(logger Logger, message string) {
+ if logger != nil {
+ logger.Log(context.Background(), log.LevelError, message)
+ return
+ }
+
+ fmt.Fprintln(os.Stderr, message)
+}
+
+// isNil checks if a value is nil, handling both untyped nil and typed nil
+// (nil interface values with concrete types).
+func isNil(v any) bool {
+ if v == nil {
+ return true
+ }
+
+ rv := reflect.ValueOf(v)
+ switch rv.Kind() {
+ case reflect.Pointer, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func:
+ return rv.IsNil()
+ default:
+ return false
+ }
+}
+
+// AssertionSpanEventName is the event name used when recording assertion failures on spans.
+const AssertionSpanEventName = constant.EventAssertionFailed
+
+// AssertionMetrics provides assertion-related metrics using OpenTelemetry.
+// It wraps lib-commons' MetricsFactory for consistent metric handling.
+type AssertionMetrics struct {
+ factory *metrics.MetricsFactory
+}
+
+// assertionFailedMetric defines the metric for counting failed assertions.
+var assertionFailedMetric = metrics.Metric{
+ Name: constant.MetricAssertionFailedTotal,
+ Unit: "1",
+ Description: "Total number of failed assertions",
+}
+
+var (
+ assertionMetricsInstance *AssertionMetrics
+ assertionMetricsMu sync.RWMutex
+)
+
+// InitAssertionMetrics initializes assertion metrics with the provided MetricsFactory.
+// This should be called once during application startup after telemetry is initialized.
+func InitAssertionMetrics(factory *metrics.MetricsFactory) {
+ assertionMetricsMu.Lock()
+ defer assertionMetricsMu.Unlock()
+
+ if factory == nil {
+ return
+ }
+
+ if assertionMetricsInstance != nil {
+ return
+ }
+
+ assertionMetricsInstance = &AssertionMetrics{factory: factory}
+}
+
+// GetAssertionMetrics returns the singleton AssertionMetrics instance.
+// Returns nil if InitAssertionMetrics has not been called.
+func GetAssertionMetrics() *AssertionMetrics {
+ assertionMetricsMu.RLock()
+ defer assertionMetricsMu.RUnlock()
+
+ return assertionMetricsInstance
+}
+
+// ResetAssertionMetrics clears the assertion metrics singleton (useful for tests).
+func ResetAssertionMetrics() {
+ assertionMetricsMu.Lock()
+ defer assertionMetricsMu.Unlock()
+
+ assertionMetricsInstance = nil
+}
+
+// RecordAssertionFailed increments the assertion_failed_total counter with labels.
+// If metrics are not initialized, this is a no-op.
+func (am *AssertionMetrics) RecordAssertionFailed(
+ ctx context.Context,
+ component, operation, assertion string,
+) {
+ if am == nil || am.factory == nil {
+ return
+ }
+
+ counter, err := am.factory.Counter(assertionFailedMetric)
+ if err != nil {
+ logAssertion(nil, fmt.Sprintf("failed to create assertion metric counter: %v", err))
+ return
+ }
+
+ err = counter.
+ WithLabels(map[string]string{
+ "component": constant.SanitizeMetricLabel(component),
+ "operation": constant.SanitizeMetricLabel(operation),
+ "assertion": constant.SanitizeMetricLabel(assertion),
+ }).
+ AddOne(ctx)
+ if err != nil {
+ logAssertion(nil, fmt.Sprintf("failed to record assertion metric: %v", err))
+ return
+ }
+}
+
+func recordAssertionMetric(ctx context.Context, component, operation, assertion string) {
+ am := GetAssertionMetrics()
+ if am != nil {
+ am.RecordAssertionFailed(ctx, component, operation, assertion)
+ }
+}
+
+func recordAssertionObservability(
+ ctx context.Context,
+ assertion, message string,
+ stack []byte,
+ component, operation string,
+) {
+ recordAssertionMetric(ctx, component, operation, assertion)
+ recordAssertionToSpan(ctx, assertion, message, stack, component, operation)
+}
+
+func recordAssertionToSpan(
+ ctx context.Context,
+ assertion, message string,
+ stack []byte,
+ component, operation string,
+) {
+ span := trace.SpanFromContext(ctx)
+ if !span.IsRecording() {
+ return
+ }
+
+ attrs := []attribute.KeyValue{
+ attribute.String("assertion.name", assertion),
+ attribute.String("assertion.message", message),
+ }
+
+ if component != "" {
+ attrs = append(attrs, attribute.String("assertion.component", component))
+ }
+
+ if operation != "" {
+ attrs = append(attrs, attribute.String("assertion.operation", operation))
+ }
+
+ if len(stack) > 0 {
+ attrs = append(attrs, attribute.String("assertion.stack", string(stack)))
+ }
+
+ span.AddEvent(AssertionSpanEventName, trace.WithAttributes(attrs...))
+ span.RecordError(fmt.Errorf("%w: %s", ErrAssertionFailed, message))
+ span.SetStatus(codes.Error, assertionStatusMessage(component, operation))
+}
+
+func assertionStatusMessage(component, operation string) string {
+ switch {
+ case component != "" && operation != "":
+ return fmt.Sprintf("assertion failed in %s/%s", component, operation)
+ case component != "":
+ return "assertion failed in " + component
+ case operation != "":
+ return "assertion failed in " + operation
+ default:
+ return "assertion failed"
+ }
+}
diff --git a/commons/assert/assert_extended_test.go b/commons/assert/assert_extended_test.go
new file mode 100644
index 00000000..4dc801c7
--- /dev/null
+++ b/commons/assert/assert_extended_test.go
@@ -0,0 +1,603 @@
+//go:build unit
+
+package assert
+
+import (
+ "context"
+ "strings"
+ "testing"
+
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
+
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/metric/noop"
+ tracesdk "go.opentelemetry.io/otel/sdk/trace"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
+)
+
+func newTestMetricsFactory(t *testing.T) *metrics.MetricsFactory {
+ t.Helper()
+
+ meter := noop.NewMeterProvider().Meter("test")
+ factory, err := metrics.NewMetricsFactory(meter, &libLog.NopLogger{})
+ require.NoError(t, err, "newTestMetricsFactory failed")
+
+ return factory
+}
+
+// --- AssertionError Tests ---
+
+func TestAssertionError_NilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var entry *AssertionError
+ msg := entry.Error()
+ require.Equal(t, ErrAssertionFailed.Error(), msg)
+}
+
+func TestAssertionError_WithoutDetails(t *testing.T) {
+ t.Parallel()
+
+ entry := &AssertionError{
+ Assertion: "That",
+ Message: "some message",
+ Component: "comp",
+ Operation: "op",
+ Details: "",
+ }
+
+ msg := entry.Error()
+ require.Equal(t, "assertion failed: some message", msg)
+}
+
+func TestAssertionError_WithDetails(t *testing.T) {
+ t.Parallel()
+
+ entry := &AssertionError{
+ Assertion: "NotNil",
+ Message: "value required",
+ Component: "comp",
+ Operation: "op",
+ Details: " key=value",
+ }
+
+ msg := entry.Error()
+ require.Contains(t, msg, "assertion failed: value required")
+ require.Contains(t, msg, "key=value")
+}
+
+func TestAssertionError_Unwrap(t *testing.T) {
+ t.Parallel()
+
+ entry := &AssertionError{Message: "test"}
+ require.ErrorIs(t, entry, ErrAssertionFailed)
+}
+
+// --- Halt Tests ---
+
+func TestHalt_NilError_NoEffect(t *testing.T) {
+ t.Parallel()
+
+ asserter := New(context.Background(), nil, "test", "halt")
+ // Halt with nil error should be a no-op, no panic or goexit.
+ asserter.Halt(nil)
+}
+
+// --- truncateValue Tests ---
+
+func TestTruncateValue_ShortValue(t *testing.T) {
+ t.Parallel()
+
+ result := truncateValue("hello")
+ require.Equal(t, "hello", result)
+}
+
+func TestTruncateValue_ExactMaxLength(t *testing.T) {
+ t.Parallel()
+
+ val := strings.Repeat("a", maxValueLength)
+ result := truncateValue(val)
+ require.Equal(t, val, result)
+}
+
+func TestTruncateValue_LongValue(t *testing.T) {
+ t.Parallel()
+
+ val := strings.Repeat("b", maxValueLength+50)
+ result := truncateValue(val)
+ require.Len(t, result, maxValueLength+len("... (truncated 50 chars)"))
+ require.Contains(t, result, "... (truncated 50 chars)")
+}
+
+func TestTruncateValue_NonStringType(t *testing.T) {
+ t.Parallel()
+
+ result := truncateValue(42)
+ require.Equal(t, "42", result)
+}
+
+// --- values Tests ---
+
+func TestValues_NilAsserter(t *testing.T) {
+ t.Parallel()
+
+ var asserter *Asserter
+ ctx, logger, component, operation := asserter.values(context.Background())
+ require.NotNil(t, ctx)
+ require.Nil(t, logger)
+ require.Empty(t, component)
+ require.Empty(t, operation)
+}
+
+func TestValues_NilAsserterNilCtx(t *testing.T) {
+ t.Parallel()
+
+ var asserter *Asserter
+ //nolint:staticcheck // intentionally passing nil ctx
+ ctx, _, _, _ := asserter.values(nil)
+ require.NotNil(t, ctx)
+}
+
+func TestValues_WithAsserterNilCtx(t *testing.T) {
+ t.Parallel()
+
+ logger := &testLogger{}
+ asserter := New(context.Background(), logger, "comp", "op")
+ //nolint:staticcheck // intentionally passing nil ctx
+ ctx, l, c, o := asserter.values(nil)
+ require.NotNil(t, ctx)
+ require.Equal(t, logger, l)
+ require.Equal(t, "comp", c)
+ require.Equal(t, "op", o)
+}
+
+func TestValues_BothNilFallsToBackground(t *testing.T) {
+ t.Parallel()
+
+ asserter := &Asserter{
+ ctx: nil,
+ logger: nil,
+ component: "",
+ operation: "",
+ }
+ //nolint:staticcheck // intentionally passing nil ctx
+ ctx, _, _, _ := asserter.values(nil)
+ require.NotNil(t, ctx)
+}
+
+// --- SanitizeMetricLabel Tests ---
+
+func TestSanitizeMetricLabel_ShortLabel(t *testing.T) {
+ t.Parallel()
+
+ result := constant.SanitizeMetricLabel("short")
+ require.Equal(t, "short", result)
+}
+
+func TestSanitizeMetricLabel_ExactMaxLength(t *testing.T) {
+ t.Parallel()
+
+ val := strings.Repeat("x", constant.MaxMetricLabelLength)
+ result := constant.SanitizeMetricLabel(val)
+ require.Equal(t, val, result)
+}
+
+func TestSanitizeMetricLabel_TruncatesLongLabel(t *testing.T) {
+ t.Parallel()
+
+ val := strings.Repeat("y", constant.MaxMetricLabelLength+20)
+ result := constant.SanitizeMetricLabel(val)
+ require.Len(t, result, constant.MaxMetricLabelLength)
+ require.Equal(t, strings.Repeat("y", constant.MaxMetricLabelLength), result)
+}
+
+// --- assertionStatusMessage Tests ---
+
+func TestAssertionStatusMessage_ComponentAndOperation(t *testing.T) {
+ t.Parallel()
+
+ msg := assertionStatusMessage("comp", "op")
+ require.Equal(t, "assertion failed in comp/op", msg)
+}
+
+func TestAssertionStatusMessage_ComponentOnly(t *testing.T) {
+ t.Parallel()
+
+ msg := assertionStatusMessage("comp", "")
+ require.Equal(t, "assertion failed in comp", msg)
+}
+
+func TestAssertionStatusMessage_OperationOnly(t *testing.T) {
+ t.Parallel()
+
+ msg := assertionStatusMessage("", "op")
+ require.Equal(t, "assertion failed in op", msg)
+}
+
+func TestAssertionStatusMessage_Neither(t *testing.T) {
+ t.Parallel()
+
+ msg := assertionStatusMessage("", "")
+ require.Equal(t, "assertion failed", msg)
+}
+
+// --- InitAssertionMetrics / ResetAssertionMetrics / GetAssertionMetrics Tests ---
+
+func TestInitAssertionMetrics_NilFactory(t *testing.T) {
+ // Not parallel - modifies global state.
+ ResetAssertionMetrics()
+ defer ResetAssertionMetrics()
+
+ InitAssertionMetrics(nil)
+ require.Nil(t, GetAssertionMetrics())
+}
+
+func TestInitAssertionMetrics_ValidFactory(t *testing.T) {
+ // Not parallel - modifies global state.
+ ResetAssertionMetrics()
+ defer ResetAssertionMetrics()
+
+ factory := newTestMetricsFactory(t)
+ InitAssertionMetrics(factory)
+
+ am := GetAssertionMetrics()
+ require.NotNil(t, am)
+ require.Equal(t, factory, am.factory)
+}
+
+func TestInitAssertionMetrics_DoubleInit_NoOverwrite(t *testing.T) {
+ // Not parallel - modifies global state.
+ ResetAssertionMetrics()
+ defer ResetAssertionMetrics()
+
+ factory1 := newTestMetricsFactory(t)
+ factory2 := newTestMetricsFactory(t)
+
+ InitAssertionMetrics(factory1)
+ InitAssertionMetrics(factory2)
+
+ am := GetAssertionMetrics()
+ require.NotNil(t, am)
+ require.Equal(t, factory1, am.factory, "second init should not overwrite")
+}
+
+func TestResetAssertionMetrics(t *testing.T) {
+ // Not parallel - modifies global state.
+ factory := newTestMetricsFactory(t)
+ InitAssertionMetrics(factory)
+
+ ResetAssertionMetrics()
+ require.Nil(t, GetAssertionMetrics())
+}
+
+// --- RecordAssertionFailed Tests ---
+
+func TestRecordAssertionFailed_NilMetrics(t *testing.T) {
+ t.Parallel()
+
+ // Should be a no-op, no panic.
+ var am *AssertionMetrics
+ am.RecordAssertionFailed(context.Background(), "comp", "op", "That")
+}
+
+func TestRecordAssertionFailed_NilFactory(t *testing.T) {
+ t.Parallel()
+
+ am := &AssertionMetrics{factory: nil}
+ // Should be a no-op, no panic.
+ am.RecordAssertionFailed(context.Background(), "comp", "op", "That")
+}
+
+func TestRecordAssertionFailed_WithFactory(t *testing.T) {
+ // Not parallel - modifies global state.
+ ResetAssertionMetrics()
+ defer ResetAssertionMetrics()
+
+ factory := newTestMetricsFactory(t)
+ InitAssertionMetrics(factory)
+
+ am := GetAssertionMetrics()
+ require.NotNil(t, am)
+ // Should not panic.
+ am.RecordAssertionFailed(context.Background(), "comp", "op", "That")
+}
+
+// --- recordAssertionMetric Tests ---
+
+func TestRecordAssertionMetric_NoMetricsInitialized(t *testing.T) {
+ // Not parallel - modifies global state.
+ ResetAssertionMetrics()
+ defer ResetAssertionMetrics()
+
+ // Should be a no-op, no panic.
+ recordAssertionMetric(context.Background(), "comp", "op", "That")
+}
+
+func TestRecordAssertionMetric_WithMetrics(t *testing.T) {
+ // Not parallel - modifies global state.
+ ResetAssertionMetrics()
+ defer ResetAssertionMetrics()
+
+ factory := newTestMetricsFactory(t)
+ InitAssertionMetrics(factory)
+
+ // Should not panic.
+ recordAssertionMetric(context.Background(), "comp", "op", "NotNil")
+}
+
+// --- recordAssertionToSpan Tests ---
+
+func TestRecordAssertionToSpan_NoSpanInContext(t *testing.T) {
+ t.Parallel()
+
+ // Background context has a no-op span, which is not recording.
+ // Should be a no-op, no panic.
+ recordAssertionToSpan(context.Background(), "That", "test message", nil, "comp", "op")
+}
+
+func TestRecordAssertionToSpan_WithRecordingSpan(t *testing.T) {
+ t.Parallel()
+
+ tp := tracesdk.NewTracerProvider()
+ tracer := tp.Tracer("test")
+ ctx, span := tracer.Start(context.Background(), "test-span")
+ defer span.End()
+
+ // Should record event and error on the span, no panic.
+ recordAssertionToSpan(ctx, "NotNil", "value is nil", nil, "comp", "op")
+}
+
+func TestRecordAssertionToSpan_WithStack(t *testing.T) {
+ t.Parallel()
+
+ tp := tracesdk.NewTracerProvider()
+ tracer := tp.Tracer("test")
+ ctx, span := tracer.Start(context.Background(), "test-span")
+ defer span.End()
+
+ stack := []byte("goroutine 1:\n main.go:10")
+ recordAssertionToSpan(ctx, "That", "condition false", stack, "comp", "op")
+}
+
+func TestRecordAssertionToSpan_EmptyComponentAndOperation(t *testing.T) {
+ t.Parallel()
+
+ tp := tracesdk.NewTracerProvider()
+ tracer := tp.Tracer("test")
+ ctx, span := tracer.Start(context.Background(), "test-span")
+ defer span.End()
+
+ recordAssertionToSpan(ctx, "Never", "unreachable", nil, "", "")
+}
+
+// --- logAssertion Tests ---
+
+func TestLogAssertion_WithNilLogger(t *testing.T) {
+ t.Parallel()
+
+ // Writes to stderr, should not panic.
+ logAssertion(nil, "test message for stderr")
+}
+
+func TestLogAssertion_WithLogger(t *testing.T) {
+ t.Parallel()
+
+ logger := &testLogger{}
+ logAssertion(logger, "test message for logger")
+ require.Len(t, logger.messages, 1)
+ require.Equal(t, "test message for logger", logger.messages[0])
+}
+
+// --- New Tests ---
+
+func TestNew_NilContext(t *testing.T) {
+ t.Parallel()
+
+ //nolint:staticcheck // intentionally passing nil ctx
+ asserter := New(nil, nil, "comp", "op")
+ require.NotNil(t, asserter)
+ require.NotNil(t, asserter.ctx)
+}
+
+func TestNew_WithAllFields(t *testing.T) {
+ t.Parallel()
+
+ logger := &testLogger{}
+ ctx := context.Background()
+ asserter := New(ctx, logger, "comp", "op")
+ require.Equal(t, ctx, asserter.ctx)
+ require.Equal(t, logger, asserter.logger)
+ require.Equal(t, "comp", asserter.component)
+ require.Equal(t, "op", asserter.operation)
+}
+
+// --- formatKeyValueLines Tests ---
+
+func TestFormatKeyValueLines_Empty(t *testing.T) {
+ t.Parallel()
+
+ result := formatKeyValueLines(nil)
+ require.Empty(t, result)
+}
+
+func TestFormatKeyValueLines_SinglePair(t *testing.T) {
+ t.Parallel()
+
+ result := formatKeyValueLines([]any{"key", "value"})
+ require.Equal(t, " key=value", result)
+}
+
+func TestFormatKeyValueLines_MultiplePairs(t *testing.T) {
+ t.Parallel()
+
+ result := formatKeyValueLines([]any{"k1", "v1", "k2", "v2"})
+ require.Contains(t, result, "k1=v1")
+ require.Contains(t, result, "k2=v2")
+}
+
+func TestFormatKeyValueLines_OddCount(t *testing.T) {
+ t.Parallel()
+
+ result := formatKeyValueLines([]any{"k1", "v1", "orphan"})
+ require.Contains(t, result, "k1=v1")
+ require.Contains(t, result, "orphan=MISSING_VALUE")
+}
+
+// --- recordAssertionObservability Tests ---
+
+func TestRecordAssertionObservability_NoMetricsNoSpan(t *testing.T) {
+ // Not parallel - modifies global state.
+ ResetAssertionMetrics()
+ defer ResetAssertionMetrics()
+
+ // Should not panic.
+ recordAssertionObservability(context.Background(), "That", "test", nil, "comp", "op")
+}
+
+// --- isNil Tests ---
+
+func TestIsNil_UntypedNil(t *testing.T) {
+ t.Parallel()
+ require.True(t, isNil(nil))
+}
+
+func TestIsNil_TypedNilPointer(t *testing.T) {
+ t.Parallel()
+
+ var p *int
+ // A typed-nil pointer stored in an interface{} should be detected as nil.
+ require.True(t, isNil(p), "typed nil pointer should be nil")
+}
+
+func TestIsNil_NonNilInt(t *testing.T) {
+ t.Parallel()
+ require.False(t, isNil(42))
+}
+
+func TestIsNil_NonNilString(t *testing.T) {
+ t.Parallel()
+ require.False(t, isNil("hello"))
+}
+
+func TestIsNil_NonNilStruct(t *testing.T) {
+ t.Parallel()
+
+ type s struct{}
+ require.False(t, isNil(s{}))
+}
+
+// --- shouldIncludeStack Tests ---
+
+func TestShouldIncludeStack_NonProduction(t *testing.T) {
+ // Not parallel - uses t.Setenv and depends on runtime global state.
+ t.Setenv("ENV", "development")
+ t.Setenv("GO_ENV", "")
+
+ require.True(t, shouldIncludeStack())
+}
+
+func TestShouldIncludeStack_ProductionENV(t *testing.T) {
+ // Not parallel - uses t.Setenv and depends on runtime global state.
+ t.Setenv("ENV", "production")
+ t.Setenv("GO_ENV", "")
+
+ require.False(t, shouldIncludeStack())
+}
+
+func TestShouldIncludeStack_ProductionGOENV(t *testing.T) {
+ // Not parallel - uses t.Setenv and depends on runtime global state.
+ t.Setenv("ENV", "")
+ t.Setenv("GO_ENV", "production")
+
+ require.False(t, shouldIncludeStack())
+}
+
+func TestShouldIncludeStack_ProductionCaseInsensitive(t *testing.T) {
+ // Not parallel - uses t.Setenv and depends on runtime global state.
+ t.Setenv("ENV", "Production")
+ t.Setenv("GO_ENV", "")
+
+ require.False(t, shouldIncludeStack())
+}
+
+func TestShouldIncludeStack_RuntimeProductionMode(t *testing.T) {
+ // Not parallel - modifies global state.
+ t.Setenv("ENV", "")
+ t.Setenv("GO_ENV", "")
+
+ runtime.SetProductionMode(true)
+ defer runtime.SetProductionMode(false)
+
+ require.False(t, shouldIncludeStack(), "should suppress stacks when runtime.IsProductionMode() is true")
+}
+
+func TestShouldIncludeStack_RuntimeProductionModeOverridesEnv(t *testing.T) {
+ // Not parallel - modifies global state.
+ // Even though env vars say non-production, runtime mode takes priority.
+ t.Setenv("ENV", "development")
+ t.Setenv("GO_ENV", "development")
+
+ runtime.SetProductionMode(true)
+ defer runtime.SetProductionMode(false)
+
+ require.False(t, shouldIncludeStack(), "runtime production mode should override env vars")
+}
+
+func TestShouldIncludeStack_EnvFallbackWhenRuntimeNotSet(t *testing.T) {
+ // Not parallel - modifies global state.
+ runtime.SetProductionMode(false)
+ defer runtime.SetProductionMode(false)
+
+ t.Setenv("ENV", "production")
+ t.Setenv("GO_ENV", "")
+
+ require.False(t, shouldIncludeStack(), "env var fallback should still detect production")
+}
+
+func TestShouldIncludeStack_NonProductionWhenBothDisabled(t *testing.T) {
+ // Not parallel - modifies global state.
+ runtime.SetProductionMode(false)
+ defer runtime.SetProductionMode(false)
+
+ t.Setenv("ENV", "development")
+ t.Setenv("GO_ENV", "")
+
+ require.True(t, shouldIncludeStack(), "should include stacks in non-production mode")
+}
+
+// --- withContextPairs Tests ---
+
+func TestWithContextPairs_AllFields(t *testing.T) {
+ t.Parallel()
+
+ result := withContextPairs("That", "comp", "op", []any{"k1", "v1"})
+ // Should contain: assertion, That, component, comp, operation, op, k1, v1
+ require.Len(t, result, 8)
+}
+
+func TestWithContextPairs_EmptyComponent(t *testing.T) {
+ t.Parallel()
+
+ result := withContextPairs("NotNil", "", "op", nil)
+ // Should contain: assertion, NotNil, operation, op
+ require.Len(t, result, 4)
+}
+
+func TestWithContextPairs_EmptyOperation(t *testing.T) {
+ t.Parallel()
+
+ result := withContextPairs("NotNil", "comp", "", nil)
+ // Should contain: assertion, NotNil, component, comp
+ require.Len(t, result, 4)
+}
+
+func TestWithContextPairs_BothEmpty(t *testing.T) {
+ t.Parallel()
+
+ result := withContextPairs("Never", "", "", nil)
+ // Should contain: assertion, Never
+ require.Len(t, result, 2)
+}
diff --git a/commons/assert/assert_test.go b/commons/assert/assert_test.go
new file mode 100644
index 00000000..dd592d30
--- /dev/null
+++ b/commons/assert/assert_test.go
@@ -0,0 +1,596 @@
+//go:build unit
+
+package assert
+
+import (
+ "context"
+ "errors"
+ "math"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/shopspring/decimal"
+ "github.com/stretchr/testify/require"
+)
+
+// errTest is a test error for assertions.
+var errTest = errors.New("test error")
+
+// errSpecificTest is a specific test error for assertions.
+var errSpecificTest = errors.New("specific test error")
+
+type testLogger struct {
+ messages []string
+}
+
+func (l *testLogger) Log(_ context.Context, _ log.Level, msg string, _ ...log.Field) {
+ l.messages = append(l.messages, msg)
+}
+
+func newTestAsserter(logger Logger) *Asserter {
+ return New(context.Background(), logger, "test-component", "test-operation")
+}
+
+func newTestAsserterWithLogger() (*Asserter, *testLogger) {
+ logger := &testLogger{}
+ return newTestAsserter(logger), logger
+}
+
+// TestThat_Pass verifies That returns nil when condition is true.
+func TestThat_Pass(t *testing.T) {
+ t.Parallel()
+
+ a := newTestAsserter(nil)
+ require.NoError(t, a.That(context.Background(), true, "should not fail"))
+}
+
+// TestThat_Fail verifies That returns an error when condition is false.
+func TestThat_Fail(t *testing.T) {
+ t.Parallel()
+
+ a, _ := newTestAsserterWithLogger()
+ err := a.That(context.Background(), false, "should fail")
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrAssertionFailed)
+}
+
+// TestThat_ErrorMessage verifies the error message contains the expected content.
+func TestThat_ErrorMessage(t *testing.T) {
+ t.Parallel()
+
+ a, _ := newTestAsserterWithLogger()
+ err := a.That(context.Background(), false, "test message", "key1", "value1", "key2", 42)
+ require.Error(t, err)
+ msg := err.Error()
+ require.Contains(t, msg, "assertion failed:")
+ require.Contains(t, msg, "test message")
+ require.Contains(t, msg, "assertion=That")
+ require.Contains(t, msg, "key1=value1")
+ require.Contains(t, msg, "key2=42")
+}
+
+// TestThat_LogIncludesStackTrace verifies stack trace is logged in non-production.
+// With structured logging, the stack trace is emitted as a separate log call
+// with message "assertion stack trace" and a stack_trace field.
+func TestThat_LogIncludesStackTrace(t *testing.T) {
+ t.Setenv("ENV", "")
+ t.Setenv("GO_ENV", "")
+
+ a, logger := newTestAsserterWithLogger()
+ err := a.That(context.Background(), false, "test message", "key1", "value1")
+ require.Error(t, err)
+ require.NotEmpty(t, logger.messages)
+ // First log is the structured assertion failure
+ require.Contains(t, logger.messages[0], "ASSERTION FAILED")
+ // Second log is the stack trace (emitted separately for structured logging)
+ require.Len(t, logger.messages, 2, "should have assertion failure + stack trace logs")
+ require.Contains(t, logger.messages[1], "assertion stack trace")
+}
+
+// TestNotNil_Pass verifies NotNil returns nil for non-nil values.
+func TestNotNil_Pass(t *testing.T) {
+ t.Parallel()
+
+ asserter := newTestAsserter(nil)
+ require.NoError(t, asserter.NotNil(context.Background(), "hello", "string should not be nil"))
+ require.NoError(t, asserter.NotNil(context.Background(), 42, "int should not be nil"))
+
+ x := new(int)
+ require.NoError(t, asserter.NotNil(context.Background(), x, "pointer should not be nil"))
+
+ s := []int{1, 2, 3}
+ require.NoError(t, asserter.NotNil(context.Background(), s, "slice should not be nil"))
+
+ m := map[string]int{"a": 1}
+ require.NoError(t, asserter.NotNil(context.Background(), m, "map should not be nil"))
+}
+
+// TestNotNil_Fail verifies NotNil returns an error for nil values.
+func TestNotNil_Fail(t *testing.T) {
+ t.Parallel()
+
+ a, _ := newTestAsserterWithLogger()
+ err := a.NotNil(context.Background(), nil, "should fail for nil")
+ require.Error(t, err)
+}
+
+// TestNotNil_TypedNil verifies NotNil correctly handles typed nil.
+// A typed nil is when an interface holds a nil pointer of a concrete type.
+func TestNotNil_TypedNil(t *testing.T) {
+ t.Parallel()
+
+ asserter, _ := newTestAsserterWithLogger()
+
+ var ptr *int
+
+ var iface any = ptr // typed nil: interface is not nil, but value is
+
+ err := asserter.NotNil(context.Background(), iface, "should fail for typed nil")
+ require.Error(t, err)
+}
+
+// TestNotNil_TypedNilSlice verifies NotNil handles typed nil slices.
+func TestNotNil_TypedNilSlice(t *testing.T) {
+ t.Parallel()
+
+ asserter, _ := newTestAsserterWithLogger()
+
+ var s []int
+
+ var iface any = s
+
+ err := asserter.NotNil(context.Background(), iface, "should fail for typed nil slice")
+ require.Error(t, err)
+}
+
+// TestNotNil_TypedNilMap verifies NotNil handles typed nil maps.
+func TestNotNil_TypedNilMap(t *testing.T) {
+ t.Parallel()
+
+ asserter, _ := newTestAsserterWithLogger()
+
+ var m map[string]int
+
+ var iface any = m
+
+ err := asserter.NotNil(context.Background(), iface, "should fail for typed nil map")
+ require.Error(t, err)
+}
+
+// TestNotNil_TypedNilChan verifies NotNil handles typed nil channels.
+func TestNotNil_TypedNilChan(t *testing.T) {
+ t.Parallel()
+
+ asserter, _ := newTestAsserterWithLogger()
+
+ var ch chan int
+
+ var iface any = ch
+
+ err := asserter.NotNil(context.Background(), iface, "should fail for typed nil channel")
+ require.Error(t, err)
+}
+
+// TestNotNil_TypedNilFunc verifies NotNil handles typed nil functions.
+func TestNotNil_TypedNilFunc(t *testing.T) {
+ t.Parallel()
+
+ asserter, _ := newTestAsserterWithLogger()
+
+ var fn func()
+
+ var iface any = fn
+
+ err := asserter.NotNil(context.Background(), iface, "should fail for typed nil function")
+ require.Error(t, err)
+}
+
+// TestNotEmpty_Pass verifies NotEmpty returns nil for non-empty strings.
+func TestNotEmpty_Pass(t *testing.T) {
+ t.Parallel()
+
+ a := newTestAsserter(nil)
+ require.NoError(t, a.NotEmpty(context.Background(), "hello", "should not fail"))
+ require.NoError(t, a.NotEmpty(context.Background(), " ", "whitespace is not empty"))
+}
+
+// TestNotEmpty_Fail verifies NotEmpty returns an error for empty strings.
+func TestNotEmpty_Fail(t *testing.T) {
+ t.Parallel()
+
+ a, _ := newTestAsserterWithLogger()
+ err := a.NotEmpty(context.Background(), "", "should fail for empty string")
+ require.Error(t, err)
+}
+
+// TestNoError_Pass verifies NoError returns nil when error is nil.
+func TestNoError_Pass(t *testing.T) {
+ t.Parallel()
+
+ a := newTestAsserter(nil)
+ require.NoError(t, a.NoError(context.Background(), nil, "should not fail"))
+}
+
+// TestNoError_Fail verifies NoError returns an error when error is not nil.
+func TestNoError_Fail(t *testing.T) {
+ t.Parallel()
+
+ a, _ := newTestAsserterWithLogger()
+ err := a.NoError(context.Background(), errTest, "should fail")
+ require.Error(t, err)
+}
+
+// TestNoError_MessageContainsError verifies the error message and type are included.
+func TestNoError_MessageContainsError(t *testing.T) {
+ t.Parallel()
+
+ a, _ := newTestAsserterWithLogger()
+ err := a.NoError(
+ context.Background(),
+ errSpecificTest,
+ "operation failed",
+ "context_key",
+ "context_value",
+ )
+ require.Error(t, err)
+ msg := err.Error()
+ require.Contains(t, msg, "assertion failed:")
+ require.Contains(t, msg, "operation failed")
+ require.Contains(t, msg, "error=specific test error")
+ require.Contains(t, msg, "error_type=*errors.errorString")
+ require.Contains(t, msg, "context_key=context_value")
+}
+
+// TestNever_AlwaysFails verifies Never always returns an error.
+func TestNever_AlwaysFails(t *testing.T) {
+ t.Parallel()
+
+ a, _ := newTestAsserterWithLogger()
+ err := a.Never(context.Background(), "unreachable code reached")
+ require.Error(t, err)
+}
+
+// TestNever_ErrorMessage verifies Never includes message and context.
+func TestNever_ErrorMessage(t *testing.T) {
+ t.Parallel()
+
+ a, _ := newTestAsserterWithLogger()
+ err := a.Never(context.Background(), "unreachable", "state", "invalid")
+ require.Error(t, err)
+ msg := err.Error()
+ require.Contains(t, msg, "assertion failed:")
+ require.Contains(t, msg, "unreachable")
+ require.Contains(t, msg, "state=invalid")
+}
+
+// TestOddKeyValuePairs verifies handling of odd number of key-value pairs.
+func TestOddKeyValuePairs(t *testing.T) {
+ t.Parallel()
+
+ a, _ := newTestAsserterWithLogger()
+ err := a.That(context.Background(), false, "test", "key1", "value1", "key2")
+ require.Error(t, err)
+ msg := err.Error()
+ require.Contains(t, msg, "key1=value1")
+ require.Contains(t, msg, "key2=MISSING_VALUE")
+}
+
+// TestPositive tests the Positive predicate.
+func TestPositive(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ n int64
+ expected bool
+ }{
+ {"positive", 1, true},
+ {"large positive", 1000000, true},
+ {"max int64", math.MaxInt64, true},
+ {"zero", 0, false},
+ {"negative", -1, false},
+ {"large negative", -1000000, false},
+ {"min int64", math.MinInt64, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, Positive(tt.n))
+ })
+ }
+}
+
+// TestNonNegative tests the NonNegative predicate.
+func TestNonNegative(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ n int64
+ expected bool
+ }{
+ {"positive", 1, true},
+ {"max int64", math.MaxInt64, true},
+ {"zero", 0, true},
+ {"negative", -1, false},
+ {"large negative", -1000000, false},
+ {"min int64", math.MinInt64, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, NonNegative(tt.n))
+ })
+ }
+}
+
+// TestNotZero tests the NotZero predicate.
+func TestNotZero(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ n int64
+ expected bool
+ }{
+ {"positive", 1, true},
+ {"negative", -1, true},
+ {"zero", 0, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, NotZero(tt.n))
+ })
+ }
+}
+
+// TestInRange tests the InRange predicate.
+func TestInRange(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ n int64
+ min int64
+ max int64
+ expected bool
+ }{
+ {"in range", 5, 1, 10, true},
+ {"at min", 1, 1, 10, true},
+ {"at max", 10, 1, 10, true},
+ {"below range", 0, 1, 10, false},
+ {"above range", 11, 1, 10, false},
+ {"inverted range", 5, 10, 1, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, InRange(tt.n, tt.min, tt.max))
+ })
+ }
+}
+
+// TestValidUUID tests ValidUUID predicate.
+func TestValidUUID(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ uuid string
+ expected bool
+ }{
+ {"valid UUID", "123e4567-e89b-12d3-a456-426614174000", true},
+ {"valid UUID without hyphens", "123e4567e89b12d3a456426614174000", true},
+ {"empty string", "", false},
+ {"invalid format", "not-a-uuid", false},
+ {"too short", "123e4567-e89b-12d3-a456-42661417400", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, ValidUUID(tt.uuid))
+ })
+ }
+}
+
+// TestValidAmount tests ValidAmount predicate.
+func TestValidAmount(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ amount decimal.Decimal
+ expected bool
+ }{
+ {"zero", decimal.Zero, true},
+ {"max positive exponent", decimal.New(1, 18), true},
+ {"min negative exponent", decimal.New(1, -18), true},
+ {"too large exponent", decimal.New(1, 19), false},
+ {"too small exponent", decimal.New(1, -19), false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, ValidAmount(tt.amount))
+ })
+ }
+}
+
+// TestValidScale tests ValidScale predicate.
+func TestValidScale(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ scale int
+ expected bool
+ }{
+ {"min scale", 0, true},
+ {"max scale", 18, true},
+ {"negative scale", -1, false},
+ {"too large scale", 19, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, ValidScale(tt.scale))
+ })
+ }
+}
+
+// TestPositiveDecimal tests PositiveDecimal predicate.
+func TestPositiveDecimal(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ amount decimal.Decimal
+ expected bool
+ }{
+ {"positive", decimal.NewFromFloat(1.23), true},
+ {"zero", decimal.Zero, false},
+ {"negative", decimal.NewFromFloat(-1.23), false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, PositiveDecimal(tt.amount))
+ })
+ }
+}
+
+// TestNonNegativeDecimal tests NonNegativeDecimal predicate.
+func TestNonNegativeDecimal(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ amount decimal.Decimal
+ expected bool
+ }{
+ {"positive", decimal.NewFromFloat(1.23), true},
+ {"zero", decimal.Zero, true},
+ {"negative", decimal.NewFromFloat(-1.23), false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, NonNegativeDecimal(tt.amount))
+ })
+ }
+}
+
+// TestValidPort tests ValidPort predicate.
+func TestValidPort(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ port string
+ expected bool
+ }{
+ {"valid port", "5432", true},
+ {"min port", "1", true},
+ {"max port", "65535", true},
+ {"zero port", "0", false},
+ {"negative", "-1", false},
+ {"too large", "65536", false},
+ {"non-numeric", "abc", false},
+ {"empty", "", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, ValidPort(tt.port))
+ })
+ }
+}
+
+// TestValidSSLMode tests ValidSSLMode predicate.
+func TestValidSSLMode(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ mode string
+ expected bool
+ }{
+ {"empty", "", true},
+ {"disable", "disable", true},
+ {"allow", "allow", true},
+ {"prefer", "prefer", true},
+ {"require", "require", true},
+ {"verify-ca", "verify-ca", true},
+ {"verify-full", "verify-full", true},
+ {"invalid", "invalid", false},
+ {"uppercase", "DISABLE", false},
+ {"with spaces", " disable ", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, ValidSSLMode(tt.mode))
+ })
+ }
+}
+
+// TestPositiveInt tests PositiveInt predicate.
+func TestPositiveInt(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ n int
+ expected bool
+ }{
+ {"positive", 1, true},
+ {"zero", 0, false},
+ {"negative", -1, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, PositiveInt(tt.n))
+ })
+ }
+}
+
+// TestInRangeInt tests InRangeInt predicate.
+func TestInRangeInt(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ n int
+ min int
+ max int
+ expected bool
+ }{
+ {"in range", 5, 1, 10, true},
+ {"at min", 1, 1, 10, true},
+ {"at max", 10, 1, 10, true},
+ {"below range", 0, 1, 10, false},
+ {"above range", 11, 1, 10, false},
+ {"inverted range", 5, 10, 1, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, InRangeInt(tt.n, tt.min, tt.max))
+ })
+ }
+}
diff --git a/commons/assert/benchmark_test.go b/commons/assert/benchmark_test.go
new file mode 100644
index 00000000..2713fc5e
--- /dev/null
+++ b/commons/assert/benchmark_test.go
@@ -0,0 +1,158 @@
+//go:build unit
+
+package assert
+
+import (
+ "context"
+ "testing"
+
+ "github.com/shopspring/decimal"
+)
+
+// Benchmarks verify assertions are lightweight enough for always-on usage.
+// Target: < 100ns for hot path (condition is true), zero allocations.
+
+// --- Core Assertion Benchmarks (Hot Path) ---
+
+func BenchmarkThat_True(b *testing.B) {
+ asserter := New(context.Background(), nil, "", "")
+ for i := 0; i < b.N; i++ {
+ _ = asserter.That(context.Background(), true, "benchmark test")
+ }
+}
+
+func BenchmarkThat_TrueWithContext(b *testing.B) {
+ asserter := New(context.Background(), nil, "", "")
+ for i := 0; i < b.N; i++ {
+ _ = asserter.That(
+ context.Background(),
+ true,
+ "benchmark test",
+ "key1",
+ "value1",
+ "key2",
+ 42,
+ )
+ }
+}
+
+func BenchmarkNotNil_NonNil(b *testing.B) {
+ asserter := New(context.Background(), nil, "", "")
+
+ v := "test"
+
+ for i := 0; i < b.N; i++ {
+ _ = asserter.NotNil(context.Background(), v, "benchmark test")
+ }
+}
+
+func BenchmarkNotNil_NonNilPointer(b *testing.B) {
+ asserter := New(context.Background(), nil, "", "")
+
+ x := 42
+ ptr := &x
+
+ for i := 0; i < b.N; i++ {
+ _ = asserter.NotNil(context.Background(), ptr, "benchmark test")
+ }
+}
+
+func BenchmarkNotEmpty_NonEmpty(b *testing.B) {
+ asserter := New(context.Background(), nil, "", "")
+
+ s := "test"
+
+ for i := 0; i < b.N; i++ {
+ _ = asserter.NotEmpty(context.Background(), s, "benchmark test")
+ }
+}
+
+func BenchmarkNoError_NilError(b *testing.B) {
+ asserter := New(context.Background(), nil, "", "")
+ for i := 0; i < b.N; i++ {
+ _ = asserter.NoError(context.Background(), nil, "benchmark test")
+ }
+}
+
+// --- Predicate Benchmarks ---
+
+func BenchmarkPositive(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ Positive(int64(i + 1))
+ }
+}
+
+func BenchmarkNonNegative(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ NonNegative(int64(i))
+ }
+}
+
+func BenchmarkInRange(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ InRange(5, 0, 10)
+ }
+}
+
+func BenchmarkValidUUID(b *testing.B) {
+ uuid := "123e4567-e89b-12d3-a456-426614174000"
+ for i := 0; i < b.N; i++ {
+ ValidUUID(uuid)
+ }
+}
+
+func BenchmarkValidAmount(b *testing.B) {
+ amount := decimal.NewFromFloat(1234.56)
+ for i := 0; i < b.N; i++ {
+ ValidAmount(amount)
+ }
+}
+
+func BenchmarkPositiveDecimal(b *testing.B) {
+ amount := decimal.NewFromFloat(1234.56)
+ for i := 0; i < b.N; i++ {
+ PositiveDecimal(amount)
+ }
+}
+
+func BenchmarkValidScale(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ ValidScale(8)
+ }
+}
+
+// --- Helper Function Benchmarks ---
+
+func BenchmarkIsNil_NonNil(b *testing.B) {
+ v := "test"
+ for i := 0; i < b.N; i++ {
+ isNil(v)
+ }
+}
+
+func BenchmarkIsNil_TypedNilPointer(b *testing.B) {
+ var ptr *int
+ for i := 0; i < b.N; i++ {
+ isNil(ptr)
+ }
+}
+
+// --- Combined Usage Benchmarks ---
+
+// BenchmarkTypicalAssertion simulates a typical assertion pattern.
+func BenchmarkTypicalAssertion(b *testing.B) {
+ asserter := New(context.Background(), nil, "", "")
+ id := "123e4567-e89b-12d3-a456-426614174000"
+ amount := decimal.NewFromFloat(100.50)
+
+ for i := 0; i < b.N; i++ {
+ _ = asserter.That(context.Background(), ValidUUID(id), "invalid id", "id", id)
+ _ = asserter.That(
+ context.Background(),
+ PositiveDecimal(amount),
+ "invalid amount",
+ "amount",
+ amount,
+ )
+ }
+}
diff --git a/commons/assert/doc.go b/commons/assert/doc.go
new file mode 100644
index 00000000..6f14e68a
--- /dev/null
+++ b/commons/assert/doc.go
@@ -0,0 +1,172 @@
+// Package assert provides always-on runtime assertions for detecting programming bugs.
+//
+// Unlike test assertions, these assertions are intended to remain enabled in production
+// code. They are designed for detecting invariant violations, programming errors, and
+// impossible states - NOT for input validation or expected error conditions.
+//
+// # Design Philosophy
+//
+// Assertions are for catching bugs, not for handling user input:
+//
+// - Use assertions for conditions that should NEVER be false if the code is correct
+// - Use error returns for conditions that CAN legitimately fail (I/O, user input, etc.)
+// - Assertions return errors so callers can stop execution immediately
+//
+// Good assertion usage:
+//
+// a := assert.New(ctx, logger, "transaction", "create")
+// if err := a.NotNil(ctx, config, "config must be loaded before server starts"); err != nil {
+// return err
+// }
+// if err := a.That(ctx, len(items) > 0, "processItems called with empty slice"); err != nil {
+// return err
+// }
+//
+// Bad assertion usage (use error returns instead):
+//
+// // DON'T: User input validation
+// _ = a.That(ctx, email != "", "email is required") // Use validation errors
+//
+// // DON'T: I/O that can fail
+// // _, err := file.Read(buf)
+// // _ = a.NoError(ctx, err, "file must read") // Use proper error handling
+//
+// # Core Assertion Methods
+//
+// The package provides five core assertion methods on Asserter:
+//
+// a.That(ctx context.Context, ok bool, msg string, kv ...any) error
+// Returns an error if ok is false. General-purpose assertion.
+//
+// a.NotNil(ctx context.Context, v any, msg string, kv ...any) error
+// Returns an error if v is nil. Handles both untyped nil and typed nil (nil interface
+// values with concrete types).
+//
+// a.NotEmpty(ctx context.Context, s string, msg string, kv ...any) error
+// Returns an error if s is an empty string.
+//
+// a.NoError(ctx context.Context, err error, msg string, kv ...any) error
+// Returns an error if err is not nil. Automatically includes the error in context.
+//
+// a.Never(ctx context.Context, msg string, kv ...any) error
+// Always returns an error. Use for unreachable code paths.
+//
+// # Key-Value Context
+//
+// All assertion methods accept optional key-value pairs to provide context
+// in logs and errors:
+//
+// if err := a.That(ctx, balance >= 0, "balance must not be negative",
+// "account_id", accountID,
+// "balance", balance,
+// ); err != nil {
+// return err
+// }
+//
+// The error message will include:
+//
+// assertion failed: balance must not be negative
+// assertion=That
+// account_id=550e8400-e29b-41d4-a716-446655440000
+// balance=-100
+//
+// Odd numbers of key-value arguments are handled gracefully with a "MISSING_VALUE" marker.
+//
+// # Domain Predicates
+//
+// The package includes predicate functions for common domain validations:
+//
+// // Numeric predicates (int64)
+// assert.Positive(n int64) bool // n > 0
+// assert.NonNegative(n int64) bool // n >= 0
+// assert.NotZero(n int64) bool // n != 0
+// assert.InRange(n, minVal, maxVal int64) bool // minVal <= n <= maxVal
+//
+// // Numeric predicates (int)
+// assert.PositiveInt(n int) bool // n > 0
+// assert.InRangeInt(n, minVal, maxVal int) bool // minVal <= n <= maxVal
+//
+// // String predicates
+// assert.ValidUUID(s string) bool // valid UUID format
+//
+// // Financial predicates (using shopspring/decimal)
+// assert.ValidAmount(amount decimal.Decimal) bool // exponent in [-18, 18]
+// assert.ValidScale(scale int) bool // scale in [0, 18]
+// assert.PositiveDecimal(amount decimal.Decimal) bool // amount > 0
+// assert.NonNegativeDecimal(amount decimal.Decimal) bool // amount >= 0
+//
+// Use predicates with Asserter:
+//
+// if err := a.That(ctx, assert.Positive(count), "count must be positive", "count", count); err != nil {
+// return err
+// }
+// if err := a.That(ctx, assert.ValidUUID(id), "invalid UUID", "id", id); err != nil {
+// return err
+// }
+//
+// # Usage Examples
+//
+// Pre-conditions (validate inputs at function entry):
+//
+// func ProcessTransaction(ctx context.Context, tx *Transaction) error {
+// a := assert.New(ctx, logger, "transaction", "process")
+// if err := a.NotNil(ctx, tx, "transaction must not be nil"); err != nil {
+// return err
+// }
+// if err := a.NotEmpty(ctx, tx.ID, "transaction must have ID", "tx", tx); err != nil {
+// return err
+// }
+// // ... rest of function
+// }
+//
+// Post-conditions (validate outputs before return):
+//
+// func CreateAccount(ctx context.Context, name string) (*Account, error) {
+// a := assert.New(ctx, logger, "account", "create")
+// acc := &Account{ID: uuid.New(), Name: name}
+// if err := a.NotEmpty(ctx, acc.ID.String(), "created account must have ID"); err != nil {
+// return nil, err
+// }
+// return acc, nil
+// }
+//
+// Unreachable code:
+//
+// switch status {
+// case Active:
+// return handleActive()
+// case Inactive:
+// return handleInactive()
+// case Deleted:
+// return handleDeleted()
+// default:
+// return a.Never(ctx, "unhandled status", "status", status)
+// }
+//
+// # Goroutine Halting
+//
+// In goroutines, use Halt to stop execution after a failed assertion:
+//
+// go func() {
+// a := assert.New(ctx, logger, "transaction", "sync")
+// if err := a.That(ctx, ready, "sync not ready"); err != nil {
+// a.Halt(err)
+// }
+// // ... rest of goroutine
+// }()
+//
+// # Observability Integration
+//
+// Failed assertions emit telemetry signals:
+//
+// 1. Metrics: Records assertion_failed_total with component/operation/assertion labels.
+// Initialize with InitAssertionMetrics(factory).
+//
+// 2. Tracing: Records assertion.failed span events (with stack traces in non-prod).
+// Automatically uses the span from the context.
+//
+// # Stack Traces
+//
+// Stack traces are included in logs and trace events only in non-production
+// environments (ENV != production and GO_ENV != production).
+package assert
diff --git a/commons/assert/predicates.go b/commons/assert/predicates.go
new file mode 100644
index 00000000..fd3d5724
--- /dev/null
+++ b/commons/assert/predicates.go
@@ -0,0 +1,345 @@
+package assert
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ txn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/google/uuid"
+ "github.com/shopspring/decimal"
+)
+
+// Positive returns true if n > 0.
+//
+// Example:
+//
+// a.That(ctx, assert.Positive(count), "count must be positive", "count", count)
+func Positive(n int64) bool {
+ return n > 0
+}
+
+// NonNegative returns true if n >= 0.
+//
+// Example:
+//
+// a.That(ctx, assert.NonNegative(balance), "balance must not be negative", "balance", balance)
+func NonNegative(n int64) bool {
+ return n >= 0
+}
+
+// NotZero returns true if n != 0.
+//
+// Example:
+//
+// a.That(ctx, assert.NotZero(divisor), "divisor must not be zero", "divisor", divisor)
+func NotZero(n int64) bool {
+ return n != 0
+}
+
+// InRange returns true if min <= n <= max.
+//
+// Note: If min > max (inverted range), always returns false. This is fail-safe
+// behavior - callers should ensure min <= max for correct results.
+//
+// Example:
+//
+// a.That(ctx, assert.InRange(page, 1, 1000), "page out of range", "page", page)
+func InRange(n, minVal, maxVal int64) bool {
+ return n >= minVal && n <= maxVal
+}
+
+// ValidUUID returns true if s is a valid UUID string.
+//
+// Note: Accepts both canonical (with hyphens) and non-canonical (without hyphens)
+// UUID formats per RFC 4122. Empty strings return false.
+//
+// Example:
+//
+// a.That(ctx, assert.ValidUUID(id), "invalid UUID format", "id", id)
+func ValidUUID(s string) bool {
+ if s == "" {
+ return false
+ }
+
+ _, err := uuid.Parse(s)
+
+ return err == nil
+}
+
+// ValidAmount returns true if the decimal's exponent is within reasonable bounds.
+// The exponent must be in the range [-18, 18] to align with supported precision
+// for financial calculations (scale up to 18 decimal places).
+//
+// Note: This validates exponent bounds only, not coefficient size. For user-facing
+// validation, consider additional bounds checks on the coefficient.
+//
+// Example:
+//
+// a.That(ctx, assert.ValidAmount(amount), "amount has invalid precision", "amount", amount)
+func ValidAmount(amount decimal.Decimal) bool {
+ exp := amount.Exponent()
+ return exp >= -18 && exp <= 18
+}
+
+// ValidScale returns true if scale is in the range [0, 18].
+// Scale represents the number of decimal places for financial amounts.
+//
+// Example:
+//
+// a.That(ctx, assert.ValidScale(scale), "invalid scale", "scale", scale)
+func ValidScale(scale int) bool {
+ return scale >= 0 && scale <= 18
+}
+
+// PositiveDecimal returns true if amount > 0.
+//
+// Example:
+//
+// a.That(ctx, assert.PositiveDecimal(price), "price must be positive", "price", price)
+func PositiveDecimal(amount decimal.Decimal) bool {
+ return amount.IsPositive()
+}
+
+// NonNegativeDecimal returns true if amount >= 0.
+//
+// Example:
+//
+// a.That(ctx, assert.NonNegativeDecimal(balance), "balance must not be negative", "balance", balance)
+func NonNegativeDecimal(amount decimal.Decimal) bool {
+ return !amount.IsNegative()
+}
+
+// ValidPort returns true if port is a valid network port number (1-65535).
+// The port must be a numeric string representing a value in the valid range.
+//
+// Note: Port 0 is invalid for configuration purposes (it's used for dynamic allocation).
+// Empty strings, non-numeric values, and out-of-range values return false.
+//
+// Example:
+//
+// a.That(ctx, assert.ValidPort(cfg.DBPort), "DB_PORT must be valid port", "port", cfg.DBPort)
+func ValidPort(port string) bool {
+ if port == "" {
+ return false
+ }
+
+ p, err := strconv.Atoi(port)
+ if err != nil {
+ return false
+ }
+
+ return p > 0 && p <= 65535
+}
+
+// validSSLModes contains the valid PostgreSQL SSL modes.
+// Package-level for zero-allocation lookups in ValidSSLMode.
+var validSSLModes = map[string]bool{
+ "": true, // Empty uses PostgreSQL default
+ "disable": true,
+ "allow": true,
+ "prefer": true,
+ "require": true,
+ "verify-ca": true,
+ "verify-full": true,
+}
+
+// ValidSSLMode returns true if mode is a valid PostgreSQL SSL mode.
+// Valid modes are: disable, allow, prefer, require, verify-ca, verify-full.
+// Empty string is also valid (uses PostgreSQL default).
+//
+// Note: SSL modes are case-sensitive per PostgreSQL documentation.
+// Unknown modes will cause connection failures.
+//
+// Example:
+//
+// a.That(ctx, assert.ValidSSLMode(cfg.DBSSLMode), "DB_SSLMODE invalid", "mode", cfg.DBSSLMode)
+func ValidSSLMode(mode string) bool {
+ return validSSLModes[mode]
+}
+
+// PositiveInt returns true if n > 0.
+// This is the int variant of Positive (which uses int64).
+//
+// Example:
+//
+// a.That(ctx, assert.PositiveInt(cfg.MaxWorkers), "MAX_WORKERS must be positive", "value", cfg.MaxWorkers)
+func PositiveInt(n int) bool {
+ return n > 0
+}
+
+// InRangeInt returns true if min <= n <= max.
+// This is the int variant of InRange (which uses int64).
+//
+// Note: If min > max (inverted range), always returns false. This is fail-safe
+// behavior - callers should ensure min <= max for correct results.
+//
+// Example:
+//
+// a.That(ctx, assert.InRangeInt(cfg.PoolSize, 1, 100), "POOL_SIZE out of range", "value", cfg.PoolSize)
+func InRangeInt(n, minVal, maxVal int) bool {
+ return n >= minVal && n <= maxVal
+}
+
+// DebitsEqualCredits returns true if debits and credits are exactly equal.
+// This validates the fundamental double-entry accounting invariant:
+// for every transaction, total debits MUST equal total credits.
+//
+// Note: Uses decimal.Equal() for exact comparison without floating point issues.
+// Even a tiny difference indicates a bug in amount calculation.
+//
+// Example:
+//
+// a.That(ctx, assert.DebitsEqualCredits(debitTotal, creditTotal),
+// "double-entry violation: debits must equal credits",
+// "debits", debitTotal, "credits", creditTotal)
+func DebitsEqualCredits(debits, credits decimal.Decimal) bool {
+ return debits.Equal(credits)
+}
+
+// NonZeroTotals returns true if both debits and credits are non-zero.
+// A transaction with zero totals is meaningless and indicates a bug.
+//
+// Example:
+//
+// a.That(ctx, assert.NonZeroTotals(debitTotal, creditTotal),
+// "transaction totals must be non-zero",
+// "debits", debitTotal, "credits", creditTotal)
+func NonZeroTotals(debits, credits decimal.Decimal) bool {
+ return !debits.IsZero() && !credits.IsZero()
+}
+
+// validTransactionStatuses contains valid transaction status values.
+// Package-level for zero-allocation lookups.
+var validTransactionStatuses = map[string]bool{
+ txn.CREATED: true,
+ txn.APPROVED: true,
+ txn.PENDING: true,
+ txn.CANCELED: true,
+ txn.NOTED: true,
+}
+
+// ValidTransactionStatus returns true if status is a valid transaction status.
+// Valid statuses are: CREATED, APPROVED, PENDING, CANCELED, NOTED.
+//
+// Note: Statuses are case-sensitive and must match exactly.
+//
+// Example:
+//
+// a.That(ctx, assert.ValidTransactionStatus(tran.Status.Code),
+// "invalid transaction status",
+// "status", tran.Status.Code)
+func ValidTransactionStatus(status string) bool {
+ return validTransactionStatuses[status]
+}
+
+// validTransitions defines the allowed state machine transitions.
+// Key: current state, Value: set of valid target states.
+// Only PENDING transactions can be committed (APPROVED) or canceled (CANCELED).
+var validTransitions = map[string]map[string]bool{
+ txn.PENDING: {
+ txn.APPROVED: true,
+ txn.CANCELED: true,
+ },
+ // CREATED, APPROVED, CANCELED, NOTED are terminal states - no forward transitions
+}
+
+// TransactionCanTransitionTo returns true if transitioning from current to target is valid.
+// The transaction state machine only allows: PENDING -> APPROVED or PENDING -> CANCELED.
+//
+// Note: This is for forward transitions only. Revert is a separate operation.
+//
+// Example:
+//
+// a.That(ctx, assert.TransactionCanTransitionTo(current, next),
+// "invalid status transition",
+// "current", current,
+// "next", next)
+func TransactionCanTransitionTo(current, target string) bool {
+ validTargets, exists := validTransitions[current]
+ if !exists {
+ return false
+ }
+
+ return validTargets[target]
+}
+
+// TransactionCanBeReverted returns true if a transaction can be reverted.
+// The transaction can only be reverted if:
+// - Status is APPROVED
+// - It has no parent transaction (i.e., it is not a reversal of another transaction)
+//
+// This ensures only original transactions can be reverted, not reversals.
+func TransactionCanBeReverted(status string, hasParent bool) bool {
+ if status != txn.APPROVED {
+ return false
+ }
+
+ return !hasParent
+}
+
+// BalanceSufficientForRelease returns true if the available on-hold balance
+// is sufficient to release the specified amount.
+func BalanceSufficientForRelease(onHold, releaseAmount decimal.Decimal) bool {
+ if onHold.IsNegative() || releaseAmount.IsNegative() {
+ return false
+ }
+
+ return onHold.GreaterThanOrEqual(releaseAmount)
+}
+
+// DateNotInFuture returns true if the date is not in the future (i.e., <= now).
+// Zero time is considered valid (returns true).
+func DateNotInFuture(date time.Time) bool {
+ if date.IsZero() {
+ return true
+ }
+
+ return !date.After(time.Now().UTC())
+}
+
+// DateAfter returns true if date is strictly after reference time.
+func DateAfter(date, reference time.Time) bool {
+ return date.After(reference)
+}
+
+// BalanceIsZero returns true if both available and onHold balances are exactly zero.
+func BalanceIsZero(available, onHold decimal.Decimal) bool {
+ return available.IsZero() && onHold.IsZero()
+}
+
+// TransactionHasOperations returns true if the transaction has operations.
+func TransactionHasOperations(operations []string) bool {
+ return len(operations) > 0
+}
+
+// TransactionOperationsContain returns true if every element in operations is
+// contained in the allowed set (i.e. operations is a subset of allowed).
+// Both empty operations and empty allowed return false.
+func TransactionOperationsContain(operations, allowed []string) bool {
+ if len(operations) == 0 || len(allowed) == 0 {
+ return false
+ }
+
+ allowedSet := make(map[string]struct{}, len(allowed))
+ for _, op := range allowed {
+ allowedSet[strings.TrimSpace(op)] = struct{}{}
+ }
+
+ for _, op := range operations {
+ if _, ok := allowedSet[strings.TrimSpace(op)]; !ok {
+ return false
+ }
+ }
+
+ return true
+}
+
+// TransactionOperationsMatch is a deprecated alias for TransactionOperationsContain.
+// It checks subset containment: every operation must be in the allowed set.
+//
+// Deprecated: Use TransactionOperationsContain instead. The name "Match" implied
+// full bidirectional equality, but the behavior is subset containment.
+func TransactionOperationsMatch(operations, allowed []string) bool {
+ return TransactionOperationsContain(operations, allowed)
+}
diff --git a/commons/assert/predicates_test.go b/commons/assert/predicates_test.go
new file mode 100644
index 00000000..66bf0c26
--- /dev/null
+++ b/commons/assert/predicates_test.go
@@ -0,0 +1,340 @@
+//go:build unit
+
+package assert
+
+import (
+ "testing"
+ "time"
+
+ "github.com/shopspring/decimal"
+ "github.com/stretchr/testify/require"
+)
+
+// TestDebitsEqualCredits tests the DebitsEqualCredits predicate for double-entry accounting.
+func TestDebitsEqualCredits(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ debits decimal.Decimal
+ credits decimal.Decimal
+ expected bool
+ }{
+ {"equal positive amounts", decimal.NewFromInt(100), decimal.NewFromInt(100), true},
+ {"equal with decimals", decimal.NewFromFloat(123.45), decimal.NewFromFloat(123.45), true},
+ {"equal zero", decimal.Zero, decimal.Zero, true},
+ {"debits greater", decimal.NewFromInt(100), decimal.NewFromInt(99), false},
+ {"credits greater", decimal.NewFromInt(99), decimal.NewFromInt(100), false},
+ {"tiny difference", decimal.NewFromFloat(100.001), decimal.NewFromFloat(100.002), false},
+ {"large equal", decimal.NewFromInt(1000000000), decimal.NewFromInt(1000000000), true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, DebitsEqualCredits(tt.debits, tt.credits))
+ })
+ }
+}
+
+// TestNonZeroTotals tests the NonZeroTotals predicate for transaction validation.
+func TestNonZeroTotals(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ debits decimal.Decimal
+ credits decimal.Decimal
+ expected bool
+ }{
+ {"both positive", decimal.NewFromInt(100), decimal.NewFromInt(100), true},
+ {"both zero", decimal.Zero, decimal.Zero, false},
+ {"debits zero", decimal.Zero, decimal.NewFromInt(100), false},
+ {"credits zero", decimal.NewFromInt(100), decimal.Zero, false},
+ {"small positive", decimal.NewFromFloat(0.01), decimal.NewFromFloat(0.01), true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, NonZeroTotals(tt.debits, tt.credits))
+ })
+ }
+}
+
+// TestValidTransactionStatus tests the ValidTransactionStatus predicate.
+func TestValidTransactionStatus(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ status string
+ expected bool
+ }{
+ {"CREATED valid", "CREATED", true},
+ {"APPROVED valid", "APPROVED", true},
+ {"PENDING valid", "PENDING", true},
+ {"CANCELED valid", "CANCELED", true},
+ {"NOTED valid", "NOTED", true},
+ {"empty invalid", "", false},
+ {"lowercase invalid", "pending", false},
+ {"unknown invalid", "UNKNOWN", false},
+ {"partial invalid", "APPROV", false},
+ {"with spaces invalid", " PENDING ", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, ValidTransactionStatus(tt.status))
+ })
+ }
+}
+
+// TestTransactionCanTransitionTo tests the TransactionCanTransitionTo predicate.
+func TestTransactionCanTransitionTo(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ current string
+ target string
+ expected bool
+ }{
+ // Valid transitions from PENDING
+ {"PENDING to APPROVED", "PENDING", "APPROVED", true},
+ {"PENDING to CANCELED", "PENDING", "CANCELED", true},
+ // Invalid transitions from PENDING
+ {"PENDING to CREATED", "PENDING", "CREATED", false},
+ {"PENDING to PENDING", "PENDING", "PENDING", false},
+ // Invalid transitions from APPROVED (terminal state for forward)
+ {"APPROVED to CANCELED", "APPROVED", "CANCELED", false},
+ {"APPROVED to PENDING", "APPROVED", "PENDING", false},
+ {"APPROVED to CREATED", "APPROVED", "CREATED", false},
+ // Invalid transitions from CANCELED (terminal state)
+ {"CANCELED to APPROVED", "CANCELED", "APPROVED", false},
+ {"CANCELED to PENDING", "CANCELED", "PENDING", false},
+ // Invalid transitions from CREATED
+ {"CREATED to APPROVED", "CREATED", "APPROVED", false},
+ {"CREATED to CANCELED", "CREATED", "CANCELED", false},
+ // Invalid statuses
+ {"invalid current", "INVALID", "APPROVED", false},
+ {"invalid target", "PENDING", "INVALID", false},
+ {"both invalid", "INVALID", "UNKNOWN", false},
+ {"empty current", "", "APPROVED", false},
+ {"empty target", "PENDING", "", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, TransactionCanTransitionTo(tt.current, tt.target))
+ })
+ }
+}
+
+// TestTransactionCanBeReverted tests the TransactionCanBeReverted predicate.
+func TestTransactionCanBeReverted(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ status string
+ hasParent bool
+ expected bool
+ }{
+ {"APPROVED without parent can revert", "APPROVED", false, true},
+ {"APPROVED with parent cannot revert", "APPROVED", true, false},
+ {"PENDING cannot revert", "PENDING", false, false},
+ {"CANCELED cannot revert", "CANCELED", false, false},
+ {"CREATED cannot revert", "CREATED", false, false},
+ {"NOTED cannot revert", "NOTED", false, false},
+ {"invalid status cannot revert", "INVALID", false, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, TransactionCanBeReverted(tt.status, tt.hasParent))
+ })
+ }
+}
+
+// TestBalanceSufficientForRelease tests the BalanceSufficientForRelease predicate.
+func TestBalanceSufficientForRelease(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ onHold decimal.Decimal
+ releaseAmount decimal.Decimal
+ expected bool
+ }{
+ {"sufficient onHold", decimal.NewFromInt(100), decimal.NewFromInt(50), true},
+ {"exactly sufficient", decimal.NewFromInt(100), decimal.NewFromInt(100), true},
+ {"insufficient onHold", decimal.NewFromInt(50), decimal.NewFromInt(100), false},
+ {"zero onHold zero release", decimal.Zero, decimal.Zero, true},
+ {"zero onHold positive release", decimal.Zero, decimal.NewFromInt(1), false},
+ {
+ "decimal precision sufficient",
+ decimal.NewFromFloat(100.50),
+ decimal.NewFromFloat(100.49),
+ true,
+ },
+ {
+ "decimal precision insufficient",
+ decimal.NewFromFloat(100.49),
+ decimal.NewFromFloat(100.50),
+ false,
+ },
+ {"negative onHold always fails", decimal.NewFromInt(-10), decimal.NewFromInt(5), false},
+ {"negative releaseAmount always fails", decimal.NewFromInt(100), decimal.NewFromInt(-5), false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, BalanceSufficientForRelease(tt.onHold, tt.releaseAmount))
+ })
+ }
+}
+
+// TestDateNotInFuture tests the DateNotInFuture predicate.
+func TestDateNotInFuture(t *testing.T) {
+ t.Parallel()
+
+ now := time.Now()
+
+ tests := []struct {
+ name string
+ date time.Time
+ expected bool
+ }{
+ {"past date valid", now.Add(-24 * time.Hour), true},
+ {"recent past valid", now.Add(-time.Second), true},
+ {"one second ago valid", now.Add(-time.Second), true},
+ {"one second future invalid", now.Add(time.Second), false},
+ {"one hour future invalid", now.Add(time.Hour), false},
+ {"far future invalid", now.Add(365 * 24 * time.Hour), false},
+ {"zero time valid", time.Time{}, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := DateNotInFuture(tt.date)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// TestDateAfter tests the DateAfter predicate.
+func TestDateAfter(t *testing.T) {
+ t.Parallel()
+
+ base := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
+
+ tests := []struct {
+ name string
+ date time.Time
+ reference time.Time
+ expected bool
+ }{
+ {"date after reference", base.Add(24 * time.Hour), base, true},
+ {"date equal to reference", base, base, false},
+ {"date before reference", base.Add(-24 * time.Hour), base, false},
+ {"date one second after", base.Add(time.Second), base, true},
+ {"date one second before", base.Add(-time.Second), base, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, DateAfter(tt.date, tt.reference))
+ })
+ }
+}
+
+// TestBalanceIsZero tests the BalanceIsZero predicate.
+func TestBalanceIsZero(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ available decimal.Decimal
+ onHold decimal.Decimal
+ expected bool
+ }{
+ {"both zero", decimal.Zero, decimal.Zero, true},
+ {"available non-zero", decimal.NewFromInt(1), decimal.Zero, false},
+ {"onHold non-zero", decimal.Zero, decimal.NewFromInt(1), false},
+ {"both non-zero", decimal.NewFromInt(1), decimal.NewFromInt(1), false},
+ {"tiny available", decimal.NewFromFloat(0.001), decimal.Zero, false},
+ {"negative available still not zero", decimal.NewFromInt(-1), decimal.Zero, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, BalanceIsZero(tt.available, tt.onHold))
+ })
+ }
+}
+
+// TestTransactionHasOperations tests the TransactionHasOperations predicate.
+func TestTransactionHasOperations(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ ops []string
+ expectedOK bool
+ }{
+ {"has operations", []string{"CREDIT"}, true},
+ {"empty operations", nil, false},
+ {"empty slice", []string{}, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expectedOK, TransactionHasOperations(tt.ops))
+ })
+ }
+}
+
+// TestTransactionOperationsContain tests the TransactionOperationsContain predicate.
+func TestTransactionOperationsContain(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ ops []string
+ allowed []string
+ expected bool
+ }{
+ {"match single", []string{"CREDIT"}, []string{"CREDIT", "DEBIT"}, true},
+ {"match multiple", []string{"CREDIT", "DEBIT"}, []string{"CREDIT", "DEBIT"}, true},
+ {"mismatch", []string{"TRANSFER"}, []string{"CREDIT", "DEBIT"}, false},
+ {"empty operations", []string{}, []string{"CREDIT"}, false},
+ {"empty allowed", []string{"CREDIT"}, []string{}, false},
+ {"whitespace tolerant", []string{" CREDIT "}, []string{"CREDIT"}, true},
+ {"whitespace mismatch", []string{" CREDIT "}, []string{"DEBIT"}, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, TransactionOperationsContain(tt.ops, tt.allowed))
+ })
+ }
+}
+
+// TestTransactionOperationsMatch_DeprecatedAlias verifies the deprecated alias delegates correctly.
+func TestTransactionOperationsMatch_DeprecatedAlias(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, TransactionOperationsMatch([]string{"CREDIT"}, []string{"CREDIT", "DEBIT"}))
+ require.False(t, TransactionOperationsMatch([]string{"TRANSFER"}, []string{"CREDIT", "DEBIT"}))
+}
diff --git a/commons/backoff/backoff.go b/commons/backoff/backoff.go
new file mode 100644
index 00000000..dc5b8257
--- /dev/null
+++ b/commons/backoff/backoff.go
@@ -0,0 +1,114 @@
+package backoff
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "math"
+ "math/big"
+ mrand "math/rand/v2"
+ "time"
+)
+
+const maxShift = 62
+
+// Exponential calculates exponential delay based on attempt number.
+// The delay is calculated as base * 2^attempt with overflow protection.
+// Negative attempts are treated as 0.
+func Exponential(base time.Duration, attempt int) time.Duration {
+ if base <= 0 {
+ return 0
+ }
+
+ if attempt < 0 {
+ attempt = 0
+ } else if attempt > maxShift {
+ attempt = maxShift
+ }
+
+ multiplier := int64(1 << attempt)
+
+ baseInt := int64(base)
+ if baseInt > math.MaxInt64/multiplier {
+ return time.Duration(math.MaxInt64)
+ }
+
+ return time.Duration(baseInt * multiplier)
+}
+
+// FullJitter returns a random duration in the range [0, delay).
+// Uses crypto/rand for secure randomness, falling back to math/rand if crypto fails.
+// Returns 0 for zero or negative delays.
+func FullJitter(delay time.Duration) time.Duration {
+ if delay <= 0 {
+ return 0
+ }
+
+ n, err := rand.Int(rand.Reader, big.NewInt(int64(delay)))
+ if err != nil {
+ return time.Duration(cryptoFallbackRand(int64(delay)))
+ }
+
+ return time.Duration(n.Int64())
+}
+
+// fallbackDivisor is used when crypto/rand fails completely.
+const fallbackDivisor = 2
+
+// cryptoFallbackRand provides a fallback random number generator when crypto/rand fails.
+// It uses a defense-in-depth strategy with two fallback layers:
+// - Layer 1: Attempt to seed a math/rand PRNG via crypto/rand. Even though
+// FullJitter's crypto/rand.Int already failed, rand.Read uses a different
+// code path (raw bytes vs big.Int) and may succeed independently.
+// - Layer 2: If even seeding fails, return a deterministic midpoint
+// (maxValue / 2) to provide a reasonable jitter value without blocking.
+//
+// This ensures backoff jitter never stalls, even under severe entropy exhaustion.
+func cryptoFallbackRand(maxValue int64) int64 {
+ var seed [8]byte
+
+ _, err := rand.Read(seed[:])
+ if err != nil {
+ return maxValue / fallbackDivisor
+ }
+
+ rng := mrand.New(
+ mrand.NewPCG(binary.LittleEndian.Uint64(seed[:]), 0),
+ ) // #nosec G404 -- Fallback when crypto/rand fails
+
+ return rng.Int64N(maxValue)
+}
+
+// ExponentialWithJitter combines exponential backoff with full jitter.
+// Returns a random duration in [0, base * 2^attempt).
+// This implements the "Full Jitter" strategy recommended by AWS.
+func ExponentialWithJitter(base time.Duration, attempt int) time.Duration {
+ exponentialDelay := Exponential(base, attempt)
+
+ return FullJitter(exponentialDelay)
+}
+
+// WaitContext sleeps for the specified duration but respects context cancellation.
+// Returns nil if the sleep completes, or an error if the context is cancelled.
+// Returns the context error for zero or negative durations if the context is already cancelled.
+// A nil context is normalized to context.Background().
+func WaitContext(ctx context.Context, duration time.Duration) error {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if duration <= 0 {
+ return ctx.Err()
+ }
+
+ timer := time.NewTimer(duration)
+ defer timer.Stop()
+
+ select {
+ case <-timer.C:
+ return nil
+ case <-ctx.Done():
+ return fmt.Errorf("context done: %w", ctx.Err())
+ }
+}
diff --git a/commons/backoff/backoff_test.go b/commons/backoff/backoff_test.go
new file mode 100644
index 00000000..bb3a36ad
--- /dev/null
+++ b/commons/backoff/backoff_test.go
@@ -0,0 +1,450 @@
+//go:build unit
+
+package backoff
+
+import (
+ "context"
+ "math"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestExponential(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ base time.Duration
+ attempt int
+ expected time.Duration
+ }{
+ {
+ name: "attempt 0 returns base",
+ base: 100 * time.Millisecond,
+ attempt: 0,
+ expected: 100 * time.Millisecond,
+ },
+ {
+ name: "attempt 1 doubles base",
+ base: 100 * time.Millisecond,
+ attempt: 1,
+ expected: 200 * time.Millisecond,
+ },
+ {
+ name: "attempt 2 quadruples base",
+ base: 100 * time.Millisecond,
+ attempt: 2,
+ expected: 400 * time.Millisecond,
+ },
+ {
+ name: "attempt 3 is 8x base",
+ base: 100 * time.Millisecond,
+ attempt: 3,
+ expected: 800 * time.Millisecond,
+ },
+ {
+ name: "attempt 10 is 1024x base",
+ base: 1 * time.Millisecond,
+ attempt: 10,
+ expected: 1024 * time.Millisecond,
+ },
+ {
+ name: "negative attempt treated as 0",
+ base: 100 * time.Millisecond,
+ attempt: -5,
+ expected: 100 * time.Millisecond,
+ },
+ {
+ name: "zero base returns 0",
+ base: 0,
+ attempt: 5,
+ expected: 0,
+ },
+ {
+ name: "negative base returns 0",
+ base: -100 * time.Millisecond,
+ attempt: 5,
+ expected: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := Exponential(tt.base, tt.attempt)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestExponential_OverflowProtection(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ attempt int
+ }{
+ {"attempt 62 (max allowed)", 62},
+ {"attempt 63 clamped to 62", 63},
+ {"attempt 100 clamped to 62", 100},
+ {"attempt 1000 clamped to 62", 1000},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := Exponential(1*time.Nanosecond, tt.attempt)
+ expected := Exponential(1*time.Nanosecond, 62)
+ assert.Equal(t, expected, result)
+ assert.NotPanics(t, func() {
+ _ = Exponential(time.Second, tt.attempt)
+ })
+ })
+ }
+}
+
+func TestExponential_MultiplicationOverflow(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ base time.Duration
+ attempt int
+ }{
+ {
+ name: "hour base with attempt 40 overflows",
+ base: time.Hour,
+ attempt: 40,
+ },
+ {
+ name: "hour base with attempt 62 overflows",
+ base: time.Hour,
+ attempt: 62,
+ },
+ {
+ name: "second base with attempt 50 overflows",
+ base: time.Second,
+ attempt: 50,
+ },
+ {
+ name: "large base with moderate attempt overflows",
+ base: 24 * time.Hour,
+ attempt: 30,
+ },
+ {
+ name: "max int64 nanoseconds base with attempt 1 overflows",
+ base: time.Duration(math.MaxInt64/2 + 1),
+ attempt: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := Exponential(tt.base, tt.attempt)
+ assert.Equal(t, time.Duration(math.MaxInt64), result,
+ "overflow should clamp to math.MaxInt64")
+ })
+ }
+}
+
+func TestExponential_MultiplicationBoundary(t *testing.T) {
+ t.Parallel()
+
+ t.Run("just below overflow threshold remains exact", func(t *testing.T) {
+ t.Parallel()
+
+ // 1 nanosecond * 2^40 = 1,099,511,627,776 ns (~18 min) -- no overflow
+ result := Exponential(1*time.Nanosecond, 40)
+ expected := time.Duration(int64(1) << 40)
+ assert.Equal(t, expected, result)
+ })
+
+ t.Run("1 nanosecond base never overflows at max shift", func(t *testing.T) {
+ t.Parallel()
+
+ // 1 ns * 2^62 = 4,611,686,018,427,387,904 ns (~146 years) -- fits int64
+ result := Exponential(1*time.Nanosecond, 62)
+ expected := time.Duration(int64(1) << 62)
+ assert.Equal(t, expected, result)
+ })
+
+ t.Run("2 nanoseconds base overflows at max shift", func(t *testing.T) {
+ t.Parallel()
+
+ // 2 ns * 2^62 would be 2^63 which overflows int64
+ result := Exponential(2*time.Nanosecond, 62)
+ assert.Equal(t, time.Duration(math.MaxInt64), result)
+ })
+
+ t.Run("result is always positive", func(t *testing.T) {
+ t.Parallel()
+
+ // Ensure no wraparound to negative values
+ largeValues := []struct {
+ base time.Duration
+ attempt int
+ }{
+ {time.Hour, 40},
+ {time.Minute, 50},
+ {time.Second, 55},
+ {time.Millisecond, 60},
+ {24 * time.Hour, 62},
+ }
+
+ for _, v := range largeValues {
+ result := Exponential(v.base, v.attempt)
+ assert.Positive(t, int64(result),
+ "Exponential(%v, %d) should never be negative", v.base, v.attempt)
+ }
+ })
+}
+
+func TestFullJitter(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ delay time.Duration
+ }{
+ {"100ms delay", 100 * time.Millisecond},
+ {"1s delay", 1 * time.Second},
+ {"10s delay", 10 * time.Second},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ for range 100 {
+ result := FullJitter(tt.delay)
+ assert.GreaterOrEqual(t, result, time.Duration(0))
+ assert.Less(t, result, tt.delay)
+ }
+ })
+ }
+}
+
+func TestFullJitter_EdgeCases(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ delay time.Duration
+ expected time.Duration
+ }{
+ {"zero delay returns 0", 0, 0},
+ {"negative delay returns 0", -100 * time.Millisecond, 0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := FullJitter(tt.delay)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestFullJitter_Distribution(t *testing.T) {
+ t.Parallel()
+
+ const iterations = 1000
+
+ delay := 100 * time.Millisecond
+
+ var sum time.Duration
+
+ for range iterations {
+ sum += FullJitter(delay)
+ }
+
+ avg := sum / iterations
+ expectedMid := delay / 2
+ tolerance := delay / 5
+
+ assert.InDelta(t, int64(expectedMid), int64(avg), float64(tolerance),
+ "average should be roughly half the delay (expected ~%v, got %v)", expectedMid, avg)
+}
+
+func TestExponentialWithJitter(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ base time.Duration
+ attempt int
+ }{
+ {"attempt 0", 100 * time.Millisecond, 0},
+ {"attempt 1", 100 * time.Millisecond, 1},
+ {"attempt 5", 100 * time.Millisecond, 5},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ maxDelay := Exponential(tt.base, tt.attempt)
+
+ for range 50 {
+ result := ExponentialWithJitter(tt.base, tt.attempt)
+ assert.GreaterOrEqual(t, result, time.Duration(0))
+ assert.Less(t, result, maxDelay)
+ }
+ })
+ }
+}
+
+func TestExponentialWithJitter_EdgeCases(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ base time.Duration
+ attempt int
+ expected time.Duration
+ }{
+ {"zero base returns 0", 0, 5, 0},
+ {"negative base returns 0", -100 * time.Millisecond, 5, 0},
+ {"negative attempt treated as 0", 100 * time.Millisecond, -5, 0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ if tt.expected == 0 && tt.base > 0 {
+ maxDelay := Exponential(tt.base, 0)
+
+ for range 50 {
+ result := ExponentialWithJitter(tt.base, tt.attempt)
+ assert.GreaterOrEqual(t, result, time.Duration(0))
+ assert.Less(t, result, maxDelay)
+ }
+ } else {
+ result := ExponentialWithJitter(tt.base, tt.attempt)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestWaitContext(t *testing.T) {
+ t.Parallel()
+
+ t.Run("completes sleep successfully", func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ start := time.Now()
+ err := WaitContext(ctx, 50*time.Millisecond)
+ elapsed := time.Since(start)
+
+ require.NoError(t, err)
+ assert.GreaterOrEqual(t, elapsed, 50*time.Millisecond)
+ })
+
+ t.Run("respects context cancellation", func(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+
+ go func() {
+ time.Sleep(20 * time.Millisecond)
+ cancel()
+ }()
+
+ start := time.Now()
+ err := WaitContext(ctx, 1*time.Second)
+ elapsed := time.Since(start)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, context.Canceled)
+ assert.Less(t, elapsed, 500*time.Millisecond)
+ })
+
+ t.Run("respects context deadline", func(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
+ defer cancel()
+
+ err := WaitContext(ctx, 1*time.Second)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, context.DeadlineExceeded)
+ })
+
+ t.Run("zero duration returns immediately", func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ start := time.Now()
+ err := WaitContext(ctx, 0)
+ elapsed := time.Since(start)
+
+ require.NoError(t, err)
+ assert.Less(t, elapsed, 200*time.Millisecond)
+ })
+
+ t.Run("negative duration returns immediately", func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ start := time.Now()
+ err := WaitContext(ctx, -100*time.Millisecond)
+ elapsed := time.Since(start)
+
+ require.NoError(t, err)
+ assert.Less(t, elapsed, 200*time.Millisecond)
+ })
+
+ t.Run("zero duration with cancelled context returns error", func(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ err := WaitContext(ctx, 0)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, context.Canceled)
+ })
+
+ t.Run("already cancelled context returns immediately", func(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ start := time.Now()
+ err := WaitContext(ctx, 1*time.Second)
+ elapsed := time.Since(start)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, context.Canceled)
+ assert.Less(t, elapsed, 200*time.Millisecond)
+ })
+}
+
+func TestCryptoFallbackRand(t *testing.T) {
+ t.Parallel()
+
+ t.Run("returns value in range", func(t *testing.T) {
+ t.Parallel()
+
+ const maxValue = 1000
+
+ for range 100 {
+ result := cryptoFallbackRand(maxValue)
+ assert.GreaterOrEqual(t, result, int64(0))
+ assert.Less(t, result, int64(maxValue))
+ }
+ })
+}
diff --git a/commons/backoff/doc.go b/commons/backoff/doc.go
new file mode 100644
index 00000000..32347c44
--- /dev/null
+++ b/commons/backoff/doc.go
@@ -0,0 +1,5 @@
+// Package backoff provides retry delay helpers with exponential growth and jitter.
+//
+// Use ExponentialWithJitter for retry loops and WaitContext to wait while
+// respecting cancellation and deadlines.
+package backoff
diff --git a/commons/circuitbreaker/config.go b/commons/circuitbreaker/config.go
index a2de4da5..02eb0b34 100644
--- a/commons/circuitbreaker/config.go
+++ b/commons/circuitbreaker/config.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package circuitbreaker
import "time"
diff --git a/commons/circuitbreaker/doc.go b/commons/circuitbreaker/doc.go
new file mode 100644
index 00000000..125de6e8
--- /dev/null
+++ b/commons/circuitbreaker/doc.go
@@ -0,0 +1,9 @@
+// Package circuitbreaker provides service-level circuit breaker orchestration
+// and health-check-driven recovery helpers.
+//
+// Use NewManager to create and manage per-service breakers, then run calls through
+// Manager.Execute so failures are tracked consistently across callers.
+//
+// Optional health-check integration can automatically reset breakers after
+// downstream services recover.
+package circuitbreaker
diff --git a/commons/circuitbreaker/fallback_example_test.go b/commons/circuitbreaker/fallback_example_test.go
new file mode 100644
index 00000000..a3bb5c5f
--- /dev/null
+++ b/commons/circuitbreaker/fallback_example_test.go
@@ -0,0 +1,58 @@
+//go:build unit
+
+package circuitbreaker_test
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/circuitbreaker"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+func ExampleManager_Execute_fallbackOnOpen() {
+ mgr, err := circuitbreaker.NewManager(&log.NopLogger{})
+ if err != nil {
+ return
+ }
+
+ _, err = mgr.GetOrCreate("core-ledger", circuitbreaker.Config{
+ MaxRequests: 1,
+ Interval: time.Minute,
+ Timeout: time.Second,
+ ConsecutiveFailures: 1,
+ })
+ if err != nil {
+ return
+ }
+
+ _, firstErr := mgr.Execute("core-ledger", func() (any, error) {
+ return nil, errors.New("upstream timeout")
+ })
+
+ _, secondErr := mgr.Execute("core-ledger", func() (any, error) {
+ return "ok", nil
+ })
+
+ fallback := "primary"
+ if secondErr != nil {
+ fallback = "cached-response"
+ }
+
+ fmt.Println(firstErr != nil)
+ fmt.Println(mgr.GetState("core-ledger") == circuitbreaker.StateOpen)
+ if secondErr != nil {
+ fmt.Println(strings.Contains(secondErr.Error(), "currently unavailable"))
+ } else {
+ fmt.Println(false)
+ }
+ fmt.Println(fallback)
+
+ // Output:
+ // true
+ // true
+ // true
+ // cached-response
+}
diff --git a/commons/circuitbreaker/healthchecker.go b/commons/circuitbreaker/healthchecker.go
index a38264c7..79ccbd60 100644
--- a/commons/circuitbreaker/healthchecker.go
+++ b/commons/circuitbreaker/healthchecker.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package circuitbreaker
import (
@@ -11,10 +7,13 @@ import (
"sync"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
)
var (
+ // ErrNilManager is returned when a nil manager is passed to NewHealthCheckerWithValidation.
+ ErrNilManager = errors.New("circuitbreaker: manager must not be nil")
// ErrInvalidHealthCheckInterval indicates that the health check interval must be positive
ErrInvalidHealthCheckInterval = errors.New("circuitbreaker: health check interval must be positive")
// ErrInvalidHealthCheckTimeout indicates that the health check timeout must be positive
@@ -32,6 +31,8 @@ type healthChecker struct {
immediateCheck chan string // Channel to trigger immediate health check for a service
wg sync.WaitGroup
mu sync.RWMutex
+ stopOnce sync.Once
+ started bool
}
// NewHealthCheckerWithValidation creates a new health checker with validation.
@@ -39,6 +40,14 @@ type healthChecker struct {
// interval: how often to run health checks
// checkTimeout: timeout for each individual health check operation
func NewHealthCheckerWithValidation(manager Manager, interval, checkTimeout time.Duration, logger log.Logger) (HealthChecker, error) {
+ if manager == nil {
+ return nil, ErrNilManager
+ }
+
+ if logger == nil {
+ return nil, ErrNilLogger
+ }
+
if interval <= 0 {
return nil, ErrInvalidHealthCheckInterval
}
@@ -58,45 +67,59 @@ func NewHealthCheckerWithValidation(manager Manager, interval, checkTimeout time
}, nil
}
-// Deprecated: Use NewHealthCheckerWithValidation instead for proper error handling.
-// NewHealthChecker creates a new health checker.
-// interval: how often to run health checks
-// checkTimeout: timeout for each individual health check operation
-func NewHealthChecker(manager Manager, interval, checkTimeout time.Duration, logger log.Logger) HealthChecker {
- hc, err := NewHealthCheckerWithValidation(manager, interval, checkTimeout, logger)
- if err != nil {
- panic(err.Error())
- }
-
- return hc
-}
-
// Register adds a service to health check
func (hc *healthChecker) Register(serviceName string, healthCheckFn HealthCheckFunc) {
+ if healthCheckFn == nil {
+ hc.logger.Log(context.Background(), log.LevelWarn, "attempted to register nil health check function", log.String("service", serviceName))
+ return
+ }
+
hc.mu.Lock()
defer hc.mu.Unlock()
hc.services[serviceName] = healthCheckFn
- hc.logger.Infof("Registered health check for service: %s", serviceName)
+ hc.logger.Log(context.Background(), log.LevelInfo, "registered health check for service", log.String("service", serviceName))
}
// Start begins the health check loop
func (hc *healthChecker) Start() {
- hc.wg.Add(1)
+ hc.mu.Lock()
- go hc.healthCheckLoop()
+ if hc.started {
+ hc.mu.Unlock()
+ hc.logger.Log(context.Background(), log.LevelWarn, "health checker already started, ignoring duplicate Start() call")
- hc.logger.Infof("Health checker started - checking services every %v", hc.interval)
+ return
+ }
+
+ hc.started = true
+ hc.wg.Add(1)
+ hc.mu.Unlock()
+
+ runtime.SafeGoWithContextAndComponent(
+ context.Background(),
+ hc.logger,
+ "circuitbreaker",
+ "health_check_loop",
+ runtime.KeepRunning,
+ func(ctx context.Context) {
+ hc.healthCheckLoop(ctx)
+ },
+ )
+
+ hc.logger.Log(context.Background(), log.LevelInfo, "health checker started", log.String("interval", hc.interval.String()))
}
// Stop gracefully stops the health checker
func (hc *healthChecker) Stop() {
- close(hc.stopChan)
+ hc.stopOnce.Do(func() {
+ close(hc.stopChan)
+ })
hc.wg.Wait()
- hc.logger.Info("Health checker stopped")
+ hc.logger.Log(context.Background(), log.LevelInfo, "Health checker stopped")
}
-func (hc *healthChecker) healthCheckLoop() {
+func (hc *healthChecker) healthCheckLoop(ctx context.Context) {
defer hc.wg.Done()
ticker := time.NewTicker(hc.interval)
@@ -110,8 +133,10 @@ func (hc *healthChecker) healthCheckLoop() {
hc.performHealthChecks()
case serviceName := <-hc.immediateCheck:
// Immediate health check for a specific service
- hc.logger.Debugf("Triggering immediate health check for service: %s", serviceName)
+ hc.logger.Log(context.Background(), log.LevelDebug, "triggering immediate health check", log.String("service", serviceName))
hc.checkServiceHealth(serviceName)
+ case <-ctx.Done():
+ return
case <-hc.stopChan:
return
}
@@ -126,7 +151,7 @@ func (hc *healthChecker) performHealthChecks() {
hc.mu.RUnlock()
- hc.logger.Debug("Performing health checks on registered services...")
+ hc.logger.Log(context.Background(), log.LevelDebug, "performing health checks on registered services")
unhealthyCount := 0
recoveredCount := 0
@@ -139,7 +164,7 @@ func (hc *healthChecker) performHealthChecks() {
unhealthyCount++
- hc.logger.Infof("Attempting to heal service: %s (circuit breaker is open)", serviceName)
+ hc.logger.Log(context.Background(), log.LevelInfo, "attempting to heal service", log.String("service", serviceName), log.String("reason", "circuit breaker open"))
ctx, cancel := context.WithTimeout(context.Background(), hc.checkTimeout)
err := healthCheckFn(ctx)
@@ -147,19 +172,19 @@ func (hc *healthChecker) performHealthChecks() {
cancel()
if err == nil {
- hc.logger.Infof("Service %s recovered - resetting circuit breaker", serviceName)
+ hc.logger.Log(context.Background(), log.LevelInfo, "service recovered, resetting circuit breaker", log.String("service", serviceName))
hc.manager.Reset(serviceName)
recoveredCount++
} else {
- hc.logger.Warnf("Service %s still unhealthy: %v - will retry in %v", serviceName, err, hc.interval)
+ hc.logger.Log(context.Background(), log.LevelWarn, "service still unhealthy", log.String("service", serviceName), log.Err(err), log.String("retry_in", hc.interval.String()))
}
}
if unhealthyCount > 0 {
- hc.logger.Infof("Health check complete: %d services needed healing, %d recovered", unhealthyCount, recoveredCount)
+ hc.logger.Log(context.Background(), log.LevelInfo, "health check complete", log.Int("unhealthy", unhealthyCount), log.Int("recovered", recoveredCount))
} else {
- hc.logger.Debug("All services healthy")
+ hc.logger.Log(context.Background(), log.LevelDebug, "all services healthy")
}
}
@@ -178,21 +203,23 @@ func (hc *healthChecker) GetHealthStatus() map[string]string {
return status
}
-// OnStateChange implements StateChangeListener interface
-// This is called when a circuit breaker changes state
-func (hc *healthChecker) OnStateChange(serviceName string, from State, to State) {
- hc.logger.Debugf("Health checker notified of state change for %s: %s -> %s", serviceName, from, to)
+// OnStateChange implements StateChangeListener interface.
+// This is called when a circuit breaker changes state.
+// The provided context carries a deadline; the health checker uses it for logging
+// but schedules checks independently.
+func (hc *healthChecker) OnStateChange(_ context.Context, serviceName string, from State, to State) {
+ hc.logger.Log(context.Background(), log.LevelDebug, "health checker notified of state change", log.String("service", serviceName), log.String("from", string(from)), log.String("to", string(to)))
// If circuit just opened, trigger immediate health check
if to == StateOpen {
- hc.logger.Infof("Circuit breaker opened for %s - scheduling immediate health check", serviceName)
+ hc.logger.Log(context.Background(), log.LevelInfo, "circuit breaker opened, scheduling immediate health check", log.String("service", serviceName))
// Non-blocking send to avoid deadlock
select {
case hc.immediateCheck <- serviceName:
- hc.logger.Debugf("Immediate health check scheduled for %s", serviceName)
+ hc.logger.Log(context.Background(), log.LevelDebug, "immediate health check scheduled", log.String("service", serviceName))
default:
- hc.logger.Warnf("Immediate health check channel full for %s, will check on next interval", serviceName)
+ hc.logger.Log(context.Background(), log.LevelWarn, "immediate health check channel full, will check on next interval", log.String("service", serviceName))
}
}
}
@@ -204,17 +231,17 @@ func (hc *healthChecker) checkServiceHealth(serviceName string) {
hc.mu.RUnlock()
if !exists {
- hc.logger.Warnf("No health check function registered for service: %s", serviceName)
+ hc.logger.Log(context.Background(), log.LevelWarn, "no health check function registered", log.String("service", serviceName))
return
}
// Skip if circuit breaker is already healthy
if hc.manager.IsHealthy(serviceName) {
- hc.logger.Debugf("Service %s is already healthy, skipping check", serviceName)
+ hc.logger.Log(context.Background(), log.LevelDebug, "service already healthy, skipping check", log.String("service", serviceName))
return
}
- hc.logger.Infof("Attempting to heal service: %s (circuit breaker is open)", serviceName)
+ hc.logger.Log(context.Background(), log.LevelInfo, "attempting to heal service", log.String("service", serviceName), log.String("reason", "circuit breaker open"))
ctx, cancel := context.WithTimeout(context.Background(), hc.checkTimeout)
err := healthCheckFn(ctx)
@@ -222,9 +249,9 @@ func (hc *healthChecker) checkServiceHealth(serviceName string) {
cancel()
if err == nil {
- hc.logger.Infof("Service %s recovered - resetting circuit breaker", serviceName)
+ hc.logger.Log(context.Background(), log.LevelInfo, "service recovered, resetting circuit breaker", log.String("service", serviceName))
hc.manager.Reset(serviceName)
} else {
- hc.logger.Warnf("Service %s still unhealthy: %v - will retry in %v", serviceName, err, hc.interval)
+ hc.logger.Log(context.Background(), log.LevelWarn, "service still unhealthy", log.String("service", serviceName), log.Err(err), log.String("retry_in", hc.interval.String()))
}
}
diff --git a/commons/circuitbreaker/healthchecker_test.go b/commons/circuitbreaker/healthchecker_test.go
index 5a344b73..f4a36934 100644
--- a/commons/circuitbreaker/healthchecker_test.go
+++ b/commons/circuitbreaker/healthchecker_test.go
@@ -1,21 +1,22 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package circuitbreaker
import (
+ "context"
"errors"
"testing"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestNewHealthCheckerWithValidation_Success(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
hc, err := NewHealthCheckerWithValidation(manager, 1*time.Second, 500*time.Millisecond, logger)
@@ -24,8 +25,9 @@ func TestNewHealthCheckerWithValidation_Success(t *testing.T) {
}
func TestNewHealthCheckerWithValidation_InvalidInterval(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
hc, err := NewHealthCheckerWithValidation(manager, 0, 500*time.Millisecond, logger)
@@ -35,8 +37,9 @@ func TestNewHealthCheckerWithValidation_InvalidInterval(t *testing.T) {
}
func TestNewHealthCheckerWithValidation_NegativeInterval(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
hc, err := NewHealthCheckerWithValidation(manager, -1*time.Second, 500*time.Millisecond, logger)
@@ -46,8 +49,9 @@ func TestNewHealthCheckerWithValidation_NegativeInterval(t *testing.T) {
}
func TestNewHealthCheckerWithValidation_InvalidTimeout(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
hc, err := NewHealthCheckerWithValidation(manager, 1*time.Second, 0, logger)
@@ -57,8 +61,9 @@ func TestNewHealthCheckerWithValidation_InvalidTimeout(t *testing.T) {
}
func TestNewHealthCheckerWithValidation_NegativeTimeout(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
hc, err := NewHealthCheckerWithValidation(manager, 1*time.Second, -500*time.Millisecond, logger)
@@ -67,63 +72,410 @@ func TestNewHealthCheckerWithValidation_NegativeTimeout(t *testing.T) {
assert.True(t, errors.Is(err, ErrInvalidHealthCheckTimeout))
}
-func TestNewHealthChecker_PanicOnInvalidInterval(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+func TestNewHealthCheckerWithValidation_NilManager(t *testing.T) {
+ logger := &log.NopLogger{}
+
+ hc, err := NewHealthCheckerWithValidation(nil, 1*time.Second, 500*time.Millisecond, logger)
+
+ assert.Nil(t, hc)
+ assert.Error(t, err)
+ assert.True(t, errors.Is(err, ErrNilManager))
+}
+
+func TestNewHealthCheckerWithValidation_NilLogger(t *testing.T) {
+ manager, err := NewManager(&log.NopLogger{})
+ require.NoError(t, err)
+
+ hc, err := NewHealthCheckerWithValidation(manager, 1*time.Second, 500*time.Millisecond, nil)
+
+ assert.Nil(t, hc)
+ assert.Error(t, err)
+ assert.True(t, errors.Is(err, ErrNilLogger))
+}
+
+// --- Helper to create a test healthChecker ---
- assert.Panics(t, func() {
- NewHealthChecker(manager, 0, 500*time.Millisecond, logger)
+func newTestHealthChecker(t *testing.T) (HealthChecker, Manager) {
+ t.Helper()
+
+ logger := &log.NopLogger{}
+ mgr, err := NewManager(logger)
+ require.NoError(t, err)
+
+ hc, err := NewHealthCheckerWithValidation(mgr, 50*time.Millisecond, 100*time.Millisecond, logger)
+ require.NoError(t, err)
+
+ return hc, mgr
+}
+
+func TestRegister_NilHealthCheckFunction(t *testing.T) {
+ hc, _ := newTestHealthChecker(t)
+
+ // Should not panic, should be a no-op (logs warning)
+ assert.NotPanics(t, func() {
+ hc.Register("svc", nil)
})
+
+ // Service should not appear in health status
+ status := hc.GetHealthStatus()
+ _, exists := status["svc"]
+ assert.False(t, exists)
}
-func TestNewHealthChecker_PanicOnInvalidTimeout(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+func TestRegister_ValidFunction(t *testing.T) {
+ hc, mgr := newTestHealthChecker(t)
+
+ cfg := DefaultConfig()
+ _, err := mgr.GetOrCreate("my-svc", cfg)
+ require.NoError(t, err)
- assert.Panics(t, func() {
- NewHealthChecker(manager, 1*time.Second, 0, logger)
+ hc.Register("my-svc", func(ctx context.Context) error {
+ return nil
})
+
+ status := hc.GetHealthStatus()
+ _, exists := status["my-svc"]
+ assert.True(t, exists)
}
-func TestNewHealthChecker_Success(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+func TestStart_DuplicateIsNoop(t *testing.T) {
+ hc, _ := newTestHealthChecker(t)
- hc := NewHealthChecker(manager, 1*time.Second, 500*time.Millisecond, logger)
+ // First start
+ hc.Start()
- assert.NotNil(t, hc)
+ // Second start should be a no-op, not panic
+ assert.NotPanics(t, func() {
+ hc.Start()
+ })
+
+ hc.Stop()
}
-func TestNewHealthCheckerWithValidation_NilManager(t *testing.T) {
- // Note: The current implementation does not validate nil manager.
- // This test documents the current behavior: a nil manager is accepted
- // and will cause a panic later when methods like IsHealthy() are called.
- // This is acceptable because:
- // 1. Manager is required for the health checker to function
- // 2. The caller is responsible for providing valid dependencies
- // 3. Adding nil validation would be a behavior change
- logger := &log.NoneLogger{}
+func TestStop(t *testing.T) {
+ hc, _ := newTestHealthChecker(t)
- hc, err := NewHealthCheckerWithValidation(nil, 1*time.Second, 500*time.Millisecond, logger)
+ hc.Start()
- // Current behavior: nil manager is accepted (no validation)
- assert.NoError(t, err)
- assert.NotNil(t, hc)
+ // Stop should complete without hanging
+ done := make(chan struct{})
+ go func() {
+ hc.Stop()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ // success
+ case <-time.After(2 * time.Second):
+ t.Fatal("Stop() did not return in time")
+ }
}
-func TestNewHealthCheckerWithValidation_NilLogger(t *testing.T) {
- // Note: The current implementation does not validate nil logger.
- // This test documents the current behavior: a nil logger is accepted
- // and will cause a panic later when logging methods are called.
- // This is acceptable because:
- // 1. Logger is required for proper operation
- // 2. The caller is responsible for providing valid dependencies
- // 3. Adding nil validation would be a behavior change
- manager := NewManager(&log.NoneLogger{})
+func TestGetHealthStatus(t *testing.T) {
+ hc, mgr := newTestHealthChecker(t)
- hc, err := NewHealthCheckerWithValidation(manager, 1*time.Second, 500*time.Millisecond, nil)
+ cfg := DefaultConfig()
- // Current behavior: nil logger is accepted (no validation)
- assert.NoError(t, err)
- assert.NotNil(t, hc)
+ _, err := mgr.GetOrCreate("svc-a", cfg)
+ require.NoError(t, err)
+
+ _, err = mgr.GetOrCreate("svc-b", cfg)
+ require.NoError(t, err)
+
+ hc.Register("svc-a", func(ctx context.Context) error { return nil })
+ hc.Register("svc-b", func(ctx context.Context) error { return nil })
+
+ status := hc.GetHealthStatus()
+ assert.Equal(t, string(StateClosed), status["svc-a"])
+ assert.Equal(t, string(StateClosed), status["svc-b"])
+}
+
+func TestOnStateChange_OpenTriggersImmediateCheck(t *testing.T) {
+ hc, _ := newTestHealthChecker(t)
+
+ // Access the internal immediateCheck channel
+ hcInternal := hc.(*healthChecker)
+
+ hc.(*healthChecker).OnStateChange(context.Background(), "test-svc", StateClosed, StateOpen)
+
+ // Should have sent a message to immediateCheck channel
+ select {
+ case svc := <-hcInternal.immediateCheck:
+ assert.Equal(t, "test-svc", svc)
+ case <-time.After(1 * time.Second):
+ t.Fatal("Expected immediate check to be scheduled")
+ }
+}
+
+func TestOnStateChange_NonOpenDoesNotTrigger(t *testing.T) {
+ hc, _ := newTestHealthChecker(t)
+
+ hcInternal := hc.(*healthChecker)
+
+ hc.(*healthChecker).OnStateChange(context.Background(), "test-svc", StateOpen, StateClosed)
+
+ // Should NOT have sent a message
+ select {
+ case <-hcInternal.immediateCheck:
+ t.Fatal("Should not trigger immediate check for non-Open state")
+ case <-time.After(50 * time.Millisecond):
+ // expected
+ }
+}
+
+func TestCheckServiceHealth_NonExistentService(t *testing.T) {
+ hc, _ := newTestHealthChecker(t)
+
+ hcInternal := hc.(*healthChecker)
+
+ // Should not panic when service is not registered
+ assert.NotPanics(t, func() {
+ hcInternal.checkServiceHealth("non-existent")
+ })
+}
+
+func TestCheckServiceHealth_AlreadyHealthy(t *testing.T) {
+ hc, mgr := newTestHealthChecker(t)
+
+ cfg := DefaultConfig()
+ _, err := mgr.GetOrCreate("healthy-svc", cfg)
+ require.NoError(t, err)
+
+ called := false
+ hc.Register("healthy-svc", func(ctx context.Context) error {
+ called = true
+ return nil
+ })
+
+ hcInternal := hc.(*healthChecker)
+ hcInternal.checkServiceHealth("healthy-svc")
+
+ // Health check function should NOT be called since service is healthy
+ assert.False(t, called)
+}
+
+func TestCheckServiceHealth_SuccessfulRecovery(t *testing.T) {
+ logger := &log.NopLogger{}
+ mgr, err := NewManager(logger)
+ require.NoError(t, err)
+
+ cfg := Config{
+ MaxRequests: 1,
+ Interval: 100 * time.Millisecond,
+ Timeout: 1 * time.Second,
+ ConsecutiveFailures: 2,
+ FailureRatio: 0.5,
+ MinRequests: 2,
+ }
+
+ _, err = mgr.GetOrCreate("recover-svc", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker
+ for i := 0; i < 3; i++ {
+ _, _ = mgr.Execute("recover-svc", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ }
+ assert.Equal(t, StateOpen, mgr.GetState("recover-svc"))
+
+ hc, err := NewHealthCheckerWithValidation(mgr, 50*time.Millisecond, 100*time.Millisecond, logger)
+ require.NoError(t, err)
+
+ hc.Register("recover-svc", func(ctx context.Context) error {
+ return nil // healthy
+ })
+
+ hcInternal := hc.(*healthChecker)
+ hcInternal.checkServiceHealth("recover-svc")
+
+ // Should have reset the breaker
+ assert.Equal(t, StateClosed, mgr.GetState("recover-svc"))
+}
+
+func TestCheckServiceHealth_FailedRecovery(t *testing.T) {
+ logger := &log.NopLogger{}
+ mgr, err := NewManager(logger)
+ require.NoError(t, err)
+
+ cfg := Config{
+ MaxRequests: 1,
+ Interval: 100 * time.Millisecond,
+ Timeout: 1 * time.Second,
+ ConsecutiveFailures: 2,
+ FailureRatio: 0.5,
+ MinRequests: 2,
+ }
+
+ _, err = mgr.GetOrCreate("fail-svc", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker
+ for i := 0; i < 3; i++ {
+ _, _ = mgr.Execute("fail-svc", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ }
+ assert.Equal(t, StateOpen, mgr.GetState("fail-svc"))
+
+ hc, err := NewHealthCheckerWithValidation(mgr, 50*time.Millisecond, 100*time.Millisecond, logger)
+ require.NoError(t, err)
+
+ hc.Register("fail-svc", func(ctx context.Context) error {
+ return errors.New("still down")
+ })
+
+ hcInternal := hc.(*healthChecker)
+ hcInternal.checkServiceHealth("fail-svc")
+
+ // Breaker should remain open
+ assert.Equal(t, StateOpen, mgr.GetState("fail-svc"))
+}
+
+func TestPerformHealthChecks_MixedServices(t *testing.T) {
+ logger := &log.NopLogger{}
+ mgr, err := NewManager(logger)
+ require.NoError(t, err)
+
+ cfg := Config{
+ MaxRequests: 1,
+ Interval: 100 * time.Millisecond,
+ Timeout: 1 * time.Second,
+ ConsecutiveFailures: 2,
+ FailureRatio: 0.5,
+ MinRequests: 2,
+ }
+
+ // Create two services
+ _, err = mgr.GetOrCreate("healthy-svc", cfg)
+ require.NoError(t, err)
+
+ _, err = mgr.GetOrCreate("unhealthy-svc", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker on unhealthy-svc only
+ for i := 0; i < 3; i++ {
+ _, _ = mgr.Execute("unhealthy-svc", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ }
+ assert.Equal(t, StateOpen, mgr.GetState("unhealthy-svc"))
+ assert.Equal(t, StateClosed, mgr.GetState("healthy-svc"))
+
+ hc, err := NewHealthCheckerWithValidation(mgr, 50*time.Millisecond, 100*time.Millisecond, logger)
+ require.NoError(t, err)
+
+ healthyChecked := false
+ hc.Register("healthy-svc", func(ctx context.Context) error {
+ healthyChecked = true
+ return nil
+ })
+
+ hc.Register("unhealthy-svc", func(ctx context.Context) error {
+ return nil // simulate recovery
+ })
+
+ hcInternal := hc.(*healthChecker)
+ hcInternal.performHealthChecks()
+
+ // Healthy service should be skipped (its health check func not called)
+ assert.False(t, healthyChecked)
+
+ // Unhealthy service should have recovered
+ assert.Equal(t, StateClosed, mgr.GetState("unhealthy-svc"))
+}
+
+func TestPerformHealthChecks_UnhealthyStaysUnhealthy(t *testing.T) {
+ logger := &log.NopLogger{}
+ mgr, err := NewManager(logger)
+ require.NoError(t, err)
+
+ cfg := Config{
+ MaxRequests: 1,
+ Interval: 100 * time.Millisecond,
+ Timeout: 1 * time.Second,
+ ConsecutiveFailures: 2,
+ FailureRatio: 0.5,
+ MinRequests: 2,
+ }
+
+ _, err = mgr.GetOrCreate("still-down", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker
+ for i := 0; i < 3; i++ {
+ _, _ = mgr.Execute("still-down", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ }
+
+ hc, err := NewHealthCheckerWithValidation(mgr, 50*time.Millisecond, 100*time.Millisecond, logger)
+ require.NoError(t, err)
+
+ hc.Register("still-down", func(ctx context.Context) error {
+ return errors.New("nope")
+ })
+
+ hcInternal := hc.(*healthChecker)
+ hcInternal.performHealthChecks()
+
+ // Should remain open
+ assert.Equal(t, StateOpen, mgr.GetState("still-down"))
+}
+
+func TestHealthCheckLoop_PeriodicChecks(t *testing.T) {
+ logger := &log.NopLogger{}
+ mgr, err := NewManager(logger)
+ require.NoError(t, err)
+
+ cfg := Config{
+ MaxRequests: 1,
+ Interval: 100 * time.Millisecond,
+ Timeout: 1 * time.Second,
+ ConsecutiveFailures: 2,
+ FailureRatio: 0.5,
+ MinRequests: 2,
+ }
+
+ _, err = mgr.GetOrCreate("periodic-svc", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker
+ for i := 0; i < 3; i++ {
+ _, _ = mgr.Execute("periodic-svc", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ }
+ assert.Equal(t, StateOpen, mgr.GetState("periodic-svc"))
+
+ hc, err := NewHealthCheckerWithValidation(mgr, 50*time.Millisecond, 100*time.Millisecond, logger)
+ require.NoError(t, err)
+
+ hc.Register("periodic-svc", func(ctx context.Context) error {
+ return nil // recovery succeeds
+ })
+
+ hc.Start()
+ defer hc.Stop()
+
+ // Poll until the periodic health check fires and recovers the breaker
+ require.Eventually(t, func() bool {
+ return mgr.GetState("periodic-svc") == StateClosed
+ }, 2*time.Second, 50*time.Millisecond, "periodic health check should recover the breaker")
+}
+
+func TestOnStateChange_ImmediateCheckChannelFull(t *testing.T) {
+ hc, _ := newTestHealthChecker(t)
+ hcInternal := hc.(*healthChecker)
+
+ // Fill the immediateCheck channel (capacity 10)
+ for i := 0; i < 10; i++ {
+ hcInternal.immediateCheck <- "fill"
+ }
+
+ // This should not block or panic — it logs a warning instead
+ assert.NotPanics(t, func() {
+ hcInternal.OnStateChange(context.Background(), "overflow-svc", StateClosed, StateOpen)
+ })
}
diff --git a/commons/circuitbreaker/manager.go b/commons/circuitbreaker/manager.go
index abed4f47..6db902e3 100644
--- a/commons/circuitbreaker/manager.go
+++ b/commons/circuitbreaker/manager.go
@@ -1,42 +1,132 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package circuitbreaker
import (
+ "context"
+ "errors"
"fmt"
+ "reflect"
"sync"
+ "time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
+ "github.com/LerianStudio/lib-commons/v4/commons/safe"
"github.com/sony/gobreaker"
)
+// stateChangeListenerTimeout limits how long a state change listener notification
+// can run before the context is cancelled.
+const stateChangeListenerTimeout = 10 * time.Second
+
type manager struct {
- breakers map[string]*gobreaker.CircuitBreaker
- configs map[string]Config // Store configs for safe reset
- listeners []StateChangeListener
- mu sync.RWMutex
- logger log.Logger
+ breakers map[string]*gobreaker.CircuitBreaker
+ configs map[string]Config // Store configs for safe reset
+ listeners []StateChangeListener
+ mu sync.RWMutex
+ logger log.Logger
+ metricsFactory *metrics.MetricsFactory
+ stateCounter *metrics.CounterBuilder
+ execCounter *metrics.CounterBuilder
+}
+
+// ManagerOption configures optional behaviour on a circuit breaker manager.
+type ManagerOption func(*manager)
+
+// WithMetricsFactory attaches a MetricsFactory so the manager emits
+// circuit_breaker_state_transitions_total and circuit_breaker_executions_total
+// counters automatically. When nil, metrics are silently skipped.
+func WithMetricsFactory(f *metrics.MetricsFactory) ManagerOption {
+ return func(m *manager) {
+ m.metricsFactory = f
+ }
+}
+
+// stateTransitionMetric defines the counter for circuit breaker state transitions.
+var stateTransitionMetric = metrics.Metric{
+ Name: "circuit_breaker_state_transitions_total",
+ Unit: "1",
+ Description: "Total number of circuit breaker state transitions",
}
-// NewManager creates a new circuit breaker manager
-func NewManager(logger log.Logger) Manager {
- return &manager{
+// executionMetric defines the counter for circuit breaker executions.
+var executionMetric = metrics.Metric{
+ Name: "circuit_breaker_executions_total",
+ Unit: "1",
+ Description: "Total number of circuit breaker executions",
+}
+
+// NewManager creates a new circuit breaker manager.
+// Returns an error if logger is nil (including typed-nil interface values).
+func NewManager(logger log.Logger, opts ...ManagerOption) (Manager, error) {
+ if isNilLogger(logger) {
+ return nil, ErrNilLogger
+ }
+
+ m := &manager{
breakers: make(map[string]*gobreaker.CircuitBreaker),
configs: make(map[string]Config),
listeners: make([]StateChangeListener, 0),
logger: logger,
}
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(m)
+ }
+ }
+
+ m.initMetricCounters()
+
+ return m, nil
+}
+
+func (m *manager) initMetricCounters() {
+ if m.metricsFactory == nil {
+ return
+ }
+
+ stateCounter, err := m.metricsFactory.Counter(stateTransitionMetric)
+ if err != nil {
+ m.logger.Log(context.Background(), log.LevelWarn, "failed to create state transition metric counter", log.Err(err))
+ } else {
+ m.stateCounter = stateCounter
+ }
+
+ execCounter, err := m.metricsFactory.Counter(executionMetric)
+ if err != nil {
+ m.logger.Log(context.Background(), log.LevelWarn, "failed to create execution metric counter", log.Err(err))
+ } else {
+ m.execCounter = execCounter
+ }
}
-func (m *manager) GetOrCreate(serviceName string, config Config) CircuitBreaker {
+// GetOrCreate returns an existing breaker or creates one for the service.
+// If a breaker already exists for the name with a different config, ErrConfigMismatch is returned.
+func (m *manager) GetOrCreate(serviceName string, config Config) (CircuitBreaker, error) {
m.mu.RLock()
breaker, exists := m.breakers[serviceName]
- m.mu.RUnlock()
if exists {
- return &circuitBreaker{breaker: breaker}
+ storedCfg := m.configs[serviceName]
+ m.mu.RUnlock()
+
+ if storedCfg != config {
+ return nil, fmt.Errorf(
+ "%w: service %q already registered with different settings",
+ ErrConfigMismatch,
+ serviceName,
+ )
+ }
+
+ return &circuitBreaker{breaker: breaker}, nil
+ }
+
+ m.mu.RUnlock()
+
+ if err := config.Validate(); err != nil {
+ return nil, fmt.Errorf("circuit breaker config for service %s: %w", serviceName, err)
}
m.mu.Lock()
@@ -44,36 +134,35 @@ func (m *manager) GetOrCreate(serviceName string, config Config) CircuitBreaker
// Double-check after acquiring write lock
if breaker, exists = m.breakers[serviceName]; exists {
- return &circuitBreaker{breaker: breaker}
- }
-
- // Create new circuit breaker with configuration
- settings := gobreaker.Settings{
- Name: fmt.Sprintf("service-%s", serviceName),
- MaxRequests: config.MaxRequests,
- Interval: config.Interval,
- Timeout: config.Timeout,
- ReadyToTrip: func(counts gobreaker.Counts) bool {
- failureRatio := float64(counts.TotalFailures) / float64(counts.Requests)
+ storedCfg := m.configs[serviceName]
+ if storedCfg != config {
+ return nil, fmt.Errorf(
+ "%w: service %q already registered with different settings",
+ ErrConfigMismatch,
+ serviceName,
+ )
+ }
- return counts.ConsecutiveFailures >= config.ConsecutiveFailures ||
- (counts.Requests >= config.MinRequests && failureRatio >= config.FailureRatio)
- },
- OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) {
- m.handleStateChange(serviceName, from, to)
- },
+ return &circuitBreaker{breaker: breaker}, nil
}
+ settings := m.buildSettings(serviceName, config)
+
breaker = gobreaker.NewCircuitBreaker(settings)
m.breakers[serviceName] = breaker
- m.configs[serviceName] = config // Store config for safe reset
+ m.configs[serviceName] = config
- m.logger.Infof("Created circuit breaker for service: %s", serviceName)
+ m.logger.Log(context.Background(), log.LevelInfo, "created circuit breaker", log.String("service", serviceName))
- return &circuitBreaker{breaker: breaker}
+ return &circuitBreaker{breaker: breaker}, nil
}
+// Execute runs fn through the named service breaker.
func (m *manager) Execute(serviceName string, fn func() (any, error)) (any, error) {
+ if fn == nil {
+ return nil, ErrNilCallback
+ }
+
m.mu.RLock()
breaker, exists := m.breakers[serviceName]
m.mu.RUnlock()
@@ -84,20 +173,32 @@ func (m *manager) Execute(serviceName string, fn func() (any, error)) (any, erro
result, err := breaker.Execute(fn)
if err != nil {
- if err == gobreaker.ErrOpenState {
- m.logger.Warnf("Circuit breaker [%s] is OPEN - request rejected immediately", serviceName)
+ if errors.Is(err, gobreaker.ErrOpenState) {
+ m.logger.Log(context.Background(), log.LevelWarn, "circuit breaker is OPEN, request rejected", log.String("service", serviceName))
+ m.recordExecution(serviceName, "rejected_open")
+
return nil, fmt.Errorf("service %s is currently unavailable (circuit breaker open): %w", serviceName, err)
}
- if err == gobreaker.ErrTooManyRequests {
- m.logger.Warnf("Circuit breaker [%s] is HALF-OPEN - too many test requests", serviceName)
+ if errors.Is(err, gobreaker.ErrTooManyRequests) {
+ m.logger.Log(context.Background(), log.LevelWarn, "circuit breaker is HALF-OPEN, too many test requests", log.String("service", serviceName))
+ m.recordExecution(serviceName, "rejected_half_open")
+
return nil, fmt.Errorf("service %s is recovering (too many requests): %w", serviceName, err)
}
+
+ // The wrapped function returned an error (not a breaker rejection)
+ m.recordExecution(serviceName, "error")
+
+ return result, err
}
+ m.recordExecution(serviceName, "success")
+
return result, err
}
+// GetState returns the current state for a service breaker.
func (m *manager) GetState(serviceName string) State {
m.mu.RLock()
breaker, exists := m.breakers[serviceName]
@@ -107,19 +208,10 @@ func (m *manager) GetState(serviceName string) State {
return StateUnknown
}
- state := breaker.State()
- switch state {
- case gobreaker.StateClosed:
- return StateClosed
- case gobreaker.StateOpen:
- return StateOpen
- case gobreaker.StateHalfOpen:
- return StateHalfOpen
- default:
- return StateUnknown
- }
+ return convertGobreakerState(breaker.State())
}
+// GetCounts returns current counters for a service breaker.
func (m *manager) GetCounts(serviceName string) Counts {
m.mu.RLock()
breaker, exists := m.breakers[serviceName]
@@ -140,60 +232,50 @@ func (m *manager) GetCounts(serviceName string) Counts {
}
}
+// IsHealthy reports whether the service breaker is in a healthy state.
+// Both Closed and HalfOpen states are considered healthy: Closed allows all traffic,
+// and HalfOpen allows limited probe traffic for recovery verification.
+// Open (rejecting all requests) and Unknown (unregistered breaker) are considered unhealthy.
func (m *manager) IsHealthy(serviceName string) bool {
state := m.GetState(serviceName)
- // Only CLOSED state is considered healthy
- // OPEN and HALF-OPEN both need health checker intervention
- isHealthy := state == StateClosed
- m.logger.Debugf("IsHealthy check: service=%s, state=%s, isHealthy=%v", serviceName, state, isHealthy)
+ // Closed and HalfOpen are healthy; Open and Unknown are unhealthy.
+ // HalfOpen is healthy because it allows probe traffic for recovery.
+ isHealthy := state != StateOpen && state != StateUnknown
+ m.logger.Log(context.Background(), log.LevelDebug, "health check result", log.String("service", serviceName), log.String("state", string(state)), log.Bool("healthy", isHealthy))
return isHealthy
}
+// Reset recreates the service breaker with its stored config.
func (m *manager) Reset(serviceName string) {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.breakers[serviceName]; exists {
- m.logger.Infof("Resetting circuit breaker for service: %s", serviceName)
+ m.logger.Log(context.Background(), log.LevelInfo, "resetting circuit breaker", log.String("service", serviceName))
- // Get stored config
config, configExists := m.configs[serviceName]
if !configExists {
- m.logger.Warnf("No stored config found for service %s, cannot recreate", serviceName)
+ m.logger.Log(context.Background(), log.LevelWarn, "no stored config found, cannot recreate circuit breaker", log.String("service", serviceName))
delete(m.breakers, serviceName)
return
}
- // Recreate circuit breaker with same configuration
- settings := gobreaker.Settings{
- Name: fmt.Sprintf("service-%s", serviceName),
- MaxRequests: config.MaxRequests,
- Interval: config.Interval,
- Timeout: config.Timeout,
- ReadyToTrip: func(counts gobreaker.Counts) bool {
- failureRatio := float64(counts.TotalFailures) / float64(counts.Requests)
-
- return counts.ConsecutiveFailures >= config.ConsecutiveFailures ||
- (counts.Requests >= config.MinRequests && failureRatio >= config.FailureRatio)
- },
- OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) {
- m.handleStateChange(serviceName, from, to)
- },
- }
+ settings := m.buildSettings(serviceName, config)
breaker := gobreaker.NewCircuitBreaker(settings)
m.breakers[serviceName] = breaker
- m.logger.Infof("Circuit breaker reset completed for service: %s", serviceName)
+ m.logger.Log(context.Background(), log.LevelInfo, "circuit breaker reset completed", log.String("service", serviceName))
}
}
-// RegisterStateChangeListener registers a listener for state change notifications
+// RegisterStateChangeListener registers a listener for state change notifications.
+// Both untyped nil and typed nil (e.g., (*MyListener)(nil)) are rejected.
func (m *manager) RegisterStateChangeListener(listener StateChangeListener) {
- if listener == nil {
- m.logger.Warnf("Attempted to register a nil state change listener")
+ if isNilListener(listener) {
+ m.logger.Log(context.Background(), log.LevelWarn, "attempted to register a nil state change listener")
return
}
@@ -202,57 +284,167 @@ func (m *manager) RegisterStateChangeListener(listener StateChangeListener) {
defer m.mu.Unlock()
m.listeners = append(m.listeners, listener)
- m.logger.Debugf("Registered state change listener (total: %d)", len(m.listeners))
+ m.logger.Log(context.Background(), log.LevelDebug, "registered state change listener", log.Int("total", len(m.listeners)))
+}
+
+// isNilLogger checks for both untyped nil and typed nil log.Logger values.
+// Mirrors the isNilListener pattern to prevent panics from typed-nil loggers.
+func isNilLogger(logger log.Logger) bool {
+ if logger == nil {
+ return true
+ }
+
+ v := reflect.ValueOf(logger)
+ if !v.IsValid() {
+ return true
+ }
+
+ switch v.Kind() {
+ case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func, reflect.Interface:
+ return v.IsNil()
+ default:
+ return false
+ }
+}
+
+// isNilListener checks for both untyped nil and typed nil interface values.
+// Handles all nilable kinds: pointers, slices, maps, channels, funcs, and interfaces.
+func isNilListener(listener StateChangeListener) bool {
+ if listener == nil {
+ return true
+ }
+
+ v := reflect.ValueOf(listener)
+ if !v.IsValid() {
+ return true
+ }
+
+ switch v.Kind() {
+ case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func, reflect.Interface:
+ return v.IsNil()
+ default:
+ return false
+ }
}
// handleStateChange processes state changes and notifies listeners
func (m *manager) handleStateChange(serviceName string, from gobreaker.State, to gobreaker.State) {
- // Log state change
- m.logger.Warnf("Circuit Breaker [%s] state changed: %s -> %s",
- serviceName, from.String(), to.String())
-
switch to {
case gobreaker.StateOpen:
- m.logger.Errorf("Circuit Breaker [%s] OPENED - service is unhealthy, requests will fast-fail", serviceName)
+ m.logger.Log(context.Background(), log.LevelError, "circuit breaker OPENED, requests will fast-fail", log.String("service", serviceName), log.String("from", from.String()))
case gobreaker.StateHalfOpen:
- m.logger.Infof("Circuit Breaker [%s] HALF-OPEN - testing service recovery", serviceName)
+ m.logger.Log(context.Background(), log.LevelInfo, "circuit breaker HALF-OPEN, testing service recovery", log.String("service", serviceName), log.String("from", from.String()))
case gobreaker.StateClosed:
- m.logger.Infof("Circuit Breaker [%s] CLOSED - service is healthy", serviceName)
+ m.logger.Log(context.Background(), log.LevelInfo, "circuit breaker CLOSED, service is healthy", log.String("service", serviceName), log.String("from", from.String()))
}
- // Notify listeners
+ // Record state transition metric
fromState := convertGobreakerState(from)
toState := convertGobreakerState(to)
+ m.recordStateTransition(serviceName, fromState, toState)
+
m.mu.RLock()
listeners := make([]StateChangeListener, len(m.listeners))
copy(listeners, m.listeners)
m.mu.RUnlock()
for _, listener := range listeners {
- // Notify in goroutine to avoid blocking circuit breaker operations
- go func(l StateChangeListener) {
- defer func() {
- if r := recover(); r != nil {
- m.logger.Errorf("Circuit breaker state change listener panic for service %s: %v", serviceName, r)
- }
- }()
+ // Notify in goroutine to avoid blocking circuit breaker operations.
+ // A timeout context prevents slow or stuck listeners from leaking goroutines.
+ listenerCopy := listener
+
+ runtime.SafeGoWithContextAndComponent(
+ context.Background(),
+ m.logger,
+ "circuitbreaker",
+ "state_change_listener_"+serviceName,
+ runtime.KeepRunning,
+ func(ctx context.Context) {
+ m.notifyStateChangeListener(ctx, listenerCopy, serviceName, fromState, toState)
+ },
+ )
+ }
+}
- l.OnStateChange(serviceName, fromState, toState)
- }(listener)
+func (m *manager) notifyStateChangeListener(
+ ctx context.Context,
+ listener StateChangeListener,
+ serviceName string,
+ fromState State,
+ toState State,
+) {
+ listenerCtx, listenerCancel := context.WithTimeout(ctx, stateChangeListenerTimeout)
+ defer listenerCancel()
+
+ listener.OnStateChange(listenerCtx, serviceName, fromState, toState)
+}
+
+// readyToTrip builds the trip function for gobreaker.Settings.
+func readyToTrip(config Config) func(counts gobreaker.Counts) bool {
+ return func(counts gobreaker.Counts) bool {
+ // Check consecutive failures (skip if threshold is 0 = disabled)
+ if config.ConsecutiveFailures > 0 && counts.ConsecutiveFailures >= config.ConsecutiveFailures {
+ return true
+ }
+
+ // Check failure ratio (skip if min requests is 0 = disabled)
+ if config.MinRequests > 0 && counts.Requests >= config.MinRequests {
+ failureRatio := safe.DivideFloat64OrZero(float64(counts.TotalFailures), float64(counts.Requests))
+ return failureRatio >= config.FailureRatio
+ }
+
+ return false
}
}
-// convertGobreakerState converts gobreaker.State to our State type
-func convertGobreakerState(state gobreaker.State) State {
- switch state {
- case gobreaker.StateClosed:
- return StateClosed
- case gobreaker.StateOpen:
- return StateOpen
- case gobreaker.StateHalfOpen:
- return StateHalfOpen
- default:
- return StateUnknown
+// buildSettings creates gobreaker.Settings from a Config for the given service.
+func (m *manager) buildSettings(serviceName string, config Config) gobreaker.Settings {
+ return gobreaker.Settings{
+ Name: "service-" + serviceName,
+ MaxRequests: config.MaxRequests,
+ Interval: config.Interval,
+ Timeout: config.Timeout,
+ ReadyToTrip: readyToTrip(config),
+ OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) {
+ m.handleStateChange(serviceName, from, to)
+ },
+ }
+}
+
+// recordStateTransition increments the state transition counter.
+// No-op when metricsFactory is nil.
+func (m *manager) recordStateTransition(serviceName string, from, to State) {
+ if m.stateCounter == nil {
+ return
+ }
+
+ err := m.stateCounter.
+ WithLabels(map[string]string{
+ "service": constant.SanitizeMetricLabel(serviceName),
+ "from_state": string(from),
+ "to_state": string(to),
+ }).
+ AddOne(context.Background())
+ if err != nil {
+ m.logger.Log(context.Background(), log.LevelWarn, "failed to record state transition metric", log.Err(err))
+ }
+}
+
+// recordExecution increments the execution counter.
+// No-op when metricsFactory is nil.
+func (m *manager) recordExecution(serviceName, result string) {
+ if m.execCounter == nil {
+ return
+ }
+
+ err := m.execCounter.
+ WithLabels(map[string]string{
+ "service": constant.SanitizeMetricLabel(serviceName),
+ "result": result,
+ }).
+ AddOne(context.Background())
+ if err != nil {
+ m.logger.Log(context.Background(), log.LevelWarn, "failed to record execution metric", log.Err(err))
}
}
diff --git a/commons/circuitbreaker/manager_example_test.go b/commons/circuitbreaker/manager_example_test.go
new file mode 100644
index 00000000..36a06360
--- /dev/null
+++ b/commons/circuitbreaker/manager_example_test.go
@@ -0,0 +1,33 @@
+//go:build unit
+
+package circuitbreaker_test
+
+import (
+ "fmt"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/circuitbreaker"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+func ExampleManager_Execute() {
+ mgr, err := circuitbreaker.NewManager(&log.NopLogger{})
+ if err != nil {
+ return
+ }
+
+ _, err = mgr.GetOrCreate("ledger-db", circuitbreaker.DefaultConfig())
+ if err != nil {
+ return
+ }
+
+ result, err := mgr.Execute("ledger-db", func() (any, error) {
+ return "ok", nil
+ })
+
+ fmt.Println(result, err == nil)
+ fmt.Println(mgr.GetState("ledger-db"))
+
+ // Output:
+ // ok true
+ // closed
+}
diff --git a/commons/circuitbreaker/manager_metrics_test.go b/commons/circuitbreaker/manager_metrics_test.go
new file mode 100644
index 00000000..9f551f9d
--- /dev/null
+++ b/commons/circuitbreaker/manager_metrics_test.go
@@ -0,0 +1,474 @@
+//go:build unit
+
+package circuitbreaker
+
+import (
+ "context"
+ "errors"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ sdkmetric "go.opentelemetry.io/otel/sdk/metric"
+ "go.opentelemetry.io/otel/sdk/metric/metricdata"
+)
+
+// ---------------------------------------------------------------------------
+// Test helpers
+// ---------------------------------------------------------------------------
+
+// newTestMetricsFactory creates a MetricsFactory backed by a real SDK meter
+// provider with a ManualReader, mirroring the pattern in metrics/v2_test.go.
+func newTestMetricsFactory(t *testing.T) (*metrics.MetricsFactory, *sdkmetric.ManualReader) {
+ t.Helper()
+
+ reader := sdkmetric.NewManualReader()
+ provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
+ meter := provider.Meter("test-circuitbreaker")
+
+ factory, err := metrics.NewMetricsFactory(meter, &log.NopLogger{})
+ require.NoError(t, err)
+
+ return factory, reader
+}
+
+// collectMetrics calls reader.Collect and returns the ResourceMetrics payload.
+func collectMetrics(t *testing.T, reader *sdkmetric.ManualReader) metricdata.ResourceMetrics {
+ t.Helper()
+
+ var rm metricdata.ResourceMetrics
+
+ err := reader.Collect(context.Background(), &rm)
+ require.NoError(t, err)
+
+ return rm
+}
+
+// findMetricByName walks the collected ResourceMetrics and returns the first
+// Metrics entry whose Name matches. Returns nil if not found.
+func findMetricByName(rm metricdata.ResourceMetrics, name string) *metricdata.Metrics {
+ for _, sm := range rm.ScopeMetrics {
+ for i := range sm.Metrics {
+ if sm.Metrics[i].Name == name {
+ return &sm.Metrics[i]
+ }
+ }
+ }
+
+ return nil
+}
+
+// sumDataPoints extracts data points from a Sum metric.
+func sumDataPoints(t *testing.T, m *metricdata.Metrics) []metricdata.DataPoint[int64] {
+ t.Helper()
+
+ sum, ok := m.Data.(metricdata.Sum[int64])
+ require.True(t, ok, "expected Sum[int64] data, got %T", m.Data)
+
+ return sum.DataPoints
+}
+
+// hasAttributeValue checks whether a data point's attribute set contains the key/value pair.
+func hasAttributeValue(dp metricdata.DataPoint[int64], key, value string) bool {
+ iter := dp.Attributes.Iter()
+ for iter.Next() {
+ kv := iter.Attribute()
+ if string(kv.Key) == key && kv.Value.AsString() == value {
+ return true
+ }
+ }
+
+ return false
+}
+
+// ---------------------------------------------------------------------------
+// Test: WithMetricsFactory(nil) — manager works, no metrics emitted, no panic
+// ---------------------------------------------------------------------------
+
+func TestMetrics_WithNilFactory_NoPanic(t *testing.T) {
+ t.Parallel()
+
+ mgr, err := NewManager(&log.NopLogger{}, WithMetricsFactory(nil))
+ require.NoError(t, err)
+
+ // Verify the metricsFactory field is nil on the concrete manager
+ m := mgr.(*manager)
+ assert.Nil(t, m.metricsFactory, "metricsFactory should be nil when WithMetricsFactory(nil) is used")
+
+ // Create a breaker and execute — must not panic even without metrics
+ _, err = mgr.GetOrCreate("no-metrics-svc", DefaultConfig())
+ require.NoError(t, err)
+
+ result, err := mgr.Execute("no-metrics-svc", func() (any, error) {
+ return "ok", nil
+ })
+ assert.NoError(t, err)
+ assert.Equal(t, "ok", result)
+
+ // Execute with error — recordExecution("error") must not panic
+ _, err = mgr.Execute("no-metrics-svc", func() (any, error) {
+ return nil, errors.New("boom")
+ })
+ assert.Error(t, err)
+}
+
+// ---------------------------------------------------------------------------
+// Test: WithMetricsFactory(factory) — option is applied to manager
+// ---------------------------------------------------------------------------
+
+func TestMetrics_WithFactory_Applied(t *testing.T) {
+ t.Parallel()
+
+ factory, _ := newTestMetricsFactory(t)
+
+ mgr, err := NewManager(&log.NopLogger{}, WithMetricsFactory(factory))
+ require.NoError(t, err)
+
+ m := mgr.(*manager)
+ assert.Same(t, factory, m.metricsFactory, "metricsFactory should be the factory passed via option")
+}
+
+// ---------------------------------------------------------------------------
+// Test: recordExecution — success path emits counter with result="success"
+// ---------------------------------------------------------------------------
+
+func TestMetrics_RecordExecution_Success(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestMetricsFactory(t)
+
+ mgr, err := NewManager(&log.NopLogger{}, WithMetricsFactory(factory))
+ require.NoError(t, err)
+
+ _, err = mgr.GetOrCreate("exec-svc", DefaultConfig())
+ require.NoError(t, err)
+
+ // Successful execution
+ _, err = mgr.Execute("exec-svc", func() (any, error) {
+ return "ok", nil
+ })
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "circuit_breaker_executions_total")
+ require.NotNil(t, m, "execution metric must be recorded")
+
+ dps := sumDataPoints(t, m)
+ require.NotEmpty(t, dps)
+
+ // Find the data point with result="success"
+ found := false
+
+ for _, dp := range dps {
+ if hasAttributeValue(dp, "result", "success") && hasAttributeValue(dp, "service", "exec-svc") {
+ found = true
+ assert.Equal(t, int64(1), dp.Value, "one successful execution should record value 1")
+ }
+ }
+
+ assert.True(t, found, "expected a data point with result=success and service=exec-svc")
+}
+
+// ---------------------------------------------------------------------------
+// Test: recordExecution — error path emits counter with result="error"
+// ---------------------------------------------------------------------------
+
+func TestMetrics_RecordExecution_Error(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestMetricsFactory(t)
+
+ mgr, err := NewManager(&log.NopLogger{}, WithMetricsFactory(factory))
+ require.NoError(t, err)
+
+ _, err = mgr.GetOrCreate("err-svc", DefaultConfig())
+ require.NoError(t, err)
+
+ // Failing execution (the wrapped function returns an error)
+ _, err = mgr.Execute("err-svc", func() (any, error) {
+ return nil, errors.New("service failure")
+ })
+ assert.Error(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "circuit_breaker_executions_total")
+ require.NotNil(t, m, "execution metric must be recorded on error path")
+
+ dps := sumDataPoints(t, m)
+
+ found := false
+
+ for _, dp := range dps {
+ if hasAttributeValue(dp, "result", "error") && hasAttributeValue(dp, "service", "err-svc") {
+ found = true
+ assert.Equal(t, int64(1), dp.Value)
+ }
+ }
+
+ assert.True(t, found, "expected a data point with result=error and service=err-svc")
+}
+
+// ---------------------------------------------------------------------------
+// Test: recordExecution — open-state rejection emits result="rejected_open"
+// ---------------------------------------------------------------------------
+
+func TestMetrics_RecordExecution_RejectedOpen(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestMetricsFactory(t)
+
+ mgr, err := NewManager(&log.NopLogger{}, WithMetricsFactory(factory))
+ require.NoError(t, err)
+
+ cfg := Config{
+ MaxRequests: 1,
+ Interval: 100 * time.Millisecond,
+ Timeout: 5 * time.Second,
+ ConsecutiveFailures: 2,
+ FailureRatio: 0.5,
+ MinRequests: 2,
+ }
+
+ _, err = mgr.GetOrCreate("reject-svc", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker open
+ for i := 0; i < 3; i++ {
+ _, _ = mgr.Execute("reject-svc", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ }
+
+ require.Equal(t, StateOpen, mgr.GetState("reject-svc"))
+
+ // This call should be rejected by the open breaker
+ _, err = mgr.Execute("reject-svc", func() (any, error) {
+ return nil, nil
+ })
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "currently unavailable")
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "circuit_breaker_executions_total")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+
+ found := false
+
+ for _, dp := range dps {
+ if hasAttributeValue(dp, "result", "rejected_open") && hasAttributeValue(dp, "service", "reject-svc") {
+ found = true
+ assert.GreaterOrEqual(t, dp.Value, int64(1))
+ }
+ }
+
+ assert.True(t, found, "expected a data point with result=rejected_open and service=reject-svc")
+}
+
+// ---------------------------------------------------------------------------
+// Test: recordStateTransition — state change from closed → open emits metric
+// ---------------------------------------------------------------------------
+
+func TestMetrics_RecordStateTransition_ClosedToOpen(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestMetricsFactory(t)
+
+ mgr, err := NewManager(&log.NopLogger{}, WithMetricsFactory(factory))
+ require.NoError(t, err)
+
+ cfg := Config{
+ MaxRequests: 1,
+ Interval: 100 * time.Millisecond,
+ Timeout: 5 * time.Second,
+ ConsecutiveFailures: 2,
+ FailureRatio: 0.5,
+ MinRequests: 2,
+ }
+
+ _, err = mgr.GetOrCreate("state-svc", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker: consecutive failures → closed→open transition
+ for i := 0; i < 3; i++ {
+ _, _ = mgr.Execute("state-svc", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ }
+
+ require.Equal(t, StateOpen, mgr.GetState("state-svc"))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "circuit_breaker_state_transitions_total")
+ require.NotNil(t, m, "state transition metric must be recorded")
+
+ dps := sumDataPoints(t, m)
+
+ found := false
+
+ for _, dp := range dps {
+ if hasAttributeValue(dp, "from_state", string(StateClosed)) &&
+ hasAttributeValue(dp, "to_state", string(StateOpen)) &&
+ hasAttributeValue(dp, "service", "state-svc") {
+ found = true
+ assert.GreaterOrEqual(t, dp.Value, int64(1))
+ }
+ }
+
+ assert.True(t, found, "expected state transition metric with from_state=closed, to_state=open, service=state-svc")
+}
+
+// ---------------------------------------------------------------------------
+// Test: recordStateTransition — direct call on manager struct (nil factory)
+// ---------------------------------------------------------------------------
+
+func TestMetrics_RecordStateTransition_NilFactory_Noop(t *testing.T) {
+ t.Parallel()
+
+ mgr, err := NewManager(&log.NopLogger{})
+ require.NoError(t, err)
+
+ m := mgr.(*manager)
+
+ // Direct call with nil metricsFactory — must be a no-op, no panic
+ assert.NotPanics(t, func() {
+ m.recordStateTransition("any-service", StateClosed, StateOpen)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Test: recordExecution — direct call on manager struct (nil factory)
+// ---------------------------------------------------------------------------
+
+func TestMetrics_RecordExecution_NilFactory_Noop(t *testing.T) {
+ t.Parallel()
+
+ mgr, err := NewManager(&log.NopLogger{})
+ require.NoError(t, err)
+
+ m := mgr.(*manager)
+
+ // Direct call with nil metricsFactory — must be a no-op, no panic
+ assert.NotPanics(t, func() {
+ m.recordExecution("any-service", "success")
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Test: SanitizeMetricLabel is applied — long service name > 64 chars
+// ---------------------------------------------------------------------------
+
+func TestMetrics_LongServiceName_Sanitized(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestMetricsFactory(t)
+
+ mgr, err := NewManager(&log.NopLogger{}, WithMetricsFactory(factory))
+ require.NoError(t, err)
+
+ // Create a service name that exceeds 64 characters
+ longName := strings.Repeat("a", 100)
+ require.Greater(t, len(longName), 64, "test precondition: service name must exceed 64 chars")
+
+ _, err = mgr.GetOrCreate(longName, DefaultConfig())
+ require.NoError(t, err)
+
+ _, err = mgr.Execute(longName, func() (any, error) {
+ return "ok", nil
+ })
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "circuit_breaker_executions_total")
+ require.NotNil(t, m, "execution metric must be recorded for long service name")
+
+ dps := sumDataPoints(t, m)
+ require.NotEmpty(t, dps)
+
+ // The service label must be truncated to 64 characters
+ truncatedName := longName[:64]
+
+ found := false
+
+ for _, dp := range dps {
+ if hasAttributeValue(dp, "service", truncatedName) {
+ found = true
+ }
+ }
+
+ assert.True(t, found, "service label should be truncated to 64 characters via SanitizeMetricLabel")
+}
+
+// ---------------------------------------------------------------------------
+// Test: Multiple executions accumulate correctly
+// ---------------------------------------------------------------------------
+
+func TestMetrics_MultipleExecutions_Accumulate(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestMetricsFactory(t)
+
+ mgr, err := NewManager(&log.NopLogger{}, WithMetricsFactory(factory))
+ require.NoError(t, err)
+
+ _, err = mgr.GetOrCreate("accum-svc", DefaultConfig())
+ require.NoError(t, err)
+
+ // Run 3 successful and 2 failed executions
+ for i := 0; i < 3; i++ {
+ _, err = mgr.Execute("accum-svc", func() (any, error) {
+ return "ok", nil
+ })
+ require.NoError(t, err)
+ }
+
+ for i := 0; i < 2; i++ {
+ _, _ = mgr.Execute("accum-svc", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ }
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "circuit_breaker_executions_total")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+
+ var successVal, errorVal int64
+
+ for _, dp := range dps {
+ if hasAttributeValue(dp, "service", "accum-svc") {
+ if hasAttributeValue(dp, "result", "success") {
+ successVal = dp.Value
+ }
+
+ if hasAttributeValue(dp, "result", "error") {
+ errorVal = dp.Value
+ }
+ }
+ }
+
+ assert.Equal(t, int64(3), successVal, "3 successful executions should be recorded")
+ assert.Equal(t, int64(2), errorVal, "2 failed executions should be recorded")
+}
+
+// ---------------------------------------------------------------------------
+// Test: Metric definitions have correct names and units
+// ---------------------------------------------------------------------------
+
+func TestMetrics_MetricDefinitions(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "circuit_breaker_state_transitions_total", stateTransitionMetric.Name)
+ assert.Equal(t, "1", stateTransitionMetric.Unit)
+ assert.NotEmpty(t, stateTransitionMetric.Description)
+
+ assert.Equal(t, "circuit_breaker_executions_total", executionMetric.Name)
+ assert.Equal(t, "1", executionMetric.Unit)
+ assert.NotEmpty(t, executionMetric.Description)
+}
diff --git a/commons/circuitbreaker/manager_test.go b/commons/circuitbreaker/manager_test.go
index 685529d8..ff416edf 100644
--- a/commons/circuitbreaker/manager_test.go
+++ b/commons/circuitbreaker/manager_test.go
@@ -1,24 +1,27 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package circuitbreaker
import (
+ "context"
"errors"
"testing"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/sony/gobreaker"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestCircuitBreaker_InitialState(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
config := DefaultConfig()
- manager.GetOrCreate("test-service", config)
+ _, err = manager.GetOrCreate("test-service", config)
+ assert.NoError(t, err)
// Circuit breaker should start in closed state
assert.Equal(t, StateClosed, manager.GetState("test-service"))
@@ -26,8 +29,9 @@ func TestCircuitBreaker_InitialState(t *testing.T) {
}
func TestCircuitBreaker_OpenState(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
config := Config{
MaxRequests: 1,
@@ -38,7 +42,8 @@ func TestCircuitBreaker_OpenState(t *testing.T) {
MinRequests: 2,
}
- manager.GetOrCreate("test-service", config)
+ _, err = manager.GetOrCreate("test-service", config)
+ assert.NoError(t, err)
// Trigger failures to open circuit breaker
for i := 0; i < 5; i++ {
@@ -54,7 +59,7 @@ func TestCircuitBreaker_OpenState(t *testing.T) {
// Requests should fast-fail
start := time.Now()
- _, err := manager.Execute("test-service", func() (any, error) {
+ _, err = manager.Execute("test-service", func() (any, error) {
time.Sleep(5 * time.Second) // This should not execute
return nil, nil
})
@@ -66,11 +71,13 @@ func TestCircuitBreaker_OpenState(t *testing.T) {
}
func TestCircuitBreaker_SuccessfulExecution(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
config := DefaultConfig()
- manager.GetOrCreate("test-service", config)
+ _, err = manager.GetOrCreate("test-service", config)
+ assert.NoError(t, err)
result, err := manager.Execute("test-service", func() (any, error) {
return "success", nil
@@ -82,24 +89,28 @@ func TestCircuitBreaker_SuccessfulExecution(t *testing.T) {
}
func TestCircuitBreaker_GetCounts(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
config := DefaultConfig()
- manager.GetOrCreate("test-service", config)
+ _, err = manager.GetOrCreate("test-service", config)
+ assert.NoError(t, err)
// Execute some requests
for i := 0; i < 5; i++ {
- _, _ = manager.Execute("test-service", func() (any, error) {
+ _, err = manager.Execute("test-service", func() (any, error) {
return "success", nil
})
+ require.NoError(t, err)
}
// Trigger some failures
for i := 0; i < 3; i++ {
- _, _ = manager.Execute("test-service", func() (any, error) {
+ _, err = manager.Execute("test-service", func() (any, error) {
return nil, errors.New("failure")
})
+ require.Error(t, err)
}
counts := manager.GetCounts("test-service")
@@ -109,8 +120,9 @@ func TestCircuitBreaker_GetCounts(t *testing.T) {
}
func TestCircuitBreaker_Reset(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
config := Config{
MaxRequests: 1,
@@ -121,13 +133,15 @@ func TestCircuitBreaker_Reset(t *testing.T) {
MinRequests: 2,
}
- manager.GetOrCreate("test-service", config)
+ _, err = manager.GetOrCreate("test-service", config)
+ assert.NoError(t, err)
// Trigger failures to open circuit breaker
for i := 0; i < 5; i++ {
- _, _ = manager.Execute("test-service", func() (any, error) {
+ _, err = manager.Execute("test-service", func() (any, error) {
return nil, errors.New("service error")
})
+ require.Error(t, err)
}
// Circuit breaker should be open
@@ -149,14 +163,15 @@ func TestCircuitBreaker_Reset(t *testing.T) {
}
func TestCircuitBreaker_UnknownService(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
// Query non-existent service
assert.Equal(t, StateUnknown, manager.GetState("non-existent"))
// Execute on non-existent service should fail
- _, err := manager.Execute("non-existent", func() (any, error) {
+ _, err = manager.Execute("non-existent", func() (any, error) {
return "success", nil
})
@@ -185,8 +200,9 @@ func TestCircuitBreaker_ConfigPresets(t *testing.T) {
}
func TestCircuitBreaker_StateChangeListenerPanicRecovery(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
config := Config{
MaxRequests: 1,
@@ -229,13 +245,15 @@ func TestCircuitBreaker_StateChangeListenerPanicRecovery(t *testing.T) {
manager.RegisterStateChangeListener(secondNormalListener)
// Create circuit breaker
- manager.GetOrCreate("test-service", config)
+ _, err = manager.GetOrCreate("test-service", config)
+ assert.NoError(t, err)
// Trigger failures to open circuit breaker and trigger state change
for i := 0; i < 3; i++ {
- _, _ = manager.Execute("test-service", func() (any, error) {
+ _, err = manager.Execute("test-service", func() (any, error) {
return nil, errors.New("service error")
})
+ require.Error(t, err)
}
// Wait for all listeners to be called (with timeout)
@@ -268,8 +286,9 @@ func TestCircuitBreaker_StateChangeListenerPanicRecovery(t *testing.T) {
}
func TestCircuitBreaker_NilListenerRegistration(t *testing.T) {
- logger := &log.NoneLogger{}
- manager := NewManager(logger)
+ logger := &log.NopLogger{}
+ manager, err := NewManager(logger)
+ require.NoError(t, err)
// Attempt to register nil listener
manager.RegisterStateChangeListener(nil)
@@ -283,13 +302,15 @@ func TestCircuitBreaker_NilListenerRegistration(t *testing.T) {
FailureRatio: 0.5,
MinRequests: 2,
}
- manager.GetOrCreate("test-service", config)
+ _, err = manager.GetOrCreate("test-service", config)
+ assert.NoError(t, err)
// Trigger a state change to ensure system still works
for i := 0; i < 3; i++ {
- _, _ = manager.Execute("test-service", func() (any, error) {
+ _, err = manager.Execute("test-service", func() (any, error) {
return nil, errors.New("service error")
})
+ require.Error(t, err)
}
// Should successfully transition to open state
@@ -301,8 +322,209 @@ type mockStateChangeListener struct {
onStateChangeFn func(serviceName string, from State, to State)
}
-func (m *mockStateChangeListener) OnStateChange(serviceName string, from State, to State) {
+func (m *mockStateChangeListener) OnStateChange(_ context.Context, serviceName string, from State, to State) {
if m.onStateChangeFn != nil {
m.onStateChangeFn(serviceName, from, to)
}
}
+
+func TestNewManager_NilLogger(t *testing.T) {
+ m, err := NewManager(nil)
+ assert.Nil(t, m)
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilLogger)
+}
+
+func TestGetOrCreate_InvalidConfig(t *testing.T) {
+ logger := &log.NopLogger{}
+ m, err := NewManager(logger)
+ require.NoError(t, err)
+
+ // Both trip conditions zero → invalid
+ invalidCfg := Config{
+ ConsecutiveFailures: 0,
+ MinRequests: 0,
+ }
+
+ cb, err := m.GetOrCreate("bad-config-service", invalidCfg)
+ assert.Nil(t, cb)
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+}
+
+func TestGetOrCreate_ReturnExistingBreaker(t *testing.T) {
+ logger := &log.NopLogger{}
+ m, err := NewManager(logger)
+ require.NoError(t, err)
+
+ cfg := DefaultConfig()
+
+ cb1, err := m.GetOrCreate("my-service", cfg)
+ require.NoError(t, err)
+
+ cb2, err := m.GetOrCreate("my-service", cfg)
+ require.NoError(t, err)
+
+ // Both should return a valid breaker in the same state
+ assert.Equal(t, cb1.State(), cb2.State())
+}
+
+func TestExecute_OpenStateRejection(t *testing.T) {
+ logger := &log.NopLogger{}
+ m, err := NewManager(logger)
+ require.NoError(t, err)
+
+ cfg := Config{
+ MaxRequests: 1,
+ Interval: 100 * time.Millisecond,
+ Timeout: 200 * time.Millisecond,
+ ConsecutiveFailures: 2,
+ FailureRatio: 0.5,
+ MinRequests: 2,
+ }
+
+ _, err = m.GetOrCreate("svc", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker open by sending consecutive failures
+ for i := 0; i < 3; i++ {
+ _, _ = m.Execute("svc", func() (any, error) {
+ return nil, errors.New("fail")
+ })
+ }
+ assert.Equal(t, StateOpen, m.GetState("svc"))
+
+ // Poll until the breaker transitions to half-open (timeout is 200ms)
+ require.Eventually(t, func() bool {
+ return m.GetState("svc") == StateHalfOpen
+ }, 2*time.Second, 10*time.Millisecond, "breaker should transition to half-open after timeout")
+
+ // MaxRequests=1, so the first call in half-open is the probe.
+ // Make it fail so the breaker re-opens.
+ _, _ = m.Execute("svc", func() (any, error) {
+ return nil, errors.New("still failing")
+ })
+
+ // Poll again until the breaker transitions back to half-open
+ require.Eventually(t, func() bool {
+ return m.GetState("svc") == StateHalfOpen
+ }, 2*time.Second, 10*time.Millisecond, "breaker should transition to half-open after second timeout")
+
+ // In half-open the probe call (first) is allowed; make it fail to re-open
+ _, err = m.Execute("svc", func() (any, error) {
+ return nil, errors.New("probe fail")
+ })
+ // After the probe fails in half-open, breaker re-opens.
+ // The next call should be rejected with an open-state error.
+ _, err = m.Execute("svc", func() (any, error) {
+ return nil, nil
+ })
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "currently unavailable")
+}
+
+func TestGetCounts_NonExistentService(t *testing.T) {
+ logger := &log.NopLogger{}
+ m, err := NewManager(logger)
+ require.NoError(t, err)
+
+ counts := m.GetCounts("does-not-exist")
+ assert.Equal(t, Counts{}, counts)
+}
+
+func TestIsHealthy_NonExistentService(t *testing.T) {
+ logger := &log.NopLogger{}
+ m, err := NewManager(logger)
+ require.NoError(t, err)
+
+ // StateUnknown != StateClosed → not healthy
+ assert.False(t, m.IsHealthy("no-such-service"))
+}
+
+func TestReset_NonExistentService(t *testing.T) {
+ logger := &log.NopLogger{}
+ m, err := NewManager(logger)
+ require.NoError(t, err)
+
+ // Should be a no-op, not panic
+ assert.NotPanics(t, func() {
+ m.Reset("non-existent-service")
+ })
+}
+
+func TestCircuitBreaker_Wrapper_Execute(t *testing.T) {
+ logger := &log.NopLogger{}
+ m, err := NewManager(logger)
+ require.NoError(t, err)
+
+ cfg := DefaultConfig()
+ cb, err := m.GetOrCreate("wrapper-test", cfg)
+ require.NoError(t, err)
+
+ // Test Execute through the CircuitBreaker interface
+ result, err := cb.Execute(func() (any, error) {
+ return 42, nil
+ })
+ assert.NoError(t, err)
+ assert.Equal(t, 42, result)
+
+ // Test State
+ assert.Equal(t, StateClosed, cb.State())
+
+ // Test Counts
+ counts := cb.Counts()
+ assert.Equal(t, uint32(1), counts.Requests)
+ assert.Equal(t, uint32(1), counts.TotalSuccesses)
+}
+
+func TestReadyToTrip_ConsecutiveFailures(t *testing.T) {
+ cfg := Config{
+ ConsecutiveFailures: 3,
+ MinRequests: 0,
+ }
+
+ tripFn := readyToTrip(cfg)
+
+ // Below threshold
+ assert.False(t, tripFn(gobreaker.Counts{ConsecutiveFailures: 2}))
+
+ // At threshold
+ assert.True(t, tripFn(gobreaker.Counts{ConsecutiveFailures: 3}))
+
+ // Above threshold
+ assert.True(t, tripFn(gobreaker.Counts{ConsecutiveFailures: 5}))
+}
+
+func TestReadyToTrip_FailureRatio(t *testing.T) {
+ cfg := Config{
+ ConsecutiveFailures: 0,
+ MinRequests: 4,
+ FailureRatio: 0.5,
+ }
+
+ tripFn := readyToTrip(cfg)
+
+ // Not enough requests
+ assert.False(t, tripFn(gobreaker.Counts{Requests: 3, TotalFailures: 3}))
+
+ // Enough requests, below ratio
+ assert.False(t, tripFn(gobreaker.Counts{Requests: 4, TotalFailures: 1}))
+
+ // Enough requests, at ratio
+ assert.True(t, tripFn(gobreaker.Counts{Requests: 4, TotalFailures: 2}))
+
+ // Enough requests, above ratio
+ assert.True(t, tripFn(gobreaker.Counts{Requests: 4, TotalFailures: 3}))
+}
+
+func TestReadyToTrip_NeitherConditionMet(t *testing.T) {
+ cfg := Config{
+ ConsecutiveFailures: 0,
+ MinRequests: 0,
+ }
+
+ tripFn := readyToTrip(cfg)
+
+ // Both conditions disabled → never trips
+ assert.False(t, tripFn(gobreaker.Counts{Requests: 100, TotalFailures: 100, ConsecutiveFailures: 100}))
+}
diff --git a/commons/circuitbreaker/types.go b/commons/circuitbreaker/types.go
index f0636613..6ace9548 100644
--- a/commons/circuitbreaker/types.go
+++ b/commons/circuitbreaker/types.go
@@ -1,20 +1,34 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package circuitbreaker
import (
"context"
+ "errors"
+ "fmt"
"time"
"github.com/sony/gobreaker"
)
+var (
+ // ErrInvalidConfig is returned when a Config has invalid or insufficient values.
+ ErrInvalidConfig = errors.New("circuitbreaker: invalid config")
+
+ // ErrNilLogger is returned when a nil logger is passed to NewManager.
+ ErrNilLogger = errors.New("circuitbreaker: logger must not be nil")
+
+ // ErrNilCallback is returned when a nil callback is passed to Execute.
+ ErrNilCallback = errors.New("circuitbreaker: callback must not be nil")
+
+ // ErrConfigMismatch is returned when GetOrCreate is called with a config that
+ // differs from the one stored for an existing breaker with the same name.
+ ErrConfigMismatch = errors.New("circuitbreaker: breaker already exists with different config")
+)
+
// Manager manages circuit breakers for external services
type Manager interface {
- // GetOrCreate returns existing circuit breaker or creates a new one
- GetOrCreate(serviceName string, config Config) CircuitBreaker
+ // GetOrCreate returns existing circuit breaker or creates a new one.
+ // Returns an error if the config is invalid.
+ GetOrCreate(serviceName string, config Config) (CircuitBreaker, error)
// Execute runs a function through the circuit breaker
Execute(serviceName string, fn func() (any, error)) (any, error)
@@ -45,21 +59,52 @@ type CircuitBreaker interface {
// Config holds circuit breaker configuration
type Config struct {
MaxRequests uint32 // Max requests in half-open state
- Interval time.Duration // Wait time before half-open retry
- Timeout time.Duration // Execution timeout
+ Interval time.Duration // Cyclic period of the closed state to clear internal counts
+ Timeout time.Duration // Period of the open state before becoming half-open
ConsecutiveFailures uint32 // Consecutive failures to trigger open state
FailureRatio float64 // Failure ratio to trigger open (e.g., 0.5 for 50%)
MinRequests uint32 // Min requests before checking ratio
}
+// Validate checks that the Config has valid values.
+// At least one trip condition (ConsecutiveFailures or MinRequests+FailureRatio) must be enabled.
+// Interval and Timeout must not be negative.
+func (c Config) Validate() error {
+ if c.ConsecutiveFailures == 0 && c.MinRequests == 0 {
+ return fmt.Errorf("%w: at least one trip condition must be set (ConsecutiveFailures > 0 or MinRequests > 0)", ErrInvalidConfig)
+ }
+
+ if c.FailureRatio < 0 || c.FailureRatio > 1 {
+ return fmt.Errorf("%w: FailureRatio must be between 0 and 1, got %f", ErrInvalidConfig, c.FailureRatio)
+ }
+
+ if c.MinRequests > 0 && c.FailureRatio <= 0 {
+ return fmt.Errorf("%w: FailureRatio must be > 0 when MinRequests > 0 (ratio-based trip is ineffective with FailureRatio=0)", ErrInvalidConfig)
+ }
+
+ if c.Interval < 0 {
+ return fmt.Errorf("%w: Interval must not be negative, got %v", ErrInvalidConfig, c.Interval)
+ }
+
+ if c.Timeout < 0 {
+ return fmt.Errorf("%w: Timeout must not be negative, got %v", ErrInvalidConfig, c.Timeout)
+ }
+
+ return nil
+}
+
// State represents circuit breaker state
type State string
const (
- StateClosed State = "closed"
- StateOpen State = "open"
+ // StateClosed allows requests to pass through normally.
+ StateClosed State = "closed"
+ // StateOpen rejects requests until the timeout elapses.
+ StateOpen State = "open"
+ // StateHalfOpen allows limited trial requests after an open period.
StateHalfOpen State = "half-open"
- StateUnknown State = "unknown"
+ // StateUnknown is returned when the underlying state cannot be mapped.
+ StateUnknown State = "unknown"
)
// Counts represents circuit breaker statistics
@@ -71,30 +116,42 @@ type Counts struct {
ConsecutiveFailures uint32
}
+// ErrNilCircuitBreaker is returned when a circuit breaker method is called on a nil or uninitialized instance.
+var ErrNilCircuitBreaker = errors.New("circuitbreaker: not initialized")
+
// circuitBreaker is the internal implementation wrapping gobreaker
type circuitBreaker struct {
breaker *gobreaker.CircuitBreaker
}
+// Execute runs fn through the underlying circuit breaker.
func (cb *circuitBreaker) Execute(fn func() (any, error)) (any, error) {
+ if cb == nil || cb.breaker == nil {
+ return nil, ErrNilCircuitBreaker
+ }
+
+ if fn == nil {
+ return nil, ErrNilCallback
+ }
+
return cb.breaker.Execute(fn)
}
+// State returns the current circuit breaker state.
func (cb *circuitBreaker) State() State {
- state := cb.breaker.State()
- switch state {
- case gobreaker.StateClosed:
- return StateClosed
- case gobreaker.StateOpen:
- return StateOpen
- case gobreaker.StateHalfOpen:
- return StateHalfOpen
- default:
+ if cb == nil || cb.breaker == nil {
return StateUnknown
}
+
+ return convertGobreakerState(cb.breaker.State())
}
+// Counts returns the current breaker counters.
func (cb *circuitBreaker) Counts() Counts {
+ if cb == nil || cb.breaker == nil {
+ return Counts{}
+ }
+
counts := cb.breaker.Counts()
return Counts{
@@ -127,8 +184,24 @@ type HealthChecker interface {
// HealthCheckFunc defines a function that checks service health
type HealthCheckFunc func(ctx context.Context) error
-// StateChangeListener is notified when circuit breaker state changes
+// StateChangeListener is notified when circuit breaker state changes.
type StateChangeListener interface {
- // OnStateChange is called when a circuit breaker changes state
- OnStateChange(serviceName string, from State, to State)
+ // OnStateChange is called when a circuit breaker changes state.
+ // The provided context carries a deadline derived from the listener timeout;
+ // implementations should respect ctx.Done() for cancellation.
+ OnStateChange(ctx context.Context, serviceName string, from State, to State)
+}
+
+// convertGobreakerState converts gobreaker.State to our State type.
+func convertGobreakerState(state gobreaker.State) State {
+ switch state {
+ case gobreaker.StateClosed:
+ return StateClosed
+ case gobreaker.StateOpen:
+ return StateOpen
+ case gobreaker.StateHalfOpen:
+ return StateHalfOpen
+ default:
+ return StateUnknown
+ }
}
diff --git a/commons/circuitbreaker/types_test.go b/commons/circuitbreaker/types_test.go
new file mode 100644
index 00000000..c6b00fab
--- /dev/null
+++ b/commons/circuitbreaker/types_test.go
@@ -0,0 +1,109 @@
+//go:build unit
+
+package circuitbreaker
+
+import (
+ "testing"
+
+ "github.com/sony/gobreaker"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestConfig_Validate_BothTripConditionsZero(t *testing.T) {
+ cfg := Config{
+ ConsecutiveFailures: 0,
+ MinRequests: 0,
+ }
+
+ err := cfg.Validate()
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ assert.Contains(t, err.Error(), "at least one trip condition must be set")
+}
+
+func TestConfig_Validate_InvalidFailureRatio_Negative(t *testing.T) {
+ cfg := Config{
+ ConsecutiveFailures: 5,
+ FailureRatio: -0.1,
+ }
+
+ err := cfg.Validate()
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ assert.Contains(t, err.Error(), "FailureRatio must be between 0 and 1")
+}
+
+func TestConfig_Validate_InvalidFailureRatio_GreaterThanOne(t *testing.T) {
+ cfg := Config{
+ ConsecutiveFailures: 5,
+ FailureRatio: 1.1,
+ }
+
+ err := cfg.Validate()
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ assert.Contains(t, err.Error(), "FailureRatio must be between 0 and 1")
+}
+
+func TestConfig_Validate_ValidConfig(t *testing.T) {
+ cfg := Config{
+ ConsecutiveFailures: 5,
+ FailureRatio: 0.5,
+ MinRequests: 10,
+ }
+
+ err := cfg.Validate()
+ assert.NoError(t, err)
+}
+
+func TestConfig_Validate_OnlyConsecutiveFailuresSet(t *testing.T) {
+ cfg := Config{
+ ConsecutiveFailures: 3,
+ MinRequests: 0,
+ FailureRatio: 0,
+ }
+
+ err := cfg.Validate()
+ assert.NoError(t, err)
+}
+
+func TestConfig_Validate_MinRequestsAndFailureRatioSet(t *testing.T) {
+ cfg := Config{
+ ConsecutiveFailures: 0,
+ MinRequests: 10,
+ FailureRatio: 0.5,
+ }
+
+ err := cfg.Validate()
+ assert.NoError(t, err)
+}
+
+func TestConfig_Validate_BoundaryFailureRatio(t *testing.T) {
+ // FailureRatio = 0 is valid
+ cfg := Config{
+ ConsecutiveFailures: 5,
+ FailureRatio: 0,
+ }
+ assert.NoError(t, cfg.Validate())
+
+ // FailureRatio = 1 is valid
+ cfg.FailureRatio = 1
+ assert.NoError(t, cfg.Validate())
+}
+
+func TestConvertGobreakerState_Closed(t *testing.T) {
+ assert.Equal(t, StateClosed, convertGobreakerState(gobreaker.StateClosed))
+}
+
+func TestConvertGobreakerState_Open(t *testing.T) {
+ assert.Equal(t, StateOpen, convertGobreakerState(gobreaker.StateOpen))
+}
+
+func TestConvertGobreakerState_HalfOpen(t *testing.T) {
+ assert.Equal(t, StateHalfOpen, convertGobreakerState(gobreaker.StateHalfOpen))
+}
+
+func TestConvertGobreakerState_Unknown(t *testing.T) {
+ // Use an arbitrary value that doesn't map to any known state
+ assert.Equal(t, StateUnknown, convertGobreakerState(gobreaker.State(99)))
+}
diff --git a/commons/constants/datasource.go b/commons/constants/datasource.go
index c61810c5..ef343596 100644
--- a/commons/constants/datasource.go
+++ b/commons/constants/datasource.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package constant
// DataSource Status
diff --git a/commons/constants/doc.go b/commons/constants/doc.go
new file mode 100644
index 00000000..9f5d36a6
--- /dev/null
+++ b/commons/constants/doc.go
@@ -0,0 +1,8 @@
+// Package constant provides shared constant values used across the library.
+//
+// The package name is singular for API compatibility, while the import path is
+// /commons/constants.
+//
+// Keep this package free of runtime behavior.
+// It is used by transport, telemetry, and logging helpers to avoid duplicated literals.
+package constant
diff --git a/commons/constants/errors.go b/commons/constants/errors.go
index f29f7cb1..d7e7f513 100644
--- a/commons/constants/errors.go
+++ b/commons/constants/errors.go
@@ -1,18 +1,51 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package constant
import "errors"
+// Error code string constants — single source of truth for numeric codes
+// shared between sentinel errors (below) and domain ErrorCode types.
+const (
+ // CodeInsufficientFunds is the code for insufficient balance.
+ CodeInsufficientFunds = "0018"
+ // CodeAccountIneligibility is the code for account ineligibility.
+ CodeAccountIneligibility = "0019"
+ // CodeAccountStatusTransactionRestriction is the code for account status restrictions.
+ CodeAccountStatusTransactionRestriction = "0024"
+ // CodeAssetCodeNotFound is the code for missing asset.
+ CodeAssetCodeNotFound = "0034"
+ // CodeMetadataKeyLengthExceeded is the code for metadata key exceeding length limit.
+ CodeMetadataKeyLengthExceeded = "0050"
+ // CodeMetadataValueLengthExceeded is the code for metadata value exceeding length limit.
+ CodeMetadataValueLengthExceeded = "0051"
+ // CodeTransactionValueMismatch is the code for allocation vs total mismatch.
+ CodeTransactionValueMismatch = "0073"
+ // CodeTransactionAmbiguous is the code for ambiguous transaction routing.
+ CodeTransactionAmbiguous = "0090"
+ // CodeOverFlowInt64 is the code for int64 overflow.
+ CodeOverFlowInt64 = "0097"
+ // CodeOnHoldExternalAccount is the code for on-hold on external accounts.
+ CodeOnHoldExternalAccount = "0098"
+)
+
var (
- ErrInsufficientFunds = errors.New("0018")
- ErrAccountIneligibility = errors.New("0019")
- ErrAccountStatusTransactionRestriction = errors.New("0024")
- ErrAssetCodeNotFound = errors.New("0034")
- ErrTransactionValueMismatch = errors.New("0073")
- ErrTransactionAmbiguous = errors.New("0090")
- ErrOverFlowInt64 = errors.New("0097")
- ErrOnHoldExternalAccount = errors.New("0098")
+ // ErrInsufficientFunds maps to transaction error code 0018.
+ ErrInsufficientFunds = errors.New(CodeInsufficientFunds)
+ // ErrAccountIneligibility maps to transaction error code 0019.
+ ErrAccountIneligibility = errors.New(CodeAccountIneligibility)
+ // ErrAccountStatusTransactionRestriction maps to transaction error code 0024.
+ ErrAccountStatusTransactionRestriction = errors.New(CodeAccountStatusTransactionRestriction)
+ // ErrAssetCodeNotFound maps to transaction error code 0034.
+ ErrAssetCodeNotFound = errors.New(CodeAssetCodeNotFound)
+ // ErrMetadataKeyLengthExceeded maps to metadata error code 0050.
+ ErrMetadataKeyLengthExceeded = errors.New(CodeMetadataKeyLengthExceeded)
+ // ErrMetadataValueLengthExceeded maps to metadata error code 0051.
+ ErrMetadataValueLengthExceeded = errors.New(CodeMetadataValueLengthExceeded)
+ // ErrTransactionValueMismatch maps to transaction error code 0073.
+ ErrTransactionValueMismatch = errors.New(CodeTransactionValueMismatch)
+ // ErrTransactionAmbiguous maps to transaction error code 0090.
+ ErrTransactionAmbiguous = errors.New(CodeTransactionAmbiguous)
+ // ErrOverFlowInt64 maps to transaction error code 0097.
+ ErrOverFlowInt64 = errors.New(CodeOverFlowInt64)
+ // ErrOnHoldExternalAccount maps to transaction error code 0098.
+ ErrOnHoldExternalAccount = errors.New(CodeOnHoldExternalAccount)
)
diff --git a/commons/constants/headers.go b/commons/constants/headers.go
index 1dee9e21..2fdd9fae 100644
--- a/commons/constants/headers.go
+++ b/commons/constants/headers.go
@@ -1,29 +1,54 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package constant
const (
- HeaderUserAgent = "User-Agent"
- HeaderRealIP = "X-Real-Ip"
- HeaderForwardedFor = "X-Forwarded-For"
+ // HeaderUserAgent is the HTTP User-Agent header key.
+ HeaderUserAgent = "User-Agent"
+ // HeaderRealIP is the de-facto upstream real client IP header key.
+ HeaderRealIP = "X-Real-Ip"
+ // HeaderForwardedFor is the X-Forwarded-For header key.
+ HeaderForwardedFor = "X-Forwarded-For"
+ // HeaderForwardedHost is the X-Forwarded-Host header key.
HeaderForwardedHost = "X-Forwarded-Host"
- HeaderHost = "Host"
- DSL = "dsl"
- FileExtension = ".gold"
- HeaderID = "X-Request-Id"
- HeaderTraceparent = "Traceparent"
- IdempotencyKey = "X-Idempotency"
- IdempotencyTTL = "X-TTL"
+ // HeaderHost is the Host header key.
+ HeaderHost = "Host"
+ // DSL is the file kind marker used for DSL resources.
+ DSL = "dsl"
+ // FileExtension is the default extension for DSL files.
+ FileExtension = ".gold"
+ // HeaderID is the request identifier header key.
+ HeaderID = "X-Request-Id"
+ // HeaderTraceparent is the W3C traceparent header key.
+ HeaderTraceparent = "Traceparent"
+ // IdempotencyKey is the idempotency key request header.
+ IdempotencyKey = "X-Idempotency"
+ // IdempotencyTTL is the idempotency record TTL header.
+ IdempotencyTTL = "X-TTL"
+ // IdempotencyReplayed signals whether a request was replayed.
IdempotencyReplayed = "X-Idempotency-Replayed"
- Authorization = "Authorization"
- Basic = "Basic"
- BasicAuth = "Basic Auth"
- WWWAuthenticate = "WWW-Authenticate"
+ // Authorization is the HTTP Authorization header key.
+ Authorization = "Authorization"
+ // Basic is the HTTP Basic auth scheme token.
+ Basic = "Basic"
+ // BasicAuth is the human-readable Basic auth label.
+ BasicAuth = "Basic Auth"
+ // WWWAuthenticate is the HTTP WWW-Authenticate header key.
+ WWWAuthenticate = "WWW-Authenticate"
+ // Bearer is the HTTP Bearer auth scheme token.
+ Bearer = "Bearer"
+
+ // HeaderReferer is the HTTP Referer header key.
+ HeaderReferer = "Referer"
+ // HeaderContentType is the HTTP Content-Type header key.
+ HeaderContentType = "Content-Type"
+ // HeaderTraceparentPascal is the PascalCase variant of the Traceparent header for gRPC metadata.
+ HeaderTraceparentPascal = "Traceparent"
+ // HeaderTracestatePascal is the PascalCase variant of the Tracestate header for gRPC metadata.
+ HeaderTracestatePascal = "Tracestate"
- // Rate Limit Headers
- RateLimitLimit = "X-RateLimit-Limit"
+ // RateLimitLimit is the header containing the configured request quota.
+ RateLimitLimit = "X-RateLimit-Limit"
+ // RateLimitRemaining is the header containing remaining requests in the current window.
RateLimitRemaining = "X-RateLimit-Remaining"
- RateLimitReset = "X-RateLimit-Reset"
+ // RateLimitReset is the header containing the reset time for the current window.
+ RateLimitReset = "X-RateLimit-Reset"
)
diff --git a/commons/constants/log.go b/commons/constants/log.go
index 8986588f..a0cfa85f 100644
--- a/commons/constants/log.go
+++ b/commons/constants/log.go
@@ -1,7 +1,4 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package constant
+// LoggerDefaultSeparator is the default delimiter used in composed log messages.
const LoggerDefaultSeparator = " | "
diff --git a/commons/constants/metadata.go b/commons/constants/metadata.go
index cc8219ed..40623cbc 100644
--- a/commons/constants/metadata.go
+++ b/commons/constants/metadata.go
@@ -1,12 +1,12 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package constant
const (
- MetadataID = "metadata_id"
- MetadataTraceparent = "traceparent"
- MetadataTracestate = "tracestate"
+ // MetadataID is the metadata key that carries the request context identifier.
+ MetadataID = "metadata_id"
+ // MetadataTraceparent is the metadata key for W3C traceparent.
+ MetadataTraceparent = "traceparent"
+ // MetadataTracestate is the metadata key for W3C tracestate.
+ MetadataTracestate = "tracestate"
+ // MetadataAuthorization is the metadata key for authorization propagation.
MetadataAuthorization = "authorization"
)
diff --git a/commons/constants/obfuscation.go b/commons/constants/obfuscation.go
index a318ec0a..a299ced1 100644
--- a/commons/constants/obfuscation.go
+++ b/commons/constants/obfuscation.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package constant
const (
diff --git a/commons/constants/opentelemetry.go b/commons/constants/opentelemetry.go
index 4c196fcd..03d21417 100644
--- a/commons/constants/opentelemetry.go
+++ b/commons/constants/opentelemetry.go
@@ -1,7 +1,73 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package constant
+// TelemetrySDKName identifies this library in OTEL telemetry resource attributes.
const TelemetrySDKName = "lib-commons/opentelemetry"
+
+// MaxMetricLabelLength is the maximum length for metric labels to prevent cardinality explosion.
+// Used by assert, runtime, and circuitbreaker packages for label sanitization.
+const MaxMetricLabelLength = 64
+
+// Telemetry attribute key prefixes.
+const (
+ // AttrPrefixAppRequest is the prefix for application request attributes.
+ AttrPrefixAppRequest = "app.request."
+ // AttrPrefixAssertion is the prefix for assertion event attributes.
+ AttrPrefixAssertion = "assertion."
+ // AttrPrefixPanic is the prefix for panic event attributes.
+ AttrPrefixPanic = "panic."
+)
+
+// Telemetry attribute keys for database connectors.
+const (
+ // AttrDBSystem is the OTEL semantic convention attribute key for the database system name.
+ AttrDBSystem = "db.system"
+ // AttrDBName is the OTEL semantic convention attribute key for the database name.
+ AttrDBName = "db.name"
+ // AttrDBMongoDBCollection is the OTEL semantic convention attribute key for the MongoDB collection.
+ AttrDBMongoDBCollection = "db.mongodb.collection"
+)
+
+// Database system identifiers used as values for AttrDBSystem.
+const (
+ // DBSystemPostgreSQL is the OTEL semantic convention value for PostgreSQL.
+ DBSystemPostgreSQL = "postgresql"
+ // DBSystemMongoDB is the OTEL semantic convention value for MongoDB.
+ DBSystemMongoDB = "mongodb"
+ // DBSystemRedis is the OTEL semantic convention value for Redis.
+ DBSystemRedis = "redis"
+ // DBSystemRabbitMQ is the OTEL semantic convention value for RabbitMQ.
+ DBSystemRabbitMQ = "rabbitmq"
+)
+
+// Telemetry metric names.
+const (
+ // MetricPanicRecoveredTotal is the counter metric for recovered panics.
+ MetricPanicRecoveredTotal = "panic_recovered_total"
+ // MetricAssertionFailedTotal is the counter metric for failed assertions.
+ MetricAssertionFailedTotal = "assertion_failed_total"
+)
+
+// Telemetry event names.
+const (
+ // EventAssertionFailed is the span event name for assertion failures.
+ EventAssertionFailed = "assertion.failed"
+ // EventPanicRecovered is the span event name for recovered panics.
+ EventPanicRecovered = "panic.recovered"
+)
+
+// SanitizeMetricLabel truncates a label value to MaxMetricLabelLength runes
+// to prevent metric cardinality explosion in OTEL backends.
+// Truncation is rune-aware to avoid splitting multibyte UTF-8 characters.
+func SanitizeMetricLabel(value string) string {
+ if len(value) <= MaxMetricLabelLength {
+ // Fast path: if byte length is within limit, rune length is too.
+ return value
+ }
+
+ runes := []rune(value)
+ if len(runes) > MaxMetricLabelLength {
+ return string(runes[:MaxMetricLabelLength])
+ }
+
+ return value
+}
diff --git a/commons/constants/opentelemetry_test.go b/commons/constants/opentelemetry_test.go
new file mode 100644
index 00000000..d9e6d659
--- /dev/null
+++ b/commons/constants/opentelemetry_test.go
@@ -0,0 +1,90 @@
+//go:build unit
+
+package constant
+
+import (
+ "strings"
+ "testing"
+ "unicode/utf8"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestSanitizeMetricLabel(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {
+ name: "empty string returns empty",
+ input: "",
+ want: "",
+ },
+ {
+ name: "short string returned as-is",
+ input: "short",
+ want: "short",
+ },
+ {
+ name: "exactly 64 chars returned as-is",
+ input: strings.Repeat("x", 64),
+ want: strings.Repeat("x", 64),
+ },
+ {
+ name: "65 chars truncated to 64",
+ input: strings.Repeat("y", 65),
+ want: strings.Repeat("y", 64),
+ },
+ {
+ name: "100 chars truncated to 64",
+ input: strings.Repeat("z", 100),
+ want: strings.Repeat("z", 64),
+ },
+ {
+ name: "single character returned as-is",
+ input: "a",
+ want: "a",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ got := SanitizeMetricLabel(tt.input)
+ assert.Equal(t, tt.want, got)
+ assert.LessOrEqual(t, len(got), MaxMetricLabelLength,
+ "result length must never exceed MaxMetricLabelLength")
+ })
+ }
+}
+
+func TestSanitizeMetricLabel_MultibyteSafety(t *testing.T) {
+ t.Parallel()
+
+ // Each emoji is 4 bytes but 1 rune. A 65-emoji string should truncate
+ // to 64 runes without splitting a codepoint.
+ emojis := strings.Repeat("\U0001F600", MaxMetricLabelLength+1) // 65 emojis
+ got := SanitizeMetricLabel(emojis)
+
+ assert.True(t, utf8.ValidString(got), "truncated string must be valid UTF-8")
+ assert.Equal(t, MaxMetricLabelLength, utf8.RuneCountInString(got),
+ "truncated string must have exactly MaxMetricLabelLength runes")
+
+ // Mixed multibyte: CJK characters (3 bytes each)
+ cjk := strings.Repeat("\u4e16", MaxMetricLabelLength+5) // 69 CJK chars
+ got = SanitizeMetricLabel(cjk)
+
+ assert.True(t, utf8.ValidString(got), "CJK truncated string must be valid UTF-8")
+ assert.Equal(t, MaxMetricLabelLength, utf8.RuneCountInString(got))
+}
+
+func TestMaxMetricLabelLength_Value(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, 64, MaxMetricLabelLength,
+ "MaxMetricLabelLength must be 64 to match OTEL cardinality safeguards")
+}
diff --git a/commons/constants/pagination.go b/commons/constants/pagination.go
index 65bf8b7c..981b0067 100644
--- a/commons/constants/pagination.go
+++ b/commons/constants/pagination.go
@@ -1,13 +1,19 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package constant
-type Order string
+// Pagination defaults.
+const (
+ // DefaultLimit is the default number of items per page.
+ DefaultLimit = 20
+ // DefaultOffset is the default pagination offset.
+ DefaultOffset = 0
+ // MaxLimit is the maximum allowed items per page.
+ MaxLimit = 200
+)
-// Order is a type that represents the ordering of a list.
+// Sort direction constants (uppercase, used by HTTP APIs).
const (
- Asc Order = "asc"
- Desc Order = "desc"
+ // SortDirASC is the ascending sort direction for API responses.
+ SortDirASC = "ASC"
+ // SortDirDESC is the descending sort direction for API responses.
+ SortDirDESC = "DESC"
)
diff --git a/commons/constants/response.go b/commons/constants/response.go
new file mode 100644
index 00000000..c7c4a096
--- /dev/null
+++ b/commons/constants/response.go
@@ -0,0 +1,9 @@
+package constant
+
+const (
+ // DefaultErrorTitle is the fallback error title used in HTTP error responses
+ // when no specific title is provided.
+ DefaultErrorTitle = "request_failed"
+ // DefaultInternalErrorMessage is the fallback message for unclassified server errors.
+ DefaultInternalErrorMessage = "An internal error occurred"
+)
diff --git a/commons/constants/transaction.go b/commons/constants/transaction.go
index 8f2cbad6..f9677613 100644
--- a/commons/constants/transaction.go
+++ b/commons/constants/transaction.go
@@ -1,20 +1,28 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package constant
const (
+ // DefaultExternalAccountAliasPrefix prefixes aliases for external accounts.
DefaultExternalAccountAliasPrefix = "@external/"
- ExternalAccountType = "external"
+ // ExternalAccountType identifies external accounts.
+ ExternalAccountType = "external"
- DEBIT = "DEBIT"
- CREDIT = "CREDIT"
- ONHOLD = "ON_HOLD"
+ // DEBIT identifies debit operations.
+ DEBIT = "DEBIT"
+ // CREDIT identifies credit operations.
+ CREDIT = "CREDIT"
+ // ONHOLD identifies hold operations.
+ ONHOLD = "ON_HOLD"
+ // RELEASE identifies release operations.
RELEASE = "RELEASE"
- CREATED = "CREATED"
+ // CREATED identifies transaction intents created but not yet approved.
+ CREATED = "CREATED"
+ // APPROVED identifies transaction intents approved for processing.
APPROVED = "APPROVED"
- PENDING = "PENDING"
+ // PENDING identifies transaction intents currently being processed.
+ PENDING = "PENDING"
+ // CANCELED identifies transaction intents canceled or rolled back.
CANCELED = "CANCELED"
+ // NOTED identifies transaction intents that have been noted/acknowledged.
+ NOTED = "NOTED"
)
diff --git a/commons/context.go b/commons/context.go
index 25339d99..4e4dd545 100644
--- a/commons/context.go
+++ b/commons/context.go
@@ -1,17 +1,14 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package commons
import (
"context"
"errors"
"strings"
+ "sync"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry/metrics"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
"github.com/google/uuid"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
@@ -25,6 +22,7 @@ var ErrNilParentContext = errors.New("cannot create context from nil parent")
type customContextKey string
+// CustomContextKey is the context key used to store CustomContextKeyValue.
var CustomContextKey = customContextKey("custom_context")
// CustomContextKeyValue holds all request-scoped facilities we attach to context.
@@ -41,25 +39,51 @@ type CustomContextKeyValue struct {
// ---- Logger helpers ----
-// NewLoggerFromContext extract the Logger from "logger" value inside context
+// NewLoggerFromContext extract the Logger from "logger" value inside context.
+// A nil ctx is normalized to context.Background() so callers never trigger a nil-pointer dereference.
//
//nolint:ireturn
func NewLoggerFromContext(ctx context.Context) log.Logger {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
if customContext, ok := ctx.Value(CustomContextKey).(*CustomContextKeyValue); ok &&
customContext.Logger != nil {
return customContext.Logger
}
- return &log.NoneLogger{}
+ return &log.NopLogger{}
+}
+
+// cloneContextValues returns a shallow copy of the CustomContextKeyValue from ctx.
+// This prevents concurrent mutation of a shared struct when multiple goroutines
+// derive child contexts from the same parent.
+// The AttrBag slice is deep-copied to avoid aliasing the underlying array.
+func cloneContextValues(ctx context.Context) *CustomContextKeyValue {
+ existing, _ := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
+
+ clone := &CustomContextKeyValue{}
+ if existing != nil {
+ *clone = *existing
+
+ // Deep-copy the slice to avoid aliasing the backing array.
+ if len(existing.AttrBag) > 0 {
+ clone.AttrBag = make([]attribute.KeyValue, len(existing.AttrBag))
+ copy(clone.AttrBag, existing.AttrBag)
+ }
+ }
+
+ return clone
}
// ContextWithLogger returns a context within a Logger in "logger" value.
func ContextWithLogger(ctx context.Context, logger log.Logger) context.Context {
- values, _ := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
- if values == nil {
- values = &CustomContextKeyValue{}
+ if ctx == nil {
+ ctx = context.Background()
}
+ values := cloneContextValues(ctx)
values.Logger = logger
return context.WithValue(ctx, CustomContextKey, values)
@@ -67,26 +91,13 @@ func ContextWithLogger(ctx context.Context, logger log.Logger) context.Context {
// ---- Tracer helpers ----
-// Deprecated: use NewTrackingFromContext instead
-// NewTracerFromContext returns a new tracer from the context.
-//
-//nolint:ireturn
-func NewTracerFromContext(ctx context.Context) trace.Tracer {
- if customContext, ok := ctx.Value(CustomContextKey).(*CustomContextKeyValue); ok &&
- customContext.Tracer != nil {
- return customContext.Tracer
- }
-
- return otel.Tracer("default")
-}
-
// ContextWithTracer returns a context within a trace.Tracer in "tracer" value.
func ContextWithTracer(ctx context.Context, tracer trace.Tracer) context.Context {
- values, _ := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
- if values == nil {
- values = &CustomContextKeyValue{}
+ if ctx == nil {
+ ctx = context.Background()
}
+ values := cloneContextValues(ctx)
values.Tracer = tracer
return context.WithValue(ctx, CustomContextKey, values)
@@ -94,27 +105,13 @@ func ContextWithTracer(ctx context.Context, tracer trace.Tracer) context.Context
// ---- Metrics helpers ----
-// Deprecated: use NewTrackingFromContext instead
-//
-// NewMetricFactoryFromContext returns a new metric factory from the context.
-//
-//nolint:ireturn
-func NewMetricFactoryFromContext(ctx context.Context) *metrics.MetricsFactory {
- if customContext, ok := ctx.Value(CustomContextKey).(*CustomContextKeyValue); ok &&
- customContext.MetricFactory != nil {
- return customContext.MetricFactory
- }
-
- return metrics.NewMetricsFactory(otel.GetMeterProvider().Meter("default"), &log.NoneLogger{})
-}
-
// ContextWithMetricFactory returns a context within a MetricsFactory in "metricFactory" value.
func ContextWithMetricFactory(ctx context.Context, metricFactory *metrics.MetricsFactory) context.Context {
- values, _ := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
- if values == nil {
- values = &CustomContextKeyValue{}
+ if ctx == nil {
+ ctx = context.Background()
}
+ values := cloneContextValues(ctx)
values.MetricFactory = metricFactory
return context.WithValue(ctx, CustomContextKey, values)
@@ -124,32 +121,16 @@ func ContextWithMetricFactory(ctx context.Context, metricFactory *metrics.Metric
// ContextWithHeaderID returns a context within a HeaderID in "headerID" value.
func ContextWithHeaderID(ctx context.Context, headerID string) context.Context {
- values, _ := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
- if values == nil {
- values = &CustomContextKeyValue{}
+ if ctx == nil {
+ ctx = context.Background()
}
+ values := cloneContextValues(ctx)
values.HeaderID = headerID
return context.WithValue(ctx, CustomContextKey, values)
}
-// Deprecated: use NewTrackingFromContext instead
-//
-// NewHeaderIDFromContext returns a HeaderID from the context.
-func NewHeaderIDFromContext(ctx context.Context) string {
- customContext, ok := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
- if !ok {
- return uuid.New().String()
- }
-
- if customContext != nil && strings.TrimSpace(customContext.HeaderID) != "" {
- return customContext.HeaderID
- }
-
- return uuid.New().String()
-}
-
// ---- Tracking bundle (convenience) ----
// TrackingComponents represents the complete set of tracking components extracted from context.
@@ -166,7 +147,12 @@ type TrackingComponents struct {
//
//nolint:ireturn
func NewTrackingFromContext(ctx context.Context) (log.Logger, trace.Tracer, string, *metrics.MetricsFactory) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
components := extractTrackingComponents(ctx)
+
return components.Logger, components.Tracer, components.HeaderID, components.MetricFactory
}
@@ -192,7 +178,7 @@ func resolveLogger(logger log.Logger) log.Logger {
return logger
}
- return &log.NoneLogger{} // Null Object Pattern - always functional
+ return &log.NopLogger{} // Null Object Pattern - always functional
}
// resolveTracer ensures a valid tracer is always available using OpenTelemetry best practices.
@@ -207,6 +193,11 @@ func resolveTracer(tracer trace.Tracer) trace.Tracer {
// resolveHeaderID implements the correlation ID pattern with UUID fallback.
// Ensures every request has a unique identifier for distributed tracing.
+//
+// IMPORTANT: When no HeaderID is present in context, a new UUID is generated on
+// every call to NewTrackingFromContext. Ingress middleware (HTTP/gRPC) MUST persist
+// the generated ID back into context via ContextWithHeaderID so that downstream
+// extractions within the same request return a stable correlation ID.
func resolveHeaderID(headerID string) string {
if trimmed := strings.TrimSpace(headerID); trimmed != "" {
return trimmed
@@ -215,24 +206,46 @@ func resolveHeaderID(headerID string) string {
return uuid.New().String() // Generate unique correlation ID
}
+var (
+ defaultFactoryOnce sync.Once
+ defaultFactory *metrics.MetricsFactory
+)
+
+func getDefaultMetricsFactory() *metrics.MetricsFactory {
+ defaultFactoryOnce.Do(func() {
+ meter := otel.GetMeterProvider().Meter("commons.default")
+
+ f, err := metrics.NewMetricsFactory(meter, &log.NopLogger{})
+ if err != nil {
+ defaultFactory = metrics.NewNopFactory()
+ return
+ }
+
+ defaultFactory = f
+ })
+
+ return defaultFactory
+}
+
// resolveMetricFactory ensures a valid metrics factory is always available following the fail-safe pattern.
-// Provides a default factory when none exists, maintaining consistency with logger and tracer resolution.
+// Provides a cached default factory when none exists, initialized once via sync.Once.
+// Never returns nil: if factory creation fails, it falls back to a no-op factory.
func resolveMetricFactory(factory *metrics.MetricsFactory) *metrics.MetricsFactory {
if factory != nil {
return factory
}
- return metrics.NewMetricsFactory(otel.GetMeterProvider().Meter("commons.default"), &log.NoneLogger{})
+ return getDefaultMetricsFactory()
}
// newDefaultTrackingComponents creates a complete set of default components.
// Used when context extraction fails entirely - ensures system remains operational.
func newDefaultTrackingComponents() TrackingComponents {
return TrackingComponents{
- Logger: &log.NoneLogger{},
+ Logger: &log.NopLogger{},
Tracer: otel.Tracer("commons.default"),
HeaderID: uuid.New().String(),
- MetricFactory: metrics.NewMetricsFactory(otel.GetMeterProvider().Meter("commons.default"), &log.NoneLogger{}),
+ MetricFactory: resolveMetricFactory(nil),
}
}
@@ -242,15 +255,16 @@ func newDefaultTrackingComponents() TrackingComponents {
// Call this once at the ingress (HTTP/gRPC middleware) and avoid per-layer duplication.
// Example keys: tenant.id, enduser.id, request.route, region, plan.
func ContextWithSpanAttributes(ctx context.Context, kv ...attribute.KeyValue) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
if len(kv) == 0 {
return ctx
}
- values, _ := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
- if values == nil {
- values = &CustomContextKeyValue{}
- }
- // Append (preserve order; low-cost).
+ values := cloneContextValues(ctx)
+ // Append to the cloned (independent) slice.
values.AttrBag = append(values.AttrBag, kv...)
return context.WithValue(ctx, CustomContextKey, values)
@@ -258,6 +272,10 @@ func ContextWithSpanAttributes(ctx context.Context, kv ...attribute.KeyValue) co
// AttributesFromContext returns a shallow copy of the AttrBag slice, safe to reuse by processors.
func AttributesFromContext(ctx context.Context) []attribute.KeyValue {
+ if ctx == nil {
+ return nil
+ }
+
if values, ok := ctx.Value(CustomContextKey).(*CustomContextKeyValue); ok && values != nil && len(values.AttrBag) > 0 {
out := make([]attribute.KeyValue, len(values.AttrBag))
copy(out, values.AttrBag)
@@ -270,12 +288,14 @@ func AttributesFromContext(ctx context.Context) []attribute.KeyValue {
// ReplaceAttributes resets the current AttrBag with a new set (rarely needed; provided for completeness).
func ReplaceAttributes(ctx context.Context, kv ...attribute.KeyValue) context.Context {
- values, _ := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
- if values == nil {
- values = &CustomContextKeyValue{}
+ if ctx == nil {
+ ctx = context.Background()
}
- values.AttrBag = append(values.AttrBag[:0], kv...)
+ values := cloneContextValues(ctx)
+ // Replace with a fresh slice -- the clone already has an independent copy.
+ values.AttrBag = make([]attribute.KeyValue, len(kv))
+ copy(values.AttrBag, kv)
return context.WithValue(ctx, CustomContextKey, values)
}
@@ -285,10 +305,7 @@ func ReplaceAttributes(ctx context.Context, kv ...attribute.KeyValue) context.Co
// WithTimeoutSafe creates a context with the specified timeout, but respects
// any existing deadline in the parent context. Returns an error if parent is nil.
//
-// This is the safe alternative to WithTimeout that returns an error instead of panicking.
-// The "Safe" suffix is used here (instead of "WithError") because the function signature
-// returns three values (context, cancel, error) rather than wrapping an existing function.
-// Use WithTimeout for backward-compatible panic behavior.
+// The function returns three values (context, cancel, error) for explicit nil-parent error handling.
//
// Note: When the parent's deadline is shorter than the requested timeout, this function
// returns a cancellable context that inherits the parent's deadline rather than creating
@@ -302,50 +319,12 @@ func WithTimeoutSafe(parent context.Context, timeout time.Duration) (context.Con
timeUntilDeadline := time.Until(deadline)
if timeUntilDeadline < timeout {
- ctx, cancel := context.WithCancel(parent)
+ ctx, cancel := context.WithCancel(parent) // #nosec G118 -- cancel is intentionally returned to the caller for lifecycle management
return ctx, cancel, nil
}
}
- ctx, cancel := context.WithTimeout(parent, timeout)
+ ctx, cancel := context.WithTimeout(parent, timeout) // #nosec G118 -- cancel is intentionally returned to the caller for lifecycle management
return ctx, cancel, nil
}
-
-// Deprecated: Use WithTimeoutSafe instead for proper error handling.
-// WithTimeout panics on nil parent. Prefer WithTimeoutSafe for graceful error handling.
-//
-// WithTimeout creates a context with the specified timeout, but respects
-// any existing deadline in the parent context. If the parent context has
-// a deadline that would expire sooner than the requested timeout, the
-// parent's deadline is used instead.
-//
-// This prevents the common mistake of extending a context's deadline
-// beyond what the caller intended.
-//
-// Example:
-//
-// // Parent has 5s deadline, we request 10s -> gets 5s
-// ctx, cancel := commons.WithTimeout(parentCtx, 10*time.Second)
-// defer cancel()
-func WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
- if parent == nil {
- panic("cannot create context from nil parent")
- }
-
- // Check if parent already has a deadline
- if deadline, ok := parent.Deadline(); ok {
- // Calculate time until parent deadline
- timeUntilDeadline := time.Until(deadline)
-
- // Use the shorter of the two timeouts
- if timeUntilDeadline < timeout {
- // Parent deadline is sooner, just return a cancellable context
- // that respects the parent's deadline
- return context.WithCancel(parent)
- }
- }
-
- // Either parent has no deadline, or our timeout is shorter
- return context.WithTimeout(parent, timeout)
-}
diff --git a/commons/context_clone_test.go b/commons/context_clone_test.go
new file mode 100644
index 00000000..e64e6afe
--- /dev/null
+++ b/commons/context_clone_test.go
@@ -0,0 +1,169 @@
+//go:build unit
+
+package commons
+
+import (
+ "context"
+ "sync"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+)
+
+func TestCloneContextValues(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil context value returns empty non-nil struct", func(t *testing.T) {
+ t.Parallel()
+
+ // context.Background() has no CustomContextKey value.
+ clone := cloneContextValues(context.Background())
+
+ require.NotNil(t, clone)
+ assert.Empty(t, clone.HeaderID)
+ assert.Nil(t, clone.Logger)
+ assert.Nil(t, clone.Tracer)
+ assert.Nil(t, clone.MetricFactory)
+ assert.Nil(t, clone.AttrBag)
+ })
+
+ t.Run("context with wrong type returns empty non-nil struct", func(t *testing.T) {
+ t.Parallel()
+
+ // Store a string instead of *CustomContextKeyValue.
+ ctx := context.WithValue(context.Background(), CustomContextKey, "not-a-struct")
+ clone := cloneContextValues(ctx)
+
+ require.NotNil(t, clone)
+ assert.Empty(t, clone.HeaderID)
+ })
+
+ t.Run("preserves existing values", func(t *testing.T) {
+ t.Parallel()
+
+ nopLogger := &log.NopLogger{}
+ tracer := otel.Tracer("test-clone")
+
+ original := &CustomContextKeyValue{
+ HeaderID: "hdr-abc",
+ Logger: nopLogger,
+ Tracer: tracer,
+ }
+ ctx := context.WithValue(context.Background(), CustomContextKey, original)
+
+ clone := cloneContextValues(ctx)
+
+ require.NotNil(t, clone)
+ assert.Equal(t, "hdr-abc", clone.HeaderID)
+ assert.Equal(t, nopLogger, clone.Logger)
+ assert.Equal(t, tracer, clone.Tracer)
+ })
+
+ t.Run("deep-copies AttrBag so mutating clone does not affect original", func(t *testing.T) {
+ t.Parallel()
+
+ original := &CustomContextKeyValue{
+ HeaderID: "hdr-deep",
+ AttrBag: []attribute.KeyValue{
+ attribute.String("tenant.id", "t1"),
+ attribute.String("region", "us-east"),
+ },
+ }
+ ctx := context.WithValue(context.Background(), CustomContextKey, original)
+
+ clone := cloneContextValues(ctx)
+
+ // Verify initial equality.
+ require.Len(t, clone.AttrBag, 2)
+ assert.Equal(t, original.AttrBag, clone.AttrBag)
+
+ // Mutate the clone's AttrBag.
+ clone.AttrBag[0] = attribute.String("tenant.id", "MUTATED")
+ clone.AttrBag = append(clone.AttrBag, attribute.String("extra", "added"))
+
+ // Original must be unchanged.
+ assert.Equal(t, "t1", original.AttrBag[0].Value.AsString())
+ assert.Len(t, original.AttrBag, 2)
+ })
+
+ t.Run("empty AttrBag is shallow-copied without deep-copy allocation", func(t *testing.T) {
+ t.Parallel()
+
+ original := &CustomContextKeyValue{
+ HeaderID: "hdr-empty-bag",
+ AttrBag: []attribute.KeyValue{},
+ }
+ ctx := context.WithValue(context.Background(), CustomContextKey, original)
+
+ clone := cloneContextValues(ctx)
+
+ // The struct copy (*clone = *existing) propagates the empty slice.
+ // The deep-copy branch is skipped (len == 0), so the clone gets the
+ // original's empty-but-non-nil slice header. This is correct behavior:
+ // no allocation needed for an empty bag.
+ assert.Empty(t, clone.AttrBag)
+ assert.Equal(t, "hdr-empty-bag", clone.HeaderID)
+ })
+
+ t.Run("clone is independent — modifying clone fields does not affect original", func(t *testing.T) {
+ t.Parallel()
+
+ nopLogger := &log.NopLogger{}
+ original := &CustomContextKeyValue{
+ HeaderID: "hdr-independent",
+ Logger: nopLogger,
+ }
+ ctx := context.WithValue(context.Background(), CustomContextKey, original)
+
+ clone := cloneContextValues(ctx)
+ clone.HeaderID = "CHANGED"
+ clone.Logger = nil
+
+ // Original must remain intact.
+ assert.Equal(t, "hdr-independent", original.HeaderID)
+ assert.Equal(t, nopLogger, original.Logger)
+ })
+}
+
+func TestCloneContextValues_Concurrent(t *testing.T) {
+ t.Parallel()
+
+ // Two goroutines derive independent clones from the same parent context.
+ // They both mutate their clone's AttrBag without data races.
+ original := &CustomContextKeyValue{
+ HeaderID: "hdr-concurrent",
+ AttrBag: []attribute.KeyValue{
+ attribute.String("shared", "value"),
+ },
+ }
+ parentCtx := context.WithValue(context.Background(), CustomContextKey, original)
+
+ const goroutines = 50
+
+ var wg sync.WaitGroup
+
+ wg.Add(goroutines)
+
+ for i := range goroutines {
+ go func(id int) {
+ defer wg.Done()
+
+ clone := cloneContextValues(parentCtx)
+
+ // Each goroutine mutates its own clone.
+ clone.AttrBag = append(clone.AttrBag, attribute.Int("goroutine", id))
+ clone.HeaderID = "modified"
+ }(i)
+ }
+
+ wg.Wait()
+
+ // After all goroutines complete, the original must be untouched.
+ assert.Equal(t, "hdr-concurrent", original.HeaderID)
+ assert.Len(t, original.AttrBag, 1)
+ assert.Equal(t, "value", original.AttrBag[0].Value.AsString())
+}
diff --git a/commons/context_example_test.go b/commons/context_example_test.go
new file mode 100644
index 00000000..ee4e7825
--- /dev/null
+++ b/commons/context_example_test.go
@@ -0,0 +1,29 @@
+//go:build unit
+
+package commons_test
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons"
+)
+
+func ExampleWithTimeoutSafe() {
+ ctx := context.Background()
+
+ timeoutCtx, cancel, err := commons.WithTimeoutSafe(ctx, 100*time.Millisecond)
+ if cancel != nil {
+ defer cancel()
+ }
+
+ _, hasDeadline := timeoutCtx.Deadline()
+
+ fmt.Println(err == nil)
+ fmt.Println(hasDeadline)
+
+ // Output:
+ // true
+ // true
+}
diff --git a/commons/context_test.go b/commons/context_test.go
index b81fe124..9f0fa0bc 100644
--- a/commons/context_test.go
+++ b/commons/context_test.go
@@ -1,6 +1,4 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package commons
@@ -9,91 +7,13 @@ import (
"errors"
"testing"
"time"
-)
-
-func TestWithTimeout_NoParentDeadline(t *testing.T) {
- parent := context.Background()
- timeout := 5 * time.Second
-
- ctx, cancel := WithTimeout(parent, timeout)
- defer cancel()
-
- deadline, ok := ctx.Deadline()
- if !ok {
- t.Fatal("expected context to have a deadline")
- }
-
- expectedDeadline := time.Now().Add(timeout)
- // Allow 200ms variance for test execution time
- timeUntil := time.Until(deadline)
- if timeUntil < 4800*time.Millisecond || timeUntil > 5200*time.Millisecond {
- t.Errorf("deadline not within expected range: got %v (%.2fs remaining), expected ~%v (5s)",
- deadline, timeUntil.Seconds(), expectedDeadline)
- }
-}
-
-func TestWithTimeout_ParentDeadlineShorter(t *testing.T) {
- // Parent has 2s deadline
- parent, parentCancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer parentCancel()
-
- // We request 10s, but parent's 2s should win
- ctx, cancel := WithTimeout(parent, 10*time.Second)
- defer cancel()
-
- deadline, ok := ctx.Deadline()
- if !ok {
- t.Fatal("expected context to have a deadline")
- }
-
- // Should use parent's deadline (2s)
- timeUntil := time.Until(deadline)
- if timeUntil > 2*time.Second || timeUntil < 1*time.Second {
- t.Errorf("expected deadline to be ~2s from now, got %v", timeUntil)
- }
-}
-
-func TestWithTimeout_ParentDeadlineLonger(t *testing.T) {
- // Parent has 10s deadline
- parent, parentCancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer parentCancel()
-
- // We request 2s, our timeout should win
- ctx, cancel := WithTimeout(parent, 2*time.Second)
- defer cancel()
- deadline, ok := ctx.Deadline()
- if !ok {
- t.Fatal("expected context to have a deadline")
- }
-
- // Should use our timeout (2s)
- timeUntil := time.Until(deadline)
- // Allow 200ms variance
- if timeUntil < 1800*time.Millisecond || timeUntil > 2200*time.Millisecond {
- t.Errorf("expected deadline to be ~2s from now, got %v (%.2fs)", timeUntil, timeUntil.Seconds())
- }
-}
-
-func TestWithTimeout_CancelWorks(t *testing.T) {
- parent := context.Background()
- ctx, cancel := WithTimeout(parent, 5*time.Second)
-
- // Cancel immediately
- cancel()
-
- // Context should be cancelled
- select {
- case <-ctx.Done():
- // Expected
- case <-time.After(100 * time.Millisecond):
- t.Error("context was not cancelled")
- }
-
- if ctx.Err() != context.Canceled {
- t.Errorf("expected context.Canceled error, got %v", ctx.Err())
- }
-}
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+)
func TestWithTimeoutSafe_NilParent(t *testing.T) {
ctx, cancel, err := WithTimeoutSafe(nil, 5*time.Second)
@@ -166,7 +86,6 @@ func TestWithTimeoutSafe_ParentDeadlineShorter(t *testing.T) {
func TestWithTimeoutSafe_CancelWorks(t *testing.T) {
parent := context.Background()
ctx, cancel, err := WithTimeoutSafe(parent, 5*time.Second)
-
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -185,16 +104,6 @@ func TestWithTimeoutSafe_CancelWorks(t *testing.T) {
}
}
-func TestWithTimeout_PanicOnNilParent(t *testing.T) {
- defer func() {
- if r := recover(); r == nil {
- t.Error("expected panic for nil parent")
- }
- }()
-
- WithTimeout(nil, 5*time.Second)
-}
-
func TestWithTimeoutSafe_ZeroTimeout(t *testing.T) {
parent := context.Background()
ctx, cancel, err := WithTimeoutSafe(parent, 0)
@@ -238,3 +147,217 @@ func TestWithTimeoutSafe_NegativeTimeout(t *testing.T) {
t.Error("expected context to be done with negative timeout")
}
}
+
+// ---- Logger context helpers ----
+
+func TestNewLoggerFromContext(t *testing.T) {
+ t.Parallel()
+
+ t.Run("without_logger", func(t *testing.T) {
+ t.Parallel()
+
+ logger := NewLoggerFromContext(context.Background())
+ require.NotNil(t, logger)
+ assert.IsType(t, &log.NopLogger{}, logger)
+ })
+
+ t.Run("with_logger", func(t *testing.T) {
+ t.Parallel()
+
+ nop := &log.NopLogger{}
+ ctx := ContextWithLogger(context.Background(), nop)
+ logger := NewLoggerFromContext(ctx)
+ assert.Equal(t, nop, logger)
+ })
+
+ t.Run("nil_ctx_returns_nop_logger", func(t *testing.T) {
+ t.Parallel()
+
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety
+ logger := NewLoggerFromContext(nil)
+ require.NotNil(t, logger, "nil ctx must not panic, must return NopLogger")
+ assert.IsType(t, &log.NopLogger{}, logger)
+ })
+}
+
+func TestContextWithLogger(t *testing.T) {
+ t.Parallel()
+
+ nop := &log.NopLogger{}
+ ctx := ContextWithLogger(context.Background(), nop)
+ v := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
+ assert.Equal(t, nop, v.Logger)
+}
+
+// ---- Tracer context helpers ----
+
+func TestContextWithTracer(t *testing.T) {
+ t.Parallel()
+
+ tracer := otel.Tracer("test")
+ ctx := ContextWithTracer(context.Background(), tracer)
+ v := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
+ assert.Equal(t, tracer, v.Tracer)
+}
+
+// ---- MetricFactory context helpers ----
+
+func TestContextWithMetricFactory(t *testing.T) {
+ t.Parallel()
+
+ ctx := ContextWithMetricFactory(context.Background(), nil)
+ v := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
+ assert.Nil(t, v.MetricFactory)
+}
+
+// ---- HeaderID context helpers ----
+
+func TestContextWithHeaderID(t *testing.T) {
+ t.Parallel()
+
+ ctx := ContextWithHeaderID(context.Background(), "hdr-123")
+ v := ctx.Value(CustomContextKey).(*CustomContextKeyValue)
+ assert.Equal(t, "hdr-123", v.HeaderID)
+}
+
+// ---- Tracking bundle ----
+
+func TestNewTrackingFromContext(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty_context_returns_defaults", func(t *testing.T) {
+ t.Parallel()
+
+ logger, tracer, headerID, mf := NewTrackingFromContext(context.Background())
+ assert.NotNil(t, logger)
+ assert.NotNil(t, tracer)
+ assert.NotEmpty(t, headerID)
+ assert.NotNil(t, mf)
+ })
+
+ t.Run("full_context", func(t *testing.T) {
+ t.Parallel()
+
+ nop := &log.NopLogger{}
+ tracer := otel.Tracer("test-tracer")
+ ctx := ContextWithLogger(context.Background(), nop)
+ ctx = ContextWithTracer(ctx, tracer)
+ ctx = ContextWithHeaderID(ctx, "id-456")
+
+ logger, tr, hid, mf := NewTrackingFromContext(ctx)
+ assert.Equal(t, nop, logger)
+ assert.Equal(t, tracer, tr)
+ assert.Equal(t, "id-456", hid)
+ assert.NotNil(t, mf)
+ })
+
+ t.Run("nil_values_get_defaults", func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.WithValue(context.Background(), CustomContextKey, &CustomContextKeyValue{})
+
+ logger, tracer, headerID, mf := NewTrackingFromContext(ctx)
+ assert.IsType(t, &log.NopLogger{}, logger)
+ assert.NotNil(t, tracer)
+ assert.NotEmpty(t, headerID)
+ assert.NotNil(t, mf)
+ })
+
+ t.Run("nil_ctx_returns_defaults", func(t *testing.T) {
+ t.Parallel()
+
+ //nolint:staticcheck // SA1012: intentionally testing nil ctx
+ logger, tracer, headerID, mf := NewTrackingFromContext(nil)
+ assert.NotNil(t, logger)
+ assert.NotNil(t, tracer)
+ assert.NotEmpty(t, headerID)
+ assert.NotNil(t, mf)
+ })
+}
+
+// ---- Attribute Bag ----
+
+func TestContextWithSpanAttributes(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty_kvs_returns_same_ctx", func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ ctx2 := ContextWithSpanAttributes(ctx)
+ assert.Equal(t, ctx, ctx2)
+ })
+
+ t.Run("nil_ctx_with_no_attrs_returns_non_nil", func(t *testing.T) {
+ t.Parallel()
+
+ //nolint:staticcheck // SA1012: intentionally testing nil ctx
+ result := ContextWithSpanAttributes(nil)
+ assert.NotNil(t, result, "nil ctx + no attrs must return context.Background(), not nil")
+ })
+
+ t.Run("nil_ctx_with_attrs_returns_non_nil", func(t *testing.T) {
+ t.Parallel()
+
+ //nolint:staticcheck // SA1012: intentionally testing nil ctx
+ result := ContextWithSpanAttributes(nil, attribute.String("k", "v"))
+ assert.NotNil(t, result, "nil ctx + attrs must return valid context, not nil")
+
+ attrs := AttributesFromContext(result)
+ assert.Len(t, attrs, 1)
+ })
+
+ t.Run("appends_attributes", func(t *testing.T) {
+ t.Parallel()
+
+ ctx := ContextWithSpanAttributes(context.Background(),
+ attribute.String("tenant.id", "t1"),
+ )
+ ctx = ContextWithSpanAttributes(ctx,
+ attribute.String("region", "us"),
+ )
+
+ attrs := AttributesFromContext(ctx)
+ assert.Len(t, attrs, 2)
+ })
+}
+
+func TestAttributesFromContext(t *testing.T) {
+ t.Parallel()
+
+ t.Run("no_attributes", func(t *testing.T) {
+ t.Parallel()
+ assert.Nil(t, AttributesFromContext(context.Background()))
+ })
+
+ t.Run("returns_copy", func(t *testing.T) {
+ t.Parallel()
+
+ ctx := ContextWithSpanAttributes(context.Background(),
+ attribute.String("k", "v"),
+ )
+
+ a1 := AttributesFromContext(ctx)
+ a2 := AttributesFromContext(ctx)
+ assert.Equal(t, a1, a2)
+
+ // Mutating the copy should not affect next retrieval.
+ a1[0] = attribute.String("k", "changed")
+ a3 := AttributesFromContext(ctx)
+ assert.Equal(t, "v", a3[0].Value.AsString())
+ })
+}
+
+func TestReplaceAttributes(t *testing.T) {
+ t.Parallel()
+
+ ctx := ContextWithSpanAttributes(context.Background(),
+ attribute.String("old", "val"),
+ )
+
+ ctx = ReplaceAttributes(ctx, attribute.String("new", "val2"))
+
+ attrs := AttributesFromContext(ctx)
+ require.Len(t, attrs, 1)
+ assert.Equal(t, "new", string(attrs[0].Key))
+}
diff --git a/commons/cron/cron.go b/commons/cron/cron.go
new file mode 100644
index 00000000..b0dee270
--- /dev/null
+++ b/commons/cron/cron.go
@@ -0,0 +1,323 @@
+package cron
+
+import (
+ "errors"
+ "fmt"
+ "slices"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// ErrInvalidExpression is returned when a cron expression cannot be parsed
+// due to incorrect field count, out-of-range values, or malformed syntax.
+var ErrInvalidExpression = errors.New("invalid cron expression")
+
+// ErrNoMatch is returned when Next exhausts its iteration limit without
+// finding a time that satisfies all cron fields.
+var ErrNoMatch = errors.New("cron: no matching time found within iteration limit")
+
+// ErrNilSchedule is returned when Next is called on a nil schedule receiver.
+var ErrNilSchedule = errors.New("cron schedule is nil")
+
+// Cron field boundary constants.
+const (
+ cronFieldCount = 5 // number of fields in a standard cron expression
+ maxMinute = 59 // maximum value for minute field
+ maxHour = 23 // maximum value for hour field
+ minDayOfMonth = 1 // minimum value for day-of-month field
+ maxDayOfMonth = 31 // maximum value for day-of-month field
+ minMonth = 1 // minimum value for month field
+ maxMonth = 12 // maximum value for month field
+ maxDayOfWeek = 7 // maximum accepted day-of-week value (7 is normalized to 0 = Sunday)
+ splitParts = 2 // number of parts when splitting step or range expressions
+)
+
+// Schedule represents a parsed cron schedule capable of computing
+// the next execution time after a given reference time.
+type Schedule interface {
+ Next(t time.Time) (time.Time, error)
+}
+
+type schedule struct {
+ minutes []int
+ hours []int
+ doms []int
+ months []int
+ dows []int
+ domIsWild bool // true when the day-of-month field was "*" (unrestricted)
+ dowIsWild bool // true when the day-of-week field was "*" (unrestricted)
+}
+
+// Parse parses a standard 5-field cron expression and returns a Schedule
+// that can compute the next execution time. The expression format is:
+// minute hour day-of-month month day-of-week
+// Returns ErrInvalidExpression if the expression is malformed or contains out-of-range values.
+func Parse(expr string) (Schedule, error) {
+ expr = strings.TrimSpace(expr)
+ if expr == "" {
+ return nil, fmt.Errorf("%w: empty expression", ErrInvalidExpression)
+ }
+
+ fields := strings.Fields(expr)
+ if len(fields) != cronFieldCount {
+ return nil, fmt.Errorf("%w: expected %d fields, got %d", ErrInvalidExpression, cronFieldCount, len(fields))
+ }
+
+ minutes, err := parseField(fields[0], 0, maxMinute)
+ if err != nil {
+ return nil, fmt.Errorf("invalid minute field: %w", err)
+ }
+
+ hours, err := parseField(fields[1], 0, maxHour)
+ if err != nil {
+ return nil, fmt.Errorf("invalid hour field: %w", err)
+ }
+
+ domIsWild := isWildcard(fields[2])
+
+ doms, err := parseField(fields[2], minDayOfMonth, maxDayOfMonth)
+ if err != nil {
+ return nil, fmt.Errorf("invalid day-of-month field: %w", err)
+ }
+
+ months, err := parseField(fields[3], minMonth, maxMonth)
+ if err != nil {
+ return nil, fmt.Errorf("invalid month field: %w", err)
+ }
+
+ dowIsWild := isWildcard(fields[4])
+
+ dows, err := parseField(fields[4], 0, maxDayOfWeek)
+ if err != nil {
+ return nil, fmt.Errorf("invalid day-of-week field: %w", err)
+ }
+
+ // Normalize DOW 7 → 0 (both mean Sunday) per widespread cron convention.
+ dows = normalizeDOW(dows)
+
+ return &schedule{
+ minutes: minutes,
+ hours: hours,
+ doms: doms,
+ months: months,
+ dows: dows,
+ domIsWild: domIsWild,
+ dowIsWild: dowIsWild,
+ }, nil
+}
+
+// Next computes the next execution time after the given reference time.
+// It normalizes the input to UTC, advances by one minute, and iteratively
+// checks each cron field (month, day-of-month, day-of-week, hour, minute)
+// to find the next matching time. Returns the matching time in UTC, or
+// ErrNoMatch if no match is found within maxIterations.
+//
+// DOM/DOW semantics follow the standard cron convention: when BOTH fields are
+// restricted (not wildcards), the day matches if EITHER condition is true (OR).
+// When only one is restricted, that field alone determines the match.
+func (sched *schedule) Next(from time.Time) (time.Time, error) {
+ if sched == nil {
+ return time.Time{}, ErrNilSchedule
+ }
+
+ from = from.UTC()
+ candidate := from.Add(time.Minute)
+ candidate = time.Date(candidate.Year(), candidate.Month(), candidate.Day(), candidate.Hour(), candidate.Minute(), 0, 0, time.UTC)
+
+ // 4 years (1461 days) to accommodate leap-day and other sparse schedules.
+ const maxIterations = 1461 * 24 * 60
+ for range maxIterations {
+ if !slices.Contains(sched.months, int(candidate.Month())) {
+ candidate = time.Date(candidate.Year(), candidate.Month()+1, 1, 0, 0, 0, 0, time.UTC)
+
+ continue
+ }
+
+ if !sched.matchDay(candidate) {
+ candidate = candidate.AddDate(0, 0, 1)
+ candidate = time.Date(candidate.Year(), candidate.Month(), candidate.Day(), 0, 0, 0, 0, time.UTC)
+
+ continue
+ }
+
+ if !slices.Contains(sched.hours, candidate.Hour()) {
+ candidate = candidate.Add(time.Hour)
+ candidate = time.Date(candidate.Year(), candidate.Month(), candidate.Day(), candidate.Hour(), 0, 0, 0, time.UTC)
+
+ continue
+ }
+
+ if !slices.Contains(sched.minutes, candidate.Minute()) {
+ candidate = candidate.Add(time.Minute)
+
+ continue
+ }
+
+ return candidate, nil
+ }
+
+ return time.Time{}, ErrNoMatch
+}
+
+// matchDay implements standard cron DOM/DOW semantics:
+// - Both wildcards: any day matches.
+// - Only DOM restricted: match on DOM alone.
+// - Only DOW restricted: match on DOW alone.
+// - Both restricted: match if EITHER DOM or DOW is satisfied (OR semantics).
+func (sched *schedule) matchDay(t time.Time) bool {
+ domMatch := slices.Contains(sched.doms, t.Day())
+ dowMatch := slices.Contains(sched.dows, int(t.Weekday()))
+
+ switch {
+ case sched.domIsWild && sched.dowIsWild:
+ return true
+ case sched.domIsWild:
+ return dowMatch
+ case sched.dowIsWild:
+ return domMatch
+ default:
+ // Both restricted: standard cron OR semantics.
+ return domMatch || dowMatch
+ }
+}
+
+func parseField(field string, minVal, maxVal int) ([]int, error) {
+ var result []int
+
+ for part := range strings.SplitSeq(field, ",") {
+ vals, err := parsePart(part, minVal, maxVal)
+ if err != nil {
+ return nil, err
+ }
+
+ result = append(result, vals...)
+ }
+
+ return deduplicate(result), nil
+}
+
+func parsePart(part string, minVal, maxVal int) ([]int, error) {
+ var rangeStart, rangeEnd, step int
+
+ stepParts := strings.SplitN(part, "/", splitParts)
+ hasStep := len(stepParts) == splitParts
+
+ if hasStep {
+ s, err := parseStep(stepParts[1])
+ if err != nil {
+ return nil, err
+ }
+
+ step = s
+ }
+
+ rangePart := stepParts[0]
+
+ switch {
+ case rangePart == "*":
+ rangeStart = minVal
+ rangeEnd = maxVal
+ case strings.Contains(rangePart, "-"):
+ lo, hi, err := parseRange(rangePart, minVal, maxVal)
+ if err != nil {
+ return nil, err
+ }
+
+ rangeStart = lo
+ rangeEnd = hi
+ default:
+ val, err := strconv.Atoi(rangePart)
+ if err != nil {
+ return nil, fmt.Errorf("%w: invalid value %q", ErrInvalidExpression, rangePart)
+ }
+
+ if val < minVal || val > maxVal {
+ return nil, fmt.Errorf("%w: value %d out of bounds [%d, %d]", ErrInvalidExpression, val, minVal, maxVal)
+ }
+
+ if hasStep {
+ rangeStart = val
+ rangeEnd = maxVal
+ } else {
+ return []int{val}, nil
+ }
+ }
+
+ if !hasStep {
+ step = 1
+ }
+
+ var vals []int
+ for v := rangeStart; v <= rangeEnd; v += step {
+ vals = append(vals, v)
+ }
+
+ return vals, nil
+}
+
+// parseStep parses and validates a cron step value, ensuring it is a positive integer.
+func parseStep(raw string) (int, error) {
+ s, err := strconv.Atoi(raw)
+ if err != nil || s <= 0 {
+ return 0, fmt.Errorf("%w: invalid step %q", ErrInvalidExpression, raw)
+ }
+
+ return s, nil
+}
+
+// parseRange parses a "lo-hi" range expression, validates bounds against
+// [minVal, maxVal], and returns the low and high values.
+func parseRange(rangePart string, minVal, maxVal int) (int, int, error) {
+ bounds := strings.SplitN(rangePart, "-", splitParts)
+
+ lo, err := strconv.Atoi(bounds[0])
+ if err != nil {
+ return 0, 0, fmt.Errorf("%w: invalid range start %q", ErrInvalidExpression, bounds[0])
+ }
+
+ hi, err := strconv.Atoi(bounds[1])
+ if err != nil {
+ return 0, 0, fmt.Errorf("%w: invalid range end %q", ErrInvalidExpression, bounds[1])
+ }
+
+ if lo < minVal || hi > maxVal || lo > hi {
+ return 0, 0, fmt.Errorf("%w: range %d-%d out of bounds [%d, %d]", ErrInvalidExpression, lo, hi, minVal, maxVal)
+ }
+
+ return lo, hi, nil
+}
+
+func deduplicate(vals []int) []int {
+ seen := make(map[int]bool, len(vals))
+ result := make([]int, 0, len(vals))
+
+ for _, v := range vals {
+ if !seen[v] {
+ seen[v] = true
+ result = append(result, v)
+ }
+ }
+
+ slices.Sort(result)
+
+ return result
+}
+
+// isWildcard reports whether a cron field token is an unrestricted wildcard.
+// Only bare "*" counts; "*/5" is a step expression and is considered restricted.
+func isWildcard(field string) bool {
+ return field == "*"
+}
+
+// normalizeDOW rewrites day-of-week value 7 to 0 (both represent Sunday)
+// and deduplicates the result, matching widespread cron convention.
+func normalizeDOW(dows []int) []int {
+ for i, v := range dows {
+ if v == 7 {
+ dows[i] = 0
+ }
+ }
+
+ return deduplicate(dows)
+}
diff --git a/commons/cron/cron_test.go b/commons/cron/cron_test.go
new file mode 100644
index 00000000..e277f793
--- /dev/null
+++ b/commons/cron/cron_test.go
@@ -0,0 +1,306 @@
+//go:build unit
+
+package cron
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParse_DailyMidnight(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse("0 0 * * *")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 16, 0, 0, 0, 0, time.UTC), next)
+}
+
+func TestParse_EveryFiveMinutes(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse("*/5 * * * *")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 15, 10, 3, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 15, 10, 5, 0, 0, time.UTC), next)
+}
+
+func TestParse_DailySixThirty(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse("30 6 * * *")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 15, 7, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 16, 6, 30, 0, 0, time.UTC), next)
+}
+
+func TestParse_DailyNoon(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse("0 12 * * *")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 15, 12, 0, 0, 0, time.UTC), next)
+}
+
+func TestParse_EveryMonday(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse("0 0 * * 1")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Monday, next.Weekday())
+ assert.Equal(t, 0, next.Hour())
+ assert.Equal(t, 0, next.Minute())
+ assert.True(t, next.After(from))
+}
+
+func TestParse_FifteenthOfMonth(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse("0 0 15 * *")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 16, 0, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, 15, next.Day())
+ assert.Equal(t, 0, next.Hour())
+ assert.Equal(t, 0, next.Minute())
+ assert.True(t, next.After(from))
+}
+
+func TestParse_Ranges(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse("0 9-17 * * *")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 15, 18, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, 9, next.Hour())
+ assert.Equal(t, time.Date(2026, 1, 16, 9, 0, 0, 0, time.UTC), next)
+}
+
+func TestParse_Lists(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse("0 6,12,18 * * *")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 15, 7, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 15, 12, 0, 0, 0, time.UTC), next)
+}
+
+func TestParse_RangeWithStep(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse("0 1-10/3 * * *")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 15, 1, 0, 0, 0, time.UTC), next)
+
+ next, err = sched.Next(next)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 15, 4, 0, 0, 0, time.UTC), next)
+}
+
+func TestParse_InvalidExpression(t *testing.T) {
+ t.Parallel()
+
+ _, err := Parse("not-a-cron")
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidExpression)
+}
+
+func TestParse_EmptyString(t *testing.T) {
+ t.Parallel()
+
+ _, err := Parse("")
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidExpression)
+}
+
+func TestParse_TooFewFields(t *testing.T) {
+ t.Parallel()
+
+ _, err := Parse("0 0 *")
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidExpression)
+}
+
+func TestParse_TooManyFields(t *testing.T) {
+ t.Parallel()
+
+ _, err := Parse("0 0 * * * *")
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidExpression)
+}
+
+func TestParse_OutOfRangeValue(t *testing.T) {
+ t.Parallel()
+
+ _, err := Parse("60 0 * * *")
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidExpression)
+}
+
+func TestParse_InvalidStep(t *testing.T) {
+ t.Parallel()
+
+ _, err := Parse("*/0 * * * *")
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidExpression)
+}
+
+func TestParse_WhitespaceHandling(t *testing.T) {
+ t.Parallel()
+
+ sched, err := Parse(" 0 0 * * * ")
+ require.NoError(t, err)
+
+ from := time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 16, 0, 0, 0, 0, time.UTC), next)
+}
+
+func TestNext_ExhaustionReturnsError(t *testing.T) {
+ t.Parallel()
+
+ // Schedule for Feb 30 — a date that never exists.
+ // DOW is wildcard so day matching uses DOM alone; February never has day 30.
+ // This forces the iterator to exhaust maxIterations without finding a match.
+ sched := &schedule{
+ minutes: []int{0},
+ hours: []int{0},
+ doms: []int{30},
+ months: []int{2},
+ dows: []int{0, 1, 2, 3, 4, 5, 6},
+ domIsWild: false,
+ dowIsWild: true, // simulate "0 0 30 2 *"
+ }
+
+ from := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNoMatch)
+ assert.True(t, next.IsZero(), "expected zero time on exhaustion")
+}
+
+func TestParse_DOW7NormalizedToSunday(t *testing.T) {
+ t.Parallel()
+
+ // DOW 7 should be accepted and treated as Sunday (0).
+ sched, err := Parse("0 0 * * 7")
+ require.NoError(t, err)
+
+ // 2026-01-18 is a Sunday.
+ from := time.Date(2026, 1, 17, 12, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Sunday, next.Weekday())
+ assert.Equal(t, time.Date(2026, 1, 18, 0, 0, 0, 0, time.UTC), next)
+}
+
+func TestParse_DOMAndDOWBothRestricted_ORSemantics(t *testing.T) {
+ t.Parallel()
+
+ // "0 0 15 * 1" = midnight on the 15th OR on any Monday.
+ // Standard cron: when both DOM and DOW are restricted, match EITHER.
+ sched, err := Parse("0 0 15 * 1")
+ require.NoError(t, err)
+
+ // 2026-01-15 is a Thursday. Should match because DOM=15.
+ from := time.Date(2026, 1, 14, 23, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC), next,
+ "should match DOM=15 even though it's Thursday, not Monday (OR semantics)")
+}
+
+func TestParse_DOMAndDOWBothRestricted_MatchesDOW(t *testing.T) {
+ t.Parallel()
+
+ // "0 0 15 * 1" = midnight on the 15th OR on any Monday.
+ sched, err := Parse("0 0 15 * 1")
+ require.NoError(t, err)
+
+ // 2026-01-19 is a Monday. Should match because DOW=1.
+ from := time.Date(2026, 1, 18, 12, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2026, 1, 19, 0, 0, 0, 0, time.UTC), next,
+ "should match DOW=Monday even though DOM is not 15 (OR semantics)")
+}
+
+func TestParse_LeapDaySparseSchedule(t *testing.T) {
+ t.Parallel()
+
+ // "0 0 29 2 *" = Feb 29 only. Needs 4-year search window.
+ sched, err := Parse("0 0 29 2 *")
+ require.NoError(t, err)
+
+ // Starting from 2025, the next Feb 29 is 2028-02-29.
+ from := time.Date(2025, 3, 1, 0, 0, 0, 0, time.UTC)
+ next, err := sched.Next(from)
+
+ require.NoError(t, err)
+ assert.Equal(t, time.Date(2028, 2, 29, 0, 0, 0, 0, time.UTC), next)
+}
+
+func TestNext_NilScheduleReturnsError(t *testing.T) {
+ t.Parallel()
+
+ var sched *schedule
+
+ next, err := sched.Next(time.Now())
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilSchedule)
+ assert.True(t, next.IsZero())
+}
diff --git a/commons/cron/doc.go b/commons/cron/doc.go
new file mode 100644
index 00000000..67135006
--- /dev/null
+++ b/commons/cron/doc.go
@@ -0,0 +1,5 @@
+// Package cron parses standard 5-field cron expressions and computes next run times.
+//
+// It supports wildcards, ranges, steps, and lists across minute, hour,
+// day-of-month, month, and day-of-week fields.
+package cron
diff --git a/commons/crypto/crypto.go b/commons/crypto/crypto.go
index 021fd88a..d2df72f6 100644
--- a/commons/crypto/crypto.go
+++ b/commons/crypto/crypto.go
@@ -1,10 +1,7 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package crypto
import (
+ "context"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
@@ -13,12 +10,38 @@ import (
"encoding/base64"
"encoding/hex"
"errors"
+ "fmt"
"io"
+ "reflect"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
- libLog "github.com/LerianStudio/lib-commons/v2/commons/log"
- "go.uber.org/zap"
+var (
+ // ErrCipherNotInitialized is returned when encryption/decryption is attempted before InitializeCipher.
+ ErrCipherNotInitialized = errors.New("cipher not initialized")
+ // ErrCiphertextTooShort is returned when the ciphertext is shorter than the nonce size.
+ ErrCiphertextTooShort = errors.New("ciphertext too short")
+ // ErrNilCrypto is returned when a Crypto method is called on a nil receiver.
+ ErrNilCrypto = errors.New("crypto instance is nil")
+ // ErrNilInput is returned when a nil pointer is passed to Encrypt or Decrypt.
+ ErrNilInput = errors.New("nil input")
+ // ErrEmptyKey is returned when an empty hash secret key is provided to GenerateHash.
+ ErrEmptyKey = errors.New("hash secret key must not be empty")
)
+// isNilInterface returns true if the interface value is nil or holds a typed nil.
+func isNilInterface(i any) bool {
+ if i == nil {
+ return true
+ }
+
+ v := reflect.ValueOf(i)
+
+ return v.Kind() == reflect.Ptr && v.IsNil()
+}
+
+// Crypto groups hashing and symmetric encryption helpers.
type Crypto struct {
HashSecretKey string
EncryptSecretKey string
@@ -26,9 +49,45 @@ type Crypto struct {
Cipher cipher.AEAD
}
-// GenerateHash using HMAC-SHA256
+// String implements fmt.Stringer to prevent accidental secret key exposure in logs or spans.
+func (c *Crypto) String() string {
+ if c == nil {
+ return ""
+ }
+
+ return "Crypto{keys:REDACTED}"
+}
+
+// GoString implements fmt.GoStringer to prevent accidental secret key exposure in %#v formatting.
+func (c *Crypto) GoString() string {
+ return c.String()
+}
+
+// logger returns the configured Logger, falling back to a NopLogger if nil.
+// Uses isNilInterface to detect typed nils (e.g. (*MyLogger)(nil)).
+func (c *Crypto) logger() libLog.Logger {
+ if c == nil || isNilInterface(c.Logger) {
+ return libLog.NewNop()
+ }
+
+ return c.Logger
+}
+
+// GenerateHash produces an HMAC-SHA256 hex-encoded hash of the plaintext.
+//
+// Returns "" for nil receiver or nil input as intentional safe degradation:
+// callers that cannot supply a Crypto instance or input get a deterministic
+// empty result rather than an error, which simplifies optional-hash pipelines.
+//
+// Returns "" with a logged error if HashSecretKey is empty, since HMAC with
+// an empty key produces a valid but insecure hash.
func (c *Crypto) GenerateHash(plaintext *string) string {
- if plaintext == nil {
+ if c == nil || plaintext == nil {
+ return ""
+ }
+
+ if c.HashSecretKey == "" {
+ c.logger().Log(context.Background(), libLog.LevelError, "GenerateHash called with empty HashSecretKey")
return ""
}
@@ -40,29 +99,35 @@ func (c *Crypto) GenerateHash(plaintext *string) string {
return hash
}
-// InitializeCipher loads an AES-GCM block cipher for encryption/decryption
+// InitializeCipher loads an AES-GCM block cipher for encryption/decryption.
+// The EncryptSecretKey must be a hex-encoded key of 16, 24, or 32 bytes
+// (corresponding to AES-128, AES-192, or AES-256 respectively).
func (c *Crypto) InitializeCipher() error {
- if c.Cipher != nil {
- c.Logger.Info("Cipher already initialized")
+ if c == nil {
+ return ErrNilCrypto
+ }
+
+ if !isNilInterface(c.Cipher) {
+ c.logger().Log(context.Background(), libLog.LevelInfo, "Cipher already initialized")
return nil
}
decodedKey, err := hex.DecodeString(c.EncryptSecretKey)
if err != nil {
- c.Logger.Error("Failed to decode hex private key", zap.Error(err))
- return err
+ c.logger().Log(context.Background(), libLog.LevelError, "Failed to decode hex private key", libLog.Err(err))
+ return fmt.Errorf("crypto: hex decode key: %w", err)
}
blockCipher, err := aes.NewCipher(decodedKey)
if err != nil {
- c.Logger.Error("Error creating AES block cipher with the private key", zap.Error(err))
- return err
+ c.logger().Log(context.Background(), libLog.LevelError, "Error creating AES block cipher with the private key", libLog.Err(err))
+ return fmt.Errorf("crypto: create AES block cipher: %w", err)
}
aesGcm, err := cipher.NewGCM(blockCipher)
if err != nil {
- c.Logger.Error("Error creating GCM cipher", zap.Error(err))
- return err
+ c.logger().Log(context.Background(), libLog.LevelError, "Error creating GCM cipher", libLog.Err(err))
+ return fmt.Errorf("crypto: create GCM cipher: %w", err)
}
c.Cipher = aesGcm
@@ -70,22 +135,28 @@ func (c *Crypto) InitializeCipher() error {
return nil
}
-// Encrypt a plaintext using AES-GCM, which requires a private 32 bytes key and a random 12 bytes nonce.
-// It generates a base64 string with the encoded ciphertext.
+// Encrypt a plaintext using AES-GCM with a random 12-byte nonce.
+// Requires InitializeCipher to have been called with a valid AES key
+// (16, 24, or 32 bytes for AES-128, AES-192, or AES-256 respectively).
+// Returns a base64-encoded string with the nonce prepended to the ciphertext.
func (c *Crypto) Encrypt(plainText *string) (*string, error) {
+ if c == nil {
+ return nil, ErrNilCrypto
+ }
+
if plainText == nil {
- return nil, nil
+ return nil, ErrNilInput
}
- if c.Cipher == nil {
- return nil, errors.New("cipher not initialized")
+ if isNilInterface(c.Cipher) {
+ return nil, ErrCipherNotInitialized
}
// Generates random nonce with a size of 12 bytes
nonce := make([]byte, c.Cipher.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
- c.Logger.Error("Failed to generate nonce", zap.Error(err))
- return nil, err
+ c.logger().Log(context.Background(), libLog.LevelError, "Failed to generate nonce", libLog.Err(err))
+ return nil, fmt.Errorf("crypto: generate nonce: %w", err)
}
// Cipher Text prefixed with the random Nonce
@@ -99,26 +170,29 @@ func (c *Crypto) Encrypt(plainText *string) (*string, error) {
// Decrypt a base64 encoded encrypted plaintext.
// The encrypted plain text must be prefixed with the random nonce used for encryption.
func (c *Crypto) Decrypt(encryptedText *string) (*string, error) {
+ if c == nil {
+ return nil, ErrNilCrypto
+ }
+
if encryptedText == nil {
- return nil, nil
+ return nil, ErrNilInput
}
- if c.Cipher == nil {
- return nil, errors.New("cipher not initialized")
+ if isNilInterface(c.Cipher) {
+ return nil, ErrCipherNotInitialized
}
decodedEncryptedText, err := base64.StdEncoding.DecodeString(*encryptedText)
if err != nil {
- c.Logger.Error("Failed to decode encrypted text", zap.Error(err))
- return nil, err
+ c.logger().Log(context.Background(), libLog.LevelError, "Failed to decode encrypted text", libLog.Err(err))
+ return nil, fmt.Errorf("crypto: decode base64: %w", err)
}
nonceSize := c.Cipher.NonceSize()
if len(decodedEncryptedText) < nonceSize {
- err := errors.New("ciphertext too short")
- c.Logger.Error("Failed to decrypt ciphertext", zap.Error(err))
+ c.logger().Log(context.Background(), libLog.LevelError, "Failed to decrypt ciphertext", libLog.Err(ErrCiphertextTooShort))
- return nil, err
+ return nil, ErrCiphertextTooShort
}
// Separating nonce from ciphertext before decrypting
@@ -128,8 +202,8 @@ func (c *Crypto) Decrypt(encryptedText *string) (*string, error) {
// False positive described at https://github.com/securego/gosec/issues/1209
plainText, err := c.Cipher.Open(nil, nonce, cipherText, nil)
if err != nil {
- c.Logger.Error("Failed to decrypt ciphertext", zap.Error(err))
- return nil, err
+ c.logger().Log(context.Background(), libLog.LevelError, "Failed to decrypt ciphertext", libLog.Err(err))
+ return nil, fmt.Errorf("crypto: decrypt: %w", err)
}
result := string(plainText)
diff --git a/commons/crypto/crypto_nil_test.go b/commons/crypto/crypto_nil_test.go
new file mode 100644
index 00000000..a9fbab7b
--- /dev/null
+++ b/commons/crypto/crypto_nil_test.go
@@ -0,0 +1,141 @@
+//go:build unit
+
+package crypto
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNilReceiver(t *testing.T) {
+ t.Parallel()
+
+ t.Run("InitializeCipher returns ErrNilCrypto", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Crypto
+
+ err := c.InitializeCipher()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilCrypto)
+ })
+
+ t.Run("Encrypt returns ErrNilCrypto", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Crypto
+ input := "data"
+
+ result, err := c.Encrypt(&input)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilCrypto)
+ assert.Nil(t, result)
+ })
+
+ t.Run("Decrypt returns ErrNilCrypto", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Crypto
+ input := "data"
+
+ result, err := c.Decrypt(&input)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilCrypto)
+ assert.Nil(t, result)
+ })
+
+ t.Run("GenerateHash returns empty string", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Crypto
+ input := "data"
+
+ result := c.GenerateHash(&input)
+ assert.Empty(t, result)
+ })
+
+ t.Run("String returns nil marker", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Crypto
+ assert.Equal(t, "", c.String())
+ })
+
+ t.Run("GoString returns nil marker", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Crypto
+ assert.Equal(t, "", c.GoString())
+ })
+
+ t.Run("logger returns NopLogger on nil receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Crypto
+ l := c.logger()
+ assert.NotNil(t, l)
+ })
+}
+
+func TestRedaction(t *testing.T) {
+ t.Parallel()
+
+ t.Run("String returns REDACTED text", func(t *testing.T) {
+ t.Parallel()
+
+ c := Crypto{
+ HashSecretKey: "super-secret-hash-key",
+ EncryptSecretKey: "super-secret-encrypt-key",
+ }
+
+ s := c.String()
+ assert.Contains(t, s, "REDACTED")
+ assert.NotContains(t, s, "super-secret-hash-key")
+ assert.NotContains(t, s, "super-secret-encrypt-key")
+ })
+
+ t.Run("GoString returns REDACTED text", func(t *testing.T) {
+ t.Parallel()
+
+ c := Crypto{
+ HashSecretKey: "super-secret-hash-key",
+ EncryptSecretKey: "super-secret-encrypt-key",
+ }
+
+ s := c.GoString()
+ assert.Contains(t, s, "REDACTED")
+ assert.NotContains(t, s, "super-secret-hash-key")
+ assert.NotContains(t, s, "super-secret-encrypt-key")
+ })
+
+ t.Run("fmt Sprintf %v does not leak secrets", func(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{
+ HashSecretKey: "secret-hash-value",
+ EncryptSecretKey: "secret-encrypt-value",
+ }
+
+ output := fmt.Sprintf("%v", c)
+ assert.NotContains(t, output, "secret-hash-value")
+ assert.NotContains(t, output, "secret-encrypt-value")
+ assert.Contains(t, output, "REDACTED")
+ })
+
+ t.Run("fmt Sprintf %#v does not leak secrets", func(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{
+ HashSecretKey: "secret-hash-value",
+ EncryptSecretKey: "secret-encrypt-value",
+ }
+
+ output := fmt.Sprintf("%#v", c)
+ assert.NotContains(t, output, "secret-hash-value")
+ assert.NotContains(t, output, "secret-encrypt-value")
+ assert.Contains(t, output, "REDACTED")
+ })
+}
diff --git a/commons/crypto/crypto_test.go b/commons/crypto/crypto_test.go
new file mode 100644
index 00000000..cb0b7db9
--- /dev/null
+++ b/commons/crypto/crypto_test.go
@@ -0,0 +1,393 @@
+//go:build unit
+
+package crypto
+
+import (
+ "encoding/base64"
+ "testing"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+const validHexKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
+
+func newTestCrypto(t *testing.T) *Crypto {
+ t.Helper()
+
+ c := &Crypto{
+ HashSecretKey: "hash-secret",
+ EncryptSecretKey: validHexKey,
+ Logger: libLog.NewNop(),
+ }
+
+ require.NoError(t, c.InitializeCipher())
+
+ return c
+}
+
+func ptr(s string) *string { return &s }
+
+func TestGenerateHash(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input *string
+ expectLen int
+ }{
+ {
+ name: "nil input returns empty string",
+ input: nil,
+ expectLen: 0,
+ },
+ {
+ name: "non-nil input returns 64-char hex hash",
+ input: ptr("hello"),
+ expectLen: 64,
+ },
+ {
+ name: "empty string input returns 64-char hex hash",
+ input: ptr(""),
+ expectLen: 64,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{HashSecretKey: "test-key", Logger: libLog.NewNop()}
+ result := c.GenerateHash(tt.input)
+
+ if tt.input == nil {
+ assert.Empty(t, result)
+ } else {
+ assert.Len(t, result, tt.expectLen)
+ }
+ })
+ }
+}
+
+func TestGenerateHash_Consistency(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{HashSecretKey: "test-key", Logger: libLog.NewNop()}
+ input := ptr("hello")
+
+ hash1 := c.GenerateHash(input)
+ hash2 := c.GenerateHash(input)
+
+ assert.Equal(t, hash1, hash2)
+}
+
+func TestGenerateHash_DifferentInputs(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{HashSecretKey: "test-key", Logger: libLog.NewNop()}
+
+ hash1 := c.GenerateHash(ptr("hello"))
+ hash2 := c.GenerateHash(ptr("world"))
+
+ assert.NotEqual(t, hash1, hash2)
+}
+
+func TestInitializeCipher(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ key string
+ expectErr bool
+ }{
+ {
+ name: "valid 32-byte hex key succeeds",
+ key: validHexKey,
+ expectErr: false,
+ },
+ {
+ name: "invalid hex characters",
+ key: "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz",
+ expectErr: true,
+ },
+ {
+ name: "wrong key length (15 bytes)",
+ key: "0123456789abcdef0123456789abcd",
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{EncryptSecretKey: tt.key, Logger: libLog.NewNop()}
+ err := c.InitializeCipher()
+
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Nil(t, c.Cipher)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, c.Cipher)
+ }
+ })
+ }
+}
+
+func TestInitializeCipher_AlreadyInitialized(t *testing.T) {
+ t.Parallel()
+
+ c := newTestCrypto(t)
+ originalCipher := c.Cipher
+
+ err := c.InitializeCipher()
+
+ assert.NoError(t, err)
+ assert.Equal(t, originalCipher, c.Cipher)
+}
+
+func TestEncrypt(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ initCipher bool
+ input *string
+ expectNil bool
+ expectErr bool
+ sentinel error
+ }{
+ {
+ name: "nil input returns error",
+ initCipher: true,
+ input: nil,
+ expectNil: true,
+ expectErr: true,
+ sentinel: ErrNilInput,
+ },
+ {
+ name: "uninitialized cipher returns error",
+ initCipher: false,
+ input: ptr("hello"),
+ expectNil: true,
+ expectErr: true,
+ sentinel: ErrCipherNotInitialized,
+ },
+ {
+ name: "successful encryption",
+ initCipher: true,
+ input: ptr("hello world"),
+ expectNil: false,
+ expectErr: false,
+ },
+ {
+ name: "empty string encrypts successfully",
+ initCipher: true,
+ input: ptr(""),
+ expectNil: false,
+ expectErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{
+ EncryptSecretKey: validHexKey,
+ Logger: libLog.NewNop(),
+ }
+ if tt.initCipher {
+ require.NoError(t, c.InitializeCipher())
+ }
+
+ result, err := c.Encrypt(tt.input)
+
+ if tt.expectErr {
+ assert.Error(t, err)
+ if tt.sentinel != nil {
+ assert.ErrorIs(t, err, tt.sentinel)
+ }
+ } else {
+ assert.NoError(t, err)
+ }
+
+ if tt.expectNil {
+ assert.Nil(t, result)
+ } else {
+ require.NotNil(t, result)
+ assert.NotEmpty(t, *result)
+ // Result must be valid base64
+ _, decErr := base64.StdEncoding.DecodeString(*result)
+ assert.NoError(t, decErr)
+ }
+ })
+ }
+}
+
+func TestDecrypt(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ initCipher bool
+ input *string
+ expectNil bool
+ expectErr bool
+ sentinel error
+ }{
+ {
+ name: "nil input returns error",
+ initCipher: true,
+ input: nil,
+ expectNil: true,
+ expectErr: true,
+ sentinel: ErrNilInput,
+ },
+ {
+ name: "uninitialized cipher returns error",
+ initCipher: false,
+ input: ptr("c29tZXRoaW5n"),
+ expectNil: true,
+ expectErr: true,
+ sentinel: ErrCipherNotInitialized,
+ },
+ {
+ name: "invalid base64 input",
+ initCipher: true,
+ input: ptr("!!!not-base64!!!"),
+ expectNil: true,
+ expectErr: true,
+ },
+ {
+ name: "ciphertext too short",
+ initCipher: true,
+ input: ptr(base64.StdEncoding.EncodeToString([]byte("short"))),
+ expectNil: true,
+ expectErr: true,
+ sentinel: ErrCiphertextTooShort,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{
+ EncryptSecretKey: validHexKey,
+ Logger: libLog.NewNop(),
+ }
+ if tt.initCipher {
+ require.NoError(t, c.InitializeCipher())
+ }
+
+ result, err := c.Decrypt(tt.input)
+
+ if tt.expectErr {
+ assert.Error(t, err)
+ if tt.sentinel != nil {
+ assert.ErrorIs(t, err, tt.sentinel)
+ }
+ } else {
+ assert.NoError(t, err)
+ }
+
+ if tt.expectNil {
+ assert.Nil(t, result)
+ }
+ })
+ }
+}
+
+func TestEncryptDecrypt_RoundTrip(t *testing.T) {
+ t.Parallel()
+
+ c := newTestCrypto(t)
+
+ inputs := []string{
+ "hello world",
+ "",
+ "special chars: !@#$%^&*()",
+ "unicode: 日本語テスト 🎉",
+ "a longer string that exercises the AES-GCM cipher with more data to process in blocks",
+ }
+
+ for _, input := range inputs {
+ t.Run(input, func(t *testing.T) {
+ t.Parallel()
+
+ encrypted, err := c.Encrypt(ptr(input))
+ require.NoError(t, err)
+ require.NotNil(t, encrypted)
+
+ decrypted, err := c.Decrypt(encrypted)
+ require.NoError(t, err)
+ require.NotNil(t, decrypted)
+
+ assert.Equal(t, input, *decrypted)
+ })
+ }
+}
+
+func TestEncrypt_ProducesUniqueOutputs(t *testing.T) {
+ t.Parallel()
+
+ c := newTestCrypto(t)
+ input := ptr("same plaintext")
+
+ enc1, err1 := c.Encrypt(input)
+ require.NoError(t, err1)
+
+ enc2, err2 := c.Encrypt(input)
+ require.NoError(t, err2)
+
+ assert.NotEqual(t, *enc1, *enc2, "AES-GCM with random nonce should produce different ciphertexts")
+}
+
+func TestGenerateHash_EmptyKey(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{HashSecretKey: "", Logger: libLog.NewNop()}
+ input := ptr("hello")
+
+ result := c.GenerateHash(input)
+ assert.Empty(t, result, "GenerateHash with empty key should return empty string")
+}
+
+func TestLogger(t *testing.T) {
+ t.Parallel()
+
+ t.Run("returns configured logger", func(t *testing.T) {
+ t.Parallel()
+
+ nop := libLog.NewNop()
+ c := &Crypto{Logger: nop}
+
+ assert.Equal(t, nop, c.logger())
+ })
+
+ t.Run("returns NopLogger when Logger is nil", func(t *testing.T) {
+ t.Parallel()
+
+ c := &Crypto{}
+ l := c.logger()
+
+ assert.NotNil(t, l)
+ assert.IsType(t, &libLog.NopLogger{}, l)
+ })
+
+ t.Run("returns NopLogger for typed-nil Logger", func(t *testing.T) {
+ t.Parallel()
+
+ // Simulate a typed-nil: interface holds (*NopLogger)(nil).
+ // This exercises the isNilInterface reflection path.
+ var nilLogger *libLog.NopLogger
+ c := &Crypto{Logger: nilLogger}
+ l := c.logger()
+
+ assert.NotNil(t, l)
+ assert.IsType(t, &libLog.NopLogger{}, l)
+ })
+}
diff --git a/commons/crypto/doc.go b/commons/crypto/doc.go
new file mode 100644
index 00000000..61ec1003
--- /dev/null
+++ b/commons/crypto/doc.go
@@ -0,0 +1,8 @@
+// Package crypto provides hashing and symmetric encryption helpers.
+//
+// The Crypto type supports:
+// - HMAC-SHA256 hashing for deterministic fingerprints
+// - AES-GCM encryption/decryption for confidential payloads
+//
+// InitializeCipher must be called before Encrypt or Decrypt.
+package crypto
diff --git a/commons/doc.go b/commons/doc.go
new file mode 100644
index 00000000..a558b9c2
--- /dev/null
+++ b/commons/doc.go
@@ -0,0 +1,14 @@
+// Package commons provides shared infrastructure helpers used across Lerian services.
+//
+// The package includes context helpers, validation utilities, error adapters,
+// and cross-cutting primitives used by higher-level subpackages.
+//
+// Typical usage at request ingress:
+//
+// ctx = commons.ContextWithLogger(ctx, logger)
+// ctx = commons.ContextWithTracer(ctx, tracer)
+// ctx = commons.ContextWithHeaderID(ctx, requestID)
+//
+// This package is intentionally dependency-light; specialized integrations live in
+// subpackages such as opentelemetry, mongo, redis, rabbitmq, and server.
+package commons
diff --git a/commons/errgroup/doc.go b/commons/errgroup/doc.go
new file mode 100644
index 00000000..d6235927
--- /dev/null
+++ b/commons/errgroup/doc.go
@@ -0,0 +1,5 @@
+// Package errgroup coordinates goroutines that share a cancellation context.
+//
+// The first goroutine error cancels the group context and is returned by Wait;
+// recovered panics are converted into errors.
+package errgroup
diff --git a/commons/errgroup/errgroup.go b/commons/errgroup/errgroup.go
new file mode 100644
index 00000000..2457e655
--- /dev/null
+++ b/commons/errgroup/errgroup.go
@@ -0,0 +1,124 @@
+package errgroup
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
+)
+
+var (
+ // ErrPanicRecovered is returned when a goroutine in the group panics.
+ ErrPanicRecovered = errors.New("errgroup: panic recovered")
+
+ // ErrNilGroup is returned when Go or Wait is called on a nil *Group.
+ ErrNilGroup = errors.New("errgroup: nil group")
+)
+
+// Group manages a set of goroutines that share a cancellation context.
+// The first error returned by any goroutine cancels the group's context
+// and is returned by Wait. Subsequent errors are discarded.
+type Group struct {
+ ctx context.Context
+ cancel context.CancelFunc
+ wg sync.WaitGroup
+ errOnce sync.Once
+ err error
+ loggerMu sync.RWMutex
+ logger libLog.Logger
+}
+
+// SetLogger sets an optional logger for panic recovery observability.
+// When set, panics recovered in goroutines will be logged before the
+// error is propagated via Wait. Safe for concurrent use.
+func (grp *Group) SetLogger(logger libLog.Logger) {
+ if grp == nil {
+ return
+ }
+
+ grp.loggerMu.Lock()
+ grp.logger = logger
+ grp.loggerMu.Unlock()
+}
+
+// getLogger returns the current logger in a concurrency-safe manner.
+func (grp *Group) getLogger() libLog.Logger {
+ grp.loggerMu.RLock()
+ l := grp.logger
+ grp.loggerMu.RUnlock()
+
+ return l
+}
+
+// effectiveCtx returns the group's context, falling back to context.Background()
+// for zero-value Groups not created via WithContext.
+func (grp *Group) effectiveCtx() context.Context {
+ if grp.ctx != nil {
+ return grp.ctx
+ }
+
+ return context.Background()
+}
+
+// WithContext returns a new Group and a derived context.Context.
+// The derived context is canceled when the first goroutine in the Group
+// returns a non-nil error or when Wait returns, whichever occurs first.
+func WithContext(ctx context.Context) (*Group, context.Context) {
+ ctx, cancel := context.WithCancel(ctx)
+ return &Group{ctx: ctx, cancel: cancel}, ctx
+}
+
+// Go starts a new goroutine in the Group. The first non-nil error returned
+// by a goroutine is recorded and triggers cancellation of the group context.
+// Callers must not mutate shared state without synchronization.
+// If called on a nil *Group, Go is a no-op.
+func (grp *Group) Go(fn func() error) {
+ if grp == nil {
+ return
+ }
+
+ grp.wg.Go(func() {
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ runtime.HandlePanicValue(grp.effectiveCtx(), grp.getLogger(), recovered, "errgroup", "group.Go")
+
+ grp.errOnce.Do(func() {
+ grp.err = fmt.Errorf("%w: %v", ErrPanicRecovered, recovered)
+ if grp.cancel != nil {
+ grp.cancel()
+ }
+ })
+ }
+ }()
+
+ if err := fn(); err != nil {
+ grp.errOnce.Do(func() {
+ grp.err = err
+ if grp.cancel != nil {
+ grp.cancel()
+ }
+ })
+ }
+ })
+}
+
+// Wait blocks until all goroutines in the Group have completed.
+// It cancels the group context after all goroutines finish and returns
+// the first non-nil error (if any) recorded by Go.
+// Returns ErrNilGroup if called on a nil *Group.
+func (grp *Group) Wait() error {
+ if grp == nil {
+ return ErrNilGroup
+ }
+
+ grp.wg.Wait()
+
+ if grp.cancel != nil {
+ grp.cancel()
+ }
+
+ return grp.err
+}
diff --git a/commons/errgroup/errgroup_nil_test.go b/commons/errgroup/errgroup_nil_test.go
new file mode 100644
index 00000000..f92e9b7a
--- /dev/null
+++ b/commons/errgroup/errgroup_nil_test.go
@@ -0,0 +1,136 @@
+//go:build unit
+
+package errgroup_test
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/errgroup"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNilReceiver_SetLogger(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil pointer SetLogger does not panic", func(t *testing.T) {
+ t.Parallel()
+
+ var g *errgroup.Group
+
+ assert.NotPanics(t, func() {
+ g.SetLogger(log.NewNop())
+ })
+ })
+
+ t.Run("nil pointer SetLogger with nil logger does not panic", func(t *testing.T) {
+ t.Parallel()
+
+ var g *errgroup.Group
+
+ assert.NotPanics(t, func() {
+ g.SetLogger(nil)
+ })
+ })
+}
+
+func TestZeroValueGroup(t *testing.T) {
+ t.Parallel()
+
+ t.Run("Go and Wait work without WithContext", func(t *testing.T) {
+ t.Parallel()
+
+ var g errgroup.Group
+
+ g.Go(func() error {
+ return nil
+ })
+
+ err := g.Wait()
+ assert.NoError(t, err)
+ })
+
+ t.Run("Go returns error through Wait", func(t *testing.T) {
+ t.Parallel()
+
+ var g errgroup.Group
+ expectedErr := errors.New("zero-value error")
+
+ g.Go(func() error {
+ return expectedErr
+ })
+
+ err := g.Wait()
+ require.Error(t, err)
+ assert.Equal(t, expectedErr, err)
+ })
+
+ t.Run("Wait with no goroutines returns nil", func(t *testing.T) {
+ t.Parallel()
+
+ var g errgroup.Group
+
+ err := g.Wait()
+ assert.NoError(t, err)
+ })
+
+ t.Run("panic in Go recovers and returns ErrPanicRecovered", func(t *testing.T) {
+ t.Parallel()
+
+ var g errgroup.Group
+
+ g.Go(func() error {
+ panic("boom from zero-value group")
+ })
+
+ err := g.Wait()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, errgroup.ErrPanicRecovered)
+ assert.Contains(t, err.Error(), "boom from zero-value group")
+ })
+
+ t.Run("panic with nil cancel does not double-panic", func(t *testing.T) {
+ t.Parallel()
+
+ // Zero-value Group has nil cancel. The panic recovery path
+ // checks cancel != nil before calling it. This test ensures
+ // the nil-guard works correctly.
+ var g errgroup.Group
+
+ assert.NotPanics(t, func() {
+ g.Go(func() error {
+ panic("nil cancel test")
+ })
+ _ = g.Wait()
+ })
+ })
+
+ t.Run("multiple goroutines on zero-value group", func(t *testing.T) {
+ t.Parallel()
+
+ var g errgroup.Group
+ firstErr := errors.New("first")
+
+ g.Go(func() error {
+ return firstErr
+ })
+
+ g.Go(func() error {
+ return errors.New("second")
+ })
+
+ g.Go(func() error {
+ return nil
+ })
+
+ err := g.Wait()
+ require.Error(t, err)
+ // errOnce guarantees the first recorded error wins.
+ // Due to goroutine scheduling, either error could be first,
+ // but we'll always get exactly one error back.
+ assert.True(t, err.Error() == "first" || err.Error() == "second",
+ "expected error to be one of the goroutine errors, got: %v", err)
+ })
+}
diff --git a/commons/errgroup/errgroup_test.go b/commons/errgroup/errgroup_test.go
new file mode 100644
index 00000000..2786cffd
--- /dev/null
+++ b/commons/errgroup/errgroup_test.go
@@ -0,0 +1,221 @@
+//go:build unit
+
+package errgroup_test
+
+import (
+ "context"
+ "errors"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/errgroup"
+)
+
+func TestWithContext_AllSucceed(t *testing.T) {
+ t.Parallel()
+
+ group, _ := errgroup.WithContext(context.Background())
+
+ group.Go(func() error { return nil })
+ group.Go(func() error { return nil })
+ group.Go(func() error { return nil })
+
+ err := group.Wait()
+ assert.NoError(t, err)
+}
+
+func TestWithContext_OneError(t *testing.T) {
+ t.Parallel()
+
+ expectedErr := errors.New("something failed")
+ group, groupCtx := errgroup.WithContext(context.Background())
+
+ group.Go(func() error { return expectedErr })
+ group.Go(func() error {
+ <-groupCtx.Done()
+ return nil
+ })
+
+ err := group.Wait()
+ require.Error(t, err)
+ assert.Equal(t, expectedErr, err)
+}
+
+func TestWithContext_MultipleErrors_ReturnsFirst(t *testing.T) {
+ t.Parallel()
+
+ firstErr := errors.New("first error")
+ group, _ := errgroup.WithContext(context.Background())
+
+ started := make(chan struct{})
+ firstDone := make(chan struct{})
+
+ group.Go(func() error {
+ <-started
+ close(firstDone)
+
+ return firstErr
+ })
+
+ group.Go(func() error {
+ <-started
+ <-firstDone // Wait for first goroutine to signal before returning
+
+ return errors.New("second error")
+ })
+
+ close(started)
+
+ err := group.Wait()
+ require.Error(t, err)
+ assert.Equal(t, firstErr, err)
+}
+
+func TestWithContext_ZeroGoroutines(t *testing.T) {
+ t.Parallel()
+
+ group, _ := errgroup.WithContext(context.Background())
+
+ err := group.Wait()
+ assert.NoError(t, err)
+}
+
+func TestWithContext_ContextCancellation(t *testing.T) {
+ t.Parallel()
+
+ var cancelled atomic.Bool
+
+ group, groupCtx := errgroup.WithContext(context.Background())
+
+ group.Go(func() error {
+ return errors.New("trigger cancel")
+ })
+
+ group.Go(func() error {
+ <-groupCtx.Done()
+ cancelled.Store(true)
+ return nil
+ })
+
+ _ = group.Wait()
+ assert.True(t, cancelled.Load())
+}
+
+func TestWithContext_PanicRecovery(t *testing.T) {
+ t.Parallel()
+
+ group, _ := errgroup.WithContext(context.Background())
+
+ group.Go(func() error {
+ panic("something went wrong")
+ })
+
+ err := group.Wait()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, errgroup.ErrPanicRecovered)
+ assert.Contains(t, err.Error(), "something went wrong")
+}
+
+func TestWithContext_PanicAlongsideSuccess(t *testing.T) {
+ t.Parallel()
+
+ var completed atomic.Bool
+
+ group, _ := errgroup.WithContext(context.Background())
+
+ group.Go(func() error {
+ panic("boom")
+ })
+
+ group.Go(func() error {
+ completed.Store(true)
+ return nil
+ })
+
+ err := group.Wait()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, errgroup.ErrPanicRecovered)
+ assert.True(t, completed.Load())
+}
+
+func TestWithContext_PanicAndError_FirstWins(t *testing.T) {
+ t.Parallel()
+
+ regularErr := errors.New("regular error")
+ group, _ := errgroup.WithContext(context.Background())
+
+ started := make(chan struct{})
+
+ // This goroutine returns a regular error first
+ group.Go(func() error {
+ <-started
+ return regularErr
+ })
+
+ // This goroutine panics after a delay
+ group.Go(func() error {
+ <-started
+ time.Sleep(50 * time.Millisecond)
+ panic("delayed panic")
+ })
+
+ close(started)
+
+ err := group.Wait()
+ require.Error(t, err)
+ // The regular error should win because it fires first
+ assert.Equal(t, regularErr, err)
+}
+
+func TestWithContext_PanicWithNonStringValue(t *testing.T) {
+ t.Parallel()
+
+ group, _ := errgroup.WithContext(context.Background())
+
+ group.Go(func() error {
+ panic(42)
+ })
+
+ err := group.Wait()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, errgroup.ErrPanicRecovered)
+}
+
+func TestWithContext_PanicWithNilValue(t *testing.T) {
+ t.Parallel()
+
+ group, _ := errgroup.WithContext(context.Background())
+
+ group.Go(func() error {
+ panic(nil)
+ })
+
+ err := group.Wait()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, errgroup.ErrPanicRecovered)
+}
+
+func TestWithContext_PanicCancelsContext(t *testing.T) {
+ t.Parallel()
+
+ var cancelled atomic.Bool
+
+ group, groupCtx := errgroup.WithContext(context.Background())
+
+ group.Go(func() error {
+ panic("trigger cancel via panic")
+ })
+
+ group.Go(func() error {
+ <-groupCtx.Done()
+ cancelled.Store(true)
+ return nil
+ })
+
+ _ = group.Wait()
+ assert.True(t, cancelled.Load())
+}
diff --git a/commons/errors.go b/commons/errors.go
index a8769a9b..92e06d44 100644
--- a/commons/errors.go
+++ b/commons/errors.go
@@ -1,11 +1,12 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package commons
import (
- constant "github.com/LerianStudio/lib-commons/v2/commons/constants"
+ "errors"
+ "fmt"
+ "strings"
+
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/security"
)
// Response represents a business error with code, title, and message.
@@ -17,6 +18,7 @@ type Response struct {
Err error `json:"err,omitempty"`
}
+// Error returns the business-facing message and satisfies the error interface.
func (e Response) Error() string {
return e.Message
}
@@ -69,9 +71,51 @@ func ValidateBusinessError(err error, entityType string, args ...any) error {
Message: "External accounts cannot be used for pending transactions in source operations. Please check the accounts and try again.",
},
}
- if mappedError, found := errorMap[err]; found {
- return mappedError
+ // Use errors.Is to match wrapped sentinels instead of exact map identity.
+ for sentinel, mappedError := range errorMap {
+ if !errors.Is(err, sentinel) {
+ continue
+ }
+
+ var response Response
+ if !errors.As(mappedError, &response) {
+ return mappedError
+ }
+
+ if len(args) > 0 {
+ parts := make([]string, 0, len(args))
+
+ for _, arg := range args {
+ s := fmt.Sprint(arg)
+ // Redact arguments that look like sensitive fields (credentials, PII)
+ // to prevent leaking them to external API consumers.
+ if looksLikeSensitiveArg(s) {
+ continue
+ }
+
+ parts = append(parts, s)
+ }
+
+ if len(parts) > 0 {
+ response.Message = fmt.Sprintf("%s (%s)", response.Message, strings.Join(parts, ", "))
+ }
+ }
+
+ return response
}
return err
}
+
+// looksLikeSensitiveArg checks whether a stringified argument contains a key=value
+// pair where the key is a known sensitive field name.
+func looksLikeSensitiveArg(s string) bool {
+ if idx := strings.IndexByte(s, '='); idx > 0 {
+ key := s[:idx]
+ if security.IsSensitiveField(key) {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/commons/errors_test.go b/commons/errors_test.go
index 8b8d238e..93bbadc4 100644
--- a/commons/errors_test.go
+++ b/commons/errors_test.go
@@ -1,14 +1,13 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package commons
import (
"errors"
+ "fmt"
"testing"
- constant "github.com/LerianStudio/lib-commons/v2/commons/constants"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
"github.com/stretchr/testify/assert"
)
@@ -156,10 +155,33 @@ func TestValidateBusinessError(t *testing.T) {
}
func TestValidateBusinessError_WithArgs(t *testing.T) {
- // Test that ValidateBusinessError accepts variadic args (even if not used currently)
- result := ValidateBusinessError(constant.ErrAccountIneligibility, "account", "arg1", "arg2")
+ result := ValidateBusinessError(constant.ErrAccountIneligibility, "account", "alias=@account1", "balance=default")
response, ok := result.(Response)
assert.True(t, ok)
assert.Equal(t, "account", response.EntityType)
+ assert.Contains(t, response.Message, "alias=@account1")
+ assert.Contains(t, response.Message, "balance=default")
+}
+
+func TestValidateBusinessError_WrappedSentinel(t *testing.T) {
+ // Wrap a known sentinel error — errors.Is should still match.
+ wrapped := fmt.Errorf("context info: %w", constant.ErrInsufficientFunds)
+ result := ValidateBusinessError(wrapped, "transaction")
+
+ response, ok := result.(Response)
+ assert.True(t, ok, "wrapped sentinel should be matched via errors.Is")
+ assert.Equal(t, "transaction", response.EntityType)
+ assert.Equal(t, constant.ErrInsufficientFunds.Error(), response.Code)
+ assert.Contains(t, response.Message, "insufficient funds")
+}
+
+func TestValidateBusinessError_SensitiveArgRedacted(t *testing.T) {
+ // Args with sensitive-looking keys (password=...) should be redacted.
+ result := ValidateBusinessError(constant.ErrAccountIneligibility, "account", "password=secret123", "alias=@acc1")
+
+ response, ok := result.(Response)
+ assert.True(t, ok)
+ assert.NotContains(t, response.Message, "secret123", "sensitive args must not appear in message")
+ assert.Contains(t, response.Message, "alias=@acc1", "non-sensitive args should appear")
}
diff --git a/commons/internal/nilcheck/nilcheck.go b/commons/internal/nilcheck/nilcheck.go
new file mode 100644
index 00000000..cd489f4d
--- /dev/null
+++ b/commons/internal/nilcheck/nilcheck.go
@@ -0,0 +1,19 @@
+package nilcheck
+
+import "reflect"
+
+// Interface reports whether value is nil, including typed-nil interfaces.
+func Interface(value any) bool {
+ if value == nil {
+ return true
+ }
+
+ v := reflect.ValueOf(value)
+
+ switch v.Kind() {
+ case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
+ return v.IsNil()
+ default:
+ return false
+ }
+}
diff --git a/commons/internal/nilcheck/nilcheck_test.go b/commons/internal/nilcheck/nilcheck_test.go
new file mode 100644
index 00000000..5bfa7c7d
--- /dev/null
+++ b/commons/internal/nilcheck/nilcheck_test.go
@@ -0,0 +1,49 @@
+//go:build unit
+
+package nilcheck
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type sampleStruct struct{}
+
+type sampleInterface interface {
+ Do()
+}
+
+type sampleImpl struct{}
+
+func (*sampleImpl) Do() {}
+
+func TestInterface(t *testing.T) {
+ t.Parallel()
+
+ var nilPointer *sampleStruct
+ var nilSlice []string
+ var nilMap map[string]string
+ var nilChan chan int
+ var nilFunc func()
+ var nilIface sampleInterface
+
+ var typedNilIface sampleInterface
+ var typedImpl *sampleImpl
+ typedNilIface = typedImpl
+
+ require.True(t, Interface(nil))
+ require.True(t, Interface(nilPointer))
+ require.True(t, Interface(nilSlice))
+ require.True(t, Interface(nilMap))
+ require.True(t, Interface(nilChan))
+ require.True(t, Interface(nilFunc))
+ require.True(t, Interface(nilIface))
+ require.True(t, Interface(typedNilIface))
+
+ require.False(t, Interface(0))
+ require.False(t, Interface(""))
+ require.False(t, Interface(sampleStruct{}))
+ require.False(t, Interface(&sampleStruct{}))
+ require.False(t, Interface([]string{}))
+}
diff --git a/commons/jwt/doc.go b/commons/jwt/doc.go
new file mode 100644
index 00000000..e1a3b316
--- /dev/null
+++ b/commons/jwt/doc.go
@@ -0,0 +1,5 @@
+// Package jwt provides minimal HMAC-based JWT signing and verification.
+//
+// The package supports HS256, HS384, and HS512, and includes helpers to
+// validate standard time-based claims (exp, nbf, iat).
+package jwt
diff --git a/commons/jwt/jwt.go b/commons/jwt/jwt.go
new file mode 100644
index 00000000..38d1039a
--- /dev/null
+++ b/commons/jwt/jwt.go
@@ -0,0 +1,383 @@
+package jwt
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/sha256"
+ "crypto/sha512"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "hash"
+ "slices"
+ "strings"
+ "time"
+)
+
+const (
+ // AlgHS256 identifies the HMAC-SHA256 signing algorithm.
+ AlgHS256 = "HS256"
+ // AlgHS384 identifies the HMAC-SHA384 signing algorithm.
+ AlgHS384 = "HS384"
+ // AlgHS512 identifies the HMAC-SHA512 signing algorithm.
+ AlgHS512 = "HS512"
+
+ // jwtPartCount is the number of dot-separated parts in a valid JWT (header.payload.signature).
+ jwtPartCount = 3
+)
+
+// MapClaims is a convenience alias for an unstructured JWT payload.
+type MapClaims = map[string]any
+
+// Token represents a parsed JWT with its header, claims, and validation state.
+// SignatureValid is true only when the token's cryptographic signature has been
+// verified successfully. It does NOT indicate that time-based claims (exp, nbf,
+// iat) have been validated. Use ParseAndValidate for full validation, or call
+// Token.ValidateTimeClaims after Parse.
+type Token struct {
+ Claims MapClaims
+ SignatureValid bool
+ Header map[string]any
+}
+
+var (
+ // ErrInvalidToken indicates the token string is malformed or cannot be decoded.
+ ErrInvalidToken = errors.New("invalid token")
+ // ErrUnsupportedAlgorithm indicates the signing algorithm is not supported or not allowed.
+ ErrUnsupportedAlgorithm = errors.New("unsupported signing algorithm")
+ // ErrSignatureInvalid indicates the token signature does not match the expected value.
+ ErrSignatureInvalid = errors.New("signature verification failed")
+ // ErrTokenExpired indicates the token's exp claim is in the past.
+ ErrTokenExpired = errors.New("token has expired")
+ // ErrTokenNotYetValid indicates the token's nbf claim is in the future.
+ ErrTokenNotYetValid = errors.New("token is not yet valid")
+ // ErrTokenIssuedInFuture indicates the token's iat claim is in the future.
+ ErrTokenIssuedInFuture = errors.New("token issued in the future")
+ // ErrEmptySecret indicates an empty secret was provided for signing or verification.
+ ErrEmptySecret = errors.New("secret must not be empty")
+ // ErrNilToken indicates a method was called on a nil *Token.
+ ErrNilToken = errors.New("token is nil")
+ // ErrInvalidTimeClaim indicates a time claim is present but has an unsupported or unparseable type.
+ ErrInvalidTimeClaim = errors.New("invalid time claim type")
+)
+
+// Parse validates and decodes a JWT token string. It expects three dot-separated
+// base64url-encoded parts (header, payload, signature), verifies the algorithm
+// is in the allowedAlgorithms whitelist, and checks the HMAC signature using
+// the provided secret with constant-time comparison. Returns a populated Token
+// on success, or ErrInvalidToken, ErrUnsupportedAlgorithm, or ErrSignatureInvalid
+// on failure.
+//
+// Note: Token.SignatureValid indicates only that the cryptographic signature
+// has been verified successfully. It does NOT validate time-based claims such
+// as exp (expiration) or nbf (not-before). Use ParseAndValidate for a single-
+// step parse-and-validate flow, or call Token.ValidateTimeClaims after Parse.
+func Parse(tokenString string, secret []byte, allowedAlgorithms []string) (*Token, error) {
+ const maxTokenLength = 8192 // 8KB is generous for any legitimate JWT
+
+ if len(secret) == 0 {
+ return nil, ErrEmptySecret
+ }
+
+ if len(tokenString) > maxTokenLength {
+ return nil, fmt.Errorf("token exceeds maximum length of %d bytes: %w", maxTokenLength, ErrInvalidToken)
+ }
+
+ if tokenString == "" {
+ return nil, fmt.Errorf("empty token string: %w", ErrInvalidToken)
+ }
+
+ parts := strings.Split(tokenString, ".")
+ if len(parts) != jwtPartCount {
+ return nil, fmt.Errorf("token must have %d parts: %w", jwtPartCount, ErrInvalidToken)
+ }
+
+ header, alg, err := parseHeader(parts[0], allowedAlgorithms)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := verifySignature(parts, alg, secret); err != nil {
+ return nil, err
+ }
+
+ claims, err := parseClaims(parts[1])
+ if err != nil {
+ return nil, err
+ }
+
+ return &Token{
+ Claims: claims,
+ SignatureValid: true,
+ Header: header,
+ }, nil
+}
+
+// ParseAndValidate parses the JWT, verifies the cryptographic signature,
+// and validates time-based claims (exp, nbf, iat). This is the recommended
+// single-step validation for most use cases. It returns the token only if
+// both the signature and time claims are valid.
+func ParseAndValidate(tokenString string, secret []byte, allowedAlgorithms []string) (*Token, error) {
+ token, err := Parse(tokenString, secret, allowedAlgorithms)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := token.ValidateTimeClaims(); err != nil {
+ return nil, err
+ }
+
+ return token, nil
+}
+
+// parseHeader decodes and validates the JWT header part. It base64url-decodes
+// the header, unmarshals it, extracts the signing algorithm, and verifies it
+// is in the allowed list.
+func parseHeader(headerPart string, allowedAlgorithms []string) (map[string]any, string, error) {
+ headerBytes, err := base64.RawURLEncoding.DecodeString(headerPart)
+ if err != nil {
+ return nil, "", fmt.Errorf("decode header: %w", ErrInvalidToken)
+ }
+
+ var header map[string]any
+ if err := json.Unmarshal(headerBytes, &header); err != nil {
+ return nil, "", fmt.Errorf("unmarshal header: %w", ErrInvalidToken)
+ }
+
+ alg, ok := header["alg"].(string)
+ if !ok || alg == "" {
+ return nil, "", fmt.Errorf("missing alg in header: %w", ErrInvalidToken)
+ }
+
+ if !isAllowed(alg, allowedAlgorithms) {
+ return nil, "", fmt.Errorf("algorithm %q not allowed: %w", alg, ErrUnsupportedAlgorithm)
+ }
+
+ return header, alg, nil
+}
+
+// verifySignature checks the HMAC signature of the JWT. It computes the expected
+// signature from the signing input (header.payload) and compares it against the
+// actual signature using constant-time comparison.
+func verifySignature(parts []string, alg string, secret []byte) error {
+ hashFunc, err := hashForAlgorithm(alg)
+ if err != nil {
+ return err
+ }
+
+ signingInput := parts[0] + "." + parts[1]
+
+ expectedSig, err := computeHMAC([]byte(signingInput), secret, hashFunc)
+ if err != nil {
+ return fmt.Errorf("compute signature: %w", ErrInvalidToken)
+ }
+
+ actualSig, err := base64.RawURLEncoding.DecodeString(parts[2])
+ if err != nil {
+ return fmt.Errorf("decode signature: %w", ErrInvalidToken)
+ }
+
+ if !hmac.Equal(expectedSig, actualSig) {
+ return ErrSignatureInvalid
+ }
+
+ return nil
+}
+
+// parseClaims decodes and unmarshals the JWT payload part into a MapClaims map.
+// Uses json.Decoder with UseNumber() to preserve numeric fidelity for time
+// claims (iat, exp, nbf) instead of converting them to float64.
+func parseClaims(payloadPart string) (MapClaims, error) {
+ payloadBytes, err := base64.RawURLEncoding.DecodeString(payloadPart)
+ if err != nil {
+ return nil, fmt.Errorf("decode payload: %w", ErrInvalidToken)
+ }
+
+ var claims MapClaims
+
+ dec := json.NewDecoder(bytes.NewReader(payloadBytes))
+ dec.UseNumber()
+
+ if err := dec.Decode(&claims); err != nil {
+ return nil, fmt.Errorf("unmarshal payload: %w", ErrInvalidToken)
+ }
+
+ return claims, nil
+}
+
+// Sign produces a compact JWT serialization from the given claims. It encodes
+// the header and payload as base64url, computes an HMAC signature using the
+// specified algorithm and secret, and returns the three-part dot-separated
+// token string. Supported algorithms: HS256, HS384, HS512.
+func Sign(claims MapClaims, algorithm string, secret []byte) (string, error) {
+ if len(secret) == 0 {
+ return "", ErrEmptySecret
+ }
+
+ hashFunc, err := hashForAlgorithm(algorithm)
+ if err != nil {
+ return "", err
+ }
+
+ header := map[string]string{"alg": algorithm, "typ": "JWT"}
+
+ headerJSON, err := json.Marshal(header)
+ if err != nil {
+ return "", fmt.Errorf("marshal header: %w", err)
+ }
+
+ claimsJSON, err := json.Marshal(claims)
+ if err != nil {
+ return "", fmt.Errorf("marshal claims: %w", err)
+ }
+
+ headerEncoded := base64.RawURLEncoding.EncodeToString(headerJSON)
+ payloadEncoded := base64.RawURLEncoding.EncodeToString(claimsJSON)
+
+ signingInput := headerEncoded + "." + payloadEncoded
+
+ sig, err := computeHMAC([]byte(signingInput), secret, hashFunc)
+ if err != nil {
+ return "", fmt.Errorf("compute signature: %w", err)
+ }
+
+ sigEncoded := base64.RawURLEncoding.EncodeToString(sig)
+
+ return signingInput + "." + sigEncoded, nil
+}
+
+func isAllowed(alg string, allowed []string) bool {
+ return slices.Contains(allowed, alg)
+}
+
+func hashForAlgorithm(alg string) (func() hash.Hash, error) {
+ switch alg {
+ case AlgHS256:
+ return sha256.New, nil
+ case AlgHS384:
+ return sha512.New384, nil
+ case AlgHS512:
+ return sha512.New, nil
+ default:
+ return nil, fmt.Errorf("algorithm %q: %w", alg, ErrUnsupportedAlgorithm)
+ }
+}
+
+func computeHMAC(data, secret []byte, hashFunc func() hash.Hash) ([]byte, error) {
+ mac := hmac.New(hashFunc, secret)
+
+ if _, err := mac.Write(data); err != nil {
+ return nil, fmt.Errorf("hmac write: %w", err)
+ }
+
+ return mac.Sum(nil), nil
+}
+
+// ValidateTimeClaims checks the standard JWT time-based claims (exp, nbf, iat)
+// on this token against the current UTC time.
+// Returns ErrNilToken if called on a nil *Token.
+func (t *Token) ValidateTimeClaims() error {
+ if t == nil {
+ return ErrNilToken
+ }
+
+ return ValidateTimeClaimsAt(t.Claims, time.Now().UTC())
+}
+
+// ValidateTimeClaimsAt checks the standard JWT time-based claims on this token
+// against the provided time.
+// Returns ErrNilToken if called on a nil *Token.
+func (t *Token) ValidateTimeClaimsAt(now time.Time) error {
+ if t == nil {
+ return ErrNilToken
+ }
+
+ return ValidateTimeClaimsAt(t.Claims, now)
+}
+
+// ValidateTimeClaimsAt checks the standard JWT time-based claims against the provided time.
+// Each claim is optional: if absent from the map, the corresponding check is skipped.
+// Returns ErrTokenExpired if the token has expired (at or past the expiry time, per
+// RFC 7519 §4.1.4), ErrTokenNotYetValid if the token cannot be used yet, or
+// ErrTokenIssuedInFuture if the issued-at time is in the future.
+// Returns ErrInvalidTimeClaim if a time claim is present but has an unsupported type.
+func ValidateTimeClaimsAt(claims MapClaims, now time.Time) error {
+ exp, expOK, err := extractTime(claims, "exp")
+ if err != nil {
+ return err
+ }
+
+ if expOK {
+ // RFC 7519 §4.1.4: the token MUST NOT be accepted on or after the expiration time.
+ // !now.Before(exp) is equivalent to now >= exp.
+ if !now.Before(exp) {
+ return fmt.Errorf("token expired at %s: %w", exp.Format(time.RFC3339), ErrTokenExpired)
+ }
+ }
+
+ nbf, nbfOK, err := extractTime(claims, "nbf")
+ if err != nil {
+ return err
+ }
+
+ if nbfOK {
+ if now.Before(nbf) {
+ return fmt.Errorf("token not valid until %s: %w", nbf.Format(time.RFC3339), ErrTokenNotYetValid)
+ }
+ }
+
+ iat, iatOK, err := extractTime(claims, "iat")
+ if err != nil {
+ return err
+ }
+
+ if iatOK {
+ if now.Before(iat) {
+ return fmt.Errorf("token issued at %s which is in the future: %w", iat.Format(time.RFC3339), ErrTokenIssuedInFuture)
+ }
+ }
+
+ return nil
+}
+
+// ValidateTimeClaims checks the standard JWT time-based claims (exp, nbf, iat)
+// against the current UTC time.
+func ValidateTimeClaims(claims MapClaims) error {
+ return ValidateTimeClaimsAt(claims, time.Now().UTC())
+}
+
+// extractTime retrieves a time value from claims by key. It supports float64
+// (the default from encoding/json), json.Number (when using a decoder with
+// UseNumber), and integer types (int, int32, int64).
+//
+// Returns:
+// - (time, true, nil) if the claim is present and successfully parsed
+// - (zero, false, nil) if the claim is absent
+// - (zero, false, error) if the claim is present but has an unsupported or unparseable type
+func extractTime(claims MapClaims, key string) (time.Time, bool, error) {
+ raw, exists := claims[key]
+ if !exists {
+ return time.Time{}, false, nil
+ }
+
+ switch v := raw.(type) {
+ case float64:
+ return time.Unix(int64(v), 0).UTC(), true, nil
+ case json.Number:
+ f, err := v.Float64()
+ if err != nil {
+ return time.Time{}, false, fmt.Errorf("claim %q: unparseable json.Number %q: %w", key, v.String(), ErrInvalidTimeClaim)
+ }
+
+ return time.Unix(int64(f), 0).UTC(), true, nil
+ case int:
+ return time.Unix(int64(v), 0).UTC(), true, nil
+ case int32:
+ return time.Unix(int64(v), 0).UTC(), true, nil
+ case int64:
+ return time.Unix(v, 0).UTC(), true, nil
+ default:
+ return time.Time{}, false, fmt.Errorf("claim %q: unsupported type %T: %w", key, raw, ErrInvalidTimeClaim)
+ }
+}
diff --git a/commons/jwt/jwt_test.go b/commons/jwt/jwt_test.go
new file mode 100644
index 00000000..a4dc695a
--- /dev/null
+++ b/commons/jwt/jwt_test.go
@@ -0,0 +1,524 @@
+//go:build unit
+
+package jwt
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var allAlgorithms = []string{AlgHS256, AlgHS384, AlgHS512}
+
+func TestSign_Parse_RoundTrip_HS256(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"sub": "user-1", "tenant_id": "abc"}
+ secret := []byte("test-secret-256")
+
+ tokenStr, err := Sign(claims, AlgHS256, secret)
+ require.NoError(t, err)
+ assert.NotEmpty(t, tokenStr)
+
+ token, err := Parse(tokenStr, secret, allAlgorithms)
+ require.NoError(t, err)
+ assert.True(t, token.SignatureValid)
+ assert.Equal(t, "user-1", token.Claims["sub"])
+ assert.Equal(t, "abc", token.Claims["tenant_id"])
+ assert.Equal(t, "HS256", token.Header["alg"])
+}
+
+func TestSign_Parse_RoundTrip_HS384(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"sub": "user-2"}
+ secret := []byte("test-secret-384")
+
+ tokenStr, err := Sign(claims, AlgHS384, secret)
+ require.NoError(t, err)
+
+ token, err := Parse(tokenStr, secret, allAlgorithms)
+ require.NoError(t, err)
+ assert.True(t, token.SignatureValid)
+ assert.Equal(t, "user-2", token.Claims["sub"])
+ assert.Equal(t, "HS384", token.Header["alg"])
+}
+
+func TestSign_Parse_RoundTrip_HS512(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"sub": "user-3"}
+ secret := []byte("test-secret-512")
+
+ tokenStr, err := Sign(claims, AlgHS512, secret)
+ require.NoError(t, err)
+
+ token, err := Parse(tokenStr, secret, allAlgorithms)
+ require.NoError(t, err)
+ assert.True(t, token.SignatureValid)
+ assert.Equal(t, "user-3", token.Claims["sub"])
+ assert.Equal(t, "HS512", token.Header["alg"])
+}
+
+func TestParse_WrongSecret(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"sub": "user-1"}
+ secret := []byte("correct-secret")
+
+ tokenStr, err := Sign(claims, AlgHS256, secret)
+ require.NoError(t, err)
+
+ _, err = Parse(tokenStr, []byte("wrong-secret"), allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrSignatureInvalid)
+}
+
+func TestParse_TamperedPayload(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"sub": "user-1", "role": "user"}
+ secret := []byte("test-secret")
+
+ tokenStr, err := Sign(claims, AlgHS256, secret)
+ require.NoError(t, err)
+
+ parts := strings.Split(tokenStr, ".")
+ tamperedPayload := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"admin","role":"admin"}`))
+ tampered := parts[0] + "." + tamperedPayload + "." + parts[2]
+
+ _, err = Parse(tampered, secret, allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrSignatureInvalid)
+}
+
+func TestParse_AlgorithmNotAllowed(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"sub": "user-1"}
+ secret := []byte("test-secret")
+
+ tokenStr, err := Sign(claims, AlgHS256, secret)
+ require.NoError(t, err)
+
+ _, err = Parse(tokenStr, secret, []string{AlgHS384, AlgHS512})
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
+}
+
+func TestParse_NoneAlgorithmRejected(t *testing.T) {
+ t.Parallel()
+
+ header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
+ payload := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"attacker"}`))
+ noneToken := header + "." + payload + "."
+
+ _, err := Parse(noneToken, []byte("secret"), allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
+}
+
+func TestParse_MalformedToken_WrongParts(t *testing.T) {
+ t.Parallel()
+
+ _, err := Parse("only.two", []byte("secret"), allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidToken)
+
+ _, err = Parse("one", []byte("secret"), allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidToken)
+
+ _, err = Parse("a.b.c.d", []byte("secret"), allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidToken)
+}
+
+func TestParse_EmptyToken(t *testing.T) {
+ t.Parallel()
+
+ _, err := Parse("", []byte("secret"), allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidToken)
+}
+
+func TestParse_ClaimsCorrectlyParsed(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{
+ "tenant_id": "550e8400-e29b-41d4-a716-446655440000",
+ "sub": "user-42",
+ "exp": float64(9999999999),
+ }
+ secret := []byte("parse-claims-secret")
+
+ tokenStr, err := Sign(claims, AlgHS256, secret)
+ require.NoError(t, err)
+
+ token, err := Parse(tokenStr, secret, allAlgorithms)
+ require.NoError(t, err)
+ assert.True(t, token.SignatureValid)
+ assert.Equal(t, "550e8400-e29b-41d4-a716-446655440000", token.Claims["tenant_id"])
+ assert.Equal(t, "user-42", token.Claims["sub"])
+
+ // With UseNumber(), numeric claims are json.Number, not float64.
+ expNum, ok := token.Claims["exp"].(json.Number)
+ require.True(t, ok, "exp claim should be json.Number after UseNumber() decoding")
+ assert.Equal(t, "9999999999", expNum.String())
+}
+
+func TestParse_OversizedToken_ReturnsError(t *testing.T) {
+ t.Parallel()
+
+ // Build a token string that exceeds the 8192-byte maximum.
+ oversized := strings.Repeat("a", 8193)
+
+ _, err := Parse(oversized, []byte("secret"), allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidToken)
+ assert.Contains(t, err.Error(), "exceeds maximum length")
+}
+
+func TestSign_UnsupportedAlgorithm(t *testing.T) {
+ t.Parallel()
+
+ _, err := Sign(MapClaims{"sub": "x"}, "RS256", []byte("secret"))
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsupportedAlgorithm)
+}
+
+func TestValidateTimeClaims_AllValid(t *testing.T) {
+ t.Parallel()
+
+ now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
+ claims := MapClaims{
+ "exp": float64(now.Add(1 * time.Hour).Unix()),
+ "nbf": float64(now.Add(-1 * time.Hour).Unix()),
+ "iat": float64(now.Add(-30 * time.Minute).Unix()),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ assert.NoError(t, err)
+}
+
+func TestValidateTimeClaims_ExpiredToken(t *testing.T) {
+ t.Parallel()
+
+ now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
+ claims := MapClaims{
+ "exp": float64(now.Add(-1 * time.Hour).Unix()),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTokenExpired)
+}
+
+func TestValidateTimeClaims_NotYetValid(t *testing.T) {
+ t.Parallel()
+
+ now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
+ claims := MapClaims{
+ "nbf": float64(now.Add(1 * time.Hour).Unix()),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTokenNotYetValid)
+}
+
+func TestValidateTimeClaims_IssuedInFuture(t *testing.T) {
+ t.Parallel()
+
+ now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
+ claims := MapClaims{
+ "iat": float64(now.Add(1 * time.Hour).Unix()),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTokenIssuedInFuture)
+}
+
+func TestValidateTimeClaims_MissingClaims(t *testing.T) {
+ t.Parallel()
+
+ err := ValidateTimeClaims(MapClaims{"sub": "user-1"})
+ assert.NoError(t, err)
+}
+
+func TestValidateTimeClaims_EmptyClaims(t *testing.T) {
+ t.Parallel()
+
+ err := ValidateTimeClaims(MapClaims{})
+ assert.NoError(t, err)
+}
+
+func TestValidateTimeClaims_WrongTypeReturnsError(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ claims MapClaims
+ }{
+ {name: "string exp", claims: MapClaims{"exp": "not-a-number"}},
+ {name: "bool nbf", claims: MapClaims{"nbf": true}},
+ {name: "slice iat", claims: MapClaims{"iat": []string{"invalid"}}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := ValidateTimeClaims(tt.claims)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidTimeClaim)
+ })
+ }
+}
+
+func TestValidateTimeClaims_JsonNumberFormat(t *testing.T) {
+ t.Parallel()
+
+ now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
+ future := now.Add(1 * time.Hour).Unix()
+ past := now.Add(-1 * time.Hour).Unix()
+
+ t.Run("valid json.Number claims", func(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{
+ "exp": json.Number(fmt.Sprintf("%d", future)),
+ "nbf": json.Number(fmt.Sprintf("%d", past)),
+ "iat": json.Number(fmt.Sprintf("%d", past)),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ assert.NoError(t, err)
+ })
+
+ t.Run("expired json.Number", func(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{
+ "exp": json.Number(fmt.Sprintf("%d", past)),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTokenExpired)
+ })
+
+ t.Run("invalid json.Number returns error", func(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{
+ "exp": json.Number("not-a-number"),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidTimeClaim)
+ })
+}
+
+func TestValidateTimeClaims_BoundaryExpNow(t *testing.T) {
+ t.Parallel()
+
+ now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
+
+ t.Run("expired 1 second ago", func(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{
+ "exp": float64(now.Add(-1 * time.Second).Unix()),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTokenExpired)
+ })
+
+ t.Run("exact expiry instant is expired per RFC 7519", func(t *testing.T) {
+ t.Parallel()
+
+ // Token expiry is exactly now. Per RFC 7519 §4.1.4, the token
+ // MUST NOT be accepted on or after the expiration time.
+ claims := MapClaims{
+ "exp": float64(now.Unix()),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTokenExpired)
+ })
+}
+
+func TestValidateTimeClaims_BoundaryNbfNow(t *testing.T) {
+ t.Parallel()
+
+ now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
+
+ // Token becomes valid 1 second ago — should be valid.
+ claims := MapClaims{
+ "nbf": float64(now.Add(-1 * time.Second).Unix()),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ assert.NoError(t, err)
+}
+
+func TestValidateTimeClaims_MultipleErrors_ReturnsFirst(t *testing.T) {
+ t.Parallel()
+
+ now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
+
+ // Both exp and nbf are invalid; exp is checked first.
+ claims := MapClaims{
+ "exp": float64(now.Add(-1 * time.Hour).Unix()),
+ "nbf": float64(now.Add(1 * time.Hour).Unix()),
+ }
+
+ err := ValidateTimeClaimsAt(claims, now)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTokenExpired)
+}
+
+func TestExtractTime_Float64(t *testing.T) {
+ t.Parallel()
+
+ ts := float64(1700000000)
+ claims := MapClaims{"exp": ts}
+
+ result, ok, err := extractTime(claims, "exp")
+ require.NoError(t, err)
+ assert.True(t, ok)
+ assert.Equal(t, time.Unix(1700000000, 0).UTC(), result)
+}
+
+func TestExtractTime_JsonNumber(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"exp": json.Number("1700000000")}
+
+ result, ok, err := extractTime(claims, "exp")
+ require.NoError(t, err)
+ assert.True(t, ok)
+ assert.Equal(t, time.Unix(1700000000, 0).UTC(), result)
+}
+
+func TestExtractTime_IntTypes(t *testing.T) {
+ t.Parallel()
+
+ t.Run("int", func(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"exp": int(1700000000)}
+ result, ok, err := extractTime(claims, "exp")
+ require.NoError(t, err)
+ assert.True(t, ok)
+ assert.Equal(t, time.Unix(1700000000, 0).UTC(), result)
+ })
+
+ t.Run("int32", func(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"exp": int32(1700000000)}
+ result, ok, err := extractTime(claims, "exp")
+ require.NoError(t, err)
+ assert.True(t, ok)
+ assert.Equal(t, time.Unix(1700000000, 0).UTC(), result)
+ })
+
+ t.Run("int64", func(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"exp": int64(1700000000)}
+ result, ok, err := extractTime(claims, "exp")
+ require.NoError(t, err)
+ assert.True(t, ok)
+ assert.Equal(t, time.Unix(1700000000, 0).UTC(), result)
+ })
+}
+
+func TestExtractTime_Missing(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"sub": "user-1"}
+
+ _, ok, err := extractTime(claims, "exp")
+ require.NoError(t, err)
+ assert.False(t, ok)
+}
+
+func TestExtractTime_InvalidType(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"exp": "string-value"}
+
+ _, ok, err := extractTime(claims, "exp")
+ require.Error(t, err)
+ assert.False(t, ok)
+ assert.ErrorIs(t, err, ErrInvalidTimeClaim)
+}
+
+func TestExtractTime_InvalidJsonNumber(t *testing.T) {
+ t.Parallel()
+
+ claims := MapClaims{"exp": json.Number("abc")}
+
+ _, ok, err := extractTime(claims, "exp")
+ require.Error(t, err)
+ assert.False(t, ok)
+ assert.ErrorIs(t, err, ErrInvalidTimeClaim)
+}
+
+func TestNilToken_ValidateTimeClaims(t *testing.T) {
+ t.Parallel()
+
+ var token *Token
+
+ err := token.ValidateTimeClaims()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilToken)
+}
+
+func TestNilToken_ValidateTimeClaimsAt(t *testing.T) {
+ t.Parallel()
+
+ var token *Token
+
+ err := token.ValidateTimeClaimsAt(time.Now())
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilToken)
+}
+
+func TestParse_EmptySecret(t *testing.T) {
+ t.Parallel()
+
+ _, err := Parse("a.b.c", nil, allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrEmptySecret)
+
+ _, err = Parse("a.b.c", []byte{}, allAlgorithms)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrEmptySecret)
+}
+
+func TestSign_EmptySecret(t *testing.T) {
+ t.Parallel()
+
+ _, err := Sign(MapClaims{"sub": "x"}, AlgHS256, nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrEmptySecret)
+
+ _, err = Sign(MapClaims{"sub": "x"}, AlgHS256, []byte{})
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrEmptySecret)
+}
diff --git a/commons/license/doc.go b/commons/license/doc.go
new file mode 100644
index 00000000..0baf91f5
--- /dev/null
+++ b/commons/license/doc.go
@@ -0,0 +1,5 @@
+// Package license provides helpers for license validation and management.
+//
+// It centralizes license parsing and policy checks so callers can enforce
+// product capabilities consistently at startup and runtime boundaries.
+package license
diff --git a/commons/license/manager.go b/commons/license/manager.go
index 4f785c7f..9fc95579 100644
--- a/commons/license/manager.go
+++ b/commons/license/manager.go
@@ -1,14 +1,13 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package license
import (
+ "context"
"errors"
"fmt"
- "os"
"sync"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
)
var (
@@ -21,13 +20,56 @@ var (
// Handler defines the function signature for termination handlers
type Handler func(reason string)
+// ManagerOption is a functional option for configuring ManagerShutdown.
+type ManagerOption func(*ManagerShutdown)
+
+// WithLogger provides a structured logger for assertion and validation logging.
+func WithLogger(l log.Logger) ManagerOption {
+ return func(m *ManagerShutdown) {
+ if l != nil {
+ m.Logger = l
+ }
+ }
+}
+
+// WithFailClosed configures the manager to record an assertion failure AND
+// log the reason at error level, providing a fail-closed posture where
+// license validation failures produce observable signals (assertion events
+// + error logs) rather than being silently swallowed.
+//
+// Callers that need an actual process exit should combine this with
+// SetHandler to provide their own os.Exit or signal-based shutdown.
+//
+// Contrast with the default fail-open behavior where validation failures are
+// only recorded as assertion events.
+func WithFailClosed() ManagerOption {
+ return func(m *ManagerShutdown) {
+ m.handler = func(reason string) {
+ // Record assertion event (same as DefaultHandler)
+ asserter := assert.New(context.Background(), m.Logger, "license", "FailClosed")
+ _ = asserter.Never(context.Background(), "LICENSE VALIDATION FAILED (fail-closed)", "reason", reason)
+
+ // Also log at error level if logger is available
+ if m.Logger != nil {
+ m.Logger.Log(context.Background(), log.LevelError, "license validation failed (fail-closed mode)",
+ log.String("reason", reason),
+ )
+ }
+ }
+ }
+}
+
// DefaultHandler is the default termination behavior.
-// It logs the failure reason to stderr and terminates the process with exit code 1.
-// This ensures the application cannot continue running with an invalid license,
-// even when a recovery middleware is present that would catch panics.
+// It records an assertion failure without panicking.
+//
+// NOTE: This intentionally implements a fail-open policy: license validation
+// failures are recorded as assertion events but do NOT terminate the process.
+// This design choice avoids unexpected shutdowns in environments where the
+// license server is unreachable. To enforce a fail-closed policy, use
+// WithFailClosed() when constructing the manager.
func DefaultHandler(reason string) {
- fmt.Fprintf(os.Stderr, "LICENSE VALIDATION FAILED: %s\n", reason)
- os.Exit(1)
+ asserter := assert.New(context.Background(), nil, "license", "DefaultHandler")
+ _ = asserter.Never(context.Background(), "LICENSE VALIDATION FAILED", "reason", reason)
}
// DefaultHandlerWithError returns an error instead of panicking.
@@ -39,20 +81,31 @@ func DefaultHandlerWithError(reason string) error {
// ManagerShutdown handles termination behavior
type ManagerShutdown struct {
handler Handler
+ Logger log.Logger
mu sync.RWMutex
}
-// New creates a new termination manager with the default handler
-func New() *ManagerShutdown {
- return &ManagerShutdown{
+// New creates a new termination manager with the default handler.
+// Options can be provided to configure the manager (e.g., WithLogger).
+// Nil options in the variadic list are silently skipped.
+func New(opts ...ManagerOption) *ManagerShutdown {
+ m := &ManagerShutdown{
handler: DefaultHandler,
}
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(m)
+ }
+ }
+
+ return m
}
// SetHandler updates the termination handler
// This should be called during application startup, before any validation occurs
func (m *ManagerShutdown) SetHandler(handler Handler) {
- if handler == nil {
+ if m == nil || handler == nil {
return
}
@@ -65,15 +118,30 @@ func (m *ManagerShutdown) SetHandler(handler Handler) {
// Terminate invokes the termination handler.
// This will trigger the application to gracefully shut down.
//
-// Note: This method panics if the manager was not initialized with New().
-// Use TerminateSafe() if you need to handle the uninitialized case gracefully.
+// Note: This method no longer panics if the manager was not initialized with New().
+// In that case it records an assertion failure and returns.
func (m *ManagerShutdown) Terminate(reason string) {
+ if m == nil {
+ // nil receiver: no logger available, nil is legitimate here.
+ asserter := assert.New(context.Background(), nil, "license", "Terminate")
+ _ = asserter.Never(context.Background(), "license.ManagerShutdown is nil")
+
+ return
+ }
+
m.mu.RLock()
handler := m.handler
+ logger := m.Logger
m.mu.RUnlock()
if handler == nil {
- panic(ErrManagerNotInitialized)
+ asserter := assert.New(context.Background(), logger, "license", "Terminate")
+ _ = asserter.NoError(context.Background(), ErrManagerNotInitialized,
+ "license terminate called without initialization",
+ "reason", reason,
+ )
+
+ return
}
handler(reason)
@@ -83,27 +151,50 @@ func (m *ManagerShutdown) Terminate(reason string) {
// Use this when you want to check license validity without triggering shutdown.
//
// Note: This method intentionally does NOT invoke the custom handler set via SetHandler().
-// It always returns ErrLicenseValidationFailed wrapped with the reason, regardless of
-// manager initialization state. This differs from Terminate() which requires initialization
-// and invokes the configured handler. Use Terminate() for actual shutdown behavior,
+// It always returns ErrLicenseValidationFailed wrapped with the reason when the
+// manager is properly initialized. Use Terminate() for actual shutdown behavior,
// and TerminateWithError() for validation checks that should return errors.
+//
+// Nil receiver: returns ErrManagerNotInitialized (not ErrLicenseValidationFailed)
+// to distinguish between "license failed" and "manager not created".
func (m *ManagerShutdown) TerminateWithError(reason string) error {
+ if m == nil {
+ return ErrManagerNotInitialized
+ }
+
+ if m.Logger != nil {
+ m.Logger.Log(context.Background(), log.LevelWarn, "license validation failed",
+ log.String("reason", reason),
+ )
+ }
+
return fmt.Errorf("%w: %s", ErrLicenseValidationFailed, reason)
}
// TerminateSafe invokes the termination handler and returns an error if the manager
// was not properly initialized. This is the safe alternative to Terminate that
-// returns an error instead of panicking when the handler is nil.
+// returns an explicit error when the handler is nil.
//
// Use this method when you need to handle the uninitialized manager case gracefully.
-// For normal shutdown behavior where panic on uninitialized manager is acceptable,
+// For normal shutdown behavior where assertion-based handling is acceptable,
// use Terminate() instead.
func (m *ManagerShutdown) TerminateSafe(reason string) error {
+ if m == nil {
+ return ErrManagerNotInitialized
+ }
+
m.mu.RLock()
handler := m.handler
+ logger := m.Logger
m.mu.RUnlock()
if handler == nil {
+ if logger != nil {
+ logger.Log(context.Background(), log.LevelWarn, "license terminate called without initialization",
+ log.String("reason", reason),
+ )
+ }
+
return ErrManagerNotInitialized
}
diff --git a/commons/license/manager_nil_test.go b/commons/license/manager_nil_test.go
new file mode 100644
index 00000000..85c0a41d
--- /dev/null
+++ b/commons/license/manager_nil_test.go
@@ -0,0 +1,105 @@
+//go:build unit
+
+package license_test
+
+import (
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/license"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNilReceiver_Terminate(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil pointer Terminate does not panic", func(t *testing.T) {
+ t.Parallel()
+
+ var m *license.ManagerShutdown
+
+ assert.NotPanics(t, func() {
+ m.Terminate("nil receiver test")
+ })
+ })
+
+ t.Run("nil pointer TerminateWithError does not panic and returns error", func(t *testing.T) {
+ t.Parallel()
+
+ var m *license.ManagerShutdown
+
+ assert.NotPanics(t, func() {
+ err := m.TerminateWithError("nil receiver test")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, license.ErrManagerNotInitialized)
+ })
+ })
+
+ t.Run("nil pointer TerminateSafe does not panic and returns error", func(t *testing.T) {
+ t.Parallel()
+
+ var m *license.ManagerShutdown
+
+ assert.NotPanics(t, func() {
+ err := m.TerminateSafe("nil receiver test")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, license.ErrManagerNotInitialized)
+ })
+ })
+
+ t.Run("nil pointer SetHandler does not panic", func(t *testing.T) {
+ t.Parallel()
+
+ var m *license.ManagerShutdown
+
+ assert.NotPanics(t, func() {
+ m.SetHandler(func(_ string) {})
+ })
+ })
+}
+
+func TestNilReceiver_WithLogger(t *testing.T) {
+ t.Parallel()
+
+ t.Run("WithLogger configures logger on new manager", func(t *testing.T) {
+ t.Parallel()
+
+ nop := log.NewNop()
+ m := license.New(license.WithLogger(nop))
+
+ // Verify the manager works — logger is used internally by TerminateWithError
+ // when Logger != nil. We verify it doesn't panic and behaves correctly.
+ err := m.TerminateWithError("test with logger")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, license.ErrLicenseValidationFailed)
+ })
+
+ t.Run("WithLogger with nil logger is safe", func(t *testing.T) {
+ t.Parallel()
+
+ // WithLogger(nil) should be a no-op — Logger remains nil.
+ m := license.New(license.WithLogger(nil))
+
+ assert.NotPanics(t, func() {
+ err := m.TerminateWithError("test with nil logger")
+ require.Error(t, err)
+ })
+ })
+
+ t.Run("WithLogger can be combined with SetHandler", func(t *testing.T) {
+ t.Parallel()
+
+ nop := log.NewNop()
+ handlerCalled := false
+
+ m := license.New(license.WithLogger(nop))
+ m.SetHandler(func(reason string) {
+ handlerCalled = true
+ assert.Equal(t, "combo test", reason)
+ })
+
+ m.Terminate("combo test")
+ assert.True(t, handlerCalled)
+ })
+}
diff --git a/commons/license/manager_test.go b/commons/license/manager_test.go
index dccc156b..b6d5e8ff 100644
--- a/commons/license/manager_test.go
+++ b/commons/license/manager_test.go
@@ -1,17 +1,12 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package license_test
import (
- "bytes"
"errors"
- "os"
- "os/exec"
"testing"
- "github.com/LerianStudio/lib-commons/v2/commons/license"
+ "github.com/LerianStudio/lib-commons/v4/commons/license"
"github.com/stretchr/testify/assert"
)
@@ -47,44 +42,12 @@ func TestSetHandlerWithNil(t *testing.T) {
assert.True(t, handlerCalled, "Original handler should still be called when nil is passed")
}
-// runSubprocessTest runs the named test in a subprocess with the given env var set to "1".
-// It asserts the process exits with code 1 and stderr contains "LICENSE VALIDATION FAILED"
-// plus any additional expected messages.
-func runSubprocessTest(t *testing.T, testName, envVar string, expectedMessages ...string) {
- t.Helper()
-
- cmd := exec.Command(os.Args[0], "-test.run="+testName)
- cmd.Env = append(os.Environ(), envVar+"=1")
-
- var stderr bytes.Buffer
- cmd.Stderr = &stderr
-
- err := cmd.Run()
-
- var exitErr *exec.ExitError
- if errors.As(err, &exitErr) {
- assert.Equal(t, 1, exitErr.ExitCode(), "Expected exit code 1")
- } else {
- t.Fatal("Expected process to exit with code 1")
- }
-
- assert.Contains(t, stderr.String(), "LICENSE VALIDATION FAILED")
-
- for _, msg := range expectedMessages {
- assert.Contains(t, stderr.String(), msg)
- }
-}
-
func TestDefaultHandler(t *testing.T) {
- // DefaultHandler calls os.Exit(1), so we test it in a subprocess
- if os.Getenv("TEST_DEFAULT_HANDLER_EXIT") == "1" {
- manager := license.New()
- manager.Terminate("default handler test")
-
- return
- }
+ manager := license.New()
- runSubprocessTest(t, "TestDefaultHandler", "TEST_DEFAULT_HANDLER_EXIT", "default handler test")
+ assert.NotPanics(t, func() {
+ manager.Terminate("default handler test")
+ }, "Default handler should not panic")
}
func TestDefaultHandlerWithError(t *testing.T) {
@@ -128,14 +91,13 @@ func TestTerminateWithError_UninitializedManager(t *testing.T) {
assert.Contains(t, err.Error(), "test reason")
}
-func TestTerminate_UninitializedManagerPanics(t *testing.T) {
- // Terminate requires a handler to be set. On a zero-value manager,
- // the handler is nil, causing a panic with ErrManagerNotInitialized.
+func TestTerminate_UninitializedManagerDoesNotPanic(t *testing.T) {
+ // Terminate on zero-value manager should fail safely without panic.
var manager license.ManagerShutdown
- assert.Panics(t, func() {
+ assert.NotPanics(t, func() {
manager.Terminate("test reason")
- }, "Terminate on uninitialized manager should panic")
+ }, "Terminate on uninitialized manager should not panic")
}
func TestDefaultHandlerWithError_EmptyReason(t *testing.T) {
@@ -178,13 +140,49 @@ func TestTerminateSafe_UninitializedManager(t *testing.T) {
}
func TestTerminateSafe_WithDefaultHandler(t *testing.T) {
- // DefaultHandler calls os.Exit(1), so we test it in a subprocess
- if os.Getenv("TEST_TERMINATE_SAFE_DEFAULT_EXIT") == "1" {
- manager := license.New()
- _ = manager.TerminateSafe("test")
+ manager := license.New()
- return
+ err := manager.TerminateSafe("test")
+ assert.NoError(t, err)
+}
+
+func TestNew_NilOptionSkipped(t *testing.T) {
+ t.Parallel()
+
+ // Nil options in the variadic list should be silently skipped.
+ assert.NotPanics(t, func() {
+ manager := license.New(nil, nil)
+ assert.NotNil(t, manager)
+ })
+}
+
+func TestNew_NilOptionMixedWithValid(t *testing.T) {
+ t.Parallel()
+
+ handlerCalled := false
+ customHandler := func(reason string) {
+ handlerCalled = true
}
- runSubprocessTest(t, "TestTerminateSafe_WithDefaultHandler", "TEST_TERMINATE_SAFE_DEFAULT_EXIT")
+ // Mix nil options with valid options.
+ manager := license.New(nil, license.WithLogger(nil), nil)
+ assert.NotNil(t, manager)
+
+ manager.SetHandler(customHandler)
+ manager.Terminate("test")
+ assert.True(t, handlerCalled)
+}
+
+func TestWithFailClosed(t *testing.T) {
+ t.Parallel()
+
+ // WithFailClosed should set the handler to TerminateSafe behavior.
+ manager := license.New(license.WithFailClosed())
+
+ // TerminateSafe returns nil when handler is non-nil (it invokes handler then returns nil).
+ // The WithFailClosed handler itself calls TerminateSafe internally.
+ // Since the manager IS initialized (New was called), the handler should not return errors.
+ assert.NotPanics(t, func() {
+ manager.Terminate("fail-closed test")
+ })
}
diff --git a/commons/log/doc.go b/commons/log/doc.go
new file mode 100644
index 00000000..c1d2664c
--- /dev/null
+++ b/commons/log/doc.go
@@ -0,0 +1,5 @@
+// Package log defines the v2 logging interface and typed logging fields.
+//
+// Adapters (such as the zap package) implement Logger so applications can keep
+// logging calls consistent across backends.
+package log
diff --git a/commons/log/go_logger.go b/commons/log/go_logger.go
new file mode 100644
index 00000000..e9e1ed17
--- /dev/null
+++ b/commons/log/go_logger.go
@@ -0,0 +1,212 @@
+package log
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "reflect"
+ "strings"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/security"
+)
+
+var logControlCharReplacer = strings.NewReplacer(
+ "\n", `\n`,
+ "\r", `\r`,
+ "\t", `\t`,
+ "\x00", `\0`,
+)
+
+func sanitizeLogString(s string) string {
+ return logControlCharReplacer.Replace(s)
+}
+
+// GoLogger is the stdlib logger implementation for Logger.
+type GoLogger struct {
+ Level Level
+ fields []Field
+ groups []string
+}
+
+// Enabled reports whether the logger emits entries at the given level.
+// On a nil receiver, Enabled returns false silently. Use NopLogger as the
+// documented nil-safe alternative.
+//
+// Unknown level policy: levels outside the defined range (LevelError..LevelDebug)
+// are treated as suppressed by GoLogger (since their numeric value exceeds any
+// configured threshold). The zap adapter maps unknown levels to Info. The net
+// effect is: unknown levels produce Info-level output if a zap backend is used,
+// or are suppressed in the stdlib GoLogger. Callers should use only the defined
+// Level constants.
+func (l *GoLogger) Enabled(level Level) bool {
+ if l == nil {
+ return false
+ }
+
+ return l.Level >= level
+}
+
+// Log writes a single log line if the level is enabled.
+func (l *GoLogger) Log(_ context.Context, level Level, msg string, fields ...Field) {
+ if !l.Enabled(level) {
+ return
+ }
+
+ line := l.hydrateLine(level, msg, fields...)
+ log.Print(line)
+}
+
+// With returns a child logger with additional persistent fields.
+//
+//nolint:ireturn
+func (l *GoLogger) With(fields ...Field) Logger {
+ if l == nil {
+ return &NopLogger{}
+ }
+
+ newFields := make([]Field, 0, len(l.fields)+len(fields))
+ newFields = append(newFields, l.fields...)
+ newFields = append(newFields, fields...)
+
+ newGroups := make([]string, 0, len(l.groups))
+ newGroups = append(newGroups, l.groups...)
+
+ return &GoLogger{
+ Level: l.Level,
+ fields: newFields,
+ groups: newGroups,
+ }
+}
+
+// WithGroup returns a child logger scoped under the provided group name.
+// Empty or whitespace-only names are silently ignored, consistent with
+// the zap adapter. This avoids creating unnecessary allocations.
+//
+//nolint:ireturn
+func (l *GoLogger) WithGroup(name string) Logger {
+ if l == nil {
+ return &NopLogger{}
+ }
+
+ if strings.TrimSpace(name) == "" {
+ return l
+ }
+
+ newGroups := make([]string, 0, len(l.groups)+1)
+ newGroups = append(newGroups, l.groups...)
+ newGroups = append(newGroups, sanitizeLogString(name))
+
+ newFields := make([]Field, 0, len(l.fields))
+ newFields = append(newFields, l.fields...)
+
+ return &GoLogger{
+ Level: l.Level,
+ fields: newFields,
+ groups: newGroups,
+ }
+}
+
+// Sync flushes buffered logs. It is a no-op for the stdlib logger.
+func (l *GoLogger) Sync(_ context.Context) error { return nil }
+
+func (l *GoLogger) hydrateLine(level Level, msg string, fields ...Field) string {
+ parts := make([]string, 0, 4)
+ parts = append(parts, fmt.Sprintf("[%s]", level.String()))
+
+ if l != nil && len(l.groups) > 0 {
+ parts = append(parts, fmt.Sprintf("[group=%s]", strings.Join(l.groups, ".")))
+ }
+
+ allFields := make([]Field, 0, len(fields))
+ if l != nil {
+ allFields = append(allFields, l.fields...)
+ }
+
+ allFields = append(allFields, fields...)
+
+ if rendered := renderFields(allFields); rendered != "" {
+ parts = append(parts, rendered)
+ }
+
+ parts = append(parts, sanitizeLogString(msg))
+
+ return strings.Join(parts, " ")
+}
+
+// redactedValue is the placeholder used for sensitive field values in log output.
+const redactedValue = "[REDACTED]"
+
+func renderFields(fields []Field) string {
+ if len(fields) == 0 {
+ return ""
+ }
+
+ parts := make([]string, 0, len(fields))
+ for _, field := range fields {
+ key := sanitizeLogString(field.Key)
+ if key == "" {
+ continue
+ }
+
+ var rendered any
+ if security.IsSensitiveField(field.Key) {
+ rendered = redactedValue
+ } else {
+ rendered = sanitizeFieldValue(field.Value)
+ }
+
+ parts = append(parts, fmt.Sprintf("%s=%v", key, rendered))
+ }
+
+ if len(parts) == 0 {
+ return ""
+ }
+
+ return fmt.Sprintf("[%s]", strings.Join(parts, ", "))
+}
+
+// isTypedNil reports whether v is a non-nil interface wrapping a nil pointer.
+// This prevents panics when calling methods (Error, String) on typed-nil values.
+func isTypedNil(v any) bool {
+ if v == nil {
+ return false
+ }
+
+ rv := reflect.ValueOf(v)
+
+ switch rv.Kind() {
+ case reflect.Ptr, reflect.Interface, reflect.Func, reflect.Map, reflect.Slice, reflect.Chan:
+ return rv.IsNil()
+ default:
+ return false
+ }
+}
+
+func sanitizeFieldValue(value any) any {
+ if value == nil {
+ return nil
+ }
+
+ // Guard against typed-nil before calling interface methods.
+ if isTypedNil(value) {
+ return ""
+ }
+
+ switch v := value.(type) {
+ case string:
+ return sanitizeLogString(v)
+ case error:
+ return sanitizeLogString(v.Error())
+ case fmt.Stringer:
+ return sanitizeLogString(v.String())
+ case bool, int, int8, int16, int32, int64,
+ uint, uint8, uint16, uint32, uint64,
+ float32, float64:
+ // Primitive types cannot carry newlines; pass through unchanged.
+ return value
+ default:
+ // Composite types (structs, slices, maps, etc.) may carry raw newlines
+ // when rendered with fmt. Pre-serialize and sanitize the result.
+ return sanitizeLogString(fmt.Sprintf("%v", v))
+ }
+}
diff --git a/commons/log/log.go b/commons/log/log.go
index 7cdbfb28..bd3e6c3d 100644
--- a/commons/log/log.go
+++ b/commons/log/log.go
@@ -1,226 +1,116 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package log
import (
+ "context"
"fmt"
- "log"
"strings"
)
-// Logger is the pkg interface for log implementation.
+// Logger is the package interface for v2 logging.
//
//go:generate mockgen --destination=log_mock.go --package=log . Logger
type Logger interface {
- Info(args ...any)
- Infof(format string, args ...any)
- Infoln(args ...any)
-
- Error(args ...any)
- Errorf(format string, args ...any)
- Errorln(args ...any)
-
- Warn(args ...any)
- Warnf(format string, args ...any)
- Warnln(args ...any)
-
- Debug(args ...any)
- Debugf(format string, args ...any)
- Debugln(args ...any)
-
- Fatal(args ...any)
- Fatalf(format string, args ...any)
- Fatalln(args ...any)
-
- WithFields(fields ...any) Logger
-
- WithDefaultMessageTemplate(message string) Logger
-
- Sync() error
+ Log(ctx context.Context, level Level, msg string, fields ...Field)
+ With(fields ...Field) Logger
+ WithGroup(name string) Logger
+ Enabled(level Level) bool
+ Sync(ctx context.Context) error
}
-// LogLevel represents the level of log system (fatal, error, warn, info and debug).
-type LogLevel int8
-
-// These are the different log levels. You can set the logging level to log.
+// Level represents the severity of a log entry.
+//
+// Lower numeric values indicate higher severity (LevelError=0 is most severe,
+// LevelDebug=3 is least). This is inverted from slog/zap conventions where
+// higher numeric values mean higher severity.
+//
+// The GoLogger.Enabled comparison uses l.Level >= level, which works because
+// the logger's Level acts as a verbosity ceiling: a logger at LevelInfo (2)
+// emits Error (0), Warn (1), and Info (2) messages, but suppresses Debug (3).
+type Level uint8
+
+// Level constants define log severity. Lower numeric values indicate higher
+// severity. Setting a logger's Level to a given constant enables that level
+// and all levels with lower numeric values (i.e., higher severity).
+//
+// LevelError (0) -- only errors
+// LevelWarn (1) -- errors + warnings
+// LevelInfo (2) -- errors + warnings + info
+// LevelDebug (3) -- everything
const (
- // PanicLevel level, highest level of severity. Logs and then calls panic with the
- // message passed to Debug, Info, ...
- PanicLevel LogLevel = iota
- // FatalLevel level. Logs and then calls `logger.Exit(1)`. It will exit even if the
- // logging level is set to Panic.
- FatalLevel
- // ErrorLevel level. Logs. Used for errors that should definitely be noted.
- // Commonly used for hooks to send errors to an error tracking service.
- ErrorLevel
- // WarnLevel level. Non-critical entries that deserve eyes.
- WarnLevel
- // InfoLevel level. General operational entries about what's going on inside the
- // application.
- InfoLevel
- // DebugLevel level. Usually only enabled when debugging. Very verbose logging.
- DebugLevel
+ LevelError Level = iota
+ LevelWarn
+ LevelInfo
+ LevelDebug
)
-// ParseLevel takes a string level and returns a LogLevel constant.
-func ParseLevel(lvl string) (LogLevel, error) {
- switch strings.ToLower(lvl) {
- case "fatal":
- return FatalLevel, nil
- case "error":
- return ErrorLevel, nil
- case "warn", "warning":
- return WarnLevel, nil
- case "info":
- return InfoLevel, nil
+// LevelUnknown represents an invalid or unrecognized log level.
+// Returned by ParseLevel on error to distinguish from LevelError (the zero value).
+const LevelUnknown Level = 255
+
+// String returns the string representation of a log level.
+func (level Level) String() string {
+ switch level {
+ case LevelDebug:
+ return "debug"
+ case LevelInfo:
+ return "info"
+ case LevelWarn:
+ return "warn"
+ case LevelError:
+ return "error"
+ default:
+ return "unknown"
+ }
+}
+
+// ParseLevel takes a string level and returns a Level constant.
+// Leading and trailing whitespace is trimmed before matching.
+func ParseLevel(lvl string) (Level, error) {
+ switch strings.ToLower(strings.TrimSpace(lvl)) {
case "debug":
- return DebugLevel, nil
- }
-
- var l LogLevel
-
- return l, fmt.Errorf("not a valid LogLevel: %q", lvl)
-}
-
-// GoLogger is the Go built-in (log) implementation of Logger interface.
-type GoLogger struct {
- fields []any
- Level LogLevel
- defaultMessageTemplate string
-}
-
-// IsLevelEnabled checks if the given level is enabled.
-func (l *GoLogger) IsLevelEnabled(level LogLevel) bool {
- return l.Level >= level
-}
-
-// Info implements Info Logger interface function.
-func (l *GoLogger) Info(args ...any) {
- if l.IsLevelEnabled(InfoLevel) {
- log.Print(args...)
- }
-}
-
-// Infof implements Infof Logger interface function.
-func (l *GoLogger) Infof(format string, args ...any) {
- if l.IsLevelEnabled(InfoLevel) {
- log.Printf(format, args...)
- }
-}
-
-// Infoln implements Infoln Logger interface function.
-func (l *GoLogger) Infoln(args ...any) {
- if l.IsLevelEnabled(InfoLevel) {
- log.Println(args...)
- }
-}
-
-// Error implements Error Logger interface function.
-func (l *GoLogger) Error(args ...any) {
- if l.IsLevelEnabled(ErrorLevel) {
- log.Print(args...)
- }
-}
-
-// Errorf implements Errorf Logger interface function.
-func (l *GoLogger) Errorf(format string, args ...any) {
- if l.IsLevelEnabled(ErrorLevel) {
- log.Printf(format, args...)
- }
-}
-
-// Errorln implements Errorln Logger interface function.
-func (l *GoLogger) Errorln(args ...any) {
- if l.IsLevelEnabled(ErrorLevel) {
- log.Println(args...)
- }
-}
-
-// Warn implements Warn Logger interface function.
-func (l *GoLogger) Warn(args ...any) {
- if l.IsLevelEnabled(WarnLevel) {
- log.Print(args...)
- }
-}
-
-// Warnf implements Warnf Logger interface function.
-func (l *GoLogger) Warnf(format string, args ...any) {
- if l.IsLevelEnabled(WarnLevel) {
- log.Printf(format, args...)
- }
-}
-
-// Warnln implements Warnln Logger interface function.
-func (l *GoLogger) Warnln(args ...any) {
- if l.IsLevelEnabled(WarnLevel) {
- log.Println(args...)
- }
-}
-
-// Debug implements Debug Logger interface function.
-func (l *GoLogger) Debug(args ...any) {
- if l.IsLevelEnabled(DebugLevel) {
- log.Print(args...)
+ return LevelDebug, nil
+ case "info":
+ return LevelInfo, nil
+ case "warn", "warning":
+ return LevelWarn, nil
+ case "error":
+ return LevelError, nil
}
-}
-// Debugf implements Debugf Logger interface function.
-func (l *GoLogger) Debugf(format string, args ...any) {
- if l.IsLevelEnabled(DebugLevel) {
- log.Printf(format, args...)
- }
+ return LevelUnknown, fmt.Errorf("not a valid Level: %q", lvl)
}
-// Debugln implements Debugln Logger interface function.
-func (l *GoLogger) Debugln(args ...any) {
- if l.IsLevelEnabled(DebugLevel) {
- log.Println(args...)
- }
+// Field is a strongly-typed key/value attribute attached to a log event.
+type Field struct {
+ Key string
+ Value any
}
-// Fatal implements Fatal Logger interface function.
-func (l *GoLogger) Fatal(args ...any) {
- if l.IsLevelEnabled(FatalLevel) {
- log.Print(args...)
- }
+// Any creates a field with an arbitrary value.
+//
+// WARNING: prefer typed constructors (String, Int, Bool, Err) to avoid
+// accidentally logging sensitive data (passwords, tokens, PII). If using
+// Any, ensure the value is sanitized or non-sensitive.
+func Any(key string, value any) Field {
+ return Field{Key: key, Value: value}
}
-// Fatalf implements Fatalf Logger interface function.
-func (l *GoLogger) Fatalf(format string, args ...any) {
- if l.IsLevelEnabled(FatalLevel) {
- log.Printf(format, args...)
- }
+// String creates a string field.
+func String(key, value string) Field {
+ return Field{Key: key, Value: value}
}
-// Fatalln implements Fatalln Logger interface function.
-func (l *GoLogger) Fatalln(args ...any) {
- if l.IsLevelEnabled(FatalLevel) {
- log.Println(args...)
- }
+// Int creates an integer field.
+func Int(key string, value int) Field {
+ return Field{Key: key, Value: value}
}
-// WithFields implements WithFields Logger interface function
-//
-//nolint:ireturn
-func (l *GoLogger) WithFields(fields ...any) Logger {
- return &GoLogger{
- Level: l.Level,
- fields: fields,
- defaultMessageTemplate: l.defaultMessageTemplate,
- }
+// Bool creates a boolean field.
+func Bool(key string, value bool) Field {
+ return Field{Key: key, Value: value}
}
-func (l *GoLogger) WithDefaultMessageTemplate(message string) Logger {
- return &GoLogger{
- Level: l.Level,
- fields: l.fields,
- defaultMessageTemplate: message,
- }
+// Err creates the conventional `error` field.
+func Err(err error) Field {
+ return Field{Key: "error", Value: err}
}
-
-// Sync implements Sync Logger interface function.
-//
-//nolint:ireturn
-func (l *GoLogger) Sync() error { return nil }
diff --git a/commons/log/log_example_test.go b/commons/log/log_example_test.go
new file mode 100644
index 00000000..e8d9345f
--- /dev/null
+++ b/commons/log/log_example_test.go
@@ -0,0 +1,20 @@
+//go:build unit
+
+package log_test
+
+import (
+ "fmt"
+
+ ulog "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+func ExampleParseLevel() {
+ level, err := ulog.ParseLevel("warning")
+
+ fmt.Println(err == nil)
+ fmt.Println(level.String())
+
+ // Output:
+ // true
+ // warn
+}
diff --git a/commons/log/log_mock.go b/commons/log/log_mock.go
index c2deb226..bde46c27 100644
--- a/commons/log/log_mock.go
+++ b/commons/log/log_mock.go
@@ -1,19 +1,15 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/LerianStudio/lib-commons/v2/commons/log (interfaces: Logger)
+// Source: github.com/LerianStudio/lib-commons/v4/commons/log (interfaces: Logger)
//
// Generated by this command:
//
// mockgen --destination=log_mock.go --package=log . Logger
//
-// Package log is a generated GoMock package.
package log
import (
+ context "context"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
@@ -43,293 +39,79 @@ func (m *MockLogger) EXPECT() *MockLoggerMockRecorder {
return m.recorder
}
-// Debug mocks base method.
-func (m *MockLogger) Debug(args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Debug", varargs...)
-}
-
-// Debug indicates an expected call of Debug.
-func (mr *MockLoggerMockRecorder) Debug(args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), args...)
-}
-
-// Debugf mocks base method.
-func (m *MockLogger) Debugf(format string, args ...any) {
+// Enabled mocks base method.
+func (m *MockLogger) Enabled(level Level) bool {
m.ctrl.T.Helper()
- varargs := []any{format}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Debugf", varargs...)
-}
-
-// Debugf indicates an expected call of Debugf.
-func (mr *MockLoggerMockRecorder) Debugf(format any, args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- varargs := append([]any{format}, args...)
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...)
-}
-
-// Debugln mocks base method.
-func (m *MockLogger) Debugln(args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Debugln", varargs...)
-}
-
-// Debugln indicates an expected call of Debugln.
-func (mr *MockLoggerMockRecorder) Debugln(args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugln", reflect.TypeOf((*MockLogger)(nil).Debugln), args...)
-}
-
-// Error mocks base method.
-func (m *MockLogger) Error(args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Error", varargs...)
-}
-
-// Error indicates an expected call of Error.
-func (mr *MockLoggerMockRecorder) Error(args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), args...)
-}
-
-// Errorf mocks base method.
-func (m *MockLogger) Errorf(format string, args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{format}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Errorf", varargs...)
-}
-
-// Errorf indicates an expected call of Errorf.
-func (mr *MockLoggerMockRecorder) Errorf(format any, args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- varargs := append([]any{format}, args...)
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...)
-}
-
-// Errorln mocks base method.
-func (m *MockLogger) Errorln(args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Errorln", varargs...)
-}
-
-// Errorln indicates an expected call of Errorln.
-func (mr *MockLoggerMockRecorder) Errorln(args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorln", reflect.TypeOf((*MockLogger)(nil).Errorln), args...)
-}
-
-// Fatal mocks base method.
-func (m *MockLogger) Fatal(args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Fatal", varargs...)
-}
-
-// Fatal indicates an expected call of Fatal.
-func (mr *MockLoggerMockRecorder) Fatal(args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fatal", reflect.TypeOf((*MockLogger)(nil).Fatal), args...)
-}
-
-// Fatalf mocks base method.
-func (m *MockLogger) Fatalf(format string, args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{format}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Fatalf", varargs...)
-}
-
-// Fatalf indicates an expected call of Fatalf.
-func (mr *MockLoggerMockRecorder) Fatalf(format any, args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- varargs := append([]any{format}, args...)
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fatalf", reflect.TypeOf((*MockLogger)(nil).Fatalf), varargs...)
-}
-
-// Fatalln mocks base method.
-func (m *MockLogger) Fatalln(args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Fatalln", varargs...)
-}
-
-// Fatalln indicates an expected call of Fatalln.
-func (mr *MockLoggerMockRecorder) Fatalln(args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fatalln", reflect.TypeOf((*MockLogger)(nil).Fatalln), args...)
-}
-
-// Info mocks base method.
-func (m *MockLogger) Info(args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Info", varargs...)
-}
-
-// Info indicates an expected call of Info.
-func (mr *MockLoggerMockRecorder) Info(args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), args...)
-}
-
-// Infof mocks base method.
-func (m *MockLogger) Infof(format string, args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{format}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Infof", varargs...)
+ ret := m.ctrl.Call(m, "Enabled", level)
+ ret0, _ := ret[0].(bool)
+ return ret0
}
-// Infof indicates an expected call of Infof.
-func (mr *MockLoggerMockRecorder) Infof(format any, args ...any) *gomock.Call {
+// Enabled indicates an expected call of Enabled.
+func (mr *MockLoggerMockRecorder) Enabled(level any) *gomock.Call {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{format}, args...)
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Infof", reflect.TypeOf((*MockLogger)(nil).Infof), varargs...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Enabled", reflect.TypeOf((*MockLogger)(nil).Enabled), level)
}
-// Infoln mocks base method.
-func (m *MockLogger) Infoln(args ...any) {
+// Log mocks base method.
+func (m *MockLogger) Log(ctx context.Context, level Level, msg string, fields ...Field) {
m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range args {
+ varargs := []any{ctx, level, msg}
+ for _, a := range fields {
varargs = append(varargs, a)
}
- m.ctrl.Call(m, "Infoln", varargs...)
+ m.ctrl.Call(m, "Log", varargs...)
}
-// Infoln indicates an expected call of Infoln.
-func (mr *MockLoggerMockRecorder) Infoln(args ...any) *gomock.Call {
+// Log indicates an expected call of Log.
+func (mr *MockLoggerMockRecorder) Log(ctx, level, msg any, fields ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Infoln", reflect.TypeOf((*MockLogger)(nil).Infoln), args...)
+ varargs := append([]any{ctx, level, msg}, fields...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Log", reflect.TypeOf((*MockLogger)(nil).Log), varargs...)
}
// Sync mocks base method.
-func (m *MockLogger) Sync() error {
+func (m *MockLogger) Sync(ctx context.Context) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Sync")
+ ret := m.ctrl.Call(m, "Sync", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Sync indicates an expected call of Sync.
-func (mr *MockLoggerMockRecorder) Sync() *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockLogger)(nil).Sync))
-}
-
-// Warn mocks base method.
-func (m *MockLogger) Warn(args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Warn", varargs...)
-}
-
-// Warn indicates an expected call of Warn.
-func (mr *MockLoggerMockRecorder) Warn(args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), args...)
-}
-
-// Warnf mocks base method.
-func (m *MockLogger) Warnf(format string, args ...any) {
- m.ctrl.T.Helper()
- varargs := []any{format}
- for _, a := range args {
- varargs = append(varargs, a)
- }
- m.ctrl.Call(m, "Warnf", varargs...)
-}
-
-// Warnf indicates an expected call of Warnf.
-func (mr *MockLoggerMockRecorder) Warnf(format any, args ...any) *gomock.Call {
+func (mr *MockLoggerMockRecorder) Sync(ctx any) *gomock.Call {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{format}, args...)
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockLogger)(nil).Warnf), varargs...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockLogger)(nil).Sync), ctx)
}
-// Warnln mocks base method.
-func (m *MockLogger) Warnln(args ...any) {
+// With mocks base method.
+func (m *MockLogger) With(fields ...Field) Logger {
m.ctrl.T.Helper()
varargs := []any{}
- for _, a := range args {
+ for _, a := range fields {
varargs = append(varargs, a)
}
- m.ctrl.Call(m, "Warnln", varargs...)
-}
-
-// Warnln indicates an expected call of Warnln.
-func (mr *MockLoggerMockRecorder) Warnln(args ...any) *gomock.Call {
- mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnln", reflect.TypeOf((*MockLogger)(nil).Warnln), args...)
-}
-
-// WithDefaultMessageTemplate mocks base method.
-func (m *MockLogger) WithDefaultMessageTemplate(message string) Logger {
- m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "WithDefaultMessageTemplate", message)
+ ret := m.ctrl.Call(m, "With", varargs...)
ret0, _ := ret[0].(Logger)
return ret0
}
-// WithDefaultMessageTemplate indicates an expected call of WithDefaultMessageTemplate.
-func (mr *MockLoggerMockRecorder) WithDefaultMessageTemplate(message any) *gomock.Call {
+// With indicates an expected call of With.
+func (mr *MockLoggerMockRecorder) With(fields ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithDefaultMessageTemplate", reflect.TypeOf((*MockLogger)(nil).WithDefaultMessageTemplate), message)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "With", reflect.TypeOf((*MockLogger)(nil).With), fields...)
}
-// WithFields mocks base method.
-func (m *MockLogger) WithFields(fields ...any) Logger {
+// WithGroup mocks base method.
+func (m *MockLogger) WithGroup(name string) Logger {
m.ctrl.T.Helper()
- varargs := []any{}
- for _, a := range fields {
- varargs = append(varargs, a)
- }
- ret := m.ctrl.Call(m, "WithFields", varargs...)
+ ret := m.ctrl.Call(m, "WithGroup", name)
ret0, _ := ret[0].(Logger)
return ret0
}
-// WithFields indicates an expected call of WithFields.
-func (mr *MockLoggerMockRecorder) WithFields(fields ...any) *gomock.Call {
+// WithGroup indicates an expected call of WithGroup.
+func (mr *MockLoggerMockRecorder) WithGroup(name any) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithFields", reflect.TypeOf((*MockLogger)(nil).WithFields), fields...)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithGroup", reflect.TypeOf((*MockLogger)(nil).WithGroup), name)
}
diff --git a/commons/log/log_test.go b/commons/log/log_test.go
index 969c53cd..04299b51 100644
--- a/commons/log/log_test.go
+++ b/commons/log/log_test.go
@@ -1,619 +1,761 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package log
import (
"bytes"
- "log"
+ "context"
+ "errors"
+ stdlog "log"
+ "strings"
+ "sync"
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/mock/gomock"
)
+var stdLoggerOutputMu sync.Mutex
+
+func withTestLoggerOutput(t *testing.T, output *bytes.Buffer) {
+ t.Helper()
+
+ stdLoggerOutputMu.Lock()
+ defer t.Cleanup(func() {
+ stdLoggerOutputMu.Unlock()
+ })
+
+ originalOutput := stdlog.Writer()
+ stdlog.SetOutput(output)
+ t.Cleanup(func() { stdlog.SetOutput(originalOutput) })
+}
+
func TestParseLevel(t *testing.T) {
tests := []struct {
- name string
- input string
- expected LogLevel
- expectError bool
+ in string
+ expected Level
+ err bool
}{
- {
- name: "parse fatal level",
- input: "fatal",
- expected: FatalLevel,
- expectError: false,
- },
- {
- name: "parse error level",
- input: "error",
- expected: ErrorLevel,
- expectError: false,
- },
- {
- name: "parse warn level",
- input: "warn",
- expected: WarnLevel,
- expectError: false,
- },
- {
- name: "parse warning level",
- input: "warning",
- expected: WarnLevel,
- expectError: false,
- },
- {
- name: "parse info level",
- input: "info",
- expected: InfoLevel,
- expectError: false,
- },
- {
- name: "parse debug level",
- input: "debug",
- expected: DebugLevel,
- expectError: false,
- },
- {
- name: "parse uppercase level",
- input: "INFO",
- expected: InfoLevel,
- expectError: false,
- },
- {
- name: "parse mixed case level",
- input: "WaRn",
- expected: WarnLevel,
- expectError: false,
- },
- {
- name: "parse invalid level",
- input: "invalid",
- expected: LogLevel(0),
- expectError: true,
- },
- {
- name: "parse empty string",
- input: "",
- expected: LogLevel(0),
- expectError: true,
- },
- {
- name: "parse panic level - not supported",
- input: "panic",
- expected: LogLevel(0),
- expectError: true,
- },
+ {in: "error", expected: LevelError},
+ {in: "warn", expected: LevelWarn},
+ {in: "warning", expected: LevelWarn},
+ {in: "info", expected: LevelInfo},
+ {in: "debug", expected: LevelDebug},
+ {in: "panic", err: true},
+ {in: "fatal", err: true},
+ {in: "INVALID", err: true},
}
for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- level, err := ParseLevel(tt.input)
-
- if tt.expectError {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- assert.Equal(t, tt.expected, level)
- }
- })
+ level, err := ParseLevel(tt.in)
+ if tt.err {
+ assert.Error(t, err)
+ continue
+ }
+
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, level)
}
}
-func TestGoLogger_IsLevelEnabled(t *testing.T) {
- tests := []struct {
- name string
- loggerLevel LogLevel
- checkLevel LogLevel
- expected bool
- }{
- {
- name: "debug logger - check debug",
- loggerLevel: DebugLevel,
- checkLevel: DebugLevel,
- expected: true,
- },
- {
- name: "debug logger - check info",
- loggerLevel: DebugLevel,
- checkLevel: InfoLevel,
- expected: true,
- },
- {
- name: "info logger - check debug",
- loggerLevel: InfoLevel,
- checkLevel: DebugLevel,
- expected: false,
- },
- {
- name: "info logger - check info",
- loggerLevel: InfoLevel,
- checkLevel: InfoLevel,
- expected: true,
- },
- {
- name: "error logger - check warn",
- loggerLevel: ErrorLevel,
- checkLevel: WarnLevel,
- expected: false,
- },
- {
- name: "error logger - check error",
- loggerLevel: ErrorLevel,
- checkLevel: ErrorLevel,
- expected: true,
- },
+func TestGoLogger_Enabled(t *testing.T) {
+ logger := &GoLogger{Level: LevelInfo}
+ assert.True(t, logger.Enabled(LevelError))
+ assert.True(t, logger.Enabled(LevelInfo))
+ assert.False(t, logger.Enabled(LevelDebug))
+}
+
+func TestGoLogger_LogWithFieldsAndGroup(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := (&GoLogger{Level: LevelDebug}).
+ WithGroup("http").
+ With(String("request_id", "r-1"))
+
+ logger.Log(context.Background(), LevelInfo, "request finished", Int("status", 200))
+
+ out := buf.String()
+ assert.Contains(t, out, "[info]")
+ assert.Contains(t, out, "group=http")
+ assert.Contains(t, out, "request_id=r-1")
+ assert.Contains(t, out, "status=200")
+ assert.Contains(t, out, "request finished")
+}
+
+func TestGoLogger_WithIsImmutable(t *testing.T) {
+ base := &GoLogger{Level: LevelDebug}
+ withField := base.With(String("k", "v"))
+
+ assert.NotEqual(t, base, withField)
+ assert.Empty(t, base.fields)
+
+ goLogger, ok := withField.(*GoLogger)
+ require.True(t, ok, "expected *GoLogger from With()")
+ assert.Len(t, goLogger.fields, 1)
+}
+
+func TestNopLogger(t *testing.T) {
+ nop := NewNop()
+ assert.NotPanics(t, func() {
+ nop.Log(context.Background(), LevelInfo, "hello")
+ _ = nop.With(String("k", "v"))
+ _ = nop.WithGroup("x")
+ _ = nop.Sync(context.Background())
+ })
+ assert.False(t, nop.Enabled(LevelError))
+}
+
+func TestLevelLegacyNamesRejected(t *testing.T) {
+ _, panicErr := ParseLevel("panic")
+ _, fatalErr := ParseLevel("fatal")
+ assert.Error(t, panicErr)
+ assert.Error(t, fatalErr)
+}
+
+// TestNoLegacyLevelSymbolsInAPI verifies that ParseLevel rejects legacy level
+// names ("panic", "fatal") that were removed in the v2 API migration.
+// This is a behavior-based assertion — it proves the API contract rather than
+// scanning source text.
+func TestNoLegacyLevelSymbolsInAPI(t *testing.T) {
+ legacyNames := []string{"panic", "fatal", "PANIC", "FATAL", "Panic", "Fatal"}
+ for _, name := range legacyNames {
+ level, err := ParseLevel(name)
+ assert.Error(t, err, "ParseLevel(%q) should reject legacy level name", name)
+ assert.Equal(t, LevelUnknown, level,
+ "ParseLevel(%q) should return LevelUnknown for rejected names", name)
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- logger := &GoLogger{Level: tt.loggerLevel}
- result := logger.IsLevelEnabled(tt.checkLevel)
- assert.Equal(t, tt.expected, result)
- })
+ // Confirm no level constant stringifies to legacy names
+ for _, lvl := range []Level{LevelError, LevelWarn, LevelInfo, LevelDebug} {
+ s := lvl.String()
+ assert.NotEqual(t, "panic", s, "no Level constant should stringify to 'panic'")
+ assert.NotEqual(t, "fatal", s, "no Level constant should stringify to 'fatal'")
}
}
-func TestGoLogger_Info(t *testing.T) {
- var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer()) // Reset to default
-
+// ===========================================================================
+// CWE-117: Log Injection Prevention Tests
+//
+// CWE-117 (Improper Output Neutralization for Logs) attacks rely on injecting
+// newlines or control characters into log messages to forge log entries, corrupt
+// log parsing, or hide malicious activity. In a financial services platform,
+// log integrity is critical for audit trails and regulatory compliance.
+// ===========================================================================
+
+// TestCWE117_MessageNewlineInjection verifies that newline characters embedded
+// in log messages are escaped, preventing an attacker from forging additional
+// log entries. This is the canonical CWE-117 attack vector.
+func TestCWE117_MessageNewlineInjection(t *testing.T) {
tests := []struct {
- name string
- loggerLevel LogLevel
- message string
- expectLogged bool
+ name string
+ input string
}{
{
- name: "info level - log info",
- loggerLevel: InfoLevel,
- message: "test info message",
- expectLogged: true,
+ name: "LF newline injection",
+ input: "legitimate message\n[info] forged log entry",
},
{
- name: "warn level - log info",
- loggerLevel: WarnLevel,
- message: "test info message",
- expectLogged: false,
+ name: "CR injection",
+ input: "legitimate message\r[info] forged log entry",
},
{
- name: "debug level - log info",
- loggerLevel: DebugLevel,
- message: "test info message",
- expectLogged: true,
+ name: "CRLF injection",
+ input: "legitimate message\r\n[info] forged log entry",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- buf.Reset()
- logger := &GoLogger{Level: tt.loggerLevel}
-
- logger.Info(tt.message)
-
- output := buf.String()
- if tt.expectLogged {
- assert.Contains(t, output, tt.message)
- } else {
- assert.Empty(t, output)
- }
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ logger.Log(context.Background(), LevelInfo, tt.input)
+
+ out := buf.String()
+
+ // The output must be a single line (the stdlib logger adds one trailing newline).
+ // Count the actual newlines -- there should be exactly 1 (the trailing one from log.Print).
+ newlineCount := strings.Count(out, "\n")
+ assert.Equal(t, 1, newlineCount,
+ "CWE-117: log output must be a single line, got %d newlines in: %q", newlineCount, out)
+
+ // The forged entry should NOT appear as if it were a real log line
+ assert.NotContains(t, out, "\n[info] forged")
})
}
}
-func TestGoLogger_Infof(t *testing.T) {
+// TestCWE117_FieldValueInjection verifies that field values containing newlines
+// are sanitized. An attacker might inject malicious data via user-controlled
+// field values (e.g., request headers, user IDs).
+func TestCWE117_FieldValueInjection(t *testing.T) {
var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ // Simulate a user-controlled value injected through a field
+ maliciousValue := "normal_user\n[error] ADMIN ACCESS GRANTED user=admin action=delete_all"
+ logger.Log(context.Background(), LevelInfo, "user login", String("user_id", maliciousValue))
- logger := &GoLogger{Level: InfoLevel}
-
- buf.Reset()
- logger.Infof("test %s message %d", "formatted", 123)
-
- output := buf.String()
- assert.Contains(t, output, "test formatted message 123")
+ out := buf.String()
+ newlineCount := strings.Count(out, "\n")
+ assert.Equal(t, 1, newlineCount,
+ "CWE-117: field injection must not create extra log lines, got: %q", out)
}
-func TestGoLogger_Infoln(t *testing.T) {
+// TestCWE117_FieldKeyInjection verifies that field keys with injection
+// characters are sanitized.
+func TestCWE117_FieldKeyInjection(t *testing.T) {
var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
+ withTestLoggerOutput(t, &buf)
- logger := &GoLogger{Level: InfoLevel}
-
- buf.Reset()
- logger.Infoln("test", "info", "line")
-
- output := buf.String()
- assert.Contains(t, output, "test info line")
+ logger := &GoLogger{Level: LevelDebug}
+ // Malicious field key containing newline
+ logger.Log(context.Background(), LevelInfo, "event",
+ String("key\ninjected_key", "value"))
+
+ out := buf.String()
+ newlineCount := strings.Count(out, "\n")
+ assert.Equal(t, 1, newlineCount,
+ "CWE-117: field key injection must not create extra log lines")
}
-func TestGoLogger_Error(t *testing.T) {
+// TestCWE117_GroupNameInjection verifies that group names with injection
+// characters are sanitized when creating logger hierarchies.
+func TestCWE117_GroupNameInjection(t *testing.T) {
var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
+ withTestLoggerOutput(t, &buf)
- tests := []struct {
- name string
- loggerLevel LogLevel
- message string
- expectLogged bool
- }{
- {
- name: "error level - log error",
- loggerLevel: ErrorLevel,
- message: "test error message",
- expectLogged: true,
- },
- {
- name: "fatal level - log error",
- loggerLevel: FatalLevel,
- message: "test error message",
- expectLogged: false,
- },
- {
- name: "debug level - log error",
- loggerLevel: DebugLevel,
- message: "test error message",
- expectLogged: true,
- },
- }
+ logger := (&GoLogger{Level: LevelDebug}).
+ WithGroup("safe\n[error] forged entry")
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- buf.Reset()
- logger := &GoLogger{Level: tt.loggerLevel}
-
- logger.Error(tt.message)
-
- output := buf.String()
- if tt.expectLogged {
- assert.Contains(t, output, tt.message)
- } else {
- assert.Empty(t, output)
- }
- })
+ logger.Log(context.Background(), LevelInfo, "test message")
+
+ out := buf.String()
+ newlineCount := strings.Count(out, "\n")
+ assert.Equal(t, 1, newlineCount,
+ "CWE-117: group name injection must not create extra log lines")
+}
+
+// TestCWE117_NullByteInjection verifies null bytes do not corrupt log output.
+// Null bytes can truncate strings in C-based log processors.
+func TestCWE117_NullByteInjection(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ logger.Log(context.Background(), LevelInfo, "before\x00after")
+
+ out := buf.String()
+ // The null byte should not appear literally in the output
+ assert.NotContains(t, out, "\x00",
+ "CWE-117: null bytes must not appear in log output")
+}
+
+// TestCWE117_ANSIEscapeSequences verifies that ANSI escape codes are handled.
+// Attackers can use ANSI escapes to hide log entries in terminal output or
+// manipulate log viewers that render ANSI colors.
+func TestCWE117_ANSIEscapeSequences(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ // \x1b[31m sets red text, \x1b[0m resets -- attacker could hide text
+ logger.Log(context.Background(), LevelInfo, "normal \x1b[31mRED ALERT\x1b[0m normal")
+
+ out := buf.String()
+ // At minimum, the output should be a single line
+ newlineCount := strings.Count(out, "\n")
+ assert.Equal(t, 1, newlineCount,
+ "ANSI escapes must not break single-line log output")
+ // Verify the message content is present (even if ANSI codes pass through,
+ // the important thing is no line splitting occurs)
+ assert.Contains(t, out, "normal")
+}
+
+// TestCWE117_TabInjection verifies tab characters are escaped.
+// Tab injection can misalign columnar log formats.
+func TestCWE117_TabInjection(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ logger.Log(context.Background(), LevelInfo, "field1\tfield2\tfield3")
+
+ out := buf.String()
+ // Tabs should be escaped to literal \t
+ assert.NotContains(t, out, "\t",
+ "tab characters should be escaped in log output")
+ assert.Contains(t, out, `\t`)
+}
+
+// TestCWE117_MultipleVectorsSimultaneously tests a message that combines
+// multiple injection techniques at once.
+func TestCWE117_MultipleVectorsSimultaneously(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ // Combine multiple attack vectors: newlines, tabs, CR, null bytes
+ attack := "msg\n[error] fake\r[warn] also fake\ttab\x00null"
+ logger.Log(context.Background(), LevelInfo, attack,
+ String("user\nfake_key", "val\nfake_val"))
+
+ out := buf.String()
+ newlineCount := strings.Count(out, "\n")
+ assert.Equal(t, 1, newlineCount,
+ "CWE-117: combined attack must not create multiple log lines")
+}
+
+// TestCWE117_VeryLongMessageDoesNotCrash ensures that extremely long messages
+// with embedded control characters are handled without panicking or truncating
+// in unexpected ways.
+func TestCWE117_VeryLongMessageDoesNotCrash(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+
+ // 100KB message with injection attempts every 1000 chars
+ var sb strings.Builder
+ for i := 0; i < 100; i++ {
+ sb.WriteString(strings.Repeat("A", 1000))
+ sb.WriteString("\n[error] forged entry ")
}
+
+ longMsg := sb.String()
+
+ assert.NotPanics(t, func() {
+ logger.Log(context.Background(), LevelInfo, longMsg)
+ })
+
+ out := buf.String()
+ newlineCount := strings.Count(out, "\n")
+ assert.Equal(t, 1, newlineCount,
+ "CWE-117: very long message with injections must remain single-line")
}
-func TestGoLogger_Warn(t *testing.T) {
+// ===========================================================================
+// GoLogger Behavioral Tests
+// ===========================================================================
+
+// TestGoLogger_OutputFormat verifies the overall format of log output.
+func TestGoLogger_OutputFormat(t *testing.T) {
var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
+ withTestLoggerOutput(t, &buf)
+ logger := &GoLogger{Level: LevelDebug}
+ logger.Log(context.Background(), LevelError, "something broke", String("code", "500"))
+
+ out := buf.String()
+ assert.Contains(t, out, "[error]")
+ assert.Contains(t, out, "code=500")
+ assert.Contains(t, out, "something broke")
+}
+
+// TestGoLogger_LevelFiltering verifies that messages below the configured
+// level are suppressed.
+func TestGoLogger_LevelFiltering(t *testing.T) {
tests := []struct {
- name string
- loggerLevel LogLevel
- message string
- expectLogged bool
+ name string
+ loggerLvl Level
+ msgLvl Level
+ shouldEmit bool
}{
- {
- name: "warn level - log warn",
- loggerLevel: WarnLevel,
- message: "test warn message",
- expectLogged: true,
- },
- {
- name: "error level - log warn",
- loggerLevel: ErrorLevel,
- message: "test warn message",
- expectLogged: false,
- },
- {
- name: "info level - log warn",
- loggerLevel: InfoLevel,
- message: "test warn message",
- expectLogged: true,
- },
+ {"error logger emits error", LevelError, LevelError, true},
+ {"error logger suppresses warn", LevelError, LevelWarn, false},
+ {"error logger suppresses info", LevelError, LevelInfo, false},
+ {"error logger suppresses debug", LevelError, LevelDebug, false},
+ {"warn logger emits error", LevelWarn, LevelError, true},
+ {"warn logger emits warn", LevelWarn, LevelWarn, true},
+ {"warn logger suppresses info", LevelWarn, LevelInfo, false},
+ {"info logger emits info", LevelInfo, LevelInfo, true},
+ {"info logger emits error", LevelInfo, LevelError, true},
+ {"debug logger emits everything", LevelDebug, LevelDebug, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- buf.Reset()
- logger := &GoLogger{Level: tt.loggerLevel}
-
- logger.Warn(tt.message)
-
- output := buf.String()
- if tt.expectLogged {
- assert.Contains(t, output, tt.message)
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: tt.loggerLvl}
+ logger.Log(context.Background(), tt.msgLvl, "test message")
+
+ if tt.shouldEmit {
+ assert.NotEmpty(t, buf.String(), "expected message to be emitted")
} else {
- assert.Empty(t, output)
+ assert.Empty(t, buf.String(), "expected message to be suppressed")
}
})
}
}
-func TestGoLogger_Debug(t *testing.T) {
+// TestGoLogger_WithPreservesFields verifies that With() attaches fields
+// that appear in subsequent log output.
+func TestGoLogger_WithPreservesFields(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := (&GoLogger{Level: LevelDebug}).
+ With(String("service", "payments"), Int("version", 2))
+
+ logger.Log(context.Background(), LevelInfo, "started")
+
+ out := buf.String()
+ assert.Contains(t, out, "service=payments")
+ assert.Contains(t, out, "version=2")
+}
+
+// TestGoLogger_WithGroupNesting verifies nested group naming.
+func TestGoLogger_WithGroupNesting(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := (&GoLogger{Level: LevelDebug}).
+ WithGroup("http").
+ WithGroup("middleware")
+
+ logger.Log(context.Background(), LevelInfo, "applied")
+
+ out := buf.String()
+ assert.Contains(t, out, "group=http.middleware")
+}
+
+// TestGoLogger_WithGroupEmptyNameIgnored verifies that empty group names
+// are silently ignored.
+func TestGoLogger_WithGroupEmptyNameIgnored(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := (&GoLogger{Level: LevelDebug}).
+ WithGroup("").
+ WithGroup(" ")
+
+ logger.Log(context.Background(), LevelInfo, "test")
+
+ out := buf.String()
+ assert.NotContains(t, out, "group=")
+}
+
+// TestGoLogger_SyncReturnsNil verifies Sync is a no-op for stdlib logger.
+func TestGoLogger_SyncReturnsNil(t *testing.T) {
+ logger := &GoLogger{Level: LevelInfo}
+ assert.NoError(t, logger.Sync(context.Background()))
+}
+
+// TestGoLogger_NilReceiverSafety ensures nil GoLogger does not panic.
+func TestGoLogger_NilReceiverSafety(t *testing.T) {
+ var logger *GoLogger
+
+ assert.False(t, logger.Enabled(LevelError))
+
+ assert.NotPanics(t, func() {
+ child := logger.With(String("k", "v"))
+ require.NotNil(t, child)
+ })
+
+ assert.NotPanics(t, func() {
+ child := logger.WithGroup("grp")
+ require.NotNil(t, child)
+ })
+}
+
+// TestGoLogger_EmptyFieldKeySkipped verifies fields with empty keys are dropped.
+func TestGoLogger_EmptyFieldKeySkipped(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ logger.Log(context.Background(), LevelInfo, "msg", String("", "should_be_dropped"))
+
+ out := buf.String()
+ assert.NotContains(t, out, "should_be_dropped")
+}
+
+// TestGoLogger_BoolAndErrFields verifies Bool and Err field constructors.
+func TestGoLogger_BoolAndErrFields(t *testing.T) {
var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ logger.Log(context.Background(), LevelInfo, "event",
+ Bool("active", true),
+ Err(assert.AnError))
+
+ out := buf.String()
+ assert.Contains(t, out, "active=true")
+ assert.Contains(t, out, "error=")
+}
+// TestGoLogger_AnyFieldConstructor verifies the Any field constructor.
+func TestGoLogger_AnyFieldConstructor(t *testing.T) {
+ f := Any("data", map[string]int{"count": 42})
+ assert.Equal(t, "data", f.Key)
+ assert.NotNil(t, f.Value)
+}
+
+// TestGoLogger_SensitiveFieldRedaction verifies that fields whose keys match
+// sensitive field patterns are redacted in log output.
+func TestGoLogger_SensitiveFieldRedaction(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ logger.Log(context.Background(), LevelInfo, "login attempt",
+ String("password", "super_secret"),
+ String("api_key", "key-12345"),
+ String("user_id", "u-42"),
+ )
+
+ out := buf.String()
+ assert.NotContains(t, out, "super_secret", "password value must be redacted")
+ assert.NotContains(t, out, "key-12345", "api_key value must be redacted")
+ assert.Contains(t, out, "[REDACTED]", "redacted fields must show [REDACTED]")
+ assert.Contains(t, out, "user_id=u-42", "non-sensitive fields must pass through")
+}
+
+// TestGoLogger_WithGroupEmptyReturnsReceiver verifies that empty group name
+// returns the same logger without allocation.
+func TestGoLogger_WithGroupEmptyReturnsReceiver(t *testing.T) {
+ logger := &GoLogger{Level: LevelDebug}
+ same := logger.WithGroup("")
+ // Should be the exact same pointer.
+ assert.Same(t, logger, same, "WithGroup(\"\") should return the same logger")
+}
+
+// TestParseLevel_WhitespaceTrimming verifies whitespace is trimmed.
+func TestParseLevel_WhitespaceTrimming(t *testing.T) {
tests := []struct {
- name string
- loggerLevel LogLevel
- message string
- expectLogged bool
+ input string
+ expected Level
}{
- {
- name: "debug level - log debug",
- loggerLevel: DebugLevel,
- message: "test debug message",
- expectLogged: true,
- },
- {
- name: "info level - log debug",
- loggerLevel: InfoLevel,
- message: "test debug message",
- expectLogged: false,
- },
+ {" debug ", LevelDebug},
+ {"\tinfo\n", LevelInfo},
+ {" warn ", LevelWarn},
+ {"\nerror\t", LevelError},
}
for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- buf.Reset()
- logger := &GoLogger{Level: tt.loggerLevel}
-
- logger.Debug(tt.message)
-
- output := buf.String()
- if tt.expectLogged {
- assert.Contains(t, output, tt.message)
- } else {
- assert.Empty(t, output)
+ level, err := ParseLevel(tt.input)
+ require.NoError(t, err, "ParseLevel(%q) should not error", tt.input)
+ assert.Equal(t, tt.expected, level)
+ }
+}
+
+// ===========================================================================
+// NopLogger Comprehensive Tests
+// ===========================================================================
+
+// TestNopLogger_AllMethodsAreNoOps verifies every method on NopLogger
+// completes without panicking and returns sensible zero values.
+func TestNopLogger_AllMethodsAreNoOps(t *testing.T) {
+ nop := NewNop()
+
+ t.Run("Log does not panic at any level", func(t *testing.T) {
+ assert.NotPanics(t, func() {
+ for _, level := range []Level{LevelError, LevelWarn, LevelInfo, LevelDebug} {
+ nop.Log(context.Background(), level, "message",
+ String("k", "v"), Int("n", 1), Bool("b", true))
}
})
- }
+ })
+
+ t.Run("With returns self", func(t *testing.T) {
+ child := nop.With(String("a", "b"), String("c", "d"))
+ // NopLogger.With returns itself
+ assert.Equal(t, nop, child)
+ })
+
+ t.Run("WithGroup returns self", func(t *testing.T) {
+ child := nop.WithGroup("any_group")
+ assert.Equal(t, nop, child)
+ })
+
+ t.Run("Enabled always false", func(t *testing.T) {
+ for _, level := range []Level{LevelError, LevelWarn, LevelInfo, LevelDebug} {
+ assert.False(t, nop.Enabled(level))
+ }
+ })
+
+ t.Run("Sync returns nil", func(t *testing.T) {
+ assert.NoError(t, nop.Sync(context.Background()))
+ })
}
-func TestGoLogger_WithFields(t *testing.T) {
- var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
-
- logger := &GoLogger{Level: InfoLevel}
-
- // Test with fields - Note: current implementation doesn't actually use fields
- buf.Reset()
- loggerWithFields := logger.WithFields("key1", "value1", "key2", 123)
- loggerWithFields.Info("test message")
-
- output := buf.String()
- assert.Contains(t, output, "test message")
- // Current implementation doesn't include fields in output
- // These assertions would fail with current implementation
- // assert.Contains(t, output, "key1")
- // assert.Contains(t, output, "value1")
- // assert.Contains(t, output, "key2")
- // assert.Contains(t, output, "123")
-
- // Verify original logger is not modified
- buf.Reset()
- logger.Info("original logger")
- output = buf.String()
- assert.Contains(t, output, "original logger")
-
- // Verify WithFields returns a new logger instance
- assert.NotEqual(t, logger, loggerWithFields)
-}
-
-func TestGoLogger_WithDefaultMessageTemplate(t *testing.T) {
- var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
+// TestNopLogger_InterfaceCompliance verifies NopLogger satisfies Logger.
+func TestNopLogger_InterfaceCompliance(t *testing.T) {
+ var _ Logger = NewNop()
+ var _ Logger = &NopLogger{}
+}
- logger := &GoLogger{Level: InfoLevel}
+// ===========================================================================
+// MockLogger Verification Tests
+// ===========================================================================
- // Test with default message template - should preserve Level
- buf.Reset()
- loggerWithTemplate := logger.WithDefaultMessageTemplate("Template: ")
- loggerWithTemplate.Info("test message")
+// TestMockLogger_RecordsCalls verifies the mock correctly records method
+// invocations for test assertions.
+func TestMockLogger_RecordsCalls(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ mock := NewMockLogger(ctrl)
- output := buf.String()
- // WithDefaultMessageTemplate preserves Level, so it should log
- assert.Contains(t, output, "test message")
+ ctx := context.Background()
- // Verify original logger is not modified (immutability)
- buf.Reset()
- logger.Info("original message")
- output = buf.String()
- assert.Contains(t, output, "original message")
+ // Set up expectations
+ mock.EXPECT().Enabled(LevelInfo).Return(true)
+ mock.EXPECT().Log(ctx, LevelInfo, "hello", String("k", "v"))
+ mock.EXPECT().Sync(ctx).Return(nil)
+
+ // Exercise
+ assert.True(t, mock.Enabled(LevelInfo))
+ mock.Log(ctx, LevelInfo, "hello", String("k", "v"))
+ assert.NoError(t, mock.Sync(ctx))
}
-func TestGoLogger_Sync(t *testing.T) {
- logger := &GoLogger{Level: InfoLevel}
- err := logger.Sync()
- assert.NoError(t, err)
+// TestMockLogger_WithAndWithGroup verifies With/WithGroup on the mock.
+func TestMockLogger_WithAndWithGroup(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ mock := NewMockLogger(ctrl)
+
+ childMock := NewMockLogger(ctrl)
+
+ mock.EXPECT().With(String("tenant", "t1")).Return(childMock)
+ mock.EXPECT().WithGroup("audit").Return(childMock)
+
+ child1 := mock.With(String("tenant", "t1"))
+ assert.Equal(t, childMock, child1)
+
+ child2 := mock.WithGroup("audit")
+ assert.Equal(t, childMock, child2)
}
-func TestGoLogger_FormattedMethods(t *testing.T) {
- var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
-
- logger := &GoLogger{Level: DebugLevel}
-
- // Test Errorf
- buf.Reset()
- logger.Errorf("error: %s %d", "test", 42)
- assert.Contains(t, buf.String(), "error: test 42")
-
- // Test Warnf
- buf.Reset()
- logger.Warnf("warning: %s %d", "test", 42)
- assert.Contains(t, buf.String(), "warning: test 42")
-
- // Test Debugf
- buf.Reset()
- logger.Debugf("debug: %s %d", "test", 42)
- assert.Contains(t, buf.String(), "debug: test 42")
-}
-
-func TestGoLogger_LineMethods(t *testing.T) {
- var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
-
- logger := &GoLogger{Level: DebugLevel}
-
- // Test Errorln
- buf.Reset()
- logger.Errorln("error", "line", "test")
- assert.Contains(t, buf.String(), "error line test")
-
- // Test Warnln
- buf.Reset()
- logger.Warnln("warn", "line", "test")
- assert.Contains(t, buf.String(), "warn line test")
-
- // Test Debugln
- buf.Reset()
- logger.Debugln("debug", "line", "test")
- assert.Contains(t, buf.String(), "debug line test")
-}
-
-func TestNoneLogger(t *testing.T) {
- // NoneLogger should not panic and should return itself for chaining methods
- logger := &NoneLogger{}
-
- // Test all methods don't panic
- assert.NotPanics(t, func() {
- logger.Info("test")
- logger.Infof("test %s", "format")
- logger.Infoln("test", "line")
-
- logger.Error("test")
- logger.Errorf("test %s", "format")
- logger.Errorln("test", "line")
-
- logger.Warn("test")
- logger.Warnf("test %s", "format")
- logger.Warnln("test", "line")
-
- logger.Debug("test")
- logger.Debugf("test %s", "format")
- logger.Debugln("test", "line")
-
- logger.Fatal("test")
- logger.Fatalf("test %s", "format")
- logger.Fatalln("test", "line")
- })
-
- // Test WithFields returns itself
- result := logger.WithFields("key", "value")
- assert.Equal(t, logger, result)
-
- // Test WithDefaultMessageTemplate returns itself
- result = logger.WithDefaultMessageTemplate("template")
- assert.Equal(t, logger, result)
-
- // Test Sync returns nil
- err := logger.Sync()
- assert.NoError(t, err)
-}
-
-func TestGoLogger_ComplexScenarios(t *testing.T) {
- var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
-
- // Test chaining methods
- logger := &GoLogger{Level: InfoLevel}
-
- // Note: Current implementation has issues with chaining
- // WithDefaultMessageTemplate doesn't preserve Level
- buf.Reset()
- // Create a logger that will actually work
- loggerWithFields := logger.WithFields("request_id", "123", "user_id", "456")
- // Since WithDefaultMessageTemplate doesn't preserve level, we can't chain it
- loggerWithFields.Info("API: request processed")
-
- output := buf.String()
- // Current implementation doesn't use fields or template
- assert.Contains(t, output, "API: request processed")
- // These would fail with current implementation
- // assert.Contains(t, output, "request_id")
- // assert.Contains(t, output, "123")
- // assert.Contains(t, output, "user_id")
- // assert.Contains(t, output, "456")
-
- // Test multiple arguments
- buf.Reset()
- logger.Info("multiple", "arguments", 123, true, 45.67)
- output = buf.String()
- assert.Contains(t, output, "multiple")
- assert.Contains(t, output, "arguments")
- assert.Contains(t, output, "123")
- assert.Contains(t, output, "true")
- assert.Contains(t, output, "45.67")
-}
-
-func TestLogLevel_String(t *testing.T) {
- // Test that log levels have proper string representations
+// TestMockLogger_InterfaceCompliance verifies MockLogger satisfies Logger.
+func TestMockLogger_InterfaceCompliance(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ var _ Logger = NewMockLogger(ctrl)
+}
+
+// ===========================================================================
+// Level String Tests
+// ===========================================================================
+
+// TestLevel_String verifies all level string representations.
+func TestLevel_String(t *testing.T) {
tests := []struct {
- level LogLevel
+ level Level
expected string
}{
- {FatalLevel, "fatal"},
- {ErrorLevel, "error"},
- {WarnLevel, "warn"},
- {InfoLevel, "info"},
- {DebugLevel, "debug"},
+ {LevelError, "error"},
+ {LevelWarn, "warn"},
+ {LevelInfo, "info"},
+ {LevelDebug, "debug"},
+ {Level(255), "unknown"},
}
for _, tt := range tests {
- t.Run(tt.expected, func(t *testing.T) {
- // Parse the string and verify we get the same level back
- parsed, err := ParseLevel(tt.expected)
- assert.NoError(t, err)
- assert.Equal(t, tt.level, parsed)
- })
+ assert.Equal(t, tt.expected, tt.level.String())
}
}
-// TestGoLogger_FatalMethods tests fatal methods without actually calling log.Fatal
-// Since Fatal methods call log.Fatal which exits the program, we can't test them directly
-// We just ensure they exist and are callable
-func TestGoLogger_FatalMethods(t *testing.T) {
- logger := &GoLogger{Level: FatalLevel}
-
- // Just verify the methods exist and are callable
- // We can't actually call them because they would exit the test
- assert.NotNil(t, logger.Fatal)
- assert.NotNil(t, logger.Fatalf)
- assert.NotNil(t, logger.Fatalln)
+// ===========================================================================
+// renderFields Tests
+// ===========================================================================
+
+// TestRenderFields_EmptyReturnsEmpty verifies that no fields produce empty output.
+func TestRenderFields_EmptyReturnsEmpty(t *testing.T) {
+ assert.Equal(t, "", renderFields(nil))
+ assert.Equal(t, "", renderFields([]Field{}))
}
-func TestGoLogger_EdgeCases(t *testing.T) {
- var buf bytes.Buffer
- log.SetOutput(&buf)
- defer log.SetOutput(log.Writer())
-
- logger := &GoLogger{Level: InfoLevel}
-
- // Test with nil arguments
- buf.Reset()
- logger.Info(nil)
- assert.Contains(t, buf.String(), "")
-
- // Test with empty string
- buf.Reset()
- logger.Info("")
- // Empty string still produces output with timestamp
- assert.NotEmpty(t, buf.String())
-
- // Test with special characters
- buf.Reset()
- logger.Info("special chars: \n\t\r")
- output := buf.String()
- assert.Contains(t, output, "special chars:")
-
- // Test format with wrong number of arguments
- buf.Reset()
- logger.Infof("format %s", "only one arg")
- output = buf.String()
- assert.Contains(t, output, "format only one arg")
+// TestRenderFields_SingleField verifies single field rendering.
+func TestRenderFields_SingleField(t *testing.T) {
+ result := renderFields([]Field{String("status", "ok")})
+ assert.Equal(t, "[status=ok]", result)
+}
+
+// TestRenderFields_MultipleFields verifies multiple field rendering.
+func TestRenderFields_MultipleFields(t *testing.T) {
+ result := renderFields([]Field{
+ String("a", "1"),
+ Int("b", 2),
+ Bool("c", true),
+ })
+ assert.Contains(t, result, "a=1")
+ assert.Contains(t, result, "b=2")
+ assert.Contains(t, result, "c=true")
+}
+
+// TestRenderFields_EmptyKeyFieldSkipped verifies empty-key fields are dropped.
+func TestRenderFields_EmptyKeyFieldSkipped(t *testing.T) {
+ result := renderFields([]Field{String("", "val")})
+ assert.Equal(t, "", result)
+}
+
+// TestRenderFields_SanitizesKeysAndValues verifies CWE-117 in field rendering.
+func TestRenderFields_SanitizesKeysAndValues(t *testing.T) {
+ result := renderFields([]Field{
+ String("status\ninjection", "value\ninjection"),
+ })
+ assert.NotContains(t, result, "\n")
+ assert.Contains(t, result, `\n`)
+}
+
+// ===========================================================================
+// sanitizeFieldValue Tests
+// ===========================================================================
+
+// testStringer is a small helper that implements fmt.Stringer for testing.
+type testStringer struct{ s string }
+
+func (ts testStringer) String() string { return ts.s }
+
+// TestSanitizeFieldValue verifies that sanitizeFieldValue handles string,
+// error, and fmt.Stringer types, sanitizing control characters in each case.
+func TestSanitizeFieldValue(t *testing.T) {
+ tests := []struct {
+ name string
+ input any
+ expected any
+ }{
+ {
+ name: "plain string passthrough",
+ input: "hello",
+ expected: "hello",
+ },
+ {
+ name: "string with newline is sanitized",
+ input: "line1\nline2",
+ expected: `line1\nline2`,
+ },
+ {
+ name: "error with newline is sanitized",
+ input: errors.New("bad\ninput"),
+ expected: `bad\ninput`,
+ },
+ {
+ name: "fmt.Stringer with newline is sanitized",
+ input: testStringer{s: "hello\nworld"},
+ expected: `hello\nworld`,
+ },
+ {
+ name: "integer passes through unchanged",
+ input: 42,
+ expected: 42,
+ },
+ {
+ name: "nil passes through unchanged",
+ input: nil,
+ expected: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := sanitizeFieldValue(tt.input)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
}
diff --git a/commons/log/nil.go b/commons/log/nil.go
index 763f1c7c..c6ab551f 100644
--- a/commons/log/nil.go
+++ b/commons/log/nil.go
@@ -1,72 +1,36 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package log
-// NoneLogger is a wrapper for log nothing.
-type NoneLogger struct{}
-
-// Info implements Info Logger interface function.
-func (l *NoneLogger) Info(args ...any) {}
-
-// Infof implements Infof Logger interface function.
-func (l *NoneLogger) Infof(format string, args ...any) {}
-
-// Infoln implements Infoln Logger interface function.
-func (l *NoneLogger) Infoln(args ...any) {}
-
-// Error implements Error Logger interface function.
-func (l *NoneLogger) Error(args ...any) {}
-
-// Errorf implements Errorf Logger interface function.
-func (l *NoneLogger) Errorf(format string, args ...any) {}
-
-// Errorln implements Errorln Logger interface function.
-func (l *NoneLogger) Errorln(args ...any) {}
-
-// Warn implements Warn Logger interface function.
-func (l *NoneLogger) Warn(args ...any) {}
-
-// Warnf implements Warnf Logger interface function.
-func (l *NoneLogger) Warnf(format string, args ...any) {}
+import "context"
-// Warnln implements Warnln Logger interface function.
-func (l *NoneLogger) Warnln(args ...any) {}
+// NopLogger is a no-op logger implementation.
+type NopLogger struct{}
-// Debug implements Debug Logger interface function.
-func (l *NoneLogger) Debug(args ...any) {}
-
-// Debugf implements Debugf Logger interface function.
-func (l *NoneLogger) Debugf(format string, args ...any) {}
-
-// Debugln implements Debugln Logger interface function.
-func (l *NoneLogger) Debugln(args ...any) {}
-
-// Fatal implements Fatal Logger interface function.
-func (l *NoneLogger) Fatal(args ...any) {}
-
-// Fatalf implements Fatalf Logger interface function.
-func (l *NoneLogger) Fatalf(format string, args ...any) {}
+// NewNop creates a no-op logger implementation.
+func NewNop() Logger {
+ return &NopLogger{}
+}
-// Fatalln implements Fatalln Logger interface function.
-func (l *NoneLogger) Fatalln(args ...any) {}
+// Log drops all log events.
+func (l *NopLogger) Log(_ context.Context, _ Level, _ string, _ ...Field) {}
-// WithFields implements WithFields Logger interface function
+// With returns the same no-op logger.
//
//nolint:ireturn
-func (l *NoneLogger) WithFields(fields ...any) Logger {
+func (l *NopLogger) With(_ ...Field) Logger {
return l
}
-// WithDefaultMessageTemplate sets the default message template for the logger.
+// WithGroup returns the same no-op logger.
//
//nolint:ireturn
-func (l *NoneLogger) WithDefaultMessageTemplate(message string) Logger {
+func (l *NopLogger) WithGroup(_ string) Logger {
return l
}
-// Sync implements Sync Logger interface function.
-//
-//nolint:ireturn
-func (l *NoneLogger) Sync() error { return nil }
+// Enabled always returns false for NopLogger.
+func (l *NopLogger) Enabled(_ Level) bool {
+ return false
+}
+
+// Sync is a no-op and always returns nil.
+func (l *NopLogger) Sync(_ context.Context) error { return nil }
diff --git a/commons/log/sanitizer.go b/commons/log/sanitizer.go
new file mode 100644
index 00000000..0c70250f
--- /dev/null
+++ b/commons/log/sanitizer.go
@@ -0,0 +1,41 @@
+package log
+
+import (
+ "context"
+ "fmt"
+)
+
+// SafeError logs errors with explicit production-aware sanitization.
+// When production is true, only the error type is logged (no message details).
+//
+// Design rationale: the production boolean is caller-supplied rather than
+// derived from a global flag. This keeps the log package free of global state
+// and lets the caller (typically a service boundary) decide the sanitization
+// policy based on its own configuration. Callers in production deployments
+// should pass true to prevent leaking sensitive error details into log output.
+func SafeError(logger Logger, ctx context.Context, msg string, err error, production bool) {
+ if logger == nil {
+ return
+ }
+
+ if err == nil {
+ return
+ }
+
+ if !logger.Enabled(LevelError) {
+ return
+ }
+
+ if production {
+ logger.Log(ctx, LevelError, msg, String("error_type", fmt.Sprintf("%T", err)))
+ return
+ }
+
+ logger.Log(ctx, LevelError, msg, Err(err))
+}
+
+// SanitizeExternalResponse removes potentially sensitive external response data.
+// Returns only status code for error messages.
+func SanitizeExternalResponse(statusCode int) string {
+ return fmt.Sprintf("external system returned status %d", statusCode)
+}
diff --git a/commons/log/sanitizer_test.go b/commons/log/sanitizer_test.go
new file mode 100644
index 00000000..30c1e787
--- /dev/null
+++ b/commons/log/sanitizer_test.go
@@ -0,0 +1,307 @@
+//go:build unit
+
+package log
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSafeError_ProductionAndNonProduction(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ logger := &GoLogger{Level: LevelDebug}
+ err := errors.New("credential_id=abc123")
+
+ SafeError(logger, context.Background(), "request failed", err, false)
+ assert.Contains(t, buf.String(), "request failed")
+ assert.Contains(t, buf.String(), "credential_id=abc123")
+
+ buf.Reset()
+ SafeError(logger, context.Background(), "request failed", err, true)
+ out := buf.String()
+ assert.Contains(t, out, "request failed")
+ assert.Contains(t, out, "error_type=*errors.errorString")
+ assert.NotContains(t, out, "credential_id=abc123")
+}
+
+func TestSafeError_NilGuards(t *testing.T) {
+ t.Run("nil logger produces no output", func(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ assert.NotPanics(t, func() {
+ SafeError(nil, context.Background(), "msg", assert.AnError, true)
+ })
+ assert.Empty(t, buf.String(), "nil logger must produce no output")
+ })
+
+ t.Run("nil error produces no output", func(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ assert.NotPanics(t, func() {
+ SafeError(&GoLogger{Level: LevelInfo}, context.Background(), "msg", nil, true)
+ })
+ assert.Empty(t, buf.String(), "nil error must produce no output")
+ })
+}
+
+func TestSanitizeExternalResponse(t *testing.T) {
+ assert.Equal(t, "external system returned status 400", SanitizeExternalResponse(400))
+}
+
+// ---------------------------------------------------------------------------
+// CWE-117: Comprehensive sanitizeLogString test matrix
+// ---------------------------------------------------------------------------
+
+func TestSanitizeLogString_ControlCharacterMatrix(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ assertFn func(t *testing.T, result string)
+ }{
+ // --- Newline variants (MUST be neutralized for CWE-117) ---
+ {
+ name: "LF newline is escaped",
+ input: "line1\nline2",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ assert.NotContains(t, result, "\n")
+ assert.Contains(t, result, `\n`)
+ },
+ },
+ {
+ name: "CR carriage return is escaped",
+ input: "line1\rline2",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ assert.NotContains(t, result, "\r")
+ assert.Contains(t, result, `\r`)
+ },
+ },
+ {
+ name: "CRLF is escaped",
+ input: "line1\r\nline2",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ assert.NotContains(t, result, "\r")
+ assert.NotContains(t, result, "\n")
+ assert.Contains(t, result, `\r`)
+ assert.Contains(t, result, `\n`)
+ },
+ },
+ {
+ name: "tab character is escaped",
+ input: "field1\tfield2",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ assert.NotContains(t, result, "\t")
+ assert.Contains(t, result, `\t`)
+ },
+ },
+
+ // --- Null bytes ---
+ {
+ name: "null byte is removed or escaped",
+ input: "before\x00after",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ // The sanitizer should at minimum not pass through raw null bytes.
+ // Depending on implementation it may remove or escape them.
+ assert.NotContains(t, result, "\x00")
+ },
+ },
+
+ // --- Normal strings (pass-through) ---
+ {
+ name: "normal ASCII passes through unchanged",
+ input: "hello world 123 !@#$%",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ assert.Equal(t, "hello world 123 !@#$%", result)
+ },
+ },
+ {
+ name: "empty string passes through",
+ input: "",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ assert.Equal(t, "", result)
+ },
+ },
+ {
+ name: "legitimate Unicode text passes through",
+ input: "Hello, \u4e16\u754c! Ol\u00e1! \u00dcber!",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ // Normal Unicode should be preserved
+ assert.Contains(t, result, "\u4e16\u754c")
+ assert.Contains(t, result, "Ol\u00e1")
+ },
+ },
+
+ // --- Multiple embedded control chars ---
+ {
+ name: "multiple newlines in single string",
+ input: "line1\nline2\nline3\nline4",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ assert.NotContains(t, result, "\n")
+ // All 3 newlines should be escaped
+ assert.Equal(t, 3, strings.Count(result, `\n`))
+ },
+ },
+ {
+ name: "mixed control characters",
+ input: "start\nmiddle\rend\ttab",
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ assert.NotContains(t, result, "\n")
+ assert.NotContains(t, result, "\r")
+ assert.NotContains(t, result, "\t")
+ },
+ },
+
+ // --- Very long strings ---
+ {
+ name: "very long string with embedded control chars",
+ input: strings.Repeat("a", 5000) + "\n" + strings.Repeat("b", 5000),
+ assertFn: func(t *testing.T, result string) {
+ t.Helper()
+ assert.NotContains(t, result, "\n")
+ assert.Contains(t, result, `\n`)
+ // Verify content integrity: the 'a's and 'b's should still be there
+ assert.True(t, len(result) > 10000)
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := sanitizeLogString(tt.input)
+ tt.assertFn(t, result)
+ })
+ }
+}
+
+// TestSanitizeFieldValue_TypeDispatch verifies the sanitizeFieldValue function
+// correctly handles both string and non-string values.
+func TestSanitizeFieldValue_TypeDispatch(t *testing.T) {
+ t.Run("string values are sanitized", func(t *testing.T) {
+ result := sanitizeFieldValue("value\ninjected")
+ s, ok := result.(string)
+ require.True(t, ok)
+ assert.NotContains(t, s, "\n")
+ assert.Contains(t, s, `\n`)
+ })
+
+ t.Run("integer values pass through", func(t *testing.T) {
+ result := sanitizeFieldValue(42)
+ assert.Equal(t, 42, result)
+ })
+
+ t.Run("boolean values pass through", func(t *testing.T) {
+ result := sanitizeFieldValue(true)
+ assert.Equal(t, true, result)
+ })
+
+ t.Run("nil values pass through", func(t *testing.T) {
+ result := sanitizeFieldValue(nil)
+ assert.Nil(t, result)
+ })
+
+ t.Run("error values are sanitized", func(t *testing.T) {
+ err := errors.New("some error\nwith newline")
+ result := sanitizeFieldValue(err)
+ s, ok := result.(string)
+ require.True(t, ok, "error values should be converted to sanitized strings")
+ assert.NotContains(t, s, "\n")
+ assert.Contains(t, s, `\n`)
+ assert.Equal(t, `some error\nwith newline`, s)
+ })
+
+ t.Run("struct values with newlines are sanitized", func(t *testing.T) {
+ type payload struct {
+ Msg string
+ }
+ result := sanitizeFieldValue(payload{Msg: "line1\nline2"})
+ s, ok := result.(string)
+ require.True(t, ok, "composite types should be serialized to sanitized strings")
+ assert.NotContains(t, s, "\n")
+ })
+
+ t.Run("slice values with newlines are sanitized", func(t *testing.T) {
+ result := sanitizeFieldValue([]string{"a\nb", "c"})
+ s, ok := result.(string)
+ require.True(t, ok, "slices should be serialized to sanitized strings")
+ assert.NotContains(t, s, "\n")
+ })
+
+ t.Run("map values with newlines are sanitized", func(t *testing.T) {
+ result := sanitizeFieldValue(map[string]string{"k": "v\ninjected"})
+ s, ok := result.(string)
+ require.True(t, ok, "maps should be serialized to sanitized strings")
+ assert.NotContains(t, s, "\n")
+ })
+
+ t.Run("typed-nil error returns placeholder", func(t *testing.T) {
+ var err *customError // typed nil
+ result := sanitizeFieldValue(err)
+ assert.Equal(t, "", result,
+ "typed-nil error should return placeholder, not panic")
+ })
+
+ t.Run("typed-nil stringer returns placeholder", func(t *testing.T) {
+ var s *testStringer // typed nil
+ result := sanitizeFieldValue(s)
+ assert.Equal(t, "", result,
+ "typed-nil Stringer should return placeholder, not panic")
+ })
+}
+
+// customError is a typed error for testing typed-nil behavior.
+type customError struct{ msg string }
+
+func (e *customError) Error() string { return e.msg }
+
+// TestSafeError_LevelFiltering verifies SafeError respects level gating.
+func TestSafeError_LevelFiltering(t *testing.T) {
+ var buf bytes.Buffer
+ withTestLoggerOutput(t, &buf)
+
+ // Logger at LevelWarn should NOT emit LevelError if LevelWarn < LevelError numerically.
+ // But in this codebase, LevelError=0 < LevelWarn=1, so LevelWarn logger should emit errors.
+ logger := &GoLogger{Level: LevelWarn}
+
+ SafeError(logger, context.Background(), "should appear", errors.New("err"), false)
+ assert.Contains(t, buf.String(), "should appear")
+}
+
+// TestSanitizeExternalResponse_VariousCodes verifies status code formatting.
+func TestSanitizeExternalResponse_VariousCodes(t *testing.T) {
+ tests := []struct {
+ code int
+ expected string
+ }{
+ {200, "external system returned status 200"},
+ {400, "external system returned status 400"},
+ {401, "external system returned status 401"},
+ {403, "external system returned status 403"},
+ {404, "external system returned status 404"},
+ {500, "external system returned status 500"},
+ {502, "external system returned status 502"},
+ {503, "external system returned status 503"},
+ }
+
+ for _, tt := range tests {
+ assert.Equal(t, tt.expected, SanitizeExternalResponse(tt.code))
+ }
+}
diff --git a/commons/mongo/connection_string.go b/commons/mongo/connection_string.go
index 68ca0722..f8d60c9d 100644
--- a/commons/mongo/connection_string.go
+++ b/commons/mongo/connection_string.go
@@ -1,90 +1,172 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package mongo
import (
- "fmt"
+ "errors"
"net/url"
+ "strconv"
"strings"
+)
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+var (
+ // ErrInvalidScheme is returned when URI scheme is not mongodb or mongodb+srv.
+ ErrInvalidScheme = errors.New("invalid mongo uri scheme")
+ // ErrEmptyHost is returned when URI host is empty.
+ ErrEmptyHost = errors.New("mongo uri host cannot be empty")
+ // ErrInvalidPort is returned when URI port is outside the valid TCP range.
+ ErrInvalidPort = errors.New("mongo uri port is invalid")
+ // ErrPortNotAllowedForSRV is returned when a port is provided for mongodb+srv.
+ ErrPortNotAllowedForSRV = errors.New("port cannot be set for mongodb+srv")
+ // ErrPasswordWithoutUser is returned when password is set without username.
+ ErrPasswordWithoutUser = errors.New("password requires username")
)
-// BuildConnectionString constructs a properly formatted MongoDB connection string.
-//
-// Features:
-// - URL-encodes credentials (handles special characters like @, :, /)
-// - Omits port for mongodb+srv URIs (SRV discovery doesn't use ports)
-// - Handles empty credentials gracefully (connects without auth)
-// - Optionally logs masked connection string for debugging
-//
-// Parameters:
-// - scheme: "mongodb" or "mongodb+srv"
-// - user: username (will be URL-encoded)
-// - password: password (will be URL-encoded)
-// - host: MongoDB host address
-// - port: port number (ignored for mongodb+srv)
-// - parameters: query parameters (e.g., "replicaSet=rs0&authSource=admin")
-// - logger: optional logger for debug output (credentials masked)
-//
-// Returns the complete connection string ready for use with MongoDB drivers.
-func BuildConnectionString(scheme, user, password, host, port, parameters string, logger log.Logger) string {
- var connectionString string
-
- credentialsPart := buildCredentialsPart(user, password)
- hostPart := buildHostPart(scheme, host, port)
-
- if credentialsPart != "" {
- connectionString = fmt.Sprintf("%s://%s@%s/", scheme, credentialsPart, hostPart)
- } else {
- connectionString = fmt.Sprintf("%s://%s/", scheme, hostPart)
- }
-
- if parameters != "" {
- connectionString += "?" + parameters
- }
-
- if logger != nil {
- logMaskedConnectionString(logger, scheme, hostPart, parameters, credentialsPart != "")
- }
-
- return connectionString
+// URIConfig contains the components used to build a MongoDB URI.
+type URIConfig struct {
+ Scheme string
+ Username string
+ Password string // #nosec G117 -- builder struct for one-time URI construction; password encoded via url.UserPassword()
+ Host string
+ Port string
+ Database string
+ Query url.Values
}
-func buildCredentialsPart(user, password string) string {
- if user == "" {
- return ""
+// BuildURI validates the structural fields of URIConfig (scheme, host, port,
+// credential presence) and assembles a MongoDB connection URI. It does NOT
+// perform DNS resolution, full RFC 3986 host validation, or MongoDB
+// connstring-level validation — those checks are deferred to the driver's
+// connstring.Parse when the URI is actually used to connect.
+func BuildURI(cfg URIConfig) (string, error) {
+ scheme := strings.TrimSpace(cfg.Scheme)
+ host := strings.TrimSpace(cfg.Host)
+ port := strings.TrimSpace(cfg.Port)
+ database := strings.TrimSpace(cfg.Database)
+ username := strings.TrimSpace(cfg.Username)
+
+ if err := validateBuildURIInput(scheme, host, port, username, cfg.Password); err != nil {
+ return "", err
}
- return url.UserPassword(user, password).String()
+ // Default authSource to "admin" when a database is specified in the path
+ // but authSource is not explicitly set. Without this, the MongoDB driver
+ // uses the path database as authSource, which breaks backward compatibility
+ // for deployments where the user was created in the "admin" database
+ // (the common default). Callers that need a different authSource can set
+ // it explicitly in cfg.Query.
+ query := cfg.Query
+ if database != "" && username != "" {
+ if query == nil || !query.Has("authSource") {
+ if query == nil {
+ query = url.Values{}
+ }
+
+ query.Set("authSource", "admin")
+ }
+ }
+
+ uri := buildURL(scheme, host, port, username, cfg.Password, database, query)
+
+ return uri.String(), nil
}
-func buildHostPart(scheme, host, port string) string {
- if strings.HasPrefix(scheme, "mongodb+srv") {
- return host
+func validateBuildURIInput(scheme, host, port, username, password string) error {
+ if err := validateScheme(scheme); err != nil {
+ return err
+ }
+
+ if host == "" {
+ return ErrEmptyHost
+ }
+
+ if username == "" && password != "" {
+ return ErrPasswordWithoutUser
}
- if port != "" {
- return fmt.Sprintf("%s:%s", host, port)
+ if scheme == "mongodb+srv" && port != "" {
+ return ErrPortNotAllowedForSRV
}
- return host
+ if scheme == "mongodb" {
+ if err := validateMongoPort(port); err != nil {
+ return err
+ }
+ }
+
+ return nil
}
-func logMaskedConnectionString(logger log.Logger, scheme, hostPart, parameters string, hasCredentials bool) {
- var maskedConnStr string
+func validateScheme(scheme string) error {
+ if scheme != "mongodb" && scheme != "mongodb+srv" {
+ return ErrInvalidScheme
+ }
+
+ return nil
+}
- if hasCredentials {
- maskedConnStr = fmt.Sprintf("%s://@%s/", scheme, hostPart)
- } else {
- maskedConnStr = fmt.Sprintf("%s://%s/", scheme, hostPart)
+func validateMongoPort(port string) error {
+ if port == "" {
+ return nil
}
- if parameters != "" {
- maskedConnStr += "?" + parameters
+ parsedPort, err := strconv.Atoi(port)
+ if err != nil || parsedPort < 1 || parsedPort > 65535 {
+ return ErrInvalidPort
+ }
+
+ return nil
+}
+
+func buildURL(scheme, host, port, username, password, database string, query url.Values) *url.URL {
+ uri := &url.URL{Scheme: scheme}
+ uri.Host = buildHost(host, port)
+ uri.User = buildUser(username, password)
+ uri.Path = buildPath(database)
+
+ if len(query) > 0 {
+ uri.RawQuery = query.Encode()
+ }
+
+ return uri
+}
+
+// buildHost concatenates host and port. IPv6 addresses are bracketed per
+// RFC 3986 to avoid ambiguity with the port separator. The caller is
+// responsible for validating that host contains only legitimate hostname
+// characters. The mongo driver validates the full URI downstream via
+// connstring.Parse.
+func buildHost(host, port string) string {
+ // Detect raw IPv6 literal: must contain at least two colons to distinguish
+ // from a simple "host:port" pair. Already-bracketed addresses are left untouched.
+ if strings.Count(host, ":") >= 2 && !strings.HasPrefix(host, "[") {
+ host = "[" + host + "]"
+ }
+
+ if port == "" {
+ return host
+ }
+
+ return host + ":" + port
+}
+
+func buildUser(username, password string) *url.Userinfo {
+ if username == "" {
+ return nil
+ }
+
+ // When password is empty, use url.User to produce "username@" instead of
+ // "username:@". The trailing colon is technically valid per RFC 3986 but
+ // can confuse some drivers and implies an empty password was intentionally set.
+ if password == "" {
+ return url.User(username)
+ }
+
+ return url.UserPassword(username, password)
+}
+
+func buildPath(database string) string {
+ if database == "" {
+ return "/"
}
- logger.Debugf("MongoDB connection string built: %s", maskedConnStr)
+ return "/" + url.PathEscape(database)
}
diff --git a/commons/mongo/connection_string_example_test.go b/commons/mongo/connection_string_example_test.go
new file mode 100644
index 00000000..1dfadefd
--- /dev/null
+++ b/commons/mongo/connection_string_example_test.go
@@ -0,0 +1,32 @@
+//go:build unit
+
+package mongo_test
+
+import (
+ "fmt"
+ "net/url"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/mongo"
+)
+
+func ExampleBuildURI() {
+ query := url.Values{}
+ query.Set("replicaSet", "rs0")
+
+ uri, err := mongo.BuildURI(mongo.URIConfig{
+ Scheme: "mongodb",
+ Username: "app",
+ Password: "EXAMPLE_DO_NOT_USE",
+ Host: "db.internal",
+ Port: "27017",
+ Database: "ledger",
+ Query: query,
+ })
+
+ fmt.Println(err == nil)
+ fmt.Println(uri)
+
+ // Output:
+ // true
+ // mongodb://app:EXAMPLE_DO_NOT_USE@db.internal:27017/ledger?authSource=admin&replicaSet=rs0
+}
diff --git a/commons/mongo/connection_string_test.go b/commons/mongo/connection_string_test.go
index 84213bf3..c252b13d 100644
--- a/commons/mongo/connection_string_test.go
+++ b/commons/mongo/connection_string_test.go
@@ -1,196 +1,246 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package mongo
import (
- "fmt"
- "strings"
+ "net/url"
"testing"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
-func TestBuildConnectionString(t *testing.T) {
+func TestBuildURI_SuccessCases(t *testing.T) {
t.Parallel()
- tests := []struct {
- name string
- scheme string
- user string
- password string
- host string
- port string
- parameters string
- expected string
- }{
- {
- name: "basic_connection_no_parameters",
- scheme: "mongodb",
- user: "admin",
- password: "secret123",
- host: "localhost",
- port: "27017",
- parameters: "",
- expected: "mongodb://admin:secret123@localhost:27017/",
- },
- {
- name: "connection_with_single_parameter",
- scheme: "mongodb",
- user: "admin",
- password: "secret123",
- host: "localhost",
- port: "27017",
- parameters: "authSource=admin",
- expected: "mongodb://admin:secret123@localhost:27017/?authSource=admin",
- },
- {
- name: "connection_with_multiple_parameters",
- scheme: "mongodb",
- user: "admin",
- password: "secret123",
- host: "mongo.example.com",
- port: "5703",
- parameters: "replicaSet=rs0&authSource=admin&directConnection=true",
- expected: "mongodb://admin:secret123@mongo.example.com:5703/?replicaSet=rs0&authSource=admin&directConnection=true",
- },
- {
- name: "mongodb_srv_scheme_omits_port",
- scheme: "mongodb+srv",
- user: "user",
- password: "pass",
- host: "cluster.mongodb.net",
- port: "27017",
- parameters: "retryWrites=true&w=majority",
- expected: "mongodb+srv://user:pass@cluster.mongodb.net/?retryWrites=true&w=majority",
- },
- {
- name: "mongodb_srv_without_parameters",
- scheme: "mongodb+srv",
- user: "user",
- password: "pass",
- host: "cluster.mongodb.net",
- port: "",
- parameters: "",
- expected: "mongodb+srv://user:pass@cluster.mongodb.net/",
- },
- {
- name: "special_characters_in_password_url_encoded",
- scheme: "mongodb",
- user: "admin",
- password: "p@ss:word/123",
- host: "localhost",
- port: "27017",
- parameters: "",
- expected: "mongodb://admin:p%40ss%3Aword%2F123@localhost:27017/",
- },
- {
- name: "special_characters_in_username_url_encoded",
- scheme: "mongodb",
- user: "user@domain",
- password: "pass",
- host: "localhost",
- port: "27017",
- parameters: "",
- expected: "mongodb://user%40domain:pass@localhost:27017/",
- },
- {
- name: "empty_credentials",
- scheme: "mongodb",
- user: "",
- password: "",
- host: "localhost",
- port: "27017",
- parameters: "",
- expected: "mongodb://localhost:27017/",
- },
- {
- name: "empty_user_with_password",
- scheme: "mongodb",
- user: "",
- password: "secret",
- host: "localhost",
- port: "27017",
- parameters: "",
- expected: "mongodb://localhost:27017/",
- },
- {
- name: "user_without_password",
- scheme: "mongodb",
- user: "admin",
- password: "",
- host: "localhost",
- port: "27017",
- parameters: "",
- expected: "mongodb://admin:@localhost:27017/",
- },
- {
- name: "empty_parameters_no_question_mark",
- scheme: "mongodb",
- user: "user",
- password: "pass",
- host: "db.local",
- port: "27017",
- parameters: "",
- expected: "mongodb://user:pass@db.local:27017/",
- },
- }
-
- for _, tt := range tests {
- tt := tt
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
-
- result := BuildConnectionString(tt.scheme, tt.user, tt.password, tt.host, tt.port, tt.parameters, nil)
- assert.Equal(t, tt.expected, result)
+ t.Run("mongodb with auth, port, database and query", func(t *testing.T) {
+ t.Parallel()
+
+ query := url.Values{}
+ query.Set("authSource", "admin")
+ query.Set("replicaSet", "rs0")
+
+ uri, err := BuildURI(URIConfig{
+ Scheme: "mongodb",
+ Username: "dbuser",
+ Password: "p@ss:word/123",
+ Host: "localhost",
+ Port: "27017",
+ Database: "ledger",
+ Query: query,
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://dbuser:p%40ss%3Aword%2F123@localhost:27017/ledger?authSource=admin&replicaSet=rs0", uri)
+ })
+
+ t.Run("mongodb+srv omits port", func(t *testing.T) {
+ t.Parallel()
+
+ query := url.Values{}
+ query.Set("retryWrites", "true")
+ query.Set("w", "majority")
+
+ uri, err := BuildURI(URIConfig{
+ Scheme: "mongodb+srv",
+ Username: "user",
+ Password: "secret",
+ Host: "cluster.mongodb.net",
+ Database: "banking",
+ Query: query,
+ })
+ require.NoError(t, err)
+ assert.Contains(t, uri, "mongodb+srv://user:secret@cluster.mongodb.net/banking?")
+ assert.Contains(t, uri, "authSource=admin")
+ assert.Contains(t, uri, "retryWrites=true")
+ assert.Contains(t, uri, "w=majority")
+ })
+
+ t.Run("without credentials defaults to root path", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{
+ Scheme: "mongodb",
+ Host: "127.0.0.1",
+ Port: "27017",
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://127.0.0.1:27017/", uri)
+ })
+
+ t.Run("username without password", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{
+ Scheme: "mongodb",
+ Username: "readonly",
+ Host: "localhost",
+ Port: "27017",
+ })
+ require.NoError(t, err)
+ // Uses url.User (not url.UserPassword) so no trailing colon before @.
+ assert.Equal(t, "mongodb://readonly@localhost:27017/", uri)
+ })
+
+ t.Run("default authSource=admin when database and credentials set", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{
+ Scheme: "mongodb",
+ Username: "appuser",
+ Password: "secret",
+ Host: "localhost",
+ Port: "27017",
+ Database: "midaz",
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://appuser:secret@localhost:27017/midaz?authSource=admin", uri)
+ })
+
+ t.Run("explicit authSource not overridden", func(t *testing.T) {
+ t.Parallel()
+
+ query := url.Values{}
+ query.Set("authSource", "myauthdb")
+
+ uri, err := BuildURI(URIConfig{
+ Scheme: "mongodb",
+ Username: "appuser",
+ Password: "secret",
+ Host: "localhost",
+ Port: "27017",
+ Database: "midaz",
+ Query: query,
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://appuser:secret@localhost:27017/midaz?authSource=myauthdb", uri)
+ })
+
+ t.Run("no authSource added without database in path", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{
+ Scheme: "mongodb",
+ Username: "appuser",
+ Password: "secret",
+ Host: "localhost",
+ Port: "27017",
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://appuser:secret@localhost:27017/", uri)
+ assert.NotContains(t, uri, "authSource")
+ })
+
+ t.Run("no authSource added without credentials", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{
+ Scheme: "mongodb",
+ Host: "localhost",
+ Port: "27017",
+ Database: "midaz",
})
- }
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://localhost:27017/midaz", uri)
+ assert.NotContains(t, uri, "authSource")
+ })
}
-func TestBuildConnectionString_LoggerMasksCredentials(t *testing.T) {
+func TestBuildURI_Validation(t *testing.T) {
t.Parallel()
- logger := &testLogger{}
+ t.Run("invalid scheme", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{Scheme: "postgres", Host: "localhost"})
+ assert.Empty(t, uri)
+ assert.ErrorIs(t, err, ErrInvalidScheme)
+ })
+
+ t.Run("empty host", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{Scheme: "mongodb", Host: " "})
+ assert.Empty(t, uri)
+ assert.ErrorIs(t, err, ErrEmptyHost)
+ })
+
+ t.Run("invalid port", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{Scheme: "mongodb", Host: "localhost", Port: "70000"})
+ assert.Empty(t, uri)
+ assert.ErrorIs(t, err, ErrInvalidPort)
+ })
+
+ t.Run("srv port is forbidden", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{Scheme: "mongodb+srv", Host: "cluster.mongodb.net", Port: "27017"})
+ assert.Empty(t, uri)
+ assert.ErrorIs(t, err, ErrPortNotAllowedForSRV)
+ })
- _ = BuildConnectionString("mongodb", "dbuser", "supersecret", "localhost", "27017", "authSource=admin", logger)
+ t.Run("password without username", func(t *testing.T) {
+ t.Parallel()
- assert.Len(t, logger.debugs, 1, "expected exactly one debug log")
- assert.True(t, strings.Contains(logger.debugs[0], ""), "expected credentials to be masked")
- assert.False(t, strings.Contains(logger.debugs[0], "dbuser"), "username should not appear in logs")
- assert.False(t, strings.Contains(logger.debugs[0], "supersecret"), "password should not appear in logs")
+ uri, err := BuildURI(URIConfig{Scheme: "mongodb", Host: "localhost", Password: "secret"})
+ assert.Empty(t, uri)
+ assert.ErrorIs(t, err, ErrPasswordWithoutUser)
+ })
+
+ t.Run("whitespace_only_username_treated_as_empty", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{Scheme: "mongodb", Host: "localhost", Username: " ", Password: "secret"})
+ assert.Empty(t, uri)
+ assert.ErrorIs(t, err, ErrPasswordWithoutUser)
+ })
}
-func TestBuildConnectionString_NilLoggerDoesNotPanic(t *testing.T) {
+func TestBuildURI_PortBoundaries(t *testing.T) {
t.Parallel()
- assert.NotPanics(t, func() {
- _ = BuildConnectionString("mongodb", "user", "pass", "localhost", "27017", "", nil)
+ t.Run("port_zero_is_invalid", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := BuildURI(URIConfig{Scheme: "mongodb", Host: "localhost", Port: "0"})
+ assert.ErrorIs(t, err, ErrInvalidPort)
})
-}
-type testLogger struct {
- debugs []string
-}
+ t.Run("port_one_is_valid", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{Scheme: "mongodb", Host: "localhost", Port: "1"})
+ require.NoError(t, err)
+ assert.Contains(t, uri, ":1/")
+ })
+
+ t.Run("port_65535_is_valid", func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := BuildURI(URIConfig{Scheme: "mongodb", Host: "localhost", Port: "65535"})
+ require.NoError(t, err)
+ assert.Contains(t, uri, ":65535/")
+ })
-func (l *testLogger) Debug(args ...any) {}
-func (l *testLogger) Debugf(format string, args ...any) {
- l.debugs = append(l.debugs, fmt.Sprintf(format, args...))
+ t.Run("port_65536_is_invalid", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := BuildURI(URIConfig{Scheme: "mongodb", Host: "localhost", Port: "65536"})
+ assert.ErrorIs(t, err, ErrInvalidPort)
+ })
+
+ t.Run("non_numeric_port", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := BuildURI(URIConfig{Scheme: "mongodb", Host: "localhost", Port: "abc"})
+ assert.ErrorIs(t, err, ErrInvalidPort)
+ })
+
+ t.Run("negative_port", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := BuildURI(URIConfig{Scheme: "mongodb", Host: "localhost", Port: "-1"})
+ assert.ErrorIs(t, err, ErrInvalidPort)
+ })
}
-func (l *testLogger) Debugln(args ...any) {}
-func (l *testLogger) Info(args ...any) {}
-func (l *testLogger) Infof(format string, args ...any) {}
-func (l *testLogger) Infoln(args ...any) {}
-func (l *testLogger) Warn(args ...any) {}
-func (l *testLogger) Warnf(format string, args ...any) {}
-func (l *testLogger) Warnln(args ...any) {}
-func (l *testLogger) Error(args ...any) {}
-func (l *testLogger) Errorf(format string, args ...any) {}
-func (l *testLogger) Errorln(args ...any) {}
-func (l *testLogger) Fatal(args ...any) {}
-func (l *testLogger) Fatalf(format string, args ...any) {}
-func (l *testLogger) Fatalln(args ...any) {}
-func (l *testLogger) WithFields(fields ...any) log.Logger { return l }
-func (l *testLogger) WithDefaultMessageTemplate(msg string) log.Logger { return l }
-func (l *testLogger) Sync() error { return nil }
diff --git a/commons/mongo/doc.go b/commons/mongo/doc.go
new file mode 100644
index 00000000..a756d841
--- /dev/null
+++ b/commons/mongo/doc.go
@@ -0,0 +1,6 @@
+// Package mongo provides resilient MongoDB connection and index management helpers.
+//
+// The package wraps connection lifecycle concerns (connect, ping, close), offers
+// EnsureIndexes for idempotent index creation with structured error reporting,
+// and supports TLS configuration for encrypted connections.
+package mongo
diff --git a/commons/mongo/mongo.go b/commons/mongo/mongo.go
index d4516dcf..d10b66f2 100644
--- a/commons/mongo/mongo.go
+++ b/commons/mongo/mongo.go
@@ -1,115 +1,784 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package mongo
import (
"context"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "errors"
"fmt"
+ neturl "net/url"
+ "regexp"
+ "sort"
"strings"
+ "sync"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ "github.com/LerianStudio/lib-commons/v4/commons/backoff"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+)
+
+const (
+ defaultServerSelectionTimeout = 5 * time.Second
+ defaultHeartbeatInterval = 10 * time.Second
+ maxMaxPoolSize = 1000
+)
+
+var (
+ // ErrNilContext is returned when a required context is nil.
+ ErrNilContext = errors.New("context cannot be nil")
+ // ErrNilClient is returned when a *Client receiver is nil.
+ ErrNilClient = errors.New("mongo client is nil")
+ // ErrClientClosed is returned when the client is not connected.
+ ErrClientClosed = errors.New("mongo client is closed")
+ // ErrNilDependency is returned when an Option sets a required dependency to nil.
+ ErrNilDependency = errors.New("mongo option set a required dependency to nil")
+ // ErrInvalidConfig indicates the provided configuration is invalid.
+ ErrInvalidConfig = errors.New("invalid mongo config")
+ // ErrEmptyURI is returned when Mongo URI is empty.
+ ErrEmptyURI = errors.New("mongo uri cannot be empty")
+ // ErrEmptyDatabaseName is returned when database name is empty.
+ ErrEmptyDatabaseName = errors.New("database name cannot be empty")
+ // ErrEmptyCollectionName is returned when collection name is empty.
+ ErrEmptyCollectionName = errors.New("collection name cannot be empty")
+ // ErrEmptyIndexes is returned when no index model is provided.
+ ErrEmptyIndexes = errors.New("at least one index must be provided")
+ // ErrConnect wraps connection establishment failures.
+ ErrConnect = errors.New("mongo connect failed")
+ // ErrPing wraps connectivity probe failures.
+ ErrPing = errors.New("mongo ping failed")
+ // ErrDisconnect wraps disconnection failures.
+ ErrDisconnect = errors.New("mongo disconnect failed")
+ // ErrCreateIndex wraps index creation failures.
+ ErrCreateIndex = errors.New("mongo create index failed")
+ // ErrNilMongoClient is returned when mongo driver returns a nil client.
+ ErrNilMongoClient = errors.New("mongo driver returned nil client")
)
-// MongoConnection is a hub which deal with mongodb connections.
-type MongoConnection struct {
- ConnectionStringSource string
- DB *mongo.Client
- Connected bool
+// nilClientAssert fires a telemetry assertion for nil-receiver calls and returns ErrNilClient.
+func nilClientAssert(operation string) error {
+ asserter := assert.New(context.Background(), nil, "mongo", operation)
+ _ = asserter.Never(context.Background(), "mongo client receiver is nil")
+
+ return ErrNilClient
+}
+
+// TLSConfig configures TLS validation for MongoDB connections.
+type TLSConfig struct {
+ CACertBase64 string
+ MinVersion uint16
+}
+
+// Config defines MongoDB connection and pool behavior.
+type Config struct {
+ URI string
Database string
- Logger log.Logger
MaxPoolSize uint64
+ ServerSelectionTimeout time.Duration
+ HeartbeatInterval time.Duration
+ TLS *TLSConfig
+ Logger log.Logger
+ MetricsFactory *metrics.MetricsFactory
}
-// Connect keeps a singleton connection with mongodb.
-func (mc *MongoConnection) Connect(ctx context.Context) error {
- mc.Logger.Info("Connecting to mongodb...")
+func (cfg Config) validate() error {
+ if strings.TrimSpace(cfg.URI) == "" {
+ return ErrEmptyURI
+ }
- clientOptions := options.
- Client().
- ApplyURI(mc.ConnectionStringSource).
- SetMaxPoolSize(mc.MaxPoolSize).
- SetServerSelectionTimeout(5 * time.Second).
- SetHeartbeatInterval(10 * time.Second)
+ if strings.TrimSpace(cfg.Database) == "" {
+ return ErrEmptyDatabaseName
+ }
- noSQLDB, err := mongo.Connect(ctx, clientOptions)
- if err != nil {
- mc.Logger.Errorf("failed to open connect to mongodb: %v", err)
- return fmt.Errorf("failed to connect to mongodb: %w", err)
+ return nil
+}
+
+// Option customizes internal client dependencies (primarily for tests).
+type Option func(*clientDeps)
+
+// connectBackoffCap is the maximum delay between lazy-connect retries.
+const connectBackoffCap = 30 * time.Second
+
+// connectionFailuresMetric defines the counter for mongo connection failures.
+var connectionFailuresMetric = metrics.Metric{
+ Name: "mongo_connection_failures_total",
+ Unit: "1",
+ Description: "Total number of mongo connection failures",
+}
+
+// Client wraps a MongoDB client with lifecycle and index helpers.
+type Client struct {
+ mu sync.RWMutex
+ client *mongo.Client
+ closed bool // terminal flag; set by Close(), prevents reconnection
+ databaseName string
+ cfg Config
+ metricsFactory *metrics.MetricsFactory
+ uri string // private copy for reconnection; cfg.URI cleared after connect
+ deps clientDeps
+
+ // Lazy-connect rate-limiting: prevents thundering-herd reconnect storms
+ // when the database is down by enforcing exponential backoff between attempts.
+ lastConnectAttempt time.Time
+ connectAttempts int
+}
+
+type clientDeps struct {
+ connect func(context.Context, *options.ClientOptions) (*mongo.Client, error)
+ ping func(context.Context, *mongo.Client) error
+ disconnect func(context.Context, *mongo.Client) error
+ createIndex func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error
+}
+
+func defaultDeps() clientDeps {
+ return clientDeps{
+ connect: func(ctx context.Context, clientOptions *options.ClientOptions) (*mongo.Client, error) {
+ return mongo.Connect(ctx, clientOptions)
+ },
+ ping: func(ctx context.Context, client *mongo.Client) error {
+ return client.Ping(ctx, nil)
+ },
+ disconnect: func(ctx context.Context, client *mongo.Client) error {
+ return client.Disconnect(ctx)
+ },
+ createIndex: func(ctx context.Context, client *mongo.Client, database, collection string, index mongo.IndexModel) error {
+ _, err := client.Database(database).Collection(collection).Indexes().CreateOne(ctx, index)
+
+ return err
+ },
+ }
+}
+
+// NewClient validates config, connects to MongoDB, and returns a ready client.
+func NewClient(ctx context.Context, cfg Config, opts ...Option) (*Client, error) {
+ if ctx == nil {
+ return nil, ErrNilContext
+ }
+
+ cfg = normalizeConfig(cfg)
+
+ if err := cfg.validate(); err != nil {
+ return nil, err
}
- if err := noSQLDB.Ping(ctx, nil); err != nil {
- mc.Logger.Errorf("MongoDBConnection.Ping failed: %v", err)
+ deps := defaultDeps()
+
+ for _, opt := range opts {
+ if opt == nil {
+ asserter := assert.New(ctx, cfg.Logger, "mongo", "NewClient")
+ _ = asserter.Never(ctx, "nil mongo option received; skipping")
- if disconnectErr := noSQLDB.Disconnect(ctx); disconnectErr != nil {
- mc.Logger.Errorf("failed to disconnect after ping failure: %v", disconnectErr)
+ continue
}
- return fmt.Errorf("failed to ping mongodb: %w", err)
+ opt(&deps)
}
- mc.Logger.Info("Connected to mongodb ✅ \n")
+ if deps.connect == nil || deps.ping == nil || deps.disconnect == nil || deps.createIndex == nil {
+ return nil, ErrNilDependency
+ }
- mc.Connected = true
+ client := &Client{
+ databaseName: cfg.Database,
+ cfg: cfg,
+ metricsFactory: cfg.MetricsFactory,
+ uri: cfg.URI,
+ deps: deps,
+ }
- mc.DB = noSQLDB
+ if err := client.Connect(ctx); err != nil {
+ return nil, err
+ }
+
+ return client, nil
+}
+
+// Connect establishes a MongoDB connection if one is not already open.
+func (c *Client) Connect(ctx context.Context) error {
+ if c == nil {
+ return nilClientAssert("connect")
+ }
+
+ if ctx == nil {
+ return ErrNilContext
+ }
+
+ tracer := otel.Tracer("mongo")
+
+ ctx, span := tracer.Start(ctx, "mongo.connect")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemMongoDB))
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.closed {
+ return ErrClientClosed
+ }
+
+ if c.client != nil {
+ return nil
+ }
+
+ if err := c.connectLocked(ctx); err != nil {
+ c.recordConnectionFailure("connect")
+
+ libOpentelemetry.HandleSpanError(span, "Failed to connect to mongo", err)
+
+ return err
+ }
return nil
}
-// GetDB returns a pointer to the mongodb connection, initializing it if necessary.
-func (mc *MongoConnection) GetDB(ctx context.Context) (*mongo.Client, error) {
- if mc.DB == nil {
- err := mc.Connect(ctx)
+// connectLocked performs the actual connection logic.
+// The caller MUST hold c.mu (write lock) before calling this method.
+func (c *Client) connectLocked(ctx context.Context) error {
+ clientOptions := options.Client().ApplyURI(c.uri)
+
+ serverSelectionTimeout := c.cfg.ServerSelectionTimeout
+ if serverSelectionTimeout <= 0 {
+ serverSelectionTimeout = defaultServerSelectionTimeout
+ }
+
+ heartbeatInterval := c.cfg.HeartbeatInterval
+ if heartbeatInterval <= 0 {
+ heartbeatInterval = defaultHeartbeatInterval
+ }
+
+ clientOptions.SetServerSelectionTimeout(serverSelectionTimeout)
+ clientOptions.SetHeartbeatInterval(heartbeatInterval)
+
+ if c.cfg.MaxPoolSize > 0 {
+ clientOptions.SetMaxPoolSize(c.cfg.MaxPoolSize)
+ }
+
+ if c.cfg.TLS != nil {
+ tlsCfg, err := buildTLSConfig(*c.cfg.TLS)
if err != nil {
- mc.Logger.Infof("ERRCONECT %s", err)
- return nil, err
+ return fmt.Errorf("%w: TLS configuration: %w", ErrInvalidConfig, err)
+ }
+
+ clientOptions.SetTLSConfig(tlsCfg)
+ }
+
+ mongoClient, err := c.deps.connect(ctx, clientOptions)
+ if err != nil {
+ sanitized := sanitizeDriverError(err)
+ c.log(ctx, "mongo connect failed", log.Err(sanitized))
+
+ return fmt.Errorf("%w: %w", ErrConnect, sanitized)
+ }
+
+ if mongoClient == nil {
+ return ErrNilMongoClient
+ }
+
+ if err := c.deps.ping(ctx, mongoClient); err != nil {
+ if disconnectErr := c.deps.disconnect(ctx, mongoClient); disconnectErr != nil {
+ c.log(ctx, "failed to disconnect after ping failure", log.Err(sanitizeDriverError(disconnectErr)))
+ }
+
+ sanitized := sanitizeDriverError(err)
+ c.log(ctx, "mongo ping failed", log.Err(sanitized))
+
+ return fmt.Errorf("%w: %w", ErrPing, sanitized)
+ }
+
+ c.client = mongoClient
+
+ if c.cfg.TLS == nil && !isTLSImplied(c.uri) {
+ c.logAtLevel(ctx, log.LevelWarn, "mongo connection established without TLS; "+
+ "consider configuring TLS for production use")
+ }
+
+ c.cfg.URI = ""
+
+ return nil
+}
+
+// Client returns the underlying mongo client if connected.
+//
+// Note: the returned *mongo.Client may become stale if Close is called
+// concurrently from another goroutine. Callers that need atomicity
+// across multiple operations should coordinate externally.
+func (c *Client) Client(ctx context.Context) (*mongo.Client, error) {
+ if c == nil {
+ return nil, nilClientAssert("client")
+ }
+
+ if ctx == nil {
+ return nil, ErrNilContext
+ }
+
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ if c.client == nil {
+ return nil, ErrClientClosed
+ }
+
+ return c.client, nil
+}
+
+// ResolveClient returns a connected mongo client, reconnecting lazily if needed.
+// Unlike Client(), this method attempts to re-establish a dropped connection using
+// double-checked locking with backoff rate-limiting to prevent reconnect storms.
+func (c *Client) ResolveClient(ctx context.Context) (*mongo.Client, error) {
+ if c == nil {
+ return nil, nilClientAssert("resolve_client")
+ }
+
+ if ctx == nil {
+ return nil, ErrNilContext
+ }
+
+ // Fast path: already connected (read-lock only).
+ c.mu.RLock()
+ closed := c.closed
+ client := c.client
+ c.mu.RUnlock()
+
+ if closed {
+ return nil, ErrClientClosed
+ }
+
+ if client != nil {
+ return client, nil
+ }
+
+ // Slow path: acquire write lock and double-check before connecting.
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.closed {
+ return nil, ErrClientClosed
+ }
+
+ if c.client != nil {
+ return c.client, nil
+ }
+
+ // Rate-limit lazy-connect retries: if previous attempts failed recently,
+ // enforce a minimum delay before the next attempt to prevent reconnect storms.
+ if c.connectAttempts > 0 {
+ delay := min(backoff.ExponentialWithJitter(1*time.Second, c.connectAttempts), connectBackoffCap)
+
+ if elapsed := time.Since(c.lastConnectAttempt); elapsed < delay {
+ return nil, fmt.Errorf("mongo resolve_client: rate-limited (next attempt in %s)", delay-elapsed)
}
}
- return mc.DB, nil
+ c.lastConnectAttempt = time.Now()
+
+ tracer := otel.Tracer("mongo")
+
+ ctx, span := tracer.Start(ctx, "mongo.resolve")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemMongoDB))
+
+ if err := c.connectLocked(ctx); err != nil {
+ c.connectAttempts++
+ c.recordConnectionFailure("resolve")
+
+ libOpentelemetry.HandleSpanError(span, "Failed to resolve mongo connection", err)
+
+ return nil, err
+ }
+
+ c.connectAttempts = 0
+
+ if c.client == nil {
+ err := ErrClientClosed
+ libOpentelemetry.HandleSpanError(span, "Mongo client not connected after resolve", err)
+
+ return nil, err
+ }
+
+ return c.client, nil
+}
+
+// DatabaseName returns the configured database name.
+func (c *Client) DatabaseName() (string, error) {
+ if c == nil {
+ return "", nilClientAssert("database_name")
+ }
+
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return c.databaseName, nil
}
-// EnsureIndexes guarantees an index exists for a given collection.
-// Idempotent. Returns error if connection or index creation fails.
-func (mc *MongoConnection) EnsureIndexes(ctx context.Context, collection string, index mongo.IndexModel) error {
- mc.Logger.Debugf("Ensuring indexes for collection: collection=%s", collection)
+// Database returns the configured mongo database handle.
+//
+// Note: the returned *mongo.Database may become stale if Close is called
+// concurrently from another goroutine. Callers that need atomicity
+// across multiple operations should coordinate externally.
+func (c *Client) Database(ctx context.Context) (*mongo.Database, error) {
+ client, err := c.Client(ctx)
+ if err != nil {
+ return nil, err
+ }
- client, err := mc.GetDB(ctx)
+ databaseName, err := c.DatabaseName()
if err != nil {
- mc.Logger.Warnf("Failed to get database connection for index creation: %v", err)
- return fmt.Errorf("failed to get database connection for index creation: %w", err)
+ return nil, err
}
- db := client.Database(mc.Database)
+ return client.Database(databaseName), nil
+}
- coll := db.Collection(collection)
+// Ping checks MongoDB availability using the active connection.
+func (c *Client) Ping(ctx context.Context) error {
+ if c == nil {
+ return nilClientAssert("ping")
+ }
- fields := indexKeysString(index.Keys)
+ if ctx == nil {
+ return ErrNilContext
+ }
- mc.Logger.Debugf("Ensuring index: collection=%s, fields=%s", collection, fields)
+ tracer := otel.Tracer("mongo")
- // Note: createIndexes is idempotent; when indexes already exist with same definition,
- // the server returns ok:1 (no error).
- // Also: if the collection does not exist yet, this operation will create it automatically.
- // Create the collection explicitly only if you need to set collection options
- // (e.g., validation rules, default collation, time-series, capped/clustered).
- _, err = coll.Indexes().CreateOne(ctx, index)
+ ctx, span := tracer.Start(ctx, "mongo.ping")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemMongoDB))
+
+ client, err := c.Client(ctx)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "Failed to get mongo client for ping", err)
+
+ return err
+ }
+
+ if err := c.deps.ping(ctx, client); err != nil {
+ sanitized := sanitizeDriverError(err)
+ pingErr := fmt.Errorf("%w: %w", ErrPing, sanitized)
+ libOpentelemetry.HandleSpanError(span, "Mongo ping failed", pingErr)
+
+ return pingErr
+ }
+
+ return nil
+}
+
+// Close releases the MongoDB connection.
+// The client is marked as closed regardless of whether disconnect succeeds or fails.
+// This prevents callers from retrying operations on a potentially half-closed client.
+func (c *Client) Close(ctx context.Context) error {
+ if c == nil {
+ return nilClientAssert("close")
+ }
+
+ if ctx == nil {
+ return ErrNilContext
+ }
+
+ tracer := otel.Tracer("mongo")
+
+ ctx, span := tracer.Start(ctx, "mongo.close")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemMongoDB))
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.closed = true
+
+ if c.client == nil {
+ return nil
+ }
+
+ err := c.deps.disconnect(ctx, c.client)
+ c.client = nil
+
+ if err != nil {
+ sanitized := sanitizeDriverError(err)
+ c.log(ctx, "mongo disconnect failed", log.Err(sanitized))
+
+ disconnectErr := fmt.Errorf("%w: %w", ErrDisconnect, sanitized)
+ libOpentelemetry.HandleSpanError(span, "Failed to disconnect from mongo", disconnectErr)
+
+ return disconnectErr
+ }
+
+ return nil
+}
+
+// EnsureIndexes creates indexes for a collection if they do not already exist.
+func (c *Client) EnsureIndexes(ctx context.Context, collection string, indexes ...mongo.IndexModel) error {
+ if c == nil {
+ return nilClientAssert("ensure_indexes")
+ }
+
+ if ctx == nil {
+ return ErrNilContext
+ }
+
+ if strings.TrimSpace(collection) == "" {
+ return ErrEmptyCollectionName
+ }
+
+ if len(indexes) == 0 {
+ return ErrEmptyIndexes
+ }
+
+ tracer := otel.Tracer("mongo")
+
+ ctx, span := tracer.Start(ctx, "mongo.ensure_indexes")
+ defer span.End()
+
+ span.SetAttributes(
+ attribute.String(constant.AttrDBSystem, constant.DBSystemMongoDB),
+ attribute.String(constant.AttrDBMongoDBCollection, collection),
+ )
+
+ client, err := c.Client(ctx)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "Failed to get mongo client for ensure indexes", err)
+
+ return err
+ }
+
+ databaseName, err := c.DatabaseName()
if err != nil {
- mc.Logger.Warnf("Failed to ensure index: collection=%s, fields=%s, err=%v", collection, fields, err)
- return fmt.Errorf("failed to ensure index on collection %s: %w", collection, err)
+ libOpentelemetry.HandleSpanError(span, "Failed to get database name for ensure indexes", err)
+
+ return err
}
- mc.Logger.Infof("Index successfully ensured: collection=%s, fields=%s \n", collection, fields)
+ var indexErrors []error
+
+ for _, index := range indexes {
+ if err := ctx.Err(); err != nil {
+ indexErrors = append(indexErrors, fmt.Errorf("%w: context cancelled: %w", ErrCreateIndex, err))
+
+ break
+ }
+
+ fields := indexKeysString(index.Keys)
+
+ if fields == "" {
+ c.logAtLevel(ctx, log.LevelWarn, "unrecognized index key type; expected bson.D or bson.M",
+ log.String("collection", collection))
+ }
+
+ c.log(ctx, "ensuring mongo index", log.String("collection", collection), log.String("fields", fields))
+
+ if err := c.deps.createIndex(ctx, client, databaseName, collection, index); err != nil {
+ c.logAtLevel(ctx, log.LevelWarn, "failed to create mongo index",
+ log.String("collection", collection),
+ log.String("fields", fields),
+ log.Err(err),
+ )
+
+ indexErrors = append(indexErrors, fmt.Errorf("%w: collection=%s fields=%s: %w", ErrCreateIndex, collection, fields, err))
+ }
+ }
+
+ if len(indexErrors) > 0 {
+ joinedErr := errors.Join(indexErrors...)
+ libOpentelemetry.HandleSpanError(span, "Failed to ensure some mongo indexes", joinedErr)
+
+ return joinedErr
+ }
return nil
}
+func (c *Client) log(ctx context.Context, message string, fields ...log.Field) {
+ c.logAtLevel(ctx, log.LevelDebug, message, fields...)
+}
+
+func (c *Client) logAtLevel(ctx context.Context, level log.Level, message string, fields ...log.Field) {
+ if c == nil || c.cfg.Logger == nil {
+ return
+ }
+
+ if !c.cfg.Logger.Enabled(level) {
+ return
+ }
+
+ c.cfg.Logger.Log(ctx, level, message, fields...)
+}
+
+// normalizeConfig applies safe defaults, trims whitespace, and clamps to a Config.
+func normalizeConfig(cfg Config) Config {
+ cfg.URI = strings.TrimSpace(cfg.URI)
+ cfg.Database = strings.TrimSpace(cfg.Database)
+
+ if cfg.MaxPoolSize > maxMaxPoolSize {
+ cfg.MaxPoolSize = maxMaxPoolSize
+ }
+
+ if cfg.TLS != nil {
+ tlsCopy := *cfg.TLS
+ tlsCopy.CACertBase64 = strings.TrimSpace(tlsCopy.CACertBase64)
+ cfg.TLS = &tlsCopy
+ }
+
+ normalizeTLSDefaults(cfg.TLS)
+
+ return cfg
+}
+
+// normalizeTLSDefaults sets MinVersion to TLS 1.2 when unspecified (zero).
+// Explicit versions are preserved so downstream validation in buildTLSConfig
+// can reject disallowed values rather than silently overwriting them.
+func normalizeTLSDefaults(tlsCfg *TLSConfig) {
+ if tlsCfg == nil {
+ return
+ }
+
+ if tlsCfg.MinVersion == 0 {
+ tlsCfg.MinVersion = tls.VersionTLS12
+ }
+}
+
+// buildTLSConfig creates a *tls.Config from a TLSConfig.
+// When CACertBase64 is provided, it is decoded and used as the root CA pool.
+// When CACertBase64 is empty, the system root CA pool is used (RootCAs = nil).
+// MinVersion defaults to TLS 1.2. If cfg.MinVersion is set, it must be
+// tls.VersionTLS12 or tls.VersionTLS13; any other value returns ErrInvalidConfig.
+func buildTLSConfig(cfg TLSConfig) (*tls.Config, error) {
+ if cfg.MinVersion != 0 && cfg.MinVersion != tls.VersionTLS12 && cfg.MinVersion != tls.VersionTLS13 {
+ return nil, fmt.Errorf("%w: unsupported TLS MinVersion %#x (must be tls.VersionTLS12 or tls.VersionTLS13)", ErrInvalidConfig, cfg.MinVersion)
+ }
+
+ tlsConfig := &tls.Config{
+ MinVersion: tls.VersionTLS12,
+ }
+
+ if cfg.MinVersion == tls.VersionTLS13 {
+ tlsConfig.MinVersion = tls.VersionTLS13
+ }
+
+ // When CACertBase64 is provided, build a custom root CA pool.
+ // When empty, RootCAs remains nil and Go uses the system root CA pool.
+ if strings.TrimSpace(cfg.CACertBase64) != "" {
+ caCert, err := base64.StdEncoding.DecodeString(cfg.CACertBase64)
+ if err != nil {
+ return nil, configError(fmt.Sprintf("decoding CA cert: %v", err))
+ }
+
+ caCertPool := x509.NewCertPool()
+ if !caCertPool.AppendCertsFromPEM(caCert) {
+ return nil, fmt.Errorf("adding CA cert to pool failed: %w", ErrInvalidConfig)
+ }
+
+ tlsConfig.RootCAs = caCertPool
+ }
+
+ return tlsConfig, nil
+}
+
+// isTLSImplied returns true if the URI scheme or query parameters indicate TLS.
+// Uses proper URI parsing to avoid false positives from substring matching
+// (e.g. credentials or unrelated params containing "tls=true").
+func isTLSImplied(uri string) bool {
+ if strings.HasPrefix(strings.ToLower(uri), "mongodb+srv://") {
+ return true
+ }
+
+ parsed, err := neturl.Parse(uri)
+ if err != nil {
+ return false
+ }
+
+ for key, values := range parsed.Query() {
+ if strings.EqualFold(key, "tls") || strings.EqualFold(key, "ssl") {
+ for _, value := range values {
+ if strings.EqualFold(value, "true") {
+ return true
+ }
+ }
+ }
+ }
+
+ return false
+}
+
+// SanitizedError wraps a driver error with a credential-free message.
+// Error() returns only the sanitized text. This prevents URI/auth details
+// from leaking through error messages into logs or upstream callers.
+// Unwrap() preserves the original error chain so callers can still use
+// errors.Is/As to match context.Canceled, context.DeadlineExceeded, or
+// driver sentinels.
+type SanitizedError struct {
+ // Message is the credential-free error description.
+ Message string
+ // cause is the original unwrapped error for errors.Is/As compatibility.
+ cause error
+}
+
+func (e *SanitizedError) Error() string { return e.Message }
+
+// Unwrap returns the original error, preserving the error chain for
+// errors.Is and errors.As matching.
+func (e *SanitizedError) Unwrap() error { return e.cause }
+
+// sanitizeDriverError wraps a raw MongoDB driver error in a SanitizedError
+// that strips potential URI and authentication details from the message.
+func sanitizeDriverError(err error) error {
+ if err == nil {
+ return nil
+ }
+
+ msg := err.Error()
+ msg = uriCredentialsPattern.ReplaceAllString(msg, "://***@")
+ msg = uriPasswordParamPattern.ReplaceAllString(msg, "${1}***")
+
+ return &SanitizedError{Message: msg, cause: err}
+}
+
+// uriCredentialsPattern matches "://user:pass@" in connection strings.
+var uriCredentialsPattern = regexp.MustCompile(`://[^@\s]+@`)
+
+// uriPasswordParamPattern matches "password=value" query parameters.
+var uriPasswordParamPattern = regexp.MustCompile(`(?i)(password=)(\S+)`)
+
+// configError wraps a configuration validation message with ErrInvalidConfig.
+func configError(msg string) error {
+ return fmt.Errorf("%w: %s", ErrInvalidConfig, msg)
+}
+
+// recordConnectionFailure increments the mongo connection failure counter.
+// No-op when metricsFactory is nil.
+func (c *Client) recordConnectionFailure(operation string) {
+ if c == nil || c.metricsFactory == nil {
+ return
+ }
+
+ counter, err := c.metricsFactory.Counter(connectionFailuresMetric)
+ if err != nil {
+ c.logAtLevel(context.Background(), log.LevelWarn, "failed to create mongo metric counter", log.Err(err))
+ return
+ }
+
+ err = counter.
+ WithLabels(map[string]string{
+ "operation": constant.SanitizeMetricLabel(operation),
+ }).
+ AddOne(context.Background())
+ if err != nil {
+ c.logAtLevel(context.Background(), log.LevelWarn, "failed to record mongo metric", log.Err(err))
+ }
+}
+
// indexKeysString returns a string representation of the index keys.
// It's used to log the index keys in a human-readable format.
func indexKeysString(keys any) string {
@@ -127,6 +796,8 @@ func indexKeysString(keys any) string {
parts = append(parts, key)
}
+ sort.Strings(parts)
+
return strings.Join(parts, ",")
default:
return ""
diff --git a/commons/mongo/mongo_integration_test.go b/commons/mongo/mongo_integration_test.go
new file mode 100644
index 00000000..ad0983a5
--- /dev/null
+++ b/commons/mongo/mongo_integration_test.go
@@ -0,0 +1,261 @@
+//go:build integration
+
+package mongo
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/testcontainers/testcontainers-go"
+ tcmongo "github.com/testcontainers/testcontainers-go/modules/mongodb"
+ "github.com/testcontainers/testcontainers-go/wait"
+ "go.mongodb.org/mongo-driver/bson"
+ mongodriver "go.mongodb.org/mongo-driver/mongo"
+)
+
+const (
+ testDatabase = "integration_test_db"
+ testCollection = "integration_test_col"
+)
+
+// setupMongoContainer starts a disposable MongoDB 7 container and returns
+// the connection string plus a cleanup function. The container is terminated
+// when cleanup runs (typically via t.Cleanup).
+func setupMongoContainer(t *testing.T) (string, func()) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+ defer cancel()
+
+ container, err := tcmongo.Run(ctx,
+ "mongo:7",
+ testcontainers.WithWaitStrategy(
+ wait.ForLog("Waiting for connections").
+ WithStartupTimeout(30*time.Second),
+ ),
+ )
+ require.NoError(t, err)
+
+ endpoint, err := container.ConnectionString(ctx)
+ require.NoError(t, err)
+
+ return endpoint, func() {
+ closeCtx, closeCancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer closeCancel()
+
+ require.NoError(t, container.Terminate(closeCtx))
+ }
+}
+
+// newIntegrationClient creates a Client backed by the testcontainer at uri.
+func newIntegrationClient(t *testing.T, uri string) *Client {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ client, err := NewClient(ctx, Config{
+ URI: uri,
+ Database: testDatabase,
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err)
+
+ return client
+}
+
+// ---------------------------------------------------------------------------
+// Integration tests
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Mongo_ConnectAndPing(t *testing.T) {
+ uri, cleanup := setupMongoContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ client := newIntegrationClient(t, uri)
+ defer func() { require.NoError(t, client.Close(ctx)) }()
+
+ // Ping must succeed on a healthy container.
+ err := client.Ping(ctx)
+ require.NoError(t, err)
+}
+
+func TestIntegration_Mongo_DatabaseAccess(t *testing.T) {
+ uri, cleanup := setupMongoContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ client := newIntegrationClient(t, uri)
+ defer func() { require.NoError(t, client.Close(ctx)) }()
+
+ // Obtain a database handle and verify the name.
+ db, err := client.Database(ctx)
+ require.NoError(t, err)
+ assert.Equal(t, testDatabase, db.Name())
+
+ // Insert a document into a fresh collection.
+ type testDoc struct {
+ Name string `bson:"name"`
+ Value int `bson:"value"`
+ }
+
+ col := db.Collection(testCollection)
+ insertDoc := testDoc{Name: "integration", Value: 42}
+
+ _, err = col.InsertOne(ctx, insertDoc)
+ require.NoError(t, err)
+
+ // Read it back and verify contents.
+ var result testDoc
+
+ err = col.FindOne(ctx, bson.M{"name": "integration"}).Decode(&result)
+ require.NoError(t, err)
+ assert.Equal(t, "integration", result.Name)
+ assert.Equal(t, 42, result.Value)
+}
+
+func TestIntegration_Mongo_EnsureIndexes(t *testing.T) {
+ uri, cleanup := setupMongoContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ client := newIntegrationClient(t, uri)
+ defer func() { require.NoError(t, client.Close(ctx)) }()
+
+ // Force-create the collection so index listing returns results.
+ db, err := client.Database(ctx)
+ require.NoError(t, err)
+
+ err = db.CreateCollection(ctx, testCollection)
+ require.NoError(t, err)
+
+ // Ensure an index on the "email" field.
+ err = client.EnsureIndexes(ctx, testCollection,
+ mongodriver.IndexModel{
+ Keys: bson.D{{Key: "email", Value: 1}},
+ },
+ )
+ require.NoError(t, err)
+
+ // List indexes and verify ours is present.
+ driverClient, err := client.Client(ctx)
+ require.NoError(t, err)
+
+ cursor, err := driverClient.Database(testDatabase).
+ Collection(testCollection).
+ Indexes().
+ List(ctx)
+ require.NoError(t, err)
+
+ var indexes []bson.M
+
+ err = cursor.All(ctx, &indexes)
+ require.NoError(t, err)
+
+ // MongoDB always creates a default _id index, so we expect at least 2.
+ require.GreaterOrEqual(t, len(indexes), 2, "expected at least the _id index + email index")
+
+ // Find the email index by inspecting the "key" document.
+ // The driver may return bson.M or bson.D depending on version/context.
+ found := false
+
+ for _, idx := range indexes {
+ switch keyDoc := idx["key"].(type) {
+ case bson.M:
+ if _, hasEmail := keyDoc["email"]; hasEmail {
+ found = true
+ }
+ case bson.D:
+ for _, elem := range keyDoc {
+ if elem.Key == "email" {
+ found = true
+
+ break
+ }
+ }
+ }
+
+ if found {
+ break
+ }
+ }
+
+ assert.True(t, found, "expected to find an index on 'email'; indexes: %+v", indexes)
+}
+
+func TestIntegration_Mongo_ResolveClient(t *testing.T) {
+ uri, cleanup := setupMongoContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ client := newIntegrationClient(t, uri)
+ defer func() {
+ require.NoError(t, client.Close(ctx))
+ }()
+
+ // Confirm the client is alive before simulating a dropped connection.
+ err := client.Ping(ctx)
+ require.NoError(t, err)
+
+ // ResolveClient should return the active connected driver client when the
+ // wrapper is healthy. The branch where the cached client is absent is covered
+ // in unit tests because it requires synthetic internal state manipulation.
+ driverClient, err := client.ResolveClient(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, driverClient)
+
+ currentClient, err := client.Client(ctx)
+ require.NoError(t, err)
+ assert.Same(t, currentClient, driverClient)
+
+ // Verify the resolved client is functional against the live container.
+ err = client.Ping(ctx)
+ require.NoError(t, err)
+}
+
+func TestIntegration_Mongo_ConcurrentPing(t *testing.T) {
+ uri, cleanup := setupMongoContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ client := newIntegrationClient(t, uri)
+ defer func() { require.NoError(t, client.Close(ctx)) }()
+
+ const goroutines = 10
+
+ var wg sync.WaitGroup
+
+ errs := make([]error, goroutines)
+
+ for i := range goroutines {
+ wg.Add(1)
+
+ go func(idx int) {
+ defer wg.Done()
+
+ errs[idx] = client.Ping(ctx)
+ }(i)
+ }
+
+ wg.Wait()
+
+ for i, err := range errs {
+ assert.NoErrorf(t, err, "goroutine %d returned an error", i)
+ }
+}
diff --git a/commons/mongo/mongo_test.go b/commons/mongo/mongo_test.go
new file mode 100644
index 00000000..a2fbb5e9
--- /dev/null
+++ b/commons/mongo/mongo_test.go
@@ -0,0 +1,1167 @@
+//go:build unit
+
+package mongo
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/base64"
+ "encoding/pem"
+ "errors"
+ "math/big"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.mongodb.org/mongo-driver/bson"
+ "go.mongodb.org/mongo-driver/mongo"
+ "go.mongodb.org/mongo-driver/mongo/options"
+)
+
+// ---------------------------------------------------------------------------
+// Test helpers
+// ---------------------------------------------------------------------------
+
+func withDeps(deps clientDeps) Option {
+ return func(current *clientDeps) {
+ *current = deps
+ }
+}
+
+func baseConfig() Config {
+ return Config{
+ URI: "mongodb://localhost:27017",
+ Database: "app",
+ }
+}
+
+func successDeps() clientDeps {
+ fakeClient := &mongo.Client{}
+
+ return clientDeps{
+ connect: func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ return fakeClient, nil
+ },
+ ping: func(context.Context, *mongo.Client) error { return nil },
+ disconnect: func(context.Context, *mongo.Client) error { return nil },
+ createIndex: func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return nil
+ },
+ }
+}
+
+func newTestClient(t *testing.T, overrides *clientDeps) *Client {
+ t.Helper()
+
+ deps := successDeps()
+ if overrides != nil {
+ if overrides.connect != nil {
+ deps.connect = overrides.connect
+ }
+
+ if overrides.ping != nil {
+ deps.ping = overrides.ping
+ }
+
+ if overrides.disconnect != nil {
+ deps.disconnect = overrides.disconnect
+ }
+
+ if overrides.createIndex != nil {
+ deps.createIndex = overrides.createIndex
+ }
+ }
+
+ client, err := NewClient(context.Background(), baseConfig(), withDeps(deps))
+ require.NoError(t, err)
+
+ return client
+}
+
+// spyLogger implements log.Logger and records messages for verification.
+type spyLogger struct {
+ mu sync.Mutex
+ messages []string
+ levels []log.Level
+}
+
+func (s *spyLogger) Log(_ context.Context, level log.Level, msg string, _ ...log.Field) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.messages = append(s.messages, msg)
+ s.levels = append(s.levels, level)
+}
+
+func (s *spyLogger) With(_ ...log.Field) log.Logger { return s }
+func (s *spyLogger) WithGroup(_ string) log.Logger { return s }
+func (s *spyLogger) Enabled(_ log.Level) bool { return true }
+func (s *spyLogger) Sync(_ context.Context) error { return nil }
+
+func generateTestCertificatePEM(t *testing.T) []byte {
+ t.Helper()
+
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ tmpl := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{CommonName: "mongo-test-ca"},
+ NotBefore: time.Now().Add(-time.Hour),
+ NotAfter: time.Now().Add(time.Hour),
+ IsCA: true,
+ BasicConstraintsValid: true,
+ KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
+ require.NoError(t, err)
+
+ return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+}
+
+// ---------------------------------------------------------------------------
+// NewClient tests
+// ---------------------------------------------------------------------------
+
+func TestNewClient_ValidatesInput(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_context", func(t *testing.T) {
+ t.Parallel()
+
+ client, err := NewClient(nil, baseConfig())
+ assert.Nil(t, client)
+ assert.ErrorIs(t, err, ErrNilContext)
+ })
+
+ t.Run("empty_uri", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := baseConfig()
+ cfg.URI = ""
+
+ client, err := NewClient(context.Background(), cfg)
+ assert.Nil(t, client)
+ assert.ErrorIs(t, err, ErrEmptyURI)
+ })
+
+ t.Run("empty_database", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := baseConfig()
+ cfg.Database = " "
+
+ client, err := NewClient(context.Background(), cfg)
+ assert.Nil(t, client)
+ assert.ErrorIs(t, err, ErrEmptyDatabaseName)
+ })
+}
+
+func TestNewClient_ConnectAndPingFailures(t *testing.T) {
+ t.Parallel()
+
+ t.Run("connect_failure", func(t *testing.T) {
+ t.Parallel()
+
+ deps := clientDeps{
+ connect: func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ return nil, errors.New("dial failed")
+ },
+ ping: func(context.Context, *mongo.Client) error { return nil },
+ disconnect: func(context.Context, *mongo.Client) error { return nil },
+ createIndex: func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return nil
+ },
+ }
+
+ client, err := NewClient(context.Background(), baseConfig(), withDeps(deps))
+ assert.Nil(t, client)
+ assert.ErrorIs(t, err, ErrConnect)
+ })
+
+ t.Run("nil_client_returned", func(t *testing.T) {
+ t.Parallel()
+
+ deps := clientDeps{
+ connect: func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ return nil, nil
+ },
+ ping: func(context.Context, *mongo.Client) error { return nil },
+ disconnect: func(context.Context, *mongo.Client) error { return nil },
+ createIndex: func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return nil
+ },
+ }
+
+ client, err := NewClient(context.Background(), baseConfig(), withDeps(deps))
+ assert.Nil(t, client)
+ assert.ErrorIs(t, err, ErrNilMongoClient)
+ })
+
+ t.Run("ping_failure_disconnects", func(t *testing.T) {
+ t.Parallel()
+
+ fakeClient := &mongo.Client{}
+ var disconnectCalls atomic.Int32
+
+ deps := clientDeps{
+ connect: func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ return fakeClient, nil
+ },
+ ping: func(context.Context, *mongo.Client) error {
+ return errors.New("ping failed")
+ },
+ disconnect: func(context.Context, *mongo.Client) error {
+ disconnectCalls.Add(1)
+ return nil
+ },
+ createIndex: func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return nil
+ },
+ }
+
+ client, err := NewClient(context.Background(), baseConfig(), withDeps(deps))
+ assert.Nil(t, client)
+ assert.ErrorIs(t, err, ErrPing)
+ assert.EqualValues(t, 1, disconnectCalls.Load())
+ })
+}
+
+func TestNewClient_NilOptionIsSkipped(t *testing.T) {
+ t.Parallel()
+
+ deps := successDeps()
+ client, err := NewClient(context.Background(), baseConfig(), nil, withDeps(deps))
+ require.NoError(t, err)
+ assert.NotNil(t, client)
+}
+
+func TestNewClient_NilDependencyRejected(t *testing.T) {
+ t.Parallel()
+
+ nilConnect := func(d *clientDeps) { d.connect = nil }
+ _, err := NewClient(context.Background(), baseConfig(), nilConnect)
+ assert.ErrorIs(t, err, ErrNilDependency)
+}
+
+func TestNewClient_ClearsURIAfterConnect(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ assert.Empty(t, client.cfg.URI, "URI should be cleared from cfg after connect")
+ assert.NotEmpty(t, client.uri, "private uri should be preserved")
+}
+
+// ---------------------------------------------------------------------------
+// Connect tests
+// ---------------------------------------------------------------------------
+
+func TestClient_ConnectIsIdempotent(t *testing.T) {
+ t.Parallel()
+
+ fakeClient := &mongo.Client{}
+ var connectCalls atomic.Int32
+
+ deps := clientDeps{
+ connect: func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ connectCalls.Add(1)
+ return fakeClient, nil
+ },
+ ping: func(context.Context, *mongo.Client) error { return nil },
+ disconnect: func(context.Context, *mongo.Client) error { return nil },
+ createIndex: func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return nil
+ },
+ }
+
+ client, err := NewClient(context.Background(), baseConfig(), withDeps(deps))
+ require.NoError(t, err)
+
+ assert.NoError(t, client.Connect(context.Background()))
+ assert.EqualValues(t, 1, connectCalls.Load())
+}
+
+func TestClient_Connect_Guards(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Client
+ assert.ErrorIs(t, c.Connect(context.Background()), ErrNilClient)
+ })
+
+ t.Run("nil_context_on_closed_client", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ require.NoError(t, client.Close(context.Background()))
+ assert.ErrorIs(t, client.Connect(nil), ErrNilContext)
+ })
+}
+
+func TestClient_Connect_ConfigPropagation(t *testing.T) {
+ t.Parallel()
+
+ fakeClient := &mongo.Client{}
+ var capturedOpts *options.ClientOptions
+
+ cfg := baseConfig()
+ cfg.MaxPoolSize = 42
+ cfg.ServerSelectionTimeout = 3 * time.Second
+ cfg.HeartbeatInterval = 7 * time.Second
+
+ deps := clientDeps{
+ connect: func(_ context.Context, opts *options.ClientOptions) (*mongo.Client, error) {
+ capturedOpts = opts
+ return fakeClient, nil
+ },
+ ping: func(context.Context, *mongo.Client) error { return nil },
+ disconnect: func(context.Context, *mongo.Client) error { return nil },
+ createIndex: func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return nil
+ },
+ }
+
+ _, err := NewClient(context.Background(), cfg, withDeps(deps))
+ require.NoError(t, err)
+ assert.NotNil(t, capturedOpts)
+}
+
+// ---------------------------------------------------------------------------
+// Client and Database tests
+// ---------------------------------------------------------------------------
+
+func TestClient_ClientAndDatabase(t *testing.T) {
+ t.Parallel()
+
+ fakeClient := &mongo.Client{}
+ deps := clientDeps{
+ connect: func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ return fakeClient, nil
+ },
+ ping: func(context.Context, *mongo.Client) error { return nil },
+ disconnect: func(context.Context, *mongo.Client) error { return nil },
+ createIndex: func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return nil
+ },
+ }
+
+ client, err := NewClient(context.Background(), baseConfig(), withDeps(deps))
+ require.NoError(t, err)
+
+ t.Run("nil_context", func(t *testing.T) {
+ t.Parallel()
+
+ mongoClient, callErr := client.Client(nil)
+ assert.Nil(t, mongoClient)
+ assert.ErrorIs(t, callErr, ErrNilContext)
+ })
+
+ t.Run("database_name", func(t *testing.T) {
+ t.Parallel()
+
+ databaseName, err := client.DatabaseName()
+ require.NoError(t, err)
+ assert.Equal(t, "app", databaseName)
+ })
+
+ t.Run("database_returns_handle", func(t *testing.T) {
+ t.Parallel()
+
+ db, callErr := client.Database(context.Background())
+ require.NoError(t, callErr)
+ assert.Equal(t, "app", db.Name())
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Ping tests
+// ---------------------------------------------------------------------------
+
+func TestClient_Ping(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Client
+ assert.ErrorIs(t, c.Ping(context.Background()), ErrNilClient)
+ })
+
+ t.Run("nil_context", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ assert.ErrorIs(t, client.Ping(nil), ErrNilContext)
+ })
+
+ t.Run("success", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ assert.NoError(t, client.Ping(context.Background()))
+ })
+
+ t.Run("wraps_ping_error", func(t *testing.T) {
+ t.Parallel()
+
+ var pingCount atomic.Int32
+ deps := successDeps()
+ deps.ping = func(context.Context, *mongo.Client) error {
+ if pingCount.Add(1) == 1 {
+ return nil // first ping (from Connect) succeeds
+ }
+
+ return errors.New("network timeout")
+ }
+
+ client := newTestClient(t, &deps)
+
+ err := client.Ping(context.Background())
+ assert.ErrorIs(t, err, ErrPing)
+ })
+
+ t.Run("closed_client", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ require.NoError(t, client.Close(context.Background()))
+ assert.ErrorIs(t, client.Ping(context.Background()), ErrClientClosed)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Close tests
+// ---------------------------------------------------------------------------
+
+func TestClient_Close(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var client *Client
+ assert.ErrorIs(t, client.Close(context.Background()), ErrNilClient)
+ })
+
+ t.Run("nil_context", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ assert.ErrorIs(t, client.Close(nil), ErrNilContext)
+ })
+
+ t.Run("disconnect_failure_clears_client", func(t *testing.T) {
+ t.Parallel()
+
+ deps := successDeps()
+ deps.disconnect = func(context.Context, *mongo.Client) error {
+ return errors.New("disconnect failed")
+ }
+
+ client := newTestClient(t, &deps)
+
+ err := client.Close(context.Background())
+ assert.ErrorIs(t, err, ErrDisconnect)
+
+ mongoClient, callErr := client.Client(context.Background())
+ assert.Nil(t, mongoClient)
+ assert.ErrorIs(t, callErr, ErrClientClosed)
+ })
+
+ t.Run("close_is_idempotent", func(t *testing.T) {
+ t.Parallel()
+
+ var disconnectCalls atomic.Int32
+ deps := successDeps()
+ deps.disconnect = func(context.Context, *mongo.Client) error {
+ disconnectCalls.Add(1)
+ return nil
+ }
+
+ client := newTestClient(t, &deps)
+
+ require.NoError(t, client.Close(context.Background()))
+ require.NoError(t, client.Close(context.Background()))
+ assert.EqualValues(t, 1, disconnectCalls.Load())
+ })
+
+ t.Run("connect_after_close_returns_error", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ require.NoError(t, client.Close(context.Background()))
+
+ err := client.Connect(context.Background())
+ assert.ErrorIs(t, err, ErrClientClosed, "Connect after Close must return ErrClientClosed")
+ })
+
+ t.Run("resolve_client_after_close_returns_error", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ require.NoError(t, client.Close(context.Background()))
+
+ mc, err := client.ResolveClient(context.Background())
+ assert.Nil(t, mc)
+ assert.ErrorIs(t, err, ErrClientClosed, "ResolveClient after Close must return ErrClientClosed")
+ })
+
+ t.Run("resolve_client_reconnects_when_cached_client_is_absent", func(t *testing.T) {
+ t.Parallel()
+
+ initialClient := &mongo.Client{}
+ reconnectedClient := &mongo.Client{}
+ var connectCalls atomic.Int32
+
+ deps := successDeps()
+ deps.connect = func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ if connectCalls.Add(1) == 1 {
+ return initialClient, nil
+ }
+
+ return reconnectedClient, nil
+ }
+
+ client := newTestClient(t, &deps)
+ assert.EqualValues(t, 1, connectCalls.Load())
+
+ client.mu.Lock()
+ client.client = nil
+ client.mu.Unlock()
+
+ resolved, err := client.ResolveClient(context.Background())
+ require.NoError(t, err)
+ assert.Same(t, reconnectedClient, resolved)
+ assert.EqualValues(t, 2, connectCalls.Load())
+ })
+
+ t.Run("close_prevents_reconnection_via_resolve", func(t *testing.T) {
+ t.Parallel()
+
+ var connectCalls atomic.Int32
+ deps := successDeps()
+ deps.connect = func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ connectCalls.Add(1)
+ return &mongo.Client{}, nil
+ }
+
+ client := newTestClient(t, &deps)
+ initialConnects := connectCalls.Load()
+
+ require.NoError(t, client.Close(context.Background()))
+
+ _, err := client.ResolveClient(context.Background())
+ assert.ErrorIs(t, err, ErrClientClosed)
+ assert.EqualValues(t, initialConnects, connectCalls.Load(), "no reconnection attempt after Close")
+ })
+}
+
+// ---------------------------------------------------------------------------
+// EnsureIndexes tests
+// ---------------------------------------------------------------------------
+
+func TestClient_EnsureIndexes(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Client
+ err := c.EnsureIndexes(context.Background(), "users", mongo.IndexModel{Keys: bson.D{{Key: "a", Value: 1}}})
+ assert.ErrorIs(t, err, ErrNilClient)
+ })
+
+ t.Run("nil_context", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ err := client.EnsureIndexes(nil, "users", mongo.IndexModel{Keys: bson.D{{Key: "tenant_id", Value: 1}}})
+ assert.ErrorIs(t, err, ErrNilContext)
+ })
+
+ t.Run("empty_collection", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ err := client.EnsureIndexes(context.Background(), " ", mongo.IndexModel{Keys: bson.D{{Key: "tenant_id", Value: 1}}})
+ assert.ErrorIs(t, err, ErrEmptyCollectionName)
+ })
+
+ t.Run("empty_indexes", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ err := client.EnsureIndexes(context.Background(), "users")
+ assert.ErrorIs(t, err, ErrEmptyIndexes)
+ })
+
+ t.Run("creates_all_indexes", func(t *testing.T) {
+ t.Parallel()
+
+ fakeClient := &mongo.Client{}
+ var createCalls atomic.Int32
+
+ deps := clientDeps{
+ connect: func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ return fakeClient, nil
+ },
+ ping: func(context.Context, *mongo.Client) error { return nil },
+ disconnect: func(context.Context, *mongo.Client) error { return nil },
+ createIndex: func(_ context.Context, client *mongo.Client, database, collection string, index mongo.IndexModel) error {
+ createCalls.Add(1)
+ assert.Same(t, fakeClient, client)
+ assert.Equal(t, "app", database)
+ assert.Equal(t, "users", collection)
+ assert.NotNil(t, index.Keys)
+
+ return nil
+ },
+ }
+
+ client, err := NewClient(context.Background(), baseConfig(), withDeps(deps))
+ require.NoError(t, err)
+
+ err = client.EnsureIndexes(
+ context.Background(),
+ "users",
+ mongo.IndexModel{Keys: bson.D{{Key: "tenant_id", Value: 1}}},
+ mongo.IndexModel{Keys: bson.D{{Key: "created_at", Value: -1}}},
+ )
+ require.NoError(t, err)
+ assert.EqualValues(t, 2, createCalls.Load())
+ })
+
+ t.Run("wraps_create_index_error", func(t *testing.T) {
+ t.Parallel()
+
+ deps := successDeps()
+ deps.createIndex = func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return errors.New("duplicate options")
+ }
+
+ client := newTestClient(t, &deps)
+
+ err := client.EnsureIndexes(context.Background(), "users", mongo.IndexModel{Keys: bson.D{{Key: "tenant_id", Value: 1}}})
+ assert.ErrorIs(t, err, ErrCreateIndex)
+ })
+
+ t.Run("batches_multiple_errors", func(t *testing.T) {
+ t.Parallel()
+
+ var createCalls atomic.Int32
+ deps := successDeps()
+ deps.createIndex = func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ createCalls.Add(1)
+ return errors.New("failed")
+ }
+
+ client := newTestClient(t, &deps)
+
+ err := client.EnsureIndexes(context.Background(), "users",
+ mongo.IndexModel{Keys: bson.D{{Key: "a", Value: 1}}},
+ mongo.IndexModel{Keys: bson.D{{Key: "b", Value: 1}}},
+ mongo.IndexModel{Keys: bson.D{{Key: "c", Value: 1}}},
+ )
+ assert.Error(t, err)
+ assert.EqualValues(t, 3, createCalls.Load()) // all 3 attempted, not short-circuited
+ assert.ErrorIs(t, err, ErrCreateIndex)
+ })
+
+ t.Run("partial_failure_continues", func(t *testing.T) {
+ t.Parallel()
+
+ var successCalls, failCalls atomic.Int32
+ deps := successDeps()
+ deps.createIndex = func(_ context.Context, _ *mongo.Client, _, _ string, idx mongo.IndexModel) error {
+ keys := idx.Keys.(bson.D)
+ if keys[0].Key == "b" {
+ failCalls.Add(1)
+ return errors.New("duplicate")
+ }
+
+ successCalls.Add(1)
+
+ return nil
+ }
+
+ client := newTestClient(t, &deps)
+
+ err := client.EnsureIndexes(context.Background(), "users",
+ mongo.IndexModel{Keys: bson.D{{Key: "a", Value: 1}}},
+ mongo.IndexModel{Keys: bson.D{{Key: "b", Value: 1}}},
+ mongo.IndexModel{Keys: bson.D{{Key: "c", Value: 1}}},
+ )
+ assert.Error(t, err)
+ assert.EqualValues(t, 2, successCalls.Load())
+ assert.EqualValues(t, 1, failCalls.Load())
+ })
+
+ t.Run("context_cancellation_stops_loop", func(t *testing.T) {
+ t.Parallel()
+
+ var calls atomic.Int32
+ deps := successDeps()
+ deps.createIndex = func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ calls.Add(1)
+ return nil
+ }
+
+ client := newTestClient(t, &deps)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel() // cancel immediately
+
+ err := client.EnsureIndexes(ctx, "users",
+ mongo.IndexModel{Keys: bson.D{{Key: "a", Value: 1}}},
+ mongo.IndexModel{Keys: bson.D{{Key: "b", Value: 1}}},
+ )
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrCreateIndex)
+ assert.EqualValues(t, 0, calls.Load()) // no indexes attempted
+ })
+
+ t.Run("closed_client", func(t *testing.T) {
+ t.Parallel()
+
+ client := newTestClient(t, nil)
+ require.NoError(t, client.Close(context.Background()))
+
+ err := client.EnsureIndexes(context.Background(), "users", mongo.IndexModel{Keys: bson.D{{Key: "a", Value: 1}}})
+ assert.ErrorIs(t, err, ErrClientClosed)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Concurrency tests
+// ---------------------------------------------------------------------------
+
+func TestClient_ConcurrentClientReads(t *testing.T) {
+ t.Parallel()
+
+ fakeClient := &mongo.Client{}
+ deps := clientDeps{
+ connect: func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ return fakeClient, nil
+ },
+ ping: func(context.Context, *mongo.Client) error { return nil },
+ disconnect: func(context.Context, *mongo.Client) error { return nil },
+ createIndex: func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return nil
+ },
+ }
+
+ client, err := NewClient(context.Background(), baseConfig(), withDeps(deps))
+ require.NoError(t, err)
+
+ const workers = 50
+ results := make([]*mongo.Client, workers)
+ errs := make([]error, workers)
+ var wg sync.WaitGroup
+
+ for i := 0; i < workers; i++ {
+ wg.Add(1)
+
+ go func(idx int) {
+ defer wg.Done()
+ results[idx], errs[idx] = client.Client(context.Background())
+ }(i)
+ }
+
+ wg.Wait()
+
+ for i := 0; i < workers; i++ {
+ assert.NoError(t, errs[i])
+ assert.Same(t, fakeClient, results[i])
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Logging tests
+// ---------------------------------------------------------------------------
+
+func TestClient_LogsOnConnectFailure(t *testing.T) {
+ t.Parallel()
+
+ spy := &spyLogger{}
+ cfg := baseConfig()
+ cfg.Logger = spy
+
+ deps := clientDeps{
+ connect: func(context.Context, *options.ClientOptions) (*mongo.Client, error) {
+ return nil, errors.New("dial failed")
+ },
+ ping: func(context.Context, *mongo.Client) error { return nil },
+ disconnect: func(context.Context, *mongo.Client) error { return nil },
+ createIndex: func(context.Context, *mongo.Client, string, string, mongo.IndexModel) error {
+ return nil
+ },
+ }
+
+ _, _ = NewClient(context.Background(), cfg, withDeps(deps))
+
+ spy.mu.Lock()
+ defer spy.mu.Unlock()
+
+ require.NotEmpty(t, spy.messages)
+ assert.Equal(t, "mongo connect failed", spy.messages[0])
+}
+
+func TestClient_LogsNonTLSWarning(t *testing.T) {
+ t.Parallel()
+
+ spy := &spyLogger{}
+ cfg := baseConfig()
+ cfg.Logger = spy
+
+ client := newTestClientWithLogger(t, nil, spy)
+ _ = client // verify no panic, warning was logged during construction
+
+ spy.mu.Lock()
+ defer spy.mu.Unlock()
+
+ found := false
+
+ for _, msg := range spy.messages {
+ if msg == "mongo connection established without TLS; consider configuring TLS for production use" {
+ found = true
+ break
+ }
+ }
+
+ assert.True(t, found, "expected non-TLS warning in log messages, got: %v", spy.messages)
+}
+
+func newTestClientWithLogger(t *testing.T, overrides *clientDeps, logger log.Logger) *Client {
+ t.Helper()
+
+ deps := successDeps()
+ if overrides != nil {
+ if overrides.connect != nil {
+ deps.connect = overrides.connect
+ }
+
+ if overrides.ping != nil {
+ deps.ping = overrides.ping
+ }
+
+ if overrides.disconnect != nil {
+ deps.disconnect = overrides.disconnect
+ }
+
+ if overrides.createIndex != nil {
+ deps.createIndex = overrides.createIndex
+ }
+ }
+
+ cfg := baseConfig()
+ cfg.Logger = logger
+
+ client, err := NewClient(context.Background(), cfg, withDeps(deps))
+ require.NoError(t, err)
+
+ return client
+}
+
+// ---------------------------------------------------------------------------
+// indexKeysString tests
+// ---------------------------------------------------------------------------
+
+func TestIndexKeysString(t *testing.T) {
+ t.Parallel()
+
+ t.Run("bson_d_preserves_order", func(t *testing.T) {
+ t.Parallel()
+
+ keys := bson.D{{Key: "tenant_id", Value: 1}, {Key: "created_at", Value: -1}}
+ assert.Equal(t, "tenant_id,created_at", indexKeysString(keys))
+ })
+
+ t.Run("bson_m_is_sorted", func(t *testing.T) {
+ t.Parallel()
+
+ keys := bson.M{"zeta": 1, "alpha": 1, "middle": -1}
+ assert.Equal(t, "alpha,middle,zeta", indexKeysString(keys))
+ })
+
+ t.Run("unknown_type", func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "", indexKeysString(42))
+ })
+
+ t.Run("nil_keys", func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "", indexKeysString(nil))
+ })
+
+ t.Run("empty_bson_d", func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "", indexKeysString(bson.D{}))
+ })
+
+ t.Run("empty_bson_m", func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "", indexKeysString(bson.M{}))
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Normalization tests
+// ---------------------------------------------------------------------------
+
+func TestNormalizeConfig(t *testing.T) {
+ t.Parallel()
+
+ t.Run("clamps_pool_size", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := normalizeConfig(Config{MaxPoolSize: 9999})
+ assert.EqualValues(t, maxMaxPoolSize, cfg.MaxPoolSize)
+ })
+
+ t.Run("preserves_valid_pool_size", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := normalizeConfig(Config{MaxPoolSize: 50})
+ assert.EqualValues(t, 50, cfg.MaxPoolSize)
+ })
+
+ t.Run("pool_size_at_cap", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := normalizeConfig(Config{MaxPoolSize: maxMaxPoolSize})
+ assert.EqualValues(t, maxMaxPoolSize, cfg.MaxPoolSize)
+ })
+
+ t.Run("trims_whitespace_from_URI_and_Database", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := normalizeConfig(Config{
+ URI: " mongodb://localhost:27017 ",
+ Database: " mydb ",
+ })
+ assert.Equal(t, "mongodb://localhost:27017", cfg.URI)
+ assert.Equal(t, "mydb", cfg.Database)
+ })
+
+ t.Run("trims_whitespace_from_TLS_CACertBase64", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := normalizeConfig(Config{
+ TLS: &TLSConfig{CACertBase64: " dGVzdA== "},
+ })
+ assert.Equal(t, "dGVzdA==", cfg.TLS.CACertBase64)
+ })
+}
+
+func TestNormalizeTLSDefaults(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_config", func(t *testing.T) {
+ t.Parallel()
+
+ normalizeTLSDefaults(nil) // should not panic
+ })
+
+ t.Run("sets_default_min_version", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := &TLSConfig{}
+ normalizeTLSDefaults(cfg)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
+ })
+
+ t.Run("preserves_tls13", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := &TLSConfig{MinVersion: tls.VersionTLS13}
+ normalizeTLSDefaults(cfg)
+ assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
+ })
+
+ t.Run("preserves_explicit_insecure_version", func(t *testing.T) {
+ t.Parallel()
+
+ // normalizeTLSDefaults only sets defaults for unspecified (zero) values.
+ // Explicit versions are preserved for downstream validation in buildTLSConfig.
+ cfg := &TLSConfig{MinVersion: tls.VersionTLS10}
+ normalizeTLSDefaults(cfg)
+ assert.Equal(t, uint16(tls.VersionTLS10), cfg.MinVersion)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// TLS tests
+// ---------------------------------------------------------------------------
+
+func TestBuildTLSConfig(t *testing.T) {
+ t.Parallel()
+
+ t.Run("invalid_base64", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := buildTLSConfig(TLSConfig{CACertBase64: "not-valid-base64!!!"})
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ assert.Contains(t, err.Error(), "decoding CA cert")
+ })
+
+ t.Run("valid_base64_invalid_pem", func(t *testing.T) {
+ t.Parallel()
+
+ invalidPEM := base64.StdEncoding.EncodeToString([]byte("not a PEM certificate"))
+ _, err := buildTLSConfig(TLSConfig{CACertBase64: invalidPEM})
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "adding CA cert to pool failed")
+ })
+
+ t.Run("valid_cert_tls12", func(t *testing.T) {
+ t.Parallel()
+
+ certPEM := generateTestCertificatePEM(t)
+ encoded := base64.StdEncoding.EncodeToString(certPEM)
+
+ cfg, err := buildTLSConfig(TLSConfig{
+ CACertBase64: encoded,
+ MinVersion: tls.VersionTLS12,
+ })
+ require.NoError(t, err)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
+ assert.NotNil(t, cfg.RootCAs)
+ })
+
+ t.Run("valid_cert_tls13", func(t *testing.T) {
+ t.Parallel()
+
+ certPEM := generateTestCertificatePEM(t)
+ encoded := base64.StdEncoding.EncodeToString(certPEM)
+
+ cfg, err := buildTLSConfig(TLSConfig{
+ CACertBase64: encoded,
+ MinVersion: tls.VersionTLS13,
+ })
+ require.NoError(t, err)
+ assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
+ })
+
+ t.Run("unsupported_version_returns_error", func(t *testing.T) {
+ t.Parallel()
+
+ certPEM := generateTestCertificatePEM(t)
+ encoded := base64.StdEncoding.EncodeToString(certPEM)
+
+ _, err := buildTLSConfig(TLSConfig{
+ CACertBase64: encoded,
+ MinVersion: tls.VersionTLS10,
+ })
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ })
+
+ t.Run("zero_version_defaults_to_tls12", func(t *testing.T) {
+ t.Parallel()
+
+ certPEM := generateTestCertificatePEM(t)
+ encoded := base64.StdEncoding.EncodeToString(certPEM)
+
+ cfg, err := buildTLSConfig(TLSConfig{
+ CACertBase64: encoded,
+ MinVersion: 0,
+ })
+ require.NoError(t, err)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
+ })
+
+ t.Run("empty_ca_cert_uses_system_roots", func(t *testing.T) {
+ t.Parallel()
+
+ cfg, err := buildTLSConfig(TLSConfig{
+ CACertBase64: "",
+ MinVersion: tls.VersionTLS12,
+ })
+ require.NoError(t, err)
+ assert.Nil(t, cfg.RootCAs, "empty CACertBase64 should leave RootCAs nil (system roots)")
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
+ })
+
+ t.Run("empty_ca_cert_with_tls13", func(t *testing.T) {
+ t.Parallel()
+
+ cfg, err := buildTLSConfig(TLSConfig{
+ CACertBase64: "",
+ MinVersion: tls.VersionTLS13,
+ })
+ require.NoError(t, err)
+ assert.Nil(t, cfg.RootCAs, "empty CACertBase64 should leave RootCAs nil (system roots)")
+ assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
+ })
+
+ t.Run("whitespace_only_ca_cert_uses_system_roots", func(t *testing.T) {
+ t.Parallel()
+
+ cfg, err := buildTLSConfig(TLSConfig{
+ CACertBase64: " ",
+ MinVersion: tls.VersionTLS12,
+ })
+ require.NoError(t, err)
+ assert.Nil(t, cfg.RootCAs, "whitespace-only CACertBase64 should use system roots")
+ })
+}
+
+func TestConfig_Validate_TLS(t *testing.T) {
+ t.Parallel()
+
+ t.Run("tls_without_ca_cert_passes_validation", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := Config{URI: "mongodb://localhost", Database: "db", TLS: &TLSConfig{}}
+ err := cfg.validate()
+ assert.NoError(t, err, "TLS without CACertBase64 should pass validation (uses system roots)")
+ })
+
+ t.Run("tls_with_min_version_only_passes", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := Config{URI: "mongodb://localhost", Database: "db", TLS: &TLSConfig{MinVersion: tls.VersionTLS13}}
+ err := cfg.validate()
+ assert.NoError(t, err, "TLS with only MinVersion should pass validation")
+ })
+
+ t.Run("tls_with_valid_cert_passes", func(t *testing.T) {
+ t.Parallel()
+
+ certPEM := generateTestCertificatePEM(t)
+ encoded := base64.StdEncoding.EncodeToString(certPEM)
+
+ cfg := Config{URI: "mongodb://localhost", Database: "db", TLS: &TLSConfig{CACertBase64: encoded}}
+ err := cfg.validate()
+ assert.NoError(t, err)
+ })
+}
+
+func TestIsTLSImplied(t *testing.T) {
+ t.Parallel()
+
+ assert.True(t, isTLSImplied("mongodb+srv://cluster.mongodb.net"))
+ assert.True(t, isTLSImplied("mongodb://host:27017/?tls=true"))
+ assert.True(t, isTLSImplied("mongodb://host:27017/?ssl=true"))
+ assert.True(t, isTLSImplied("mongodb://host:27017/?tls=true&appName=myapp"))
+ assert.True(t, isTLSImplied("mongodb://host:27017/?TLS=True"))
+ assert.False(t, isTLSImplied("mongodb://localhost:27017"))
+ assert.False(t, isTLSImplied("mongodb://localhost:27017/?tls=false"))
+ assert.False(t, isTLSImplied("mongodb://localhost:27017/?appName=notls%3Dtrue"))
+}
diff --git a/commons/net/http/context.go b/commons/net/http/context.go
new file mode 100644
index 00000000..8c62ff3a
--- /dev/null
+++ b/commons/net/http/context.go
@@ -0,0 +1,136 @@
+package http
+
+import (
+ "errors"
+ "sync"
+
+ "context"
+
+ "github.com/gofiber/fiber/v2"
+)
+
+// TenantExtractor extracts tenant ID string from a request context.
+type TenantExtractor func(ctx context.Context) string
+
+// IDLocation defines where a resource ID should be extracted from.
+type IDLocation string
+
+const (
+ // IDLocationParam extracts the ID from a URL path parameter.
+ IDLocationParam IDLocation = "param"
+ // IDLocationQuery extracts the ID from a query string parameter.
+ IDLocationQuery IDLocation = "query"
+)
+
+// ErrInvalidIDLocation indicates an unsupported ID source location.
+var ErrInvalidIDLocation = errors.New("invalid id location")
+
+// Sentinel errors for context ownership verification.
+var (
+ ErrMissingContextID = errors.New("context ID is required")
+ ErrInvalidContextID = errors.New("context ID must be a valid UUID")
+ ErrMissingResourceID = errors.New("resource ID is required")
+ ErrInvalidResourceID = errors.New("resource ID must be a valid UUID")
+ ErrTenantIDNotFound = errors.New("tenant ID not found in request context")
+ ErrTenantExtractorNil = errors.New("tenant extractor is not configured")
+ ErrInvalidTenantID = errors.New("invalid tenant ID format")
+ ErrContextNotFound = errors.New("context not found")
+ ErrContextNotOwned = errors.New("context does not belong to the requesting tenant")
+ ErrContextAccessDenied = errors.New("access to context denied")
+ ErrContextNotActive = errors.New("context is not active")
+ ErrContextLookupFailed = errors.New("context lookup failed")
+)
+
+// ErrVerifierNotConfigured indicates that no ownership verifier was provided.
+var ErrVerifierNotConfigured = errors.New("ownership verifier is not configured")
+
+// ErrLookupFailed indicates an ownership lookup failed unexpectedly.
+var ErrLookupFailed = errors.New("resource lookup failed")
+
+// Sentinel errors for exception ownership verification.
+//
+// Deprecated: Domain-specific errors should be defined in consuming services.
+// Use RegisterResourceErrors to register custom resource error mappings instead.
+var (
+ ErrMissingExceptionID = errors.New("exception ID is required")
+ ErrInvalidExceptionID = errors.New("exception ID must be a valid UUID")
+ ErrExceptionNotFound = errors.New("exception not found")
+ ErrExceptionAccessDenied = errors.New("access to exception denied")
+)
+
+// Sentinel errors for dispute ownership verification.
+//
+// Deprecated: Domain-specific errors should be defined in consuming services.
+// Use RegisterResourceErrors to register custom resource error mappings instead.
+var (
+ ErrMissingDisputeID = errors.New("dispute ID is required")
+ ErrInvalidDisputeID = errors.New("dispute ID must be a valid UUID")
+ ErrDisputeNotFound = errors.New("dispute not found")
+ ErrDisputeAccessDenied = errors.New("access to dispute denied")
+)
+
+// ResourceErrorMapping defines how a resource type's ownership errors should be classified.
+// Register mappings via RegisterResourceErrors to extend classifyResourceOwnershipError
+// without modifying this library.
+type ResourceErrorMapping struct {
+ // NotFoundErr is matched via errors.Is to detect "not found" responses from verifiers.
+ NotFoundErr error
+ // AccessDeniedErr is matched via errors.Is to detect "access denied" responses.
+ AccessDeniedErr error
+}
+
+// resourceErrorRegistry holds registered resource-specific error mappings.
+// Protected by registryMu for concurrent safety.
+var (
+ resourceErrorRegistry []ResourceErrorMapping
+ registryMu sync.RWMutex
+)
+
+func init() {
+ // Register legacy exception/dispute errors for backward compatibility.
+ resourceErrorRegistry = []ResourceErrorMapping{
+ {NotFoundErr: ErrExceptionNotFound, AccessDeniedErr: ErrExceptionAccessDenied},
+ {NotFoundErr: ErrDisputeNotFound, AccessDeniedErr: ErrDisputeAccessDenied},
+ }
+}
+
+// RegisterResourceErrors adds a resource error mapping to the global registry.
+// Safe for concurrent use. Call at service initialization to register domain-specific
+// error pairs so that classifyResourceOwnershipError can recognize them.
+//
+// Example:
+//
+// func init() {
+// http.RegisterResourceErrors(http.ResourceErrorMapping{
+// NotFoundErr: ErrInvoiceNotFound,
+// AccessDeniedErr: ErrInvoiceAccessDenied,
+// })
+// }
+func RegisterResourceErrors(mapping ResourceErrorMapping) {
+ registryMu.Lock()
+ defer registryMu.Unlock()
+
+ // Detect duplicate registrations by comparing error sentinel pointers.
+ for _, existing := range resourceErrorRegistry {
+ if errors.Is(existing.NotFoundErr, mapping.NotFoundErr) && errors.Is(existing.AccessDeniedErr, mapping.AccessDeniedErr) {
+ return
+ }
+ }
+
+ resourceErrorRegistry = append(resourceErrorRegistry, mapping)
+}
+
+func getIDValue(fiberCtx *fiber.Ctx, idName string, location IDLocation) (string, error) {
+ if fiberCtx == nil {
+ return "", ErrContextNotFound
+ }
+
+ switch location {
+ case IDLocationParam:
+ return fiberCtx.Params(idName), nil
+ case IDLocationQuery:
+ return fiberCtx.Query(idName), nil
+ default:
+ return "", ErrInvalidIDLocation
+ }
+}
diff --git a/commons/net/http/context_nil_error_test.go b/commons/net/http/context_nil_error_test.go
new file mode 100644
index 00000000..7ba79f3c
--- /dev/null
+++ b/commons/net/http/context_nil_error_test.go
@@ -0,0 +1,192 @@
+//go:build unit
+
+package http
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseAndVerifyTenantScopedID_NilValidationErrorsFallbackToGenericSentinels(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.NewString()
+
+ t.Run("missing id", func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var gotErr error
+
+ app.Get("/contexts", func(c *fiber.Ctx) error {
+ _, _, gotErr = ParseAndVerifyTenantScopedID(
+ c,
+ "context_id",
+ IDLocationQuery,
+ func(ctx context.Context, tenantID, resourceID uuid.UUID) error { return nil },
+ func(_ context.Context) string { return tenantID },
+ nil,
+ nil,
+ nil,
+ )
+ return nil
+ })
+
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/contexts", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.Error(t, gotErr)
+ assert.ErrorIs(t, gotErr, ErrMissingContextID)
+ })
+
+ t.Run("invalid id", func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var gotErr error
+
+ app.Get("/contexts", func(c *fiber.Ctx) error {
+ _, _, gotErr = ParseAndVerifyTenantScopedID(
+ c,
+ "context_id",
+ IDLocationQuery,
+ func(ctx context.Context, tenantID, resourceID uuid.UUID) error { return nil },
+ func(_ context.Context) string { return tenantID },
+ nil,
+ nil,
+ nil,
+ )
+ return nil
+ })
+
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/contexts?context_id=not-a-uuid", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.Error(t, gotErr)
+ assert.ErrorIs(t, gotErr, ErrInvalidContextID)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_NilValidationErrorsFallbackToGenericSentinels(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.NewString()
+
+ t.Run("missing id", func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var gotErr error
+
+ app.Get("/resources", func(c *fiber.Ctx) error {
+ _, _, gotErr = ParseAndVerifyResourceScopedID(
+ c,
+ "resource_id",
+ IDLocationQuery,
+ func(ctx context.Context, resourceID uuid.UUID) error { return nil },
+ func(_ context.Context) string { return tenantID },
+ nil,
+ nil,
+ nil,
+ "resource",
+ )
+ return nil
+ })
+
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/resources", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.Error(t, gotErr)
+ assert.ErrorIs(t, gotErr, ErrMissingResourceID)
+ })
+
+ t.Run("invalid id", func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var gotErr error
+
+ app.Get("/resources", func(c *fiber.Ctx) error {
+ _, _, gotErr = ParseAndVerifyResourceScopedID(
+ c,
+ "resource_id",
+ IDLocationQuery,
+ func(ctx context.Context, resourceID uuid.UUID) error { return nil },
+ func(_ context.Context) string { return tenantID },
+ nil,
+ nil,
+ nil,
+ "resource",
+ )
+ return nil
+ })
+
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/resources?resource_id=not-a-uuid", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.Error(t, gotErr)
+ assert.ErrorIs(t, gotErr, ErrInvalidResourceID)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_DefaultFiberUserContextDoesNotPanic(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+
+ assert.NotPanics(t, func() {
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTenantIDNotFound)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_DefaultFiberUserContextDoesNotPanic(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+
+ assert.NotPanics(t, func() {
+ runInFiber(t, "/resources/:resourceId", "/resources/"+resourceID.String(), func(c *fiber.Ctx) error {
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "resourceId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingResourceID,
+ ErrInvalidResourceID,
+ nil,
+ "resource",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTenantIDNotFound)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+ })
+}
diff --git a/commons/net/http/context_ownership.go b/commons/net/http/context_ownership.go
new file mode 100644
index 00000000..957593da
--- /dev/null
+++ b/commons/net/http/context_ownership.go
@@ -0,0 +1,204 @@
+package http
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/google/uuid"
+)
+
+// TenantOwnershipVerifier validates ownership using tenant and resource IDs.
+type TenantOwnershipVerifier func(ctx context.Context, tenantID, resourceID uuid.UUID) error
+
+// ResourceOwnershipVerifier validates ownership using resource ID only.
+type ResourceOwnershipVerifier func(ctx context.Context, resourceID uuid.UUID) error
+
+// ParseAndVerifyTenantScopedID extracts and validates tenant + resource IDs.
+func ParseAndVerifyTenantScopedID(
+ fiberCtx *fiber.Ctx,
+ idName string,
+ location IDLocation,
+ verifier TenantOwnershipVerifier,
+ tenantExtractor TenantExtractor,
+ missingErr error,
+ invalidErr error,
+ accessErr error,
+) (uuid.UUID, uuid.UUID, error) {
+ if fiberCtx == nil {
+ return uuid.Nil, uuid.Nil, ErrContextNotFound
+ }
+
+ if verifier == nil {
+ return uuid.Nil, uuid.Nil, ErrVerifierNotConfigured
+ }
+
+ missingErr = normalizeIDValidationError(missingErr, ErrMissingContextID)
+ invalidErr = normalizeIDValidationError(invalidErr, ErrInvalidContextID)
+
+ resourceID, ctx, tenantID, err := parseTenantAndResourceID(
+ fiberCtx,
+ idName,
+ location,
+ tenantExtractor,
+ missingErr,
+ invalidErr,
+ )
+ if err != nil {
+ return uuid.Nil, uuid.Nil, err
+ }
+
+ if err := verifier(ctx, tenantID, resourceID); err != nil {
+ return uuid.Nil, uuid.Nil, classifyOwnershipError(err, accessErr)
+ }
+
+ return resourceID, tenantID, nil
+}
+
+// ParseAndVerifyResourceScopedID extracts and validates tenant + resource IDs,
+// then verifies resource ownership where tenant is implicit in the verifier.
+func ParseAndVerifyResourceScopedID(
+ fiberCtx *fiber.Ctx,
+ idName string,
+ location IDLocation,
+ verifier ResourceOwnershipVerifier,
+ tenantExtractor TenantExtractor,
+ missingErr error,
+ invalidErr error,
+ accessErr error,
+ verificationLabel string,
+) (uuid.UUID, uuid.UUID, error) {
+ if fiberCtx == nil {
+ return uuid.Nil, uuid.Nil, ErrContextNotFound
+ }
+
+ if verifier == nil {
+ return uuid.Nil, uuid.Nil, ErrVerifierNotConfigured
+ }
+
+ missingErr = normalizeIDValidationError(missingErr, ErrMissingResourceID)
+ invalidErr = normalizeIDValidationError(invalidErr, ErrInvalidResourceID)
+
+ resourceID, ctx, tenantID, err := parseTenantAndResourceID(
+ fiberCtx,
+ idName,
+ location,
+ tenantExtractor,
+ missingErr,
+ invalidErr,
+ )
+ if err != nil {
+ return uuid.Nil, uuid.Nil, err
+ }
+
+ if err := verifier(ctx, resourceID); err != nil {
+ return uuid.Nil, uuid.Nil, classifyResourceOwnershipError(verificationLabel, err, accessErr)
+ }
+
+ return resourceID, tenantID, nil
+}
+
+// parseTenantAndResourceID extracts and validates both tenant and resource UUIDs
+// from the Fiber request context, returning them along with the Go context.
+func parseTenantAndResourceID(
+ fiberCtx *fiber.Ctx,
+ idName string,
+ location IDLocation,
+ tenantExtractor TenantExtractor,
+ missingErr error,
+ invalidErr error,
+) (uuid.UUID, context.Context, uuid.UUID, error) {
+ ctx := fiberCtx.UserContext()
+
+ if tenantExtractor == nil {
+ return uuid.Nil, ctx, uuid.Nil, ErrTenantExtractorNil
+ }
+
+ resourceIDStr, err := getIDValue(fiberCtx, idName, location)
+ if err != nil {
+ return uuid.Nil, ctx, uuid.Nil, err
+ }
+
+ if resourceIDStr == "" {
+ return uuid.Nil, ctx, uuid.Nil, missingErr
+ }
+
+ resourceID, err := uuid.Parse(resourceIDStr)
+ if err != nil {
+ return uuid.Nil, ctx, uuid.Nil, fmt.Errorf("%w: %s", invalidErr, resourceIDStr)
+ }
+
+ tenantIDStr := tenantExtractor(ctx)
+ if tenantIDStr == "" {
+ return uuid.Nil, ctx, uuid.Nil, ErrTenantIDNotFound
+ }
+
+ tenantID, err := uuid.Parse(tenantIDStr)
+ if err != nil {
+ return uuid.Nil, ctx, uuid.Nil, fmt.Errorf("%w: %w", ErrInvalidTenantID, err)
+ }
+
+ return resourceID, ctx, tenantID, nil
+}
+
+func normalizeIDValidationError(err, fallback error) error {
+ if err != nil {
+ return err
+ }
+
+ return fallback
+}
+
+// classifyOwnershipError maps a verifier error to the appropriate sentinel,
+// substituting accessErr when a custom access-denied error is provided.
+func classifyOwnershipError(err, accessErr error) error {
+ switch {
+ case errors.Is(err, ErrContextNotFound):
+ return ErrContextNotFound
+ case errors.Is(err, ErrContextNotOwned):
+ if accessErr != nil {
+ return accessErr
+ }
+
+ return ErrContextNotOwned
+ case errors.Is(err, ErrContextNotActive):
+ return ErrContextNotActive
+ case errors.Is(err, ErrContextAccessDenied):
+ if accessErr != nil {
+ return accessErr
+ }
+
+ return ErrContextAccessDenied
+ default:
+ return fmt.Errorf("%w: %w", ErrContextLookupFailed, err)
+ }
+}
+
+// classifyResourceOwnershipError maps a resource-scoped verifier error to the
+// appropriate sentinel using the global resource error registry.
+// This allows consuming services to register their own domain-specific error
+// mappings without modifying the shared library.
+func classifyResourceOwnershipError(label string, err, accessErr error) error {
+ registryMu.RLock()
+
+ registry := make([]ResourceErrorMapping, len(resourceErrorRegistry))
+ copy(registry, resourceErrorRegistry)
+ registryMu.RUnlock()
+
+ for _, mapping := range registry {
+ if mapping.NotFoundErr != nil && errors.Is(err, mapping.NotFoundErr) {
+ return err
+ }
+
+ if mapping.AccessDeniedErr != nil && errors.Is(err, mapping.AccessDeniedErr) {
+ if accessErr != nil {
+ return accessErr
+ }
+
+ return err
+ }
+ }
+
+ return fmt.Errorf("%s %w: %w", label, ErrLookupFailed, err)
+}
diff --git a/commons/net/http/context_span.go b/commons/net/http/context_span.go
new file mode 100644
index 00000000..49c446dd
--- /dev/null
+++ b/commons/net/http/context_span.go
@@ -0,0 +1,57 @@
+package http
+
+import (
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ "github.com/google/uuid"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+)
+
+// isNilSpan reports whether span is nil, including typed-nil interface values
+// where a concrete nil pointer is stored in a trace.Span interface.
+// This prevents panics when calling methods on a typed-nil span.
+func isNilSpan(span trace.Span) bool {
+ return nilcheck.Interface(span)
+}
+
+// SetHandlerSpanAttributes adds tenant_id and context_id attributes to a trace span.
+func SetHandlerSpanAttributes(span trace.Span, tenantID, contextID uuid.UUID) {
+ if isNilSpan(span) {
+ return
+ }
+
+ span.SetAttributes(attribute.String("tenant.id", tenantID.String()))
+
+ if contextID != uuid.Nil {
+ span.SetAttributes(attribute.String("context.id", contextID.String()))
+ }
+}
+
+// SetTenantSpanAttribute adds tenant_id attribute to a trace span.
+func SetTenantSpanAttribute(span trace.Span, tenantID uuid.UUID) {
+ if isNilSpan(span) {
+ return
+ }
+
+ span.SetAttributes(attribute.String("tenant.id", tenantID.String()))
+}
+
+// SetExceptionSpanAttributes adds tenant_id and exception_id attributes to a trace span.
+func SetExceptionSpanAttributes(span trace.Span, tenantID, exceptionID uuid.UUID) {
+ if isNilSpan(span) {
+ return
+ }
+
+ span.SetAttributes(attribute.String("tenant.id", tenantID.String()))
+ span.SetAttributes(attribute.String("exception.id", exceptionID.String()))
+}
+
+// SetDisputeSpanAttributes adds tenant_id and dispute_id attributes to a trace span.
+func SetDisputeSpanAttributes(span trace.Span, tenantID, disputeID uuid.UUID) {
+ if isNilSpan(span) {
+ return
+ }
+
+ span.SetAttributes(attribute.String("tenant.id", tenantID.String()))
+ span.SetAttributes(attribute.String("dispute.id", disputeID.String()))
+}
diff --git a/commons/net/http/context_test.go b/commons/net/http/context_test.go
new file mode 100644
index 00000000..17bb3992
--- /dev/null
+++ b/commons/net/http/context_test.go
@@ -0,0 +1,1519 @@
+//go:build unit
+
+package http
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+)
+
+type tenantKey struct{}
+
+func testTenantExtractor(ctx context.Context) string {
+ v, _ := ctx.Value(tenantKey{}).(string)
+ return v
+}
+
+func setupApp(t *testing.T, path string, h fiber.Handler) *fiber.App {
+ t.Helper()
+ app := fiber.New()
+ app.Get(path, h)
+ return app
+}
+
+// runInFiber runs a handler inside a real Fiber context so assertions
+// that depend on Fiber's *fiber.Ctx work correctly.
+func runInFiber(t *testing.T, path, url string, handler fiber.Handler) {
+ t.Helper()
+
+ app := setupApp(t, path, handler)
+ req := httptest.NewRequest(http.MethodGet, url, nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+// ---------------------------------------------------------------------------
+// ParseAndVerifyTenantScopedID
+// ---------------------------------------------------------------------------
+
+func TestParseAndVerifyTenantScopedID_HappyPath_Param(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ contextID := uuid.New()
+
+ app := setupApp(t, "/contexts/:contextId", func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ gotContextID, gotTenantID, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(ctx context.Context, tID, resourceID uuid.UUID) error {
+ if tID != tenantID || resourceID != contextID {
+ return errors.New("bad verifier input")
+ }
+ return nil
+ },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.NoError(t, err)
+ assert.Equal(t, contextID, gotContextID)
+ assert.Equal(t, tenantID, gotTenantID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/contexts/"+contextID.String(), nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestParseAndVerifyTenantScopedID_HappyPath_Query(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ contextID := uuid.New()
+
+ app := setupApp(t, "/search", func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ gotContextID, gotTenantID, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationQuery,
+ func(ctx context.Context, tID, resourceID uuid.UUID) error {
+ if tID != tenantID || resourceID != contextID {
+ return errors.New("bad verifier input")
+ }
+ return nil
+ },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.NoError(t, err)
+ assert.Equal(t, contextID, gotContextID)
+ assert.Equal(t, tenantID, gotTenantID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/search?contextId="+contextID.String(), nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestParseAndVerifyTenantScopedID_NilFiberContext(t *testing.T) {
+ t.Parallel()
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ nil,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestParseAndVerifyTenantScopedID_NilVerifier(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ nil, // verifier is nil
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrVerifierNotConfigured)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_NilTenantExtractor(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ nil, // tenant extractor is nil
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTenantExtractorNil)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_MissingResourceID_Param(t *testing.T) {
+ t.Parallel()
+
+ // When route param is not defined in the path, Params returns "".
+ runInFiber(t, "/contexts", "/contexts", func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrMissingContextID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_MissingResourceID_Query(t *testing.T) {
+ t.Parallel()
+
+ runInFiber(t, "/search", "/search", func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationQuery,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrMissingContextID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_InvalidResourceID(t *testing.T) {
+ t.Parallel()
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/not-a-uuid", func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidContextID)
+ assert.Contains(t, err.Error(), "not-a-uuid")
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_InvalidResourceID_Query(t *testing.T) {
+ t.Parallel()
+
+ runInFiber(t, "/search", "/search?contextId=garbage-value", func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationQuery,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidContextID)
+ assert.Contains(t, err.Error(), "garbage-value")
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_EmptyTenantFromExtractor(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+ emptyTenantExtractor := func(ctx context.Context) string { return "" }
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ emptyTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTenantIDNotFound)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_InvalidTenantID(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+ badTenantExtractor := func(ctx context.Context) string { return "not-a-valid-uuid" }
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ badTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidTenantID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_InvalidIDLocation(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocation("body"), // invalid location
+ func(context.Context, uuid.UUID, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidIDLocation)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_VerifierReturnsContextNotFound(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return ErrContextNotFound },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_VerifierReturnsContextNotOwned(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return ErrContextNotOwned },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextAccessDenied)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_VerifierReturnsContextNotActive(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return ErrContextNotActive },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotActive)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_VerifierReturnsContextAccessDenied(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return ErrContextAccessDenied },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextAccessDenied)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_VerifierReturnsUnknownError(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return errors.New("db connection lost") },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextLookupFailed)
+ assert.Contains(t, err.Error(), "db connection lost")
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyTenantScopedID_WrappedVerifierError(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ // Wrap ErrContextNotFound in another error to verify errors.Is traversal works.
+ wrappedErr := fmt.Errorf("database issue: %w", ErrContextNotFound)
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID, uuid.UUID) error { return wrappedErr },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// ParseAndVerifyResourceScopedID
+// ---------------------------------------------------------------------------
+
+func TestParseAndVerifyResourceScopedID_HappyPath(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ exceptionID := uuid.New()
+
+ app := setupApp(t, "/exceptions/:exceptionId", func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ gotID, gotTenantID, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ func(ctx context.Context, resourceID uuid.UUID) error {
+ if resourceID != exceptionID {
+ return ErrExceptionNotFound
+ }
+ return nil
+ },
+ testTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.NoError(t, err)
+ assert.Equal(t, exceptionID, gotID)
+ assert.Equal(t, tenantID, gotTenantID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/exceptions/"+exceptionID.String(), nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestParseAndVerifyResourceScopedID_NilFiberContext(t *testing.T) {
+ t.Parallel()
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ nil,
+ "exceptionId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestParseAndVerifyResourceScopedID_NilVerifier(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+
+ runInFiber(t, "/exceptions/:exceptionId", "/exceptions/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ nil,
+ testTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrVerifierNotConfigured)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_NilTenantExtractor(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+
+ runInFiber(t, "/exceptions/:exceptionId", "/exceptions/"+resourceID.String(), func(c *fiber.Ctx) error {
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return nil },
+ nil,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTenantExtractorNil)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_MissingResourceID(t *testing.T) {
+ t.Parallel()
+
+ runInFiber(t, "/exceptions", "/exceptions", func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrMissingExceptionID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_InvalidResourceID(t *testing.T) {
+ t.Parallel()
+
+ runInFiber(t, "/exceptions/:exceptionId", "/exceptions/not-valid", func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidExceptionID)
+ assert.Contains(t, err.Error(), "not-valid")
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_EmptyTenantFromExtractor(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+ emptyTenantExtractor := func(ctx context.Context) string { return "" }
+
+ runInFiber(t, "/exceptions/:exceptionId", "/exceptions/"+resourceID.String(), func(c *fiber.Ctx) error {
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return nil },
+ emptyTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrTenantIDNotFound)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_InvalidTenantID(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+ badExtractor := func(ctx context.Context) string { return "zzz-invalid" }
+
+ runInFiber(t, "/exceptions/:exceptionId", "/exceptions/"+resourceID.String(), func(c *fiber.Ctx) error {
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return nil },
+ badExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidTenantID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_VerifierReturnsExceptionNotFound(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/exceptions/:exceptionId", "/exceptions/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return ErrExceptionNotFound },
+ testTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrExceptionNotFound)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_VerifierReturnsExceptionAccessDenied(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/exceptions/:exceptionId", "/exceptions/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return ErrExceptionAccessDenied },
+ testTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrExceptionAccessDenied)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_VerifierReturnsDisputeNotFound(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/disputes/:disputeId", "/disputes/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "disputeId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return ErrDisputeNotFound },
+ testTenantExtractor,
+ ErrMissingDisputeID,
+ ErrInvalidDisputeID,
+ ErrDisputeAccessDenied,
+ "dispute",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrDisputeNotFound)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_VerifierReturnsDisputeAccessDenied(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/disputes/:disputeId", "/disputes/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "disputeId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return ErrDisputeAccessDenied },
+ testTenantExtractor,
+ ErrMissingDisputeID,
+ ErrInvalidDisputeID,
+ ErrDisputeAccessDenied,
+ "dispute",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrDisputeAccessDenied)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_VerifierReturnsUnknownError(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+
+ runInFiber(t, "/exceptions/:exceptionId", "/exceptions/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocationParam,
+ func(context.Context, uuid.UUID) error { return errors.New("db exploded") },
+ testTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrLookupFailed)
+ assert.Contains(t, err.Error(), "exception")
+ assert.Contains(t, err.Error(), "db exploded")
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestParseAndVerifyResourceScopedID_InvalidIDLocation(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.New()
+
+ runInFiber(t, "/exceptions/:exceptionId", "/exceptions/"+resourceID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, uuid.NewString()))
+
+ _, _, err := ParseAndVerifyResourceScopedID(
+ c,
+ "exceptionId",
+ IDLocation("cookie"),
+ func(context.Context, uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionAccessDenied,
+ "exception",
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidIDLocation)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// getIDValue
+// ---------------------------------------------------------------------------
+
+func TestGetIDValue_NilFiberContext(t *testing.T) {
+ t.Parallel()
+
+ _, err := getIDValue(nil, "id", IDLocationParam)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestGetIDValue_Param(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.NewString()
+
+ runInFiber(t, "/items/:id", "/items/"+resourceID, func(c *fiber.Ctx) error {
+ val, err := getIDValue(c, "id", IDLocationParam)
+ require.NoError(t, err)
+ assert.Equal(t, resourceID, val)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestGetIDValue_Query(t *testing.T) {
+ t.Parallel()
+
+ resourceID := uuid.NewString()
+
+ runInFiber(t, "/items", "/items?id="+resourceID, func(c *fiber.Ctx) error {
+ val, err := getIDValue(c, "id", IDLocationQuery)
+ require.NoError(t, err)
+ assert.Equal(t, resourceID, val)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestGetIDValue_InvalidLocation(t *testing.T) {
+ t.Parallel()
+
+ runInFiber(t, "/items", "/items", func(c *fiber.Ctx) error {
+ _, err := getIDValue(c, "id", IDLocation("header"))
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidIDLocation)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestGetIDValue_EmptyLocationString(t *testing.T) {
+ t.Parallel()
+
+ runInFiber(t, "/items", "/items", func(c *fiber.Ctx) error {
+ _, err := getIDValue(c, "id", IDLocation(""))
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidIDLocation)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+func TestGetIDValue_SpecialCharactersInQuery(t *testing.T) {
+ t.Parallel()
+
+ // Fiber URL-decodes query params, so %20 becomes a space.
+ runInFiber(t, "/items", "/items?id=hello%20world", func(c *fiber.Ctx) error {
+ val, err := getIDValue(c, "id", IDLocationQuery)
+ require.NoError(t, err)
+ assert.Equal(t, "hello world", val)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// classifyOwnershipError
+// ---------------------------------------------------------------------------
+
+func TestClassifyOwnershipError_AllSentinels(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input error
+ expected error
+ }{
+ {"context not found", ErrContextNotFound, ErrContextNotFound},
+ {"context not owned", ErrContextNotOwned, ErrContextNotOwned},
+ {"context not active", ErrContextNotActive, ErrContextNotActive},
+ {"context access denied", ErrContextAccessDenied, ErrContextAccessDenied},
+ {"wrapped context not found", fmt.Errorf("db: %w", ErrContextNotFound), ErrContextNotFound},
+ {"wrapped context not owned", fmt.Errorf("db: %w", ErrContextNotOwned), ErrContextNotOwned},
+ {"wrapped context not active", fmt.Errorf("db: %w", ErrContextNotActive), ErrContextNotActive},
+ {"wrapped context access denied", fmt.Errorf("db: %w", ErrContextAccessDenied), ErrContextAccessDenied},
+ {"unknown error", errors.New("something else"), ErrContextLookupFailed},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := classifyOwnershipError(tc.input, nil)
+ assert.ErrorIs(t, err, tc.expected)
+ })
+ }
+}
+
+func TestClassifyOwnershipError_UnknownErrorPreservesOriginal(t *testing.T) {
+ t.Parallel()
+
+ originalErr := errors.New("network timeout")
+ err := classifyOwnershipError(originalErr, nil)
+ assert.ErrorIs(t, err, ErrContextLookupFailed)
+ assert.ErrorIs(t, err, originalErr)
+}
+
+// ---------------------------------------------------------------------------
+// classifyOwnershipError with non-nil accessErr
+// ---------------------------------------------------------------------------
+
+func TestClassifyOwnershipError_WithAccessErr_NotOwned(t *testing.T) {
+ t.Parallel()
+
+ customErr := errors.New("custom access denied")
+ err := classifyOwnershipError(ErrContextNotOwned, customErr)
+ assert.Equal(t, customErr, err)
+}
+
+func TestClassifyOwnershipError_WithAccessErr_AccessDenied(t *testing.T) {
+ t.Parallel()
+
+ customErr := errors.New("custom forbidden")
+ err := classifyOwnershipError(ErrContextAccessDenied, customErr)
+ assert.Equal(t, customErr, err)
+}
+
+func TestClassifyOwnershipError_WithAccessErr_NotFound(t *testing.T) {
+ t.Parallel()
+
+ // For not-found, accessErr is irrelevant -- returns ErrContextNotFound
+ customErr := errors.New("custom err")
+ err := classifyOwnershipError(ErrContextNotFound, customErr)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestClassifyOwnershipError_WithAccessErr_NotActive(t *testing.T) {
+ t.Parallel()
+
+ // For not-active, accessErr is irrelevant
+ customErr := errors.New("custom err")
+ err := classifyOwnershipError(ErrContextNotActive, customErr)
+ assert.ErrorIs(t, err, ErrContextNotActive)
+}
+
+func TestClassifyOwnershipError_WithAccessErr_Unknown(t *testing.T) {
+ t.Parallel()
+
+ // For unknown errors, accessErr is irrelevant
+ customErr := errors.New("custom err")
+ err := classifyOwnershipError(errors.New("db timeout"), customErr)
+ assert.ErrorIs(t, err, ErrContextLookupFailed)
+}
+
+// ---------------------------------------------------------------------------
+// classifyResourceOwnershipError
+// ---------------------------------------------------------------------------
+
+func TestClassifyResourceOwnershipError_AllSentinels(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ label string
+ input error
+ expected error
+ }{
+ {"exception not found", "exception", ErrExceptionNotFound, ErrExceptionNotFound},
+ {"exception access denied", "exception", ErrExceptionAccessDenied, ErrExceptionAccessDenied},
+ {"dispute not found", "dispute", ErrDisputeNotFound, ErrDisputeNotFound},
+ {"dispute access denied", "dispute", ErrDisputeAccessDenied, ErrDisputeAccessDenied},
+ {"unknown error", "exception", errors.New("oops"), ErrLookupFailed},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := classifyResourceOwnershipError(tc.label, tc.input, nil)
+ assert.ErrorIs(t, err, tc.expected)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// ResourceErrorMapping registry
+// ---------------------------------------------------------------------------
+
+func snapshotResourceRegistry() []ResourceErrorMapping {
+ registryMu.RLock()
+ defer registryMu.RUnlock()
+
+ snapshot := make([]ResourceErrorMapping, len(resourceErrorRegistry))
+ copy(snapshot, resourceErrorRegistry)
+
+ return snapshot
+}
+
+func restoreResourceRegistry(snapshot []ResourceErrorMapping) {
+ registryMu.Lock()
+ defer registryMu.Unlock()
+
+ resourceErrorRegistry = make([]ResourceErrorMapping, len(snapshot))
+ copy(resourceErrorRegistry, snapshot)
+}
+
+func TestRegisterResourceErrors_CustomMapping(t *testing.T) {
+ original := snapshotResourceRegistry()
+ t.Cleanup(func() {
+ restoreResourceRegistry(original)
+ })
+
+ // Define custom resource errors
+ errInvoiceNotFound := errors.New("invoice not found")
+ errInvoiceAccessDenied := errors.New("invoice access denied")
+
+ // Register custom mappings for this test.
+ RegisterResourceErrors(ResourceErrorMapping{
+ NotFoundErr: errInvoiceNotFound,
+ AccessDeniedErr: errInvoiceAccessDenied,
+ })
+
+ // classifyResourceOwnershipError should recognize the new mapping
+ err := classifyResourceOwnershipError("invoice", errInvoiceNotFound, nil)
+ assert.ErrorIs(t, err, errInvoiceNotFound)
+
+ err = classifyResourceOwnershipError("invoice", errInvoiceAccessDenied, nil)
+ assert.ErrorIs(t, err, errInvoiceAccessDenied)
+}
+
+func TestClassifyResourceOwnershipError_WithAccessErr_ReturnsAccessErr(t *testing.T) {
+ t.Parallel()
+
+ customAccessErr := errors.New("custom forbidden for exception")
+
+ // When verifier returns ErrExceptionAccessDenied and we provide a custom accessErr,
+ // classifyResourceOwnershipError should return the custom accessErr.
+ err := classifyResourceOwnershipError("exception", ErrExceptionAccessDenied, customAccessErr)
+ assert.Equal(t, customAccessErr, err)
+}
+
+func TestClassifyResourceOwnershipError_WithAccessErr_NotFoundIgnoresAccessErr(t *testing.T) {
+ t.Parallel()
+
+ customAccessErr := errors.New("custom denied")
+
+ // For not-found errors, accessErr is ignored -- the original error is returned.
+ err := classifyResourceOwnershipError("exception", ErrExceptionNotFound, customAccessErr)
+ assert.ErrorIs(t, err, ErrExceptionNotFound)
+}
+
+func TestClassifyResourceOwnershipError_LabelInMessage(t *testing.T) {
+ t.Parallel()
+
+ err := classifyResourceOwnershipError("my_resource", errors.New("db failure"), nil)
+ assert.ErrorIs(t, err, ErrLookupFailed)
+ assert.Contains(t, err.Error(), "my_resource")
+ assert.Contains(t, err.Error(), "db failure")
+}
+
+// ---------------------------------------------------------------------------
+// SetHandlerSpanAttributes
+// ---------------------------------------------------------------------------
+
+type mockSpan struct {
+ trace.Span
+ attrs []attribute.KeyValue
+}
+
+func (m *mockSpan) SetAttributes(kv ...attribute.KeyValue) {
+ m.attrs = append(m.attrs, kv...)
+}
+
+func (m *mockSpan) findAttr(key string) (attribute.KeyValue, bool) {
+ for _, a := range m.attrs {
+ if string(a.Key) == key {
+ return a, true
+ }
+ }
+ return attribute.KeyValue{}, false
+}
+
+func TestSetHandlerSpanAttributes_AllFields(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ contextID := uuid.New()
+
+ span := &mockSpan{}
+ SetHandlerSpanAttributes(span, tenantID, contextID)
+
+ tenantAttr, ok := span.findAttr("tenant.id")
+ require.True(t, ok)
+ assert.Equal(t, tenantID.String(), tenantAttr.Value.AsString())
+
+ ctxAttr, ok := span.findAttr("context.id")
+ require.True(t, ok)
+ assert.Equal(t, contextID.String(), ctxAttr.Value.AsString())
+}
+
+func TestSetHandlerSpanAttributes_NilContextID(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+
+ span := &mockSpan{}
+ SetHandlerSpanAttributes(span, tenantID, uuid.Nil)
+
+ tenantAttr, ok := span.findAttr("tenant.id")
+ require.True(t, ok)
+ assert.Equal(t, tenantID.String(), tenantAttr.Value.AsString())
+
+ _, ok = span.findAttr("context.id")
+ assert.False(t, ok, "context.id should not be set when contextID is uuid.Nil")
+}
+
+func TestSetHandlerSpanAttributes_NilSpan(t *testing.T) {
+ t.Parallel()
+
+ // Should not panic.
+ SetHandlerSpanAttributes(nil, uuid.New(), uuid.New())
+}
+
+func TestSetHandlerSpanAttributes_TypedNilSpan(t *testing.T) {
+ t.Parallel()
+
+ var typedNil *mockSpan
+ var span trace.Span = typedNil
+
+ assert.NotPanics(t, func() {
+ SetHandlerSpanAttributes(span, uuid.New(), uuid.New())
+ })
+}
+
+// ---------------------------------------------------------------------------
+// SetTenantSpanAttribute
+// ---------------------------------------------------------------------------
+
+func TestSetTenantSpanAttribute_HappyPath(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ span := &mockSpan{}
+
+ SetTenantSpanAttribute(span, tenantID)
+
+ tenantAttr, ok := span.findAttr("tenant.id")
+ require.True(t, ok)
+ assert.Equal(t, tenantID.String(), tenantAttr.Value.AsString())
+}
+
+func TestSetTenantSpanAttribute_NilSpan(t *testing.T) {
+ t.Parallel()
+
+ // Should not panic.
+ SetTenantSpanAttribute(nil, uuid.New())
+}
+
+func TestSetTenantSpanAttribute_TypedNilSpan(t *testing.T) {
+ t.Parallel()
+
+ var typedNil *mockSpan
+ var span trace.Span = typedNil
+
+ assert.NotPanics(t, func() {
+ SetTenantSpanAttribute(span, uuid.New())
+ })
+}
+
+// ---------------------------------------------------------------------------
+// SetExceptionSpanAttributes
+// ---------------------------------------------------------------------------
+
+func TestSetExceptionSpanAttributes_HappyPath(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ exceptionID := uuid.New()
+ span := &mockSpan{}
+
+ SetExceptionSpanAttributes(span, tenantID, exceptionID)
+
+ tenantAttr, ok := span.findAttr("tenant.id")
+ require.True(t, ok)
+ assert.Equal(t, tenantID.String(), tenantAttr.Value.AsString())
+
+ exAttr, ok := span.findAttr("exception.id")
+ require.True(t, ok)
+ assert.Equal(t, exceptionID.String(), exAttr.Value.AsString())
+}
+
+func TestSetExceptionSpanAttributes_NilSpan(t *testing.T) {
+ t.Parallel()
+
+ // Should not panic.
+ SetExceptionSpanAttributes(nil, uuid.New(), uuid.New())
+}
+
+func TestSetExceptionSpanAttributes_TypedNilSpan(t *testing.T) {
+ t.Parallel()
+
+ var typedNil *mockSpan
+ var span trace.Span = typedNil
+
+ assert.NotPanics(t, func() {
+ SetExceptionSpanAttributes(span, uuid.New(), uuid.New())
+ })
+}
+
+// ---------------------------------------------------------------------------
+// SetDisputeSpanAttributes
+// ---------------------------------------------------------------------------
+
+func TestSetDisputeSpanAttributes_HappyPath(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ disputeID := uuid.New()
+ span := &mockSpan{}
+
+ SetDisputeSpanAttributes(span, tenantID, disputeID)
+
+ tenantAttr, ok := span.findAttr("tenant.id")
+ require.True(t, ok)
+ assert.Equal(t, tenantID.String(), tenantAttr.Value.AsString())
+
+ dAttr, ok := span.findAttr("dispute.id")
+ require.True(t, ok)
+ assert.Equal(t, disputeID.String(), dAttr.Value.AsString())
+}
+
+func TestSetDisputeSpanAttributes_NilSpan(t *testing.T) {
+ t.Parallel()
+
+ // Should not panic.
+ SetDisputeSpanAttributes(nil, uuid.New(), uuid.New())
+}
+
+func TestSetDisputeSpanAttributes_TypedNilSpan(t *testing.T) {
+ t.Parallel()
+
+ var typedNil *mockSpan
+ var span trace.Span = typedNil
+
+ assert.NotPanics(t, func() {
+ SetDisputeSpanAttributes(span, uuid.New(), uuid.New())
+ })
+}
+
+// ---------------------------------------------------------------------------
+// IDLocation constants
+// ---------------------------------------------------------------------------
+
+func TestIDLocationConstants(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, IDLocation("param"), IDLocationParam)
+ assert.Equal(t, IDLocation("query"), IDLocationQuery)
+}
+
+// ---------------------------------------------------------------------------
+// Sentinel error identity
+// ---------------------------------------------------------------------------
+
+func TestSentinelErrorIdentity(t *testing.T) {
+ t.Parallel()
+
+ // Ensure all sentinel errors are distinct.
+ sentinels := []error{
+ ErrInvalidIDLocation,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrMissingResourceID,
+ ErrInvalidResourceID,
+ ErrTenantIDNotFound,
+ ErrTenantExtractorNil,
+ ErrInvalidTenantID,
+ ErrContextNotFound,
+ ErrContextNotOwned,
+ ErrContextAccessDenied,
+ ErrContextNotActive,
+ ErrContextLookupFailed,
+ ErrLookupFailed,
+ ErrMissingExceptionID,
+ ErrInvalidExceptionID,
+ ErrExceptionNotFound,
+ ErrExceptionAccessDenied,
+ ErrMissingDisputeID,
+ ErrInvalidDisputeID,
+ ErrDisputeNotFound,
+ ErrDisputeAccessDenied,
+ }
+
+ for i, a := range sentinels {
+ for j, b := range sentinels {
+ if i != j {
+ assert.NotEqual(t, a.Error(), b.Error(),
+ "sentinel errors %d and %d have identical messages: %q", i, j, a.Error())
+ }
+ }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Edge: special characters in param values
+// ---------------------------------------------------------------------------
+
+func TestParseAndVerifyTenantScopedID_UUIDWithUpperCase(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ contextID := uuid.New()
+ // UUID strings are case-insensitive; pass an uppercase version.
+ upperContextID := strings.ToUpper(contextID.String())
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+upperContextID, func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ gotContextID, gotTenantID, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(ctx context.Context, tID, resourceID uuid.UUID) error { return nil },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.NoError(t, err)
+ assert.Equal(t, contextID, gotContextID)
+ assert.Equal(t, tenantID, gotTenantID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// parseTenantAndResourceID returns correct context.Context to verifier
+// ---------------------------------------------------------------------------
+
+func TestParseTenantAndResourceID_ContextPassedToVerifier(t *testing.T) {
+ t.Parallel()
+
+ type ctxValueKey struct{}
+ tenantID := uuid.New()
+ resourceID := uuid.New()
+ contextValue := "custom-value-12345"
+
+ runInFiber(t, "/contexts/:contextId", "/contexts/"+resourceID.String(), func(c *fiber.Ctx) error {
+ userCtx := context.WithValue(context.Background(), tenantKey{}, tenantID.String())
+ userCtx = context.WithValue(userCtx, ctxValueKey{}, contextValue)
+ c.SetUserContext(userCtx)
+
+ _, _, err := ParseAndVerifyTenantScopedID(
+ c,
+ "contextId",
+ IDLocationParam,
+ func(ctx context.Context, tID, resID uuid.UUID) error {
+ // Verify the context passed to verifier contains our custom value.
+ val, ok := ctx.Value(ctxValueKey{}).(string)
+ require.True(t, ok)
+ assert.Equal(t, contextValue, val)
+ return nil
+ },
+ testTenantExtractor,
+ ErrMissingContextID,
+ ErrInvalidContextID,
+ ErrContextAccessDenied,
+ )
+ require.NoError(t, err)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Dispute-specific scoped ID tests
+// ---------------------------------------------------------------------------
+
+func TestParseAndVerifyResourceScopedID_DisputeHappyPath(t *testing.T) {
+ t.Parallel()
+
+ tenantID := uuid.New()
+ disputeID := uuid.New()
+
+ runInFiber(t, "/disputes/:disputeId", "/disputes/"+disputeID.String(), func(c *fiber.Ctx) error {
+ c.SetUserContext(context.WithValue(context.Background(), tenantKey{}, tenantID.String()))
+
+ gotID, gotTenantID, err := ParseAndVerifyResourceScopedID(
+ c,
+ "disputeId",
+ IDLocationParam,
+ func(ctx context.Context, resourceID uuid.UUID) error {
+ if resourceID != disputeID {
+ return ErrDisputeNotFound
+ }
+ return nil
+ },
+ testTenantExtractor,
+ ErrMissingDisputeID,
+ ErrInvalidDisputeID,
+ ErrDisputeAccessDenied,
+ "dispute",
+ )
+ require.NoError(t, err)
+ assert.Equal(t, disputeID, gotID)
+ assert.Equal(t, tenantID, gotTenantID)
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+}
diff --git a/commons/net/http/cursor.go b/commons/net/http/cursor.go
index 69652a2a..81d2862a 100644
--- a/commons/net/http/cursor.go
+++ b/commons/net/http/cursor.go
@@ -1,170 +1,161 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package http
import (
"encoding/base64"
"encoding/json"
- "strings"
+ "errors"
+ "fmt"
+
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+)
- "github.com/LerianStudio/lib-commons/v2/commons"
- "github.com/LerianStudio/lib-commons/v2/commons/constants"
- "github.com/Masterminds/squirrel"
+const (
+ // CursorDirectionNext is the cursor direction for forward navigation.
+ CursorDirectionNext = "next"
+ // CursorDirectionPrev is the cursor direction for backward navigation.
+ CursorDirectionPrev = "prev"
)
+// ErrInvalidCursorDirection indicates an invalid next/prev cursor direction.
+var ErrInvalidCursorDirection = errors.New("invalid cursor direction")
+
+// Cursor is the cursor contract for keyset navigation.
type Cursor struct {
- ID string `json:"id"`
- PointsNext bool `json:"points_next"`
+ ID string `json:"id"`
+ Direction string `json:"direction"`
}
+// CursorPagination carries encoded next and previous cursors.
type CursorPagination struct {
Next string `json:"next"`
Prev string `json:"prev"`
}
-func CreateCursor(id string, pointsNext bool) Cursor {
- return Cursor{
- ID: id,
- PointsNext: pointsNext,
+// EncodeCursor encodes a Cursor as a base64 JSON token.
+func EncodeCursor(cursor Cursor) (string, error) {
+ if cursor.ID == "" {
+ return "", ErrInvalidCursor
+ }
+
+ if cursor.Direction != CursorDirectionNext && cursor.Direction != CursorDirectionPrev {
+ return "", ErrInvalidCursorDirection
+ }
+
+ cursorBytes, err := json.Marshal(cursor)
+ if err != nil {
+ return "", err
}
+
+ return base64.StdEncoding.EncodeToString(cursorBytes), nil
}
+// DecodeCursor decodes a base64 JSON cursor token and validates it.
func DecodeCursor(cursor string) (Cursor, error) {
decodedCursor, err := base64.StdEncoding.DecodeString(cursor)
if err != nil {
- return Cursor{}, err
+ return Cursor{}, fmt.Errorf("%w: decode failed: %w", ErrInvalidCursor, err)
}
var cur Cursor
if err := json.Unmarshal(decodedCursor, &cur); err != nil {
- return Cursor{}, err
+ return Cursor{}, fmt.Errorf("%w: unmarshal failed: %w", ErrInvalidCursor, err)
}
- return cur, nil
-}
-
-func ApplyCursorPagination(
- findAll squirrel.SelectBuilder,
- decodedCursor Cursor,
- orderDirection string,
- limit int,
- tableAlias ...string,
-) (squirrel.SelectBuilder, string) {
- var operator string
+ if cur.ID == "" {
+ return Cursor{}, fmt.Errorf("%w: missing id", ErrInvalidCursor)
+ }
- var actualOrder string
+ if cur.Direction != CursorDirectionNext && cur.Direction != CursorDirectionPrev {
+ return Cursor{}, ErrInvalidCursorDirection
+ }
- ascOrder := strings.ToUpper(string(constant.Asc))
- descOrder := strings.ToUpper(string(constant.Desc))
+ return cur, nil
+}
- ID := "id"
- if len(tableAlias) > 0 {
- ID = tableAlias[0] + "." + ID
- }
+// CursorDirectionRules returns the comparison operator and effective order.
+func CursorDirectionRules(requestedSortDirection, cursorDirection string) (operator, effectiveOrder string, err error) {
+ order := ValidateSortDirection(requestedSortDirection)
- if decodedCursor.ID != "" {
- if decodedCursor.PointsNext {
- if orderDirection == ascOrder {
- operator = ">"
- actualOrder = ascOrder
- } else {
- operator = "<"
- actualOrder = descOrder
- }
- } else {
- if orderDirection == ascOrder {
- operator = "<"
- actualOrder = descOrder
- } else {
- operator = ">"
- actualOrder = ascOrder
- }
+ switch cursorDirection {
+ case CursorDirectionNext:
+ if order == cn.SortDirASC {
+ return ">", cn.SortDirASC, nil
}
- whereClause := squirrel.Expr(ID+" "+operator+" ?", decodedCursor.ID)
- findAll = findAll.Where(whereClause).OrderBy(ID + " " + actualOrder)
+ return "<", cn.SortDirDESC, nil
+ case CursorDirectionPrev:
+ if order == cn.SortDirASC {
+ return "<", cn.SortDirDESC, nil
+ }
- return findAll.Limit(commons.SafeIntToUint64(limit + 1)), actualOrder
+ return ">", cn.SortDirASC, nil
+ default:
+ return "", "", ErrInvalidCursorDirection
}
-
- findAll = findAll.OrderBy(ID + " " + orderDirection)
-
- return findAll.Limit(commons.SafeIntToUint64(limit + 1)), orderDirection
}
+// PaginateRecords slices records to the requested page and normalizes prev direction order.
func PaginateRecords[T any](
isFirstPage bool,
hasPagination bool,
- pointsNext bool,
+ cursorDirection string,
items []T,
limit int,
- orderUsed string,
) []T {
if !hasPagination {
return items
}
- paginated := items[:limit]
+ if limit < 0 {
+ limit = 0
+ }
+
+ if limit > len(items) {
+ limit = len(items)
+ }
+
+ paginated := make([]T, limit)
+ copy(paginated, items[:limit])
- if !pointsNext {
+ if !isFirstPage && cursorDirection == CursorDirectionPrev {
return commons.Reverse(paginated)
}
return paginated
}
+// CalculateCursor builds next/prev cursor tokens for a paged record set.
func CalculateCursor(
- isFirstPage, hasPagination, pointsNext bool,
+ isFirstPage, hasPagination bool,
+ cursorDirection string,
firstItemID, lastItemID string,
) (CursorPagination, error) {
var pagination CursorPagination
- if pointsNext {
- if hasPagination {
- next := CreateCursor(lastItemID, true)
-
- cursorBytes, err := json.Marshal(next)
- if err != nil {
- return CursorPagination{}, err
- }
-
- pagination.Next = base64.StdEncoding.EncodeToString(cursorBytes)
- }
-
- if !isFirstPage {
- prev := CreateCursor(firstItemID, false)
+ if cursorDirection != CursorDirectionNext && cursorDirection != CursorDirectionPrev {
+ return CursorPagination{}, ErrInvalidCursorDirection
+ }
- cursorBytes, err := json.Marshal(prev)
- if err != nil {
- return CursorPagination{}, err
- }
+ hasNext := (cursorDirection == CursorDirectionNext && hasPagination) ||
+ (cursorDirection == CursorDirectionPrev && (hasPagination || isFirstPage))
- pagination.Prev = base64.StdEncoding.EncodeToString(cursorBytes)
+ if hasNext {
+ next, err := EncodeCursor(Cursor{ID: lastItemID, Direction: CursorDirectionNext})
+ if err != nil {
+ return CursorPagination{}, err
}
- } else {
- if hasPagination || isFirstPage {
- next := CreateCursor(lastItemID, true)
- cursorBytesNext, err := json.Marshal(next)
- if err != nil {
- return CursorPagination{}, err
- }
+ pagination.Next = next
+ }
- pagination.Next = base64.StdEncoding.EncodeToString(cursorBytesNext)
+ if !isFirstPage {
+ prev, err := EncodeCursor(Cursor{ID: firstItemID, Direction: CursorDirectionPrev})
+ if err != nil {
+ return CursorPagination{}, err
}
- if !isFirstPage {
- prev := CreateCursor(firstItemID, false)
-
- cursorBytesPrev, err := json.Marshal(prev)
- if err != nil {
- return CursorPagination{}, err
- }
-
- pagination.Prev = base64.StdEncoding.EncodeToString(cursorBytesPrev)
- }
+ pagination.Prev = prev
}
return pagination, nil
diff --git a/commons/net/http/cursor_example_test.go b/commons/net/http/cursor_example_test.go
new file mode 100644
index 00000000..1a0359a0
--- /dev/null
+++ b/commons/net/http/cursor_example_test.go
@@ -0,0 +1,35 @@
+//go:build unit
+
+package http_test
+
+import (
+ "fmt"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ uhttp "github.com/LerianStudio/lib-commons/v4/commons/net/http"
+)
+
+func ExampleEncodeCursor() {
+ encoded, err := uhttp.EncodeCursor(uhttp.Cursor{ID: "acc_01", Direction: uhttp.CursorDirectionNext})
+ if err != nil {
+ fmt.Println("encode error")
+ return
+ }
+
+ decoded, err := uhttp.DecodeCursor(encoded)
+ if err != nil {
+ fmt.Println("decode error")
+ return
+ }
+
+ op, order, err := uhttp.CursorDirectionRules(cn.SortDirASC, decoded.Direction)
+
+ fmt.Println(err == nil)
+ fmt.Println(decoded.ID, decoded.Direction)
+ fmt.Println(op, order)
+
+ // Output:
+ // true
+ // acc_01 next
+ // > ASC
+}
diff --git a/commons/net/http/cursor_test.go b/commons/net/http/cursor_test.go
index 01c38f6f..97eb111d 100644
--- a/commons/net/http/cursor_test.go
+++ b/commons/net/http/cursor_test.go
@@ -1,6 +1,4 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package http
@@ -8,756 +6,650 @@ import (
"encoding/base64"
"encoding/json"
"strings"
+ "sync"
"testing"
- "time"
- "github.com/LerianStudio/lib-commons/v2/commons/constants"
- "github.com/Masterminds/squirrel"
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-func TestDecodeCursor(t *testing.T) {
- cursor := CreateCursor("test_id", true)
- encodedCursor := base64.StdEncoding.EncodeToString([]byte(`{"id":"test_id","points_next":true}`))
+// ---------------------------------------------------------------------------
+// EncodeCursor
+// ---------------------------------------------------------------------------
- decodedCursor, err := DecodeCursor(encodedCursor)
- assert.NoError(t, err)
- assert.Equal(t, cursor, decodedCursor)
-}
-
-func TestApplyCursorPaginationDesc(t *testing.T) {
- query := squirrel.Select("*").From("test_table")
- decodedCursor := CreateCursor("test_id", true)
- orderDirection := strings.ToUpper(string(constant.Desc))
- limit := 10
+func TestEncodeCursor_HappyPath_Next(t *testing.T) {
+ t.Parallel()
- resultQuery, resultOrder := ApplyCursorPagination(query, decodedCursor, orderDirection, limit)
- sqlResult, _, _ := resultQuery.ToSql()
+ id := uuid.NewString()
+ encoded, err := EncodeCursor(Cursor{ID: id, Direction: CursorDirectionNext})
+ require.NoError(t, err)
+ assert.NotEmpty(t, encoded)
- expectedQuery := query.Where(squirrel.Expr("id < ?", "test_id")).OrderBy("id DESC").Limit(uint64(limit + 1))
- sqlExpected, _, _ := expectedQuery.ToSql()
+ // Verify it is valid base64.
+ raw, err := base64.StdEncoding.DecodeString(encoded)
+ require.NoError(t, err)
- assert.Equal(t, sqlExpected, sqlResult)
- assert.Equal(t, "DESC", resultOrder)
+ var cur Cursor
+ require.NoError(t, json.Unmarshal(raw, &cur))
+ assert.Equal(t, id, cur.ID)
+ assert.Equal(t, CursorDirectionNext, cur.Direction)
}
-func TestApplyCursorPaginationNoCursor(t *testing.T) {
- query := squirrel.Select("*").From("test_table")
- decodedCursor := CreateCursor("", true)
- orderDirection := strings.ToUpper(string(constant.Asc))
- limit := 10
+func TestEncodeCursor_HappyPath_Prev(t *testing.T) {
+ t.Parallel()
- resultQuery, resultOrder := ApplyCursorPagination(query, decodedCursor, orderDirection, limit)
- sqlResult, _, _ := resultQuery.ToSql()
+ id := uuid.NewString()
+ encoded, err := EncodeCursor(Cursor{ID: id, Direction: CursorDirectionPrev})
+ require.NoError(t, err)
+ assert.NotEmpty(t, encoded)
- expectedQuery := query.OrderBy("id ASC").Limit(uint64(limit + 1))
- sqlExpected, _, _ := expectedQuery.ToSql()
+ decoded, err := DecodeCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, id, decoded.ID)
+ assert.Equal(t, CursorDirectionPrev, decoded.Direction)
+}
+
+func TestEncodeCursor_EmptyID(t *testing.T) {
+ t.Parallel()
- assert.Equal(t, sqlExpected, sqlResult)
- assert.Equal(t, "ASC", resultOrder)
+ _, err := EncodeCursor(Cursor{ID: "", Direction: CursorDirectionNext})
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
}
-func TestApplyCursorPaginationPrevPage(t *testing.T) {
- query := squirrel.Select("*").From("test_table")
- decodedCursor := CreateCursor("test_id", false)
- orderDirection := strings.ToUpper(string(constant.Asc))
- limit := 10
+func TestEncodeCursor_InvalidDirection(t *testing.T) {
+ t.Parallel()
- resultQuery, resultOrder := ApplyCursorPagination(query, decodedCursor, orderDirection, limit)
- sqlResult, _, _ := resultQuery.ToSql()
+ _, err := EncodeCursor(Cursor{ID: "some-id", Direction: "sideways"})
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursorDirection)
+}
- expectedQuery := query.Where(squirrel.Expr("id < ?", "test_id")).OrderBy("id DESC").Limit(uint64(limit + 1))
- sqlExpected, _, _ := expectedQuery.ToSql()
+func TestEncodeCursor_EmptyDirection(t *testing.T) {
+ t.Parallel()
- assert.Equal(t, sqlExpected, sqlResult)
- assert.Equal(t, "DESC", resultOrder)
+ _, err := EncodeCursor(Cursor{ID: "some-id", Direction: ""})
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursorDirection)
}
-func TestApplyCursorPaginationPrevPageDesc(t *testing.T) {
- query := squirrel.Select("*").From("test_table")
- decodedCursor := CreateCursor("test_id", false)
- orderDirection := strings.ToUpper(string(constant.Desc))
- limit := 10
+// ---------------------------------------------------------------------------
+// DecodeCursor
+// ---------------------------------------------------------------------------
- resultQuery, resultOrder := ApplyCursorPagination(query, decodedCursor, orderDirection, limit)
- sqlResult, _, _ := resultQuery.ToSql()
+func TestDecodeCursor_HappyPath_RoundTrip(t *testing.T) {
+ t.Parallel()
+
+ id := uuid.NewString()
+ encoded, err := EncodeCursor(Cursor{ID: id, Direction: CursorDirectionNext})
+ require.NoError(t, err)
- expectedQuery := query.Where(squirrel.Expr("id > ?", "test_id")).OrderBy("id ASC").Limit(uint64(limit + 1))
- sqlExpected, _, _ := expectedQuery.ToSql()
+ decoded, err := DecodeCursor(encoded)
+ require.NoError(t, err)
- assert.Equal(t, sqlExpected, sqlResult)
- assert.Equal(t, "ASC", resultOrder)
+ assert.Equal(t, id, decoded.ID)
+ assert.Equal(t, CursorDirectionNext, decoded.Direction)
}
-func TestPaginateRecords(t *testing.T) {
- limit := 3
+func TestDecodeCursor_InvalidBase64(t *testing.T) {
+ t.Parallel()
- items1 := []int{1, 2, 3, 4, 5}
- result := PaginateRecords(true, true, true, items1, limit, "ASC")
- assert.Equal(t, []int{1, 2, 3}, result)
+ _, err := DecodeCursor("not-valid-base64!!!")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+ assert.Contains(t, err.Error(), "decode failed")
+}
- items2 := []int{1, 2, 3, 4, 5}
- result = PaginateRecords(false, true, true, items2, limit, "ASC")
- assert.Equal(t, []int{1, 2, 3}, result)
+func TestDecodeCursor_ValidBase64InvalidJSON(t *testing.T) {
+ t.Parallel()
- items3 := []int{1, 2, 3, 4, 5}
- result = PaginateRecords(false, true, false, items3, limit, "ASC")
- assert.Equal(t, []int{3, 2, 1}, result)
+ encoded := base64.StdEncoding.EncodeToString([]byte("not json at all"))
+ _, err := DecodeCursor(encoded)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+ assert.Contains(t, err.Error(), "unmarshal failed")
+}
- items4 := []int{1, 2, 3, 4, 5}
- result = PaginateRecords(true, true, true, items4, limit, "DESC")
- assert.Equal(t, []int{1, 2, 3}, result)
+func TestDecodeCursor_MissingID(t *testing.T) {
+ t.Parallel()
- items5 := []int{1, 2, 3, 4, 5}
- result = PaginateRecords(false, true, true, items5, limit, "DESC")
- assert.Equal(t, []int{1, 2, 3}, result)
+ encoded := base64.StdEncoding.EncodeToString([]byte(`{"direction":"next"}`))
+ _, err := DecodeCursor(encoded)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+ assert.Contains(t, err.Error(), "missing id")
+}
- items6 := []int{1, 2, 3, 4, 5}
- result = PaginateRecords(false, true, false, items6, limit, "DESC")
- assert.Equal(t, []int{3, 2, 1}, result)
+func TestDecodeCursor_EmptyID(t *testing.T) {
+ t.Parallel()
+
+ encoded := base64.StdEncoding.EncodeToString([]byte(`{"id":"","direction":"next"}`))
+ _, err := DecodeCursor(encoded)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+ assert.Contains(t, err.Error(), "missing id")
}
-func TestCalculateCursor(t *testing.T) {
- firstItemID := "first_id"
- lastItemID := "last_id"
+func TestDecodeCursor_InvalidDirection(t *testing.T) {
+ t.Parallel()
- pagination, err := CalculateCursor(true, true, true, firstItemID, lastItemID)
- assert.NoError(t, err)
- assert.NotEmpty(t, pagination.Next)
- assert.Empty(t, pagination.Prev)
+ encoded := base64.StdEncoding.EncodeToString([]byte(`{"id":"test-id","direction":"weird"}`))
+ _, err := DecodeCursor(encoded)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursorDirection)
+}
- pagination, err = CalculateCursor(false, true, true, firstItemID, lastItemID)
- assert.NoError(t, err)
- assert.NotEmpty(t, pagination.Next)
- assert.NotEmpty(t, pagination.Prev)
+func TestDecodeCursor_EmptyDirection(t *testing.T) {
+ t.Parallel()
- pagination, err = CalculateCursor(false, true, false, firstItemID, lastItemID)
- assert.NoError(t, err)
- assert.NotEmpty(t, pagination.Next)
- assert.NotEmpty(t, pagination.Prev)
+ encoded := base64.StdEncoding.EncodeToString([]byte(`{"id":"test-id","direction":""}`))
+ _, err := DecodeCursor(encoded)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursorDirection)
+}
- pagination, err = CalculateCursor(true, false, true, firstItemID, lastItemID)
- assert.NoError(t, err)
- assert.Empty(t, pagination.Next)
- assert.Empty(t, pagination.Prev)
+func TestDecodeCursor_MissingDirection(t *testing.T) {
+ t.Parallel()
- pagination, err = CalculateCursor(false, false, true, firstItemID, lastItemID)
- assert.NoError(t, err)
- assert.Empty(t, pagination.Next)
- assert.NotEmpty(t, pagination.Prev)
+ encoded := base64.StdEncoding.EncodeToString([]byte(`{"id":"test-id"}`))
+ _, err := DecodeCursor(encoded)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursorDirection)
+}
- pagination, err = CalculateCursor(false, false, false, firstItemID, lastItemID)
- assert.NoError(t, err)
- assert.Empty(t, pagination.Next)
- assert.NotEmpty(t, pagination.Prev)
+func TestDecodeCursor_EmptyString(t *testing.T) {
+ t.Parallel()
+
+ _, err := DecodeCursor("")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
}
-func TestCursorWithUUIDv7(t *testing.T) {
- uuid2, err := uuid.NewV7()
- require.NoError(t, err)
+func TestDecodeCursor_ExtraFields(t *testing.T) {
+ t.Parallel()
- cursor := CreateCursor(uuid2.String(), true)
- cursorBytes, err := json.Marshal(cursor)
+ // Extra JSON fields should be ignored.
+ encoded := base64.StdEncoding.EncodeToString([]byte(`{"id":"test-id","direction":"next","extra":"ignored"}`))
+ decoded, err := DecodeCursor(encoded)
require.NoError(t, err)
- encodedCursor := base64.StdEncoding.EncodeToString(cursorBytes)
-
- decodedCursor, err := DecodeCursor(encodedCursor)
- assert.NoError(t, err)
- assert.Equal(t, uuid2.String(), decodedCursor.ID)
- assert.True(t, decodedCursor.PointsNext)
+ assert.Equal(t, "test-id", decoded.ID)
+ assert.Equal(t, CursorDirectionNext, decoded.Direction)
}
-func TestApplyCursorPaginationWithUUIDv7(t *testing.T) {
- uuid2, err := uuid.NewV7()
- require.NoError(t, err)
+// ---------------------------------------------------------------------------
+// CursorDirectionRules (4 combos + invalid)
+// ---------------------------------------------------------------------------
+
+func TestCursorDirectionRules_AllCombinations(t *testing.T) {
+ t.Parallel()
tests := []struct {
- name string
- cursorID string
- pointsNext bool
- orderDirection string
- expectedOp string
- expectedOrder string
+ name string
+ requestedSort string
+ cursorDir string
+ expectedOperator string
+ expectedOrder string
+ expectErr bool
}{
{
- name: "next page with UUID v7 - ASC",
- cursorID: uuid2.String(),
- pointsNext: true,
- orderDirection: "ASC",
- expectedOp: ">",
- expectedOrder: "ASC",
+ name: "ASC + next",
+ requestedSort: cn.SortDirASC,
+ cursorDir: CursorDirectionNext,
+ expectedOperator: ">",
+ expectedOrder: cn.SortDirASC,
},
{
- name: "next page with UUID v7 - DESC",
- cursorID: uuid2.String(),
- pointsNext: true,
- orderDirection: "DESC",
- expectedOp: "<",
- expectedOrder: "DESC",
+ name: "ASC + prev",
+ requestedSort: cn.SortDirASC,
+ cursorDir: CursorDirectionPrev,
+ expectedOperator: "<",
+ expectedOrder: cn.SortDirDESC,
},
{
- name: "prev page with UUID v7 - ASC",
- cursorID: uuid2.String(),
- pointsNext: false,
- orderDirection: "ASC",
- expectedOp: "<",
- expectedOrder: "DESC",
+ name: "DESC + next",
+ requestedSort: cn.SortDirDESC,
+ cursorDir: CursorDirectionNext,
+ expectedOperator: "<",
+ expectedOrder: cn.SortDirDESC,
},
{
- name: "prev page with UUID v7 - DESC",
- cursorID: uuid2.String(),
- pointsNext: false,
- orderDirection: "DESC",
- expectedOp: ">",
- expectedOrder: "ASC",
+ name: "DESC + prev",
+ requestedSort: cn.SortDirDESC,
+ cursorDir: CursorDirectionPrev,
+ expectedOperator: ">",
+ expectedOrder: cn.SortDirASC,
+ },
+ {
+ name: "invalid cursor direction",
+ requestedSort: cn.SortDirASC,
+ cursorDir: "invalid",
+ expectErr: true,
},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- query := squirrel.Select("*").From("test_table")
- decodedCursor := CreateCursor(tt.cursorID, tt.pointsNext)
- limit := 10
-
- resultQuery, resultOrder := ApplyCursorPagination(query, decodedCursor, tt.orderDirection, limit)
- sqlResult, args, err := resultQuery.ToSql()
- require.NoError(t, err)
-
- expectedQuery := query.Where(squirrel.Expr("id "+tt.expectedOp+" ?", tt.cursorID)).
- OrderBy("id " + tt.expectedOrder).
- Limit(uint64(limit + 1))
- sqlExpected, expectedArgs, err := expectedQuery.ToSql()
- require.NoError(t, err)
-
- assert.Equal(t, sqlExpected, sqlResult)
- assert.Equal(t, expectedArgs, args)
- assert.Equal(t, tt.expectedOrder, resultOrder)
- })
- }
-}
-
-func TestPaginateRecordsWithUUIDv7(t *testing.T) {
- uuids := make([]uuid.UUID, 5)
- for i := 0; i < 5; i++ {
- var err error
- uuids[i], err = uuid.NewV7()
- require.NoError(t, err)
- time.Sleep(1 * time.Millisecond)
- }
-
- items := make([]string, len(uuids))
- for i, u := range uuids {
- items[i] = u.String()
- }
-
- limit := 3
-
- result1 := PaginateRecords(true, true, true, append([]string{}, items...), limit, "ASC")
- assert.Equal(t, items[:3], result1)
-
- result2 := PaginateRecords(false, true, false, append([]string{}, items...), limit, "ASC")
- expected := []string{items[2], items[1], items[0]}
- assert.Equal(t, expected, result2)
-}
-
-func TestCalculateCursorWithUUIDv7(t *testing.T) {
- firstUUID, err := uuid.NewV7()
- require.NoError(t, err)
- time.Sleep(1 * time.Millisecond)
- lastUUID, err := uuid.NewV7()
- require.NoError(t, err)
-
- firstItemID := firstUUID.String()
- lastItemID := lastUUID.String()
-
- tests := []struct {
- name string
- isFirstPage bool
- hasPagination bool
- pointsNext bool
- expectNext bool
- expectPrev bool
- }{
{
- name: "first page with pagination - points next",
- isFirstPage: true,
- hasPagination: true,
- pointsNext: true,
- expectNext: true,
- expectPrev: false,
+ name: "empty cursor direction",
+ requestedSort: cn.SortDirASC,
+ cursorDir: "",
+ expectErr: true,
},
{
- name: "middle page with pagination - points next",
- isFirstPage: false,
- hasPagination: true,
- pointsNext: true,
- expectNext: true,
- expectPrev: true,
+ name: "lowercase sort direction defaults to ASC + next",
+ requestedSort: "asc",
+ cursorDir: CursorDirectionNext,
+ expectedOperator: ">",
+ expectedOrder: cn.SortDirASC,
},
{
- name: "page with pagination - points prev",
- isFirstPage: false,
- hasPagination: true,
- pointsNext: false,
- expectNext: true,
- expectPrev: true,
+ name: "lowercase desc + next",
+ requestedSort: "desc",
+ cursorDir: CursorDirectionNext,
+ expectedOperator: "<",
+ expectedOrder: cn.SortDirDESC,
+ },
+ {
+ name: "garbage sort direction defaults to ASC + next",
+ requestedSort: "GARBAGE",
+ cursorDir: CursorDirectionNext,
+ expectedOperator: ">",
+ expectedOrder: cn.SortDirASC,
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- pagination, err := CalculateCursor(tt.isFirstPage, tt.hasPagination, tt.pointsNext, firstItemID, lastItemID)
- require.NoError(t, err)
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
- if tt.expectNext {
- assert.NotEmpty(t, pagination.Next)
+ operator, order, err := CursorDirectionRules(tc.requestedSort, tc.cursorDir)
- decodedNext, err := DecodeCursor(pagination.Next)
- require.NoError(t, err)
- assert.Equal(t, lastItemID, decodedNext.ID)
- assert.True(t, decodedNext.PointsNext)
- } else {
- assert.Empty(t, pagination.Next)
+ if tc.expectErr {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursorDirection)
+ return
}
- if tt.expectPrev {
- assert.NotEmpty(t, pagination.Prev)
-
- decodedPrev, err := DecodeCursor(pagination.Prev)
- require.NoError(t, err)
- assert.Equal(t, firstItemID, decodedPrev.ID)
- assert.False(t, decodedPrev.PointsNext)
- } else {
- assert.Empty(t, pagination.Prev)
- }
+ require.NoError(t, err)
+ assert.Equal(t, tc.expectedOperator, operator)
+ assert.Equal(t, tc.expectedOrder, order)
})
}
}
-func TestUUIDv7TimestampOrdering(t *testing.T) {
- uuids := make([]uuid.UUID, 10)
- timestamps := make([]time.Time, 10)
-
- for i := 0; i < 10; i++ {
- timestamps[i] = time.Now()
- var err error
- uuids[i], err = uuid.NewV7()
- require.NoError(t, err)
- time.Sleep(1 * time.Millisecond)
- }
+// ---------------------------------------------------------------------------
+// PaginateRecords
+// ---------------------------------------------------------------------------
- for i := 0; i < 9; i++ {
- uuid1Str := uuids[i].String()
- uuid2Str := uuids[i+1].String()
+func TestPaginateRecords_NoPagination(t *testing.T) {
+ t.Parallel()
- assert.True(t, uuid1Str < uuid2Str,
- "UUID v7 at index %d (%s) should be lexicographically smaller than UUID at index %d (%s)",
- i, uuid1Str, i+1, uuid2Str)
-
- assert.True(t, timestamps[i].Before(timestamps[i+1]) || timestamps[i].Equal(timestamps[i+1]),
- "Timestamp at index %d should be before or equal to timestamp at index %d", i, i+1)
- }
+ items := []int{1, 2, 3, 4, 5}
+ result := PaginateRecords(true, false, CursorDirectionNext, items, 3)
+ assert.Equal(t, []int{1, 2, 3, 4, 5}, result)
}
-func TestCursorPaginationRealWorldScenario(t *testing.T) {
- type Item struct {
- ID string
- Name string
- CreatedAt time.Time
- }
+func TestPaginateRecords_NextDirection(t *testing.T) {
+ t.Parallel()
- items := make([]Item, 20)
- for i := 0; i < 20; i++ {
- itemUUID, err := uuid.NewV7()
- require.NoError(t, err)
- items[i] = Item{
- ID: itemUUID.String(),
- Name: "Item " + itemUUID.String()[:8],
- CreatedAt: time.Now(),
- }
- time.Sleep(1 * time.Millisecond)
- }
+ items := []int{1, 2, 3, 4, 5}
+ result := PaginateRecords(false, true, CursorDirectionNext, items, 3)
+ assert.Equal(t, []int{1, 2, 3}, result)
+}
- limit := 5
+func TestPaginateRecords_PrevDirection_NotFirstPage(t *testing.T) {
+ t.Parallel()
- page1Items := items[:limit]
+ items := []int{1, 2, 3, 4, 5}
+ result := PaginateRecords(false, true, CursorDirectionPrev, items, 3)
+ assert.Equal(t, []int{3, 2, 1}, result)
- pagination, err := CalculateCursor(true, true, true, page1Items[0].ID, page1Items[len(page1Items)-1].ID)
- require.NoError(t, err)
- assert.NotEmpty(t, pagination.Next)
- assert.Empty(t, pagination.Prev)
+ // Original slice should not be mutated.
+ assert.Equal(t, []int{1, 2, 3, 4, 5}, items)
+}
- nextCursor, err := DecodeCursor(pagination.Next)
- require.NoError(t, err)
- assert.Equal(t, page1Items[len(page1Items)-1].ID, nextCursor.ID)
- assert.True(t, nextCursor.PointsNext)
+func TestPaginateRecords_PrevDirection_FirstPage(t *testing.T) {
+ t.Parallel()
- query := squirrel.Select("id", "name", "created_at").From("items")
- paginatedQuery, order := ApplyCursorPagination(query, nextCursor, "ASC", limit)
+ items := []int{1, 2, 3, 4, 5}
+ // When isFirstPage=true, prev direction should NOT reverse.
+ result := PaginateRecords(true, true, CursorDirectionPrev, items, 3)
+ assert.Equal(t, []int{1, 2, 3}, result)
+}
- sql, args, err := paginatedQuery.ToSql()
- require.NoError(t, err)
+func TestPaginateRecords_EmptySlice(t *testing.T) {
+ t.Parallel()
- expectedSQL := "SELECT id, name, created_at FROM items WHERE id > ? ORDER BY id ASC LIMIT 6"
- assert.Equal(t, expectedSQL, sql)
- assert.Equal(t, []interface{}{page1Items[len(page1Items)-1].ID}, args)
- assert.Equal(t, "ASC", order)
+ result := PaginateRecords(true, true, CursorDirectionNext, []int{}, 10)
+ assert.Empty(t, result)
}
-func TestLastPageScenario(t *testing.T) {
- uuids := make([]uuid.UUID, 5)
- for i := 0; i < 5; i++ {
- var err error
- uuids[i], err = uuid.NewV7()
- require.NoError(t, err)
- time.Sleep(1 * time.Millisecond)
- }
-
- items := make([]string, len(uuids))
- for i, u := range uuids {
- items[i] = u.String()
- }
+func TestPaginateRecords_SingleItem(t *testing.T) {
+ t.Parallel()
- limit := 3
- lastPageItems := items[limit-1:]
+ result := PaginateRecords(false, true, CursorDirectionNext, []int{42}, 10)
+ assert.Equal(t, []int{42}, result)
+}
- isFirstPage := false
- hasPagination := false
- pointsNext := true
+func TestPaginateRecords_ExactlyLimit(t *testing.T) {
+ t.Parallel()
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, lastPageItems[0], lastPageItems[len(lastPageItems)-1])
- require.NoError(t, err)
+ items := []int{1, 2, 3}
+ result := PaginateRecords(false, true, CursorDirectionNext, items, 3)
+ assert.Equal(t, []int{1, 2, 3}, result)
+}
- assert.Empty(t, pagination.Next, "Last page should not have next_cursor")
- assert.NotEmpty(t, pagination.Prev, "Last page should have prev_cursor")
+func TestPaginateRecords_MoreThanLimit(t *testing.T) {
+ t.Parallel()
- decodedPrev, err := DecodeCursor(pagination.Prev)
- require.NoError(t, err)
- assert.Equal(t, lastPageItems[0], decodedPrev.ID)
- assert.False(t, decodedPrev.PointsNext)
+ items := []int{1, 2, 3, 4, 5}
+ result := PaginateRecords(false, true, CursorDirectionNext, items, 2)
+ assert.Equal(t, []int{1, 2}, result)
}
-func TestNavigationFromSecondPageBackToFirst(t *testing.T) {
- uuids := make([]uuid.UUID, 5)
- for i := 0; i < 5; i++ {
- var err error
- uuids[i], err = uuid.NewV7()
- require.NoError(t, err)
- time.Sleep(1 * time.Millisecond)
- }
+func TestPaginateRecords_LimitZero(t *testing.T) {
+ t.Parallel()
- items := make([]string, len(uuids))
- for i, u := range uuids {
- items[i] = u.String()
- }
+ // Limit 0 with hasPagination=true should return empty.
+ items := []int{1, 2, 3}
+ result := PaginateRecords(false, true, CursorDirectionNext, items, 0)
+ assert.Empty(t, result)
+}
- limit := 3
+func TestPaginateRecords_NegativeLimit(t *testing.T) {
+ t.Parallel()
- t.Run("simulate second page", func(t *testing.T) {
- secondPageItems := items[1 : limit+1]
+ // Negative limit is clamped to 0.
+ items := []int{1, 2, 3}
+ result := PaginateRecords(false, true, CursorDirectionNext, items, -5)
+ assert.Empty(t, result)
+}
- isFirstPage := false
- hasPagination := len(items) > limit
- pointsNext := true
+func TestPaginateRecords_LimitOne(t *testing.T) {
+ t.Parallel()
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, secondPageItems[0], secondPageItems[len(secondPageItems)-1])
- require.NoError(t, err)
+ items := []int{10, 20, 30}
+ result := PaginateRecords(false, true, CursorDirectionNext, items, 1)
+ assert.Equal(t, []int{10}, result)
+}
- assert.NotEmpty(t, pagination.Next, "Second page should have next_cursor")
- assert.NotEmpty(t, pagination.Prev, "Second page should have prev_cursor")
- })
+func TestPaginateRecords_LimitLargerThanSlice(t *testing.T) {
+ t.Parallel()
- t.Run("navigate back to first page using prev_cursor", func(t *testing.T) {
- firstPageItemsFromPrev := items[:limit]
+ items := []int{1, 2}
+ result := PaginateRecords(false, true, CursorDirectionNext, items, 100)
+ assert.Equal(t, []int{1, 2}, result)
+}
- isFirstPage := true
- hasPagination := len(items) > limit
- pointsNext := false
+func TestPaginateRecords_PrevSingleItemNotFirstPage(t *testing.T) {
+ t.Parallel()
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, firstPageItemsFromPrev[0], firstPageItemsFromPrev[len(firstPageItemsFromPrev)-1])
- require.NoError(t, err)
+ items := []int{42}
+ result := PaginateRecords(false, true, CursorDirectionPrev, items, 5)
+ assert.Equal(t, []int{42}, result, "single item reversed is still that item")
+}
- assert.NotEmpty(t, pagination.Next, "When returning to first page via prev, should have next_cursor")
- assert.Empty(t, pagination.Prev, "When returning to first page via prev, should NOT have prev_cursor - first page never has prev")
+func TestPaginateRecords_StringType(t *testing.T) {
+ t.Parallel()
- decodedNext, err := DecodeCursor(pagination.Next)
- require.NoError(t, err)
- assert.Equal(t, firstPageItemsFromPrev[len(firstPageItemsFromPrev)-1], decodedNext.ID)
- assert.True(t, decodedNext.PointsNext)
- })
+ items := []string{"a", "b", "c", "d"}
+ result := PaginateRecords(false, true, CursorDirectionPrev, items, 3)
+ assert.Equal(t, []string{"c", "b", "a"}, result)
}
-func TestCompleteNavigationFlow(t *testing.T) {
- uuids := make([]uuid.UUID, 7)
- for i := 0; i < 7; i++ {
- var err error
- uuids[i], err = uuid.NewV7()
- require.NoError(t, err)
- time.Sleep(1 * time.Millisecond)
- }
+func TestPaginateRecords_ConcurrentUsage(t *testing.T) {
+ t.Parallel()
- items := make([]string, len(uuids))
- for i, u := range uuids {
- items[i] = u.String()
- }
+ items := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+ var wg sync.WaitGroup
- limit := 3
+ for i := 0; i < 20; i++ {
+ wg.Add(1)
- t.Run("first page - initial load", func(t *testing.T) {
- firstPageItems := items[:limit]
+ go func(limit int) {
+ defer wg.Done()
- isFirstPage := true
- hasPagination := len(items) > limit
- pointsNext := true
+ // Each goroutine gets its own copy.
+ localItems := make([]int, len(items))
+ copy(localItems, items)
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, firstPageItems[0], firstPageItems[len(firstPageItems)-1])
- require.NoError(t, err)
+ result := PaginateRecords(false, true, CursorDirectionPrev, localItems, limit)
+ assert.LessOrEqual(t, len(result), limit)
+ }(i%5 + 1)
+ }
- assert.NotEmpty(t, pagination.Next, "First page should have next_cursor")
- assert.Empty(t, pagination.Prev, "First page should NOT have prev_cursor")
- })
+ wg.Wait()
+}
- t.Run("second page - using next_cursor", func(t *testing.T) {
- secondPageItems := items[limit : limit*2]
+// ---------------------------------------------------------------------------
+// CalculateCursor
+// ---------------------------------------------------------------------------
- isFirstPage := false
- hasPagination := len(items) > limit*2
- pointsNext := true
+func TestCalculateCursor_FirstPageWithMore(t *testing.T) {
+ t.Parallel()
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, secondPageItems[0], secondPageItems[len(secondPageItems)-1])
- require.NoError(t, err)
+ firstID := uuid.NewString()
+ lastID := uuid.NewString()
- assert.NotEmpty(t, pagination.Next, "Second page should have next_cursor")
- assert.NotEmpty(t, pagination.Prev, "Second page should have prev_cursor")
- })
+ pagination, err := CalculateCursor(true, true, CursorDirectionNext, firstID, lastID)
+ require.NoError(t, err)
+ assert.NotEmpty(t, pagination.Next)
+ assert.Empty(t, pagination.Prev, "first page should not have prev cursor")
- t.Run("last page - using next_cursor", func(t *testing.T) {
- lastPageItems := items[limit*2:]
+ next, err := DecodeCursor(pagination.Next)
+ require.NoError(t, err)
+ assert.Equal(t, lastID, next.ID)
+ assert.Equal(t, CursorDirectionNext, next.Direction)
+}
- isFirstPage := false
- hasPagination := false
- pointsNext := true
+func TestCalculateCursor_MiddlePage(t *testing.T) {
+ t.Parallel()
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, lastPageItems[0], lastPageItems[len(lastPageItems)-1])
- require.NoError(t, err)
+ firstID := uuid.NewString()
+ lastID := uuid.NewString()
- assert.Empty(t, pagination.Next, "Last page should NOT have next_cursor")
- assert.NotEmpty(t, pagination.Prev, "Last page should have prev_cursor")
- })
+ pagination, err := CalculateCursor(false, true, CursorDirectionNext, firstID, lastID)
+ require.NoError(t, err)
+ assert.NotEmpty(t, pagination.Next)
+ assert.NotEmpty(t, pagination.Prev)
- t.Run("back to second page - using prev_cursor", func(t *testing.T) {
- secondPageItems := items[limit : limit*2]
+ next, err := DecodeCursor(pagination.Next)
+ require.NoError(t, err)
+ assert.Equal(t, lastID, next.ID)
+ assert.Equal(t, CursorDirectionNext, next.Direction)
- isFirstPage := false
- hasPagination := len(items) > limit
- pointsNext := false
+ prev, err := DecodeCursor(pagination.Prev)
+ require.NoError(t, err)
+ assert.Equal(t, firstID, prev.ID)
+ assert.Equal(t, CursorDirectionPrev, prev.Direction)
+}
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, secondPageItems[0], secondPageItems[len(secondPageItems)-1])
- require.NoError(t, err)
+func TestCalculateCursor_LastPage(t *testing.T) {
+ t.Parallel()
- assert.NotEmpty(t, pagination.Next, "Second page (via prev) should have next_cursor")
- assert.NotEmpty(t, pagination.Prev, "Second page (via prev) should have prev_cursor")
- })
+ firstID := uuid.NewString()
+ lastID := uuid.NewString()
- t.Run("back to first page - using prev_cursor", func(t *testing.T) {
- firstPageItems := items[:limit]
+ pagination, err := CalculateCursor(false, false, CursorDirectionNext, firstID, lastID)
+ require.NoError(t, err)
+ assert.Empty(t, pagination.Next, "last page should not have next cursor")
+ assert.NotEmpty(t, pagination.Prev)
+}
- isFirstPage := true
- hasPagination := len(items) > limit
- pointsNext := false
+func TestCalculateCursor_SinglePage(t *testing.T) {
+ t.Parallel()
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, firstPageItems[0], firstPageItems[len(firstPageItems)-1])
- require.NoError(t, err)
+ firstID := uuid.NewString()
+ lastID := uuid.NewString()
- assert.NotEmpty(t, pagination.Next, "First page (via prev) should have next_cursor")
- assert.Empty(t, pagination.Prev, "First page (via prev) should NOT have prev_cursor - first page never has prev")
- })
+ pagination, err := CalculateCursor(true, false, CursorDirectionNext, firstID, lastID)
+ require.NoError(t, err)
+ assert.Empty(t, pagination.Next)
+ assert.Empty(t, pagination.Prev)
}
-func TestPaginationEdgeCases(t *testing.T) {
- t.Run("single page - no pagination needed", func(t *testing.T) {
- uuid1, err := uuid.NewV7()
- require.NoError(t, err)
-
- items := []string{uuid1.String()}
+func TestCalculateCursor_PrevDirection_NotFirstPage_WithPagination(t *testing.T) {
+ t.Parallel()
- isFirstPage := true
- hasPagination := false
- pointsNext := true
+ firstID := uuid.NewString()
+ lastID := uuid.NewString()
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, items[0], items[0])
- require.NoError(t, err)
-
- assert.Empty(t, pagination.Next, "Single page should not have next_cursor")
- assert.Empty(t, pagination.Prev, "Single page should not have prev_cursor")
- })
+ pagination, err := CalculateCursor(false, true, CursorDirectionPrev, firstID, lastID)
+ require.NoError(t, err)
+ // For prev direction: (cursorDirection == CursorDirectionPrev && (hasPagination || isFirstPage))
+ assert.NotEmpty(t, pagination.Next)
+ assert.NotEmpty(t, pagination.Prev)
+}
- t.Run("exactly two pages", func(t *testing.T) {
- uuids := make([]uuid.UUID, 4)
- for i := 0; i < 4; i++ {
- var err error
- uuids[i], err = uuid.NewV7()
- require.NoError(t, err)
- time.Sleep(1 * time.Millisecond)
- }
+func TestCalculateCursor_PrevDirection_FirstPage_NoPagination(t *testing.T) {
+ t.Parallel()
- items := make([]string, len(uuids))
- for i, u := range uuids {
- items[i] = u.String()
- }
+ firstID := uuid.NewString()
+ lastID := uuid.NewString()
- limit := 2
+ // isFirstPage=true, hasPagination=false, direction=prev
+ // hasNext = (prev && (false || true)) = true
+ pagination, err := CalculateCursor(true, false, CursorDirectionPrev, firstID, lastID)
+ require.NoError(t, err)
+ assert.NotEmpty(t, pagination.Next)
+ assert.Empty(t, pagination.Prev, "first page should not have prev")
+}
- firstPageItems := items[:limit]
- isFirstPage := true
- hasPagination := len(items) > limit
- pointsNext := true
+func TestCalculateCursor_InvalidDirection(t *testing.T) {
+ t.Parallel()
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, firstPageItems[0], firstPageItems[len(firstPageItems)-1])
- require.NoError(t, err)
+ _, err := CalculateCursor(true, true, "invalid", "id1", "id2")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursorDirection)
+}
- assert.NotEmpty(t, pagination.Next, "First page of two should have next_cursor")
- assert.Empty(t, pagination.Prev, "First page of two should not have prev_cursor")
+func TestCalculateCursor_EmptyDirection(t *testing.T) {
+ t.Parallel()
- lastPageItems := items[limit:]
- isFirstPage = false
- hasPagination = false
- pointsNext = true
+ _, err := CalculateCursor(true, true, "", "id1", "id2")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursorDirection)
+}
- pagination, err = CalculateCursor(isFirstPage, hasPagination, pointsNext, lastPageItems[0], lastPageItems[len(lastPageItems)-1])
- require.NoError(t, err)
+func TestCalculateCursor_EmptyLastItemID(t *testing.T) {
+ t.Parallel()
- assert.Empty(t, pagination.Next, "Last page of two should not have next_cursor")
- assert.NotEmpty(t, pagination.Prev, "Last page of two should have prev_cursor")
- })
+ // EncodeCursor will fail because ID is empty.
+ _, err := CalculateCursor(true, true, CursorDirectionNext, "first", "")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
}
-func TestBugReproduction(t *testing.T) {
- t.Run("REAL bug reproduction: repository implementation", func(t *testing.T) {
- uuids := make([]uuid.UUID, 3)
- for i := 0; i < 3; i++ {
- var err error
- uuids[i], err = uuid.NewV7()
- require.NoError(t, err)
- time.Sleep(1 * time.Millisecond)
- }
+func TestCalculateCursor_EmptyFirstItemID_NotFirstPage(t *testing.T) {
+ t.Parallel()
- items := make([]string, len(uuids))
- for i, u := range uuids {
- items[i] = u.String()
- }
+ // When not first page, prev cursor is built with firstItemID; empty will fail.
+ _, err := CalculateCursor(false, false, CursorDirectionNext, "", "last")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+}
- limit := 2
+// ---------------------------------------------------------------------------
+// Cursor encode/decode round-trip with various ID formats
+// ---------------------------------------------------------------------------
- t.Run("step 1: first page initial load (cursor empty)", func(t *testing.T) {
- cursor := ""
- allResults := append(items[:limit], "dummy_item")
+func TestCursor_RoundTrip_UUIDId(t *testing.T) {
+ t.Parallel()
- isFirstPage := cursor == ""
- hasPagination := len(allResults) > limit
- pointsNext := true
- actualResults := allResults[:limit]
+ id := uuid.NewString()
+ encoded, err := EncodeCursor(Cursor{ID: id, Direction: CursorDirectionNext})
+ require.NoError(t, err)
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, actualResults[0], actualResults[len(actualResults)-1])
- require.NoError(t, err)
+ decoded, err := DecodeCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, id, decoded.ID)
+ assert.Equal(t, CursorDirectionNext, decoded.Direction)
+}
- t.Logf("First page initial: next=%s, prev=%s", pagination.Next, pagination.Prev)
- assert.NotEmpty(t, pagination.Next, "Initial first page should have next_cursor")
- assert.Empty(t, pagination.Prev, "Initial first page should NOT have prev_cursor")
- })
+func TestCursor_RoundTrip_ArbitraryStringID(t *testing.T) {
+ t.Parallel()
- t.Run("step 2: second page using next_cursor", func(t *testing.T) {
- firstPageCursor := CreateCursor(items[1], true)
- cursorBytes, _ := json.Marshal(firstPageCursor)
- cursor := base64.StdEncoding.EncodeToString(cursorBytes)
+ id := "custom-resource-id-12345"
+ encoded, err := EncodeCursor(Cursor{ID: id, Direction: CursorDirectionPrev})
+ require.NoError(t, err)
- decodedCursor, _ := DecodeCursor(cursor)
- allResults := items[1:]
+ decoded, err := DecodeCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, id, decoded.ID)
+ assert.Equal(t, CursorDirectionPrev, decoded.Direction)
+}
- isFirstPage := false
- hasPagination := false
- pointsNext := decodedCursor.PointsNext
- actualResults := allResults
+func TestCursor_RoundTrip_SpecialCharacters(t *testing.T) {
+ t.Parallel()
- if len(actualResults) > limit {
- actualResults = actualResults[:limit]
- }
+ id := "id/with+special=characters&more"
+ encoded, err := EncodeCursor(Cursor{ID: id, Direction: CursorDirectionNext})
+ require.NoError(t, err)
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, actualResults[0], actualResults[len(actualResults)-1])
- require.NoError(t, err)
+ decoded, err := DecodeCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, id, decoded.ID)
+}
- t.Logf("Second page: next=%s, prev=%s", pagination.Next, pagination.Prev)
- assert.Empty(t, pagination.Next, "Second page (last) should not have next_cursor")
- assert.NotEmpty(t, pagination.Prev, "Second page should have prev_cursor")
- })
+func TestCursor_RoundTrip_VeryLongID(t *testing.T) {
+ t.Parallel()
- t.Run("step 3: back to first page using prev_cursor - CORRECT", func(t *testing.T) {
- prevPageCursor := CreateCursor(items[1], false)
- cursorBytes, _ := json.Marshal(prevPageCursor)
- cursor := base64.StdEncoding.EncodeToString(cursorBytes)
+ id := strings.Repeat("x", 1000)
+ encoded, err := EncodeCursor(Cursor{ID: id, Direction: CursorDirectionNext})
+ require.NoError(t, err)
- decodedCursor, _ := DecodeCursor(cursor)
- firstPageItems := items[:limit]
+ decoded, err := DecodeCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, id, decoded.ID)
+}
- isFirstPage := len(firstPageItems) < limit || firstPageItems[0] == items[0]
- hasPagination := len(items) > limit
- pointsNext := decodedCursor.PointsNext
+func TestCursor_RoundTrip_UnicodeID(t *testing.T) {
+ t.Parallel()
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, firstPageItems[0], firstPageItems[len(firstPageItems)-1])
- require.NoError(t, err)
+ id := "id-with-unicode-\u00e9\u00e8\u00ea"
+ encoded, err := EncodeCursor(Cursor{ID: id, Direction: CursorDirectionNext})
+ require.NoError(t, err)
- t.Logf("Back to first page: isFirstPage=%v, hasPagination=%v, pointsNext=%v", isFirstPage, hasPagination, pointsNext)
- t.Logf("Back to first page result: next=%s, prev=%s", pagination.Next, pagination.Prev)
-
- assert.NotEmpty(t, pagination.Next, "First page (back from prev) should have next_cursor")
- assert.Empty(t, pagination.Prev, "First page (back from prev) should NOT have prev_cursor - CORRECT")
- })
+ decoded, err := DecodeCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, id, decoded.ID)
+}
- t.Run("step 3: back to first page - WRONG IMPLEMENTATION (YOUR BUG)", func(t *testing.T) {
- prevPageCursor := CreateCursor(items[1], false)
- cursorBytes, _ := json.Marshal(prevPageCursor)
- cursor := base64.StdEncoding.EncodeToString(cursorBytes)
+// ---------------------------------------------------------------------------
+// Cursor constants
+// ---------------------------------------------------------------------------
- decodedCursor, _ := DecodeCursor(cursor)
- firstPageItems := items[:limit]
+func TestCursorDirectionConstants(t *testing.T) {
+ t.Parallel()
- isFirstPage := false
- hasPagination := len(items) > limit
- pointsNext := decodedCursor.PointsNext
+ assert.Equal(t, "next", CursorDirectionNext)
+ assert.Equal(t, "prev", CursorDirectionPrev)
+}
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, firstPageItems[0], firstPageItems[len(firstPageItems)-1])
- require.NoError(t, err)
+// ---------------------------------------------------------------------------
+// CursorPagination struct
+// ---------------------------------------------------------------------------
- t.Logf("WRONG: Back to first page: isFirstPage=%v, hasPagination=%v, pointsNext=%v", isFirstPage, hasPagination, pointsNext)
- t.Logf("WRONG: Back to first page result: next=%s, prev=%s", pagination.Next, pagination.Prev)
-
- assert.NotEmpty(t, pagination.Next, "First page should have next_cursor")
- assert.NotEmpty(t, pagination.Prev, "BUG: First page incorrectly has prev_cursor because isFirstPage=false")
- })
- })
+func TestCursorPagination_JSON(t *testing.T) {
+ t.Parallel()
- t.Run("bug: infinite loop with same cursor values", func(t *testing.T) {
- firstItemID := "0198c376-87de-7234-a8da-8e6ec327889d"
- lastItemID := "0198c376-2a4b-74e5-a25a-2777b1a87ab9"
+ cp := CursorPagination{Next: "abc", Prev: "def"}
+ data, err := json.Marshal(cp)
+ require.NoError(t, err)
- isFirstPage := false
- hasPagination := true
- pointsNext := false
+ var decoded CursorPagination
+ require.NoError(t, json.Unmarshal(data, &decoded))
+ assert.Equal(t, "abc", decoded.Next)
+ assert.Equal(t, "def", decoded.Prev)
+}
- pagination, err := CalculateCursor(isFirstPage, hasPagination, pointsNext, firstItemID, lastItemID)
- require.NoError(t, err)
+func TestCursorPagination_EmptyJSON(t *testing.T) {
+ t.Parallel()
- if pagination.Next != "" && pagination.Prev != "" {
- nextCursor, err := DecodeCursor(pagination.Next)
- require.NoError(t, err)
- prevCursor, err := DecodeCursor(pagination.Prev)
- require.NoError(t, err)
+ cp := CursorPagination{}
+ data, err := json.Marshal(cp)
+ require.NoError(t, err)
- assert.NotEqual(t, nextCursor.ID, prevCursor.ID, "Next and Prev cursors should point to different IDs to avoid infinite loops")
- assert.True(t, nextCursor.PointsNext, "Next cursor should have PointsNext=true")
- assert.False(t, prevCursor.PointsNext, "Prev cursor should have PointsNext=false")
- }
- })
+ var decoded CursorPagination
+ require.NoError(t, json.Unmarshal(data, &decoded))
+ assert.Empty(t, decoded.Next)
+ assert.Empty(t, decoded.Prev)
}
diff --git a/commons/net/http/doc.go b/commons/net/http/doc.go
new file mode 100644
index 00000000..7a1f43b4
--- /dev/null
+++ b/commons/net/http/doc.go
@@ -0,0 +1,5 @@
+// Package http provides Fiber-oriented HTTP helpers, middleware, and error handling.
+//
+// Core entry points include response helpers (Respond, RespondError, RenderError),
+// middleware builders, and FiberErrorHandler for consistent request failure handling.
+package http
diff --git a/commons/net/http/error.go b/commons/net/http/error.go
new file mode 100644
index 00000000..20c0da6d
--- /dev/null
+++ b/commons/net/http/error.go
@@ -0,0 +1,14 @@
+package http
+
+import (
+ "github.com/gofiber/fiber/v2"
+)
+
+// RespondError writes a structured error response using the ErrorResponse schema.
+func RespondError(c *fiber.Ctx, status int, title, message string) error {
+ return Respond(c, status, ErrorResponse{
+ Code: status,
+ Title: title,
+ Message: message,
+ })
+}
diff --git a/commons/net/http/error_test.go b/commons/net/http/error_test.go
new file mode 100644
index 00000000..ad5ee5e4
--- /dev/null
+++ b/commons/net/http/error_test.go
@@ -0,0 +1,1007 @@
+//go:build unit
+
+package http
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// RespondError -- comprehensive status code and structure coverage
+// ---------------------------------------------------------------------------
+
+func TestRespondError_HappyPath(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RespondError(c, fiber.StatusBadRequest, "test_error", "test message")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var errResp ErrorResponse
+ require.NoError(t, json.Unmarshal(body, &errResp))
+
+ assert.Equal(t, 400, errResp.Code)
+ assert.Equal(t, "test_error", errResp.Title)
+ assert.Equal(t, "test message", errResp.Message)
+}
+
+func TestRespondError_AllStatusCodes(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ status int
+ title string
+ message string
+ }{
+ {"400 Bad Request", 400, "bad_request", "Invalid input"},
+ {"401 Unauthorized", 401, "unauthorized", "Missing token"},
+ {"403 Forbidden", 403, "forbidden", "Access denied"},
+ {"404 Not Found", 404, "not_found", "Resource not found"},
+ {"405 Method Not Allowed", 405, "method_not_allowed", "POST not supported"},
+ {"409 Conflict", 409, "conflict", "Resource already exists"},
+ {"412 Precondition Failed", 412, "precondition_failed", "ETag mismatch"},
+ {"422 Unprocessable Entity", 422, "unprocessable_entity", "Validation failed"},
+ {"429 Too Many Requests", 429, "rate_limited", "Rate limit exceeded"},
+ {"500 Internal Server Error", 500, "internal_error", "Something went wrong"},
+ {"502 Bad Gateway", 502, "bad_gateway", "Upstream unavailable"},
+ {"503 Service Unavailable", 503, "service_unavailable", "Service temporarily down"},
+ {"504 Gateway Timeout", 504, "gateway_timeout", "Upstream timeout"},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RespondError(c, tc.status, tc.title, tc.message)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, tc.status, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var errResp ErrorResponse
+ require.NoError(t, json.Unmarshal(body, &errResp))
+
+ assert.Equal(t, tc.status, errResp.Code)
+ assert.Equal(t, tc.title, errResp.Title)
+ assert.Equal(t, tc.message, errResp.Message)
+ })
+ }
+}
+
+func TestRespondError_NoLegacyField(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RespondError(c, fiber.StatusUnauthorized, "invalid_credentials", "invalid")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+ _, exists := parsed["error"]
+ assert.False(t, exists, "response should not contain legacy 'error' field")
+}
+
+func TestRespondError_JSONStructureExactlyThreeFields(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RespondError(c, fiber.StatusUnprocessableEntity, "validation_error", "field 'name' required")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(body, &parsed))
+
+ assert.Len(t, parsed, 3, "response should have exactly 3 fields: code, title, message")
+ assert.Contains(t, parsed, "code")
+ assert.Contains(t, parsed, "title")
+ assert.Contains(t, parsed, "message")
+
+ assert.Equal(t, float64(422), parsed["code"])
+ assert.Equal(t, "validation_error", parsed["title"])
+ assert.Equal(t, "field 'name' required", parsed["message"])
+}
+
+func TestRespondError_EmptyTitleAndMessage(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RespondError(c, fiber.StatusBadRequest, "", "")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var errResp ErrorResponse
+ require.NoError(t, json.Unmarshal(body, &errResp))
+ assert.Equal(t, 400, errResp.Code)
+ assert.Empty(t, errResp.Title)
+ assert.Empty(t, errResp.Message)
+}
+
+func TestRespondError_LongMessage(t *testing.T) {
+ t.Parallel()
+
+ longMsg := "The request could not be processed because the 'transaction_amount' field exceeds " +
+ "the maximum allowed value of 999999999.99 for the specified currency code (USD). " +
+ "Please verify the amount and retry the request."
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RespondError(c, fiber.StatusBadRequest, "amount_exceeded", longMsg)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var errResp ErrorResponse
+ require.NoError(t, json.Unmarshal(body, &errResp))
+ assert.Equal(t, longMsg, errResp.Message)
+}
+
+func TestRespondError_ContentTypeIsJSON(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RespondError(c, fiber.StatusBadRequest, "bad", "bad request")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Contains(t, resp.Header.Get("Content-Type"), "application/json")
+}
+
+// ---------------------------------------------------------------------------
+// ErrorResponse interface and marshaling
+// ---------------------------------------------------------------------------
+
+func TestErrorResponse_ImplementsError(t *testing.T) {
+ t.Parallel()
+
+ errResp := ErrorResponse{
+ Code: 400,
+ Title: "bad_request",
+ Message: "invalid input",
+ }
+
+ var err error = errResp
+ assert.Equal(t, "invalid input", err.Error())
+}
+
+func TestErrorResponse_MarshalUnmarshalRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ errResp := ErrorResponse{
+ Code: 404,
+ Title: "not_found",
+ Message: "resource does not exist",
+ }
+
+ data, err := json.Marshal(errResp)
+ require.NoError(t, err)
+
+ var decoded ErrorResponse
+ require.NoError(t, json.Unmarshal(data, &decoded))
+ assert.Equal(t, errResp, decoded)
+}
+
+// ---------------------------------------------------------------------------
+// RenderError -- extended edge cases (not covered in matcher_response_test.go)
+// ---------------------------------------------------------------------------
+
+func TestRenderError_ErrorResponseWithValidCodes(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err ErrorResponse
+ wantCode int
+ wantTitle string
+ wantMessage string
+ }{
+ {
+ name: "503 Service Unavailable",
+ err: ErrorResponse{
+ Code: 503,
+ Title: "service_unavailable",
+ Message: "Maintenance mode",
+ },
+ wantCode: 503,
+ wantTitle: "service_unavailable",
+ wantMessage: "Maintenance mode",
+ },
+ {
+ name: "429 Too Many Requests",
+ err: ErrorResponse{
+ Code: 429,
+ Title: "rate_limited",
+ Message: "Slow down",
+ },
+ wantCode: 429,
+ wantTitle: "rate_limited",
+ wantMessage: "Slow down",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, tc.err)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, tc.wantCode, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+
+ assert.Equal(t, float64(tc.wantCode), result["code"])
+ assert.Equal(t, tc.wantTitle, result["title"])
+ assert.Equal(t, tc.wantMessage, result["message"])
+ })
+ }
+}
+
+func TestRenderError_MultipleGenericErrorsSanitized(t *testing.T) {
+ t.Parallel()
+
+ genericErrors := []error{
+ errors.New("password=secret123"),
+ fmt.Errorf("wrapped: %w", errors.New("nested internal")),
+ errors.New("sql: connection refused at 10.0.0.1:5432"),
+ }
+
+ for _, genericErr := range genericErrors {
+ t.Run(genericErr.Error(), func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, genericErr)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+
+ assert.Equal(t, "request_failed", result["title"])
+ assert.Equal(t, "An internal error occurred", result["message"])
+ assert.NotContains(t, string(body), genericErr.Error(),
+ "internal error message should not leak through to the client")
+ })
+ }
+}
+
+func TestRenderError_WrappedErrorResponseConflict(t *testing.T) {
+ t.Parallel()
+
+ original := ErrorResponse{
+ Code: 409,
+ Title: "conflict",
+ Message: "duplicate resource",
+ }
+ wrappedErr := fmt.Errorf("layer: %w", original)
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, wrappedErr)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, 409, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "conflict", result["title"])
+ assert.Equal(t, "duplicate resource", result["message"])
+}
+
+func TestRenderError_WrappedFiberErrorForbidden(t *testing.T) {
+ t.Parallel()
+
+ wrappedErr := fmt.Errorf("context: %w", fiber.NewError(403, "forbidden resource"))
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, wrappedErr)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, 403, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "forbidden resource", result["message"])
+}
+
+// ---------------------------------------------------------------------------
+// FiberErrorHandler
+// ---------------------------------------------------------------------------
+
+func TestFiberErrorHandler_FiberErrorNotFound(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New(fiber.Config{
+ ErrorHandler: FiberErrorHandler,
+ })
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return fiber.NewError(fiber.StatusNotFound, "route not found")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusNotFound, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, float64(404), result["code"])
+ assert.Equal(t, "request_failed", result["title"])
+ assert.Equal(t, "route not found", result["message"])
+}
+
+func TestFiberErrorHandler_GenericError(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New(fiber.Config{
+ ErrorHandler: FiberErrorHandler,
+ })
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return errors.New("database connection refused")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+
+ assert.Equal(t, "request_failed", result["title"])
+ assert.Equal(t, "An internal error occurred", result["message"])
+}
+
+func TestFiberErrorHandler_FiberErrorWithVariousStatusCodes(t *testing.T) {
+ t.Parallel()
+
+ codes := []int{400, 401, 403, 404, 405, 409, 422, 429, 500, 502, 503}
+
+ for _, code := range codes {
+ t.Run(fmt.Sprintf("status_%d", code), func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New(fiber.Config{
+ ErrorHandler: FiberErrorHandler,
+ })
+ msg := fmt.Sprintf("error with code %d", code)
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return fiber.NewError(code, msg)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, code, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, float64(code), result["code"])
+ assert.Equal(t, msg, result["message"])
+ })
+ }
+}
+
+func TestFiberErrorHandler_ErrorResponseType(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New(fiber.Config{
+ ErrorHandler: FiberErrorHandler,
+ })
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return ErrorResponse{
+ Code: 422,
+ Title: "validation_error",
+ Message: "field required",
+ }
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "validation_error", result["title"])
+ assert.Equal(t, "field required", result["message"])
+}
+
+func TestFiberErrorHandler_RouteNotFound(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New(fiber.Config{
+ ErrorHandler: FiberErrorHandler,
+ })
+ app.Get("/exists", func(c *fiber.Ctx) error {
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/does-not-exist", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusNotFound, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, float64(404), result["code"])
+ assert.Equal(t, "request_failed", result["title"])
+}
+
+func TestFiberErrorHandler_MethodNotAllowed(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New(fiber.Config{
+ ErrorHandler: FiberErrorHandler,
+ })
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // Fiber sends 404 by default unless MethodNotAllowed is enabled.
+ assert.True(t, resp.StatusCode == 404 || resp.StatusCode == 405)
+}
+
+// ---------------------------------------------------------------------------
+// Respond and RespondStatus helpers
+// ---------------------------------------------------------------------------
+
+func TestRespond_ValidPayload(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return Respond(c, fiber.StatusOK, fiber.Map{"result": "ok"})
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusOK, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "ok", result["result"])
+}
+
+func TestRespond_InvalidStatusDefaultsTo500(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ status int
+ }{
+ {"negative status", -1},
+ {"zero status", 0},
+ {"status below 100", 99},
+ {"status above 599", 600},
+ {"very large status", 9999},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return Respond(c, tc.status, fiber.Map{"msg": "test"})
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
+ })
+ }
+}
+
+func TestRespond_BoundaryStatusCodes(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ status int
+ wantStatus int
+ }{
+ {"100 Continue", 100, 100},
+ {"599 custom", 599, 599},
+ {"200 OK", 200, 200},
+ {"204 No Content", 204, 204},
+ {"301 Moved", 301, 301},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return Respond(c, tc.status, fiber.Map{"msg": "test"})
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, tc.wantStatus, resp.StatusCode)
+ })
+ }
+}
+
+func TestRespondStatus_ValidStatus(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RespondStatus(c, fiber.StatusNoContent)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusNoContent, resp.StatusCode)
+}
+
+func TestRespondStatus_InvalidStatusDefaultsTo500(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RespondStatus(c, -1)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
+}
+
+func TestRespond_NilPayload(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return Respond(c, fiber.StatusOK, nil)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusOK, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Equal(t, "null", string(body))
+}
+
+// ---------------------------------------------------------------------------
+// ExtractTokenFromHeader
+// ---------------------------------------------------------------------------
+
+func TestExtractTokenFromHeader_BearerToken(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var token string
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ token = ExtractTokenFromHeader(c)
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer my-jwt-token-123")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, "my-jwt-token-123", token)
+}
+
+func TestExtractTokenFromHeader_BearerCaseInsensitive(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var token string
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ token = ExtractTokenFromHeader(c)
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "BEARER my-jwt-token-123")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, "my-jwt-token-123", token)
+}
+
+func TestExtractTokenFromHeader_RawTokenPreserved(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var token string
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ token = ExtractTokenFromHeader(c)
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "raw-token-no-bearer-prefix")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, "raw-token-no-bearer-prefix", token)
+}
+
+func TestExtractTokenFromHeader_BearerWithoutTokenRejected(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var token string
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ token = ExtractTokenFromHeader(c)
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Empty(t, token)
+}
+
+func TestExtractTokenFromHeader_BearerWithExtraFieldsRejected(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var token string
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ token = ExtractTokenFromHeader(c)
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer token extra")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Empty(t, token)
+}
+
+func TestExtractTokenFromHeader_EmptyHeader(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var token string
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ token = ExtractTokenFromHeader(c)
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Empty(t, token)
+}
+
+func TestExtractTokenFromHeader_BearerWithExtraSpaces(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var token string
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ token = ExtractTokenFromHeader(c)
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer my-token ")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // strings.Fields collapses whitespace, so "Bearer my-token " => ["Bearer", "my-token"].
+ // The token is correctly extracted regardless of extra whitespace.
+ assert.Equal(t, "my-token", token, "extra spaces between Bearer and token should be handled correctly")
+}
+
+func TestExtractTokenFromHeader_BearerLowercase(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var token string
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ token = ExtractTokenFromHeader(c)
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "bearer my-token-lowercase")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, "my-token-lowercase", token)
+}
+
+func TestExtractTokenFromHeader_NonBearerMultiPartReturnsEmpty(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var token string
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ token = ExtractTokenFromHeader(c)
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Basic abc123")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Empty(t, token)
+}
+
+// ---------------------------------------------------------------------------
+// Ping, Version, NotImplemented, Welcome handlers
+// ---------------------------------------------------------------------------
+
+func TestPing(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/ping", Ping)
+
+ req := httptest.NewRequest(http.MethodGet, "/ping", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusOK, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Equal(t, "healthy", string(body))
+}
+
+func TestVersion(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/version", Version)
+
+ req := httptest.NewRequest(http.MethodGet, "/version", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusOK, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Contains(t, result, "version")
+ assert.Contains(t, result, "requestDate")
+}
+
+func TestNotImplementedEndpoint(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", NotImplementedEndpoint)
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusNotImplemented, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "not_implemented", result["title"])
+ assert.Equal(t, "Not implemented yet", result["message"])
+}
+
+func TestWelcome(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", Welcome("my-service", "A financial service"))
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusOK, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "my-service", result["service"])
+ assert.Equal(t, "A financial service", result["description"])
+}
diff --git a/commons/net/http/handler.go b/commons/net/http/handler.go
index 52abbba0..40fe5c2a 100644
--- a/commons/net/http/handler.go
+++ b/commons/net/http/handler.go
@@ -1,33 +1,36 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package http
import (
+ "context"
"errors"
- "log"
- "net/http"
"strings"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
"github.com/gofiber/fiber/v2"
"go.opentelemetry.io/otel/trace"
)
-// Ping returns HTTP Status 200 with response "pong".
+// Ping returns HTTP Status 200 with response "healthy".
func Ping(c *fiber.Ctx) error {
- if err := c.SendString("healthy"); err != nil {
- log.Print(err.Error())
+ if c == nil {
+ return ErrContextNotFound
}
- return nil
+ return c.SendString("healthy")
}
-// Version returns HTTP Status 200 with given version.
+// Version returns HTTP Status 200 with the service version from the VERSION
+// environment variable (defaults to "0.0.0").
+//
+// NOTE: This endpoint intentionally exposes the build version. Callers that
+// need to restrict visibility should gate this route behind authentication
+// or omit it from public-facing routers.
func Version(c *fiber.Ctx) error {
- return OK(c, fiber.Map{
+ return Respond(c, fiber.StatusOK, fiber.Map{
"version": commons.GetenvOrDefault("VERSION", "0.0.0"),
"requestDate": time.Now().UTC(),
})
@@ -36,6 +39,10 @@ func Version(c *fiber.Ctx) error {
// Welcome returns HTTP Status 200 with service info.
func Welcome(service string, description string) fiber.Handler {
return func(c *fiber.Ctx) error {
+ if c == nil {
+ return ErrContextNotFound
+ }
+
return c.JSON(fiber.Map{
"service": service,
"description": description,
@@ -45,65 +52,96 @@ func Welcome(service string, description string) fiber.Handler {
// NotImplementedEndpoint returns HTTP 501 with not implemented message.
func NotImplementedEndpoint(c *fiber.Ctx) error {
- return c.Status(fiber.StatusNotImplemented).JSON(fiber.Map{"error": "Not implemented yet"})
+ return RespondError(c, fiber.StatusNotImplemented, "not_implemented", "Not implemented yet")
}
-// File servers a specific file.
+// File serves a specific file.
func File(filePath string) fiber.Handler {
return func(c *fiber.Ctx) error {
+ if c == nil {
+ return ErrContextNotFound
+ }
+
return c.SendFile(filePath)
}
}
-// ExtractTokenFromHeader extracts the authentication token from the Authorization header.
-// It handles both "Bearer TOKEN" format and raw token format.
+// ExtractTokenFromHeader extracts a token from the Authorization header.
+// It accepts `Bearer ` case-insensitively and also preserves the
+// legacy raw-token form when the header contains a single token with no scheme.
+// Malformed Bearer values and non-Bearer multi-part values return an empty string.
func ExtractTokenFromHeader(c *fiber.Ctx) string {
- authHeader := c.Get(fiber.HeaderAuthorization)
+ if c == nil {
+ return ""
+ }
+ authHeader := strings.TrimSpace(c.Get(fiber.HeaderAuthorization))
if authHeader == "" {
return ""
}
- splitToken := strings.Split(authHeader, " ")
+ fields := strings.Fields(authHeader)
- if len(splitToken) > 1 && strings.EqualFold(splitToken[0], "bearer") {
- return strings.TrimSpace(splitToken[1])
+ if len(fields) == 2 && strings.EqualFold(fields[0], cn.Bearer) {
+ return fields[1]
}
- if len(splitToken) > 0 {
- return strings.TrimSpace(splitToken[0])
+ if len(fields) > 2 && strings.EqualFold(fields[0], cn.Bearer) {
+ return ""
+ }
+
+ if len(fields) == 1 {
+ if strings.EqualFold(fields[0], cn.Bearer) {
+ return ""
+ }
+
+ return fields[0]
}
return ""
}
-// HandleFiberError handles errors for Fiber, properly unwrapping errors to check for fiber.Error
-func HandleFiberError(c *fiber.Ctx, err error) error {
+// FiberErrorHandler is the canonical Fiber error handler.
+// It uses the structured logger from the request context so that error
+// details pass through the sanitization pipeline instead of going to
+// plain stdlib log.Printf.
+func FiberErrorHandler(c *fiber.Ctx, err error) error {
+ if c == nil {
+ if err != nil {
+ return err
+ }
+
+ return ErrContextNotFound
+ }
+
// Safely end spans if user context exists
ctx := c.UserContext()
if ctx != nil {
- // End the span immediately instead of in a goroutine to ensure prompt completion
- trace.SpanFromContext(ctx).End()
+ span := trace.SpanFromContext(ctx)
+ libOpentelemetry.HandleSpanError(span, "handler error", err)
+ span.End()
}
- // Default error handling
- code := fiber.StatusInternalServerError
-
- var e *fiber.Error
- if errors.As(err, &e) {
- code = e.Code
+ var fe *fiber.Error
+ if errors.As(err, &fe) {
+ return RenderError(c, ErrorResponse{
+ Code: fe.Code,
+ Title: cn.DefaultErrorTitle,
+ Message: fe.Message,
+ })
}
- if code == fiber.StatusInternalServerError {
- // Log the actual error for debugging purposes.
- log.Printf("handler error on %s %s: %v", c.Method(), c.Path(), err)
-
- return c.Status(code).JSON(fiber.Map{
- "error": http.StatusText(code),
- })
+ if ctx == nil {
+ ctx = context.Background()
}
- return c.Status(code).JSON(fiber.Map{
- "error": err.Error(),
- })
+ logger := commons.NewLoggerFromContext(ctx)
+ logger.Log(ctx, libLog.LevelError,
+ "handler error",
+ libLog.String("method", c.Method()),
+ libLog.String("path", c.Path()),
+ libLog.Err(err),
+ )
+
+ return RenderError(c, err)
}
diff --git a/commons/net/http/handler_nil_test.go b/commons/net/http/handler_nil_test.go
new file mode 100644
index 00000000..654e6393
--- /dev/null
+++ b/commons/net/http/handler_nil_test.go
@@ -0,0 +1,75 @@
+//go:build unit
+
+package http
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestPing_NilContext(t *testing.T) {
+ t.Parallel()
+
+ err := Ping(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestExtractTokenFromHeader_NilContext(t *testing.T) {
+ t.Parallel()
+
+ assert.Empty(t, ExtractTokenFromHeader(nil))
+}
+
+func TestFiberErrorHandler_NilContext(t *testing.T) {
+ t.Parallel()
+
+ handlerErr := errors.New("boom")
+ err := FiberErrorHandler(nil, handlerErr)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, handlerErr)
+}
+
+func TestFiberErrorHandler_NilContextAndNilError(t *testing.T) {
+ t.Parallel()
+
+ err := FiberErrorHandler(nil, nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestWelcome_NilContext(t *testing.T) {
+ t.Parallel()
+
+ err := Welcome("svc", "desc")(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestFile_NilContext(t *testing.T) {
+ t.Parallel()
+
+ err := File("/tmp/ignored")(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestHealthWithDependencies_NilContext(t *testing.T) {
+ t.Parallel()
+
+ err := HealthWithDependencies()(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestEndTracingSpans_NilContext(t *testing.T) {
+ t.Parallel()
+
+ middleware := &TelemetryMiddleware{}
+ err := middleware.EndTracingSpans(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
diff --git a/commons/net/http/handler_test.go b/commons/net/http/handler_test.go
new file mode 100644
index 00000000..0f3b0b88
--- /dev/null
+++ b/commons/net/http/handler_test.go
@@ -0,0 +1,27 @@
+//go:build unit
+
+package http
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestFileHandler(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/file", File("../../../go.mod"))
+
+ req := httptest.NewRequest(http.MethodGet, "/file", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { assert.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
diff --git a/commons/net/http/health.go b/commons/net/http/health.go
index fd1a5987..6f898925 100644
--- a/commons/net/http/health.go
+++ b/commons/net/http/health.go
@@ -1,15 +1,23 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package http
import (
- "github.com/LerianStudio/lib-commons/v2/commons/circuitbreaker"
- "github.com/LerianStudio/lib-commons/v2/commons/constants"
+ "errors"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/circuitbreaker"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
"github.com/gofiber/fiber/v2"
)
+var (
+ // ErrEmptyDependencyName indicates a DependencyCheck was registered with an empty Name.
+ ErrEmptyDependencyName = errors.New("dependency name must not be empty")
+ // ErrDuplicateDependencyName indicates two DependencyChecks share the same Name.
+ ErrDuplicateDependencyName = errors.New("duplicate dependency name")
+ // ErrCBWithoutServiceName indicates a CircuitBreaker was provided without a ServiceName.
+ ErrCBWithoutServiceName = errors.New("CircuitBreaker provided without ServiceName")
+)
+
// DependencyCheck represents a health check configuration for a single dependency.
//
// At minimum, provide a Name. For circuit breaker integration, provide both
@@ -65,6 +73,10 @@ type DependencyStatus struct {
// Returns HTTP 200 (status: "available") when all dependencies are healthy,
// or HTTP 503 (status: "degraded") when any dependency fails.
//
+// Security note: this response includes dependency names and health metadata.
+// Prefer `Ping` for public liveness probes, and keep
+// `HealthWithDependencies` on internal or authenticated routes.
+//
// Example:
//
// f.Get("/health", commonsHttp.HealthWithDependencies(
@@ -81,7 +93,48 @@ type DependencyStatus struct {
// },
// ))
func HealthWithDependencies(dependencies ...DependencyCheck) fiber.Handler {
+ // Validate dependency names at registration time (2.21).
+ // Errors here are configuration bugs, so we capture them and return
+ // 503 on every request to make misconfiguration visible immediately.
+ seen := make(map[string]struct{}, len(dependencies))
+
+ var configErr error
+
+ for _, dep := range dependencies {
+ if dep.Name == "" {
+ configErr = ErrEmptyDependencyName
+
+ break
+ }
+
+ if _, exists := seen[dep.Name]; exists {
+ configErr = ErrDuplicateDependencyName
+
+ break
+ }
+
+ seen[dep.Name] = struct{}{}
+
+ // 2.6/2.8: CircuitBreaker provided without ServiceName is a misconfiguration.
+ if !nilcheck.Interface(dep.CircuitBreaker) && dep.ServiceName == "" {
+ configErr = ErrCBWithoutServiceName
+
+ break
+ }
+ }
+
return func(c *fiber.Ctx) error {
+ if c == nil {
+ return ErrContextNotFound
+ }
+
+ if configErr != nil {
+ return Respond(c, fiber.StatusServiceUnavailable, fiber.Map{
+ "status": constant.DataSourceStatusDegraded,
+ "error": configErr.Error(),
+ })
+ }
+
overallStatus := constant.DataSourceStatusAvailable
httpStatus := fiber.StatusOK
@@ -92,8 +145,10 @@ func HealthWithDependencies(dependencies ...DependencyCheck) fiber.Handler {
Healthy: true, // Default to healthy unless proven otherwise
}
- // Check circuit breaker state if provided
- if dep.CircuitBreaker != nil && dep.ServiceName != "" {
+ // Check circuit breaker state if provided.
+ // Uses typed-nil-safe check (2.13) so a concrete nil manager
+ // does not sneak past the interface-nil gate.
+ if !nilcheck.Interface(dep.CircuitBreaker) && dep.ServiceName != "" {
cbState := dep.CircuitBreaker.GetState(dep.ServiceName)
cbCounts := dep.CircuitBreaker.GetCounts(dep.ServiceName)
@@ -107,11 +162,14 @@ func HealthWithDependencies(dependencies ...DependencyCheck) fiber.Handler {
status.Healthy = dep.CircuitBreaker.IsHealthy(dep.ServiceName)
}
- // Run custom health check if provided
- // This overrides the circuit breaker health status if both are provided
+ // Run custom health check if provided.
+ // When both CircuitBreaker and HealthCheck are configured, both must
+ // report healthy (AND semantics) to prevent silently bypassing
+ // circuit breaker protection.
if dep.HealthCheck != nil {
- healthy := dep.HealthCheck()
- status.Healthy = healthy
+ if !dep.HealthCheck() {
+ status.Healthy = false
+ }
}
// Update overall status based on final dependency health
@@ -131,14 +189,3 @@ func HealthWithDependencies(dependencies ...DependencyCheck) fiber.Handler {
})
}
}
-
-// HealthSimple is an alias for the existing Ping function for backward compatibility.
-// Use this when you don't need detailed dependency health checks.
-//
-// Returns:
-// - HTTP 200 OK with "healthy" text response
-//
-// Example usage:
-//
-// f.Get("/health", commonsHttp.HealthSimple)
-var HealthSimple = Ping
diff --git a/commons/net/http/health_config_test.go b/commons/net/http/health_config_test.go
new file mode 100644
index 00000000..97205189
--- /dev/null
+++ b/commons/net/http/health_config_test.go
@@ -0,0 +1,101 @@
+//go:build unit
+
+package http
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestHealthWithDependencies_EmptyDependencyName(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(DependencyCheck{Name: ""}))
+
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/health", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ body := readHealthConfigBody(t, resp)
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "degraded", result["status"])
+ assert.Equal(t, ErrEmptyDependencyName.Error(), result["error"])
+ _, hasMessage := result["message"]
+ _, hasDependencies := result["dependencies"]
+ assert.False(t, hasMessage)
+ assert.False(t, hasDependencies)
+ assert.Len(t, result, 2)
+}
+
+func TestHealthWithDependencies_DuplicateDependencyName(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{Name: "database"},
+ DependencyCheck{Name: "database"},
+ ))
+
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/health", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ body := readHealthConfigBody(t, resp)
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "degraded", result["status"])
+ assert.Equal(t, ErrDuplicateDependencyName.Error(), result["error"])
+ _, hasMessage := result["message"]
+ _, hasDependencies := result["dependencies"]
+ assert.False(t, hasMessage)
+ assert.False(t, hasDependencies)
+ assert.Len(t, result, 2)
+}
+
+func TestHealthWithDependencies_CBWithoutServiceName_ConfigPayload(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{Name: "database", CircuitBreaker: &mockCBManager{}},
+ ))
+
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/health", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ body := readHealthConfigBody(t, resp)
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "degraded", result["status"])
+ assert.Equal(t, ErrCBWithoutServiceName.Error(), result["error"])
+ _, hasMessage := result["message"]
+ _, hasDependencies := result["dependencies"]
+ assert.False(t, hasMessage)
+ assert.False(t, hasDependencies)
+ assert.Len(t, result, 2)
+}
+
+func readHealthConfigBody(t *testing.T, resp *http.Response) []byte {
+ t.Helper()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ return body
+}
diff --git a/commons/net/http/health_integration_test.go b/commons/net/http/health_integration_test.go
new file mode 100644
index 00000000..1be748e4
--- /dev/null
+++ b/commons/net/http/health_integration_test.go
@@ -0,0 +1,458 @@
+//go:build integration
+
+package http
+
+import (
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/circuitbreaker"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// errSimulated is a sentinel error used to drive circuit breaker failures in tests.
+var errSimulated = errors.New("simulated service failure")
+
+// testConfig returns a circuit breaker Config with very short timeouts and
+// low thresholds suitable for integration tests that need the breaker to trip
+// quickly and recover within a bounded wall-clock time.
+func testConfig() circuitbreaker.Config {
+ return circuitbreaker.Config{
+ MaxRequests: 1,
+ Interval: 1 * time.Second,
+ Timeout: 1 * time.Second,
+ ConsecutiveFailures: 2,
+ FailureRatio: 0.5,
+ MinRequests: 2,
+ }
+}
+
+// parseHealthResponse reads and decodes the JSON body from a health endpoint
+// response. It returns the top-level map and the nested dependencies map.
+func parseHealthResponse(t *testing.T, resp *http.Response) (result map[string]any, deps map[string]any) {
+ t.Helper()
+
+ err := json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err, "failed to decode health response body")
+
+ raw, ok := result["dependencies"]
+ require.True(t, ok, "expected 'dependencies' key in response")
+
+ deps, ok = raw.(map[string]any)
+ require.True(t, ok, "expected 'dependencies' to be a JSON object")
+
+ return result, deps
+}
+
+// depStatus extracts a single dependency's status object from the dependencies map.
+func depStatus(t *testing.T, deps map[string]any, name string) map[string]any {
+ t.Helper()
+
+ raw, ok := deps[name]
+ require.True(t, ok, "expected dependency %q in response", name)
+
+ status, ok := raw.(map[string]any)
+ require.True(t, ok, "expected dependency %q to be a JSON object", name)
+
+ return status
+}
+
+// tripCircuitBreaker drives enough failures through the manager's Execute path
+// to move the circuit breaker for serviceName into the open state.
+func tripCircuitBreaker(t *testing.T, mgr circuitbreaker.Manager, serviceName string, failures int) {
+ t.Helper()
+
+ for i := range failures {
+ _, err := mgr.Execute(serviceName, func() (any, error) {
+ return nil, errSimulated
+ })
+ // Early executions return the simulated error; once the breaker
+ // trips the manager wraps the gobreaker open-state error.
+ require.Error(t, err, "expected error on failure iteration %d", i)
+ }
+
+ state := mgr.GetState(serviceName)
+ require.Equal(t, circuitbreaker.StateOpen, state,
+ "circuit breaker should be open after %d consecutive failures", failures)
+}
+
+// ---------------------------------------------------------------------------
+// Test 1: All dependencies healthy — circuit breaker in closed state.
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Health_AllDependenciesHealthy(t *testing.T) {
+ logger := log.NewNop()
+
+ mgr, err := circuitbreaker.NewManager(logger)
+ require.NoError(t, err)
+
+ _, err = mgr.GetOrCreate("postgres", circuitbreaker.DefaultConfig())
+ require.NoError(t, err)
+
+ // Drive one successful execution so the breaker has activity.
+ _, err = mgr.Execute("postgres", func() (any, error) {
+ return "ok", nil
+ })
+ require.NoError(t, err)
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "database",
+ CircuitBreaker: mgr,
+ ServiceName: "postgres",
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ result, deps := parseHealthResponse(t, resp)
+
+ assert.Equal(t, "available", result["status"])
+
+ dbStatus := depStatus(t, deps, "database")
+ assert.Equal(t, true, dbStatus["healthy"])
+ assert.Equal(t, "closed", dbStatus["circuit_breaker_state"])
+
+ // Verify the counts reflect the successful execution.
+ assert.GreaterOrEqual(t, dbStatus["total_successes"], float64(1),
+ "should report at least 1 success")
+}
+
+// ---------------------------------------------------------------------------
+// Test 2: Dependency unhealthy — circuit breaker tripped to open state.
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Health_DependencyUnhealthy_CircuitOpen(t *testing.T) {
+ logger := log.NewNop()
+ cfg := testConfig()
+
+ mgr, err := circuitbreaker.NewManager(logger)
+ require.NoError(t, err)
+
+ _, err = mgr.GetOrCreate("redis", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker: cfg.ConsecutiveFailures == 2, so 2 failures suffice.
+ tripCircuitBreaker(t, mgr, "redis", int(cfg.ConsecutiveFailures))
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "cache",
+ CircuitBreaker: mgr,
+ ServiceName: "redis",
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ result, deps := parseHealthResponse(t, resp)
+
+ assert.Equal(t, "degraded", result["status"])
+
+ cacheStatus := depStatus(t, deps, "cache")
+ assert.Equal(t, false, cacheStatus["healthy"])
+ assert.Equal(t, "open", cacheStatus["circuit_breaker_state"])
+
+ // NOTE: gobreaker resets internal counters to zero when transitioning to
+ // the open state. Because DependencyStatus uses `omitempty` on uint32
+ // counter fields, zero-valued counters are omitted from the JSON.
+ // We verify the breaker tripped by confirming state == "open" above.
+}
+
+// ---------------------------------------------------------------------------
+// Test 3: Custom HealthCheck function (no circuit breaker).
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Health_CustomHealthCheck(t *testing.T) {
+ // Sub-test: healthy custom check → 200.
+ t.Run("healthy", func(t *testing.T) {
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "external-api",
+ HealthCheck: func() bool { return true },
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ result, deps := parseHealthResponse(t, resp)
+ assert.Equal(t, "available", result["status"])
+
+ apiStatus := depStatus(t, deps, "external-api")
+ assert.Equal(t, true, apiStatus["healthy"])
+
+ // No circuit breaker configured — state field must be absent.
+ _, hasCBState := apiStatus["circuit_breaker_state"]
+ assert.False(t, hasCBState, "circuit_breaker_state should be omitted when no CB is configured")
+ })
+
+ // Sub-test: unhealthy custom check → 503.
+ t.Run("unhealthy", func(t *testing.T) {
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "external-api",
+ HealthCheck: func() bool { return false },
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ result, deps := parseHealthResponse(t, resp)
+ assert.Equal(t, "degraded", result["status"])
+
+ apiStatus := depStatus(t, deps, "external-api")
+ assert.Equal(t, false, apiStatus["healthy"])
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Test 4: AND semantics — both circuit breaker AND health check must pass.
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Health_BothChecks_ANDSemantics(t *testing.T) {
+ logger := log.NewNop()
+
+ mgr, err := circuitbreaker.NewManager(logger)
+ require.NoError(t, err)
+
+ // Create a breaker and leave it in the closed (healthy) state.
+ _, err = mgr.GetOrCreate("postgres", circuitbreaker.DefaultConfig())
+ require.NoError(t, err)
+
+ // One successful execution to confirm the breaker is alive and closed.
+ _, err = mgr.Execute("postgres", func() (any, error) {
+ return "ok", nil
+ })
+ require.NoError(t, err)
+
+ assert.Equal(t, circuitbreaker.StateClosed, mgr.GetState("postgres"),
+ "precondition: circuit breaker should be closed")
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "database",
+ CircuitBreaker: mgr,
+ ServiceName: "postgres",
+ HealthCheck: func() bool { return false }, // custom check fails
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // Circuit breaker is healthy (closed), but HealthCheck returns false.
+ // AND semantics → overall degraded.
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ result, deps := parseHealthResponse(t, resp)
+
+ assert.Equal(t, "degraded", result["status"])
+
+ dbStatus := depStatus(t, deps, "database")
+ assert.Equal(t, false, dbStatus["healthy"])
+ // The circuit breaker itself is still closed.
+ assert.Equal(t, "closed", dbStatus["circuit_breaker_state"])
+}
+
+// ---------------------------------------------------------------------------
+// Test 5: Multiple dependencies — 2 healthy, 1 unhealthy → 503.
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Health_MultipleDependencies(t *testing.T) {
+ logger := log.NewNop()
+ cfg := testConfig()
+
+ mgr, err := circuitbreaker.NewManager(logger)
+ require.NoError(t, err)
+
+ // Service 1: postgres — healthy.
+ _, err = mgr.GetOrCreate("postgres", circuitbreaker.DefaultConfig())
+ require.NoError(t, err)
+
+ _, err = mgr.Execute("postgres", func() (any, error) {
+ return "ok", nil
+ })
+ require.NoError(t, err)
+
+ // Service 2: redis — tripped to open.
+ _, err = mgr.GetOrCreate("redis", cfg)
+ require.NoError(t, err)
+
+ tripCircuitBreaker(t, mgr, "redis", int(cfg.ConsecutiveFailures))
+
+ // Service 3: external-api — healthy via custom check (no circuit breaker).
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "database",
+ CircuitBreaker: mgr,
+ ServiceName: "postgres",
+ },
+ DependencyCheck{
+ Name: "cache",
+ CircuitBreaker: mgr,
+ ServiceName: "redis",
+ },
+ DependencyCheck{
+ Name: "external-api",
+ HealthCheck: func() bool { return true },
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // One unhealthy dependency makes the overall status degraded.
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ result, deps := parseHealthResponse(t, resp)
+
+ assert.Equal(t, "degraded", result["status"])
+ require.Len(t, deps, 3, "should report all 3 dependencies")
+
+ // Database: healthy, closed.
+ dbStatus := depStatus(t, deps, "database")
+ assert.Equal(t, true, dbStatus["healthy"])
+ assert.Equal(t, "closed", dbStatus["circuit_breaker_state"])
+
+ // Cache: unhealthy, open.
+ cacheStatus := depStatus(t, deps, "cache")
+ assert.Equal(t, false, cacheStatus["healthy"])
+ assert.Equal(t, "open", cacheStatus["circuit_breaker_state"])
+
+ // External API: healthy, no circuit breaker state.
+ apiStatus := depStatus(t, deps, "external-api")
+ assert.Equal(t, true, apiStatus["healthy"])
+
+ _, hasCBState := apiStatus["circuit_breaker_state"]
+ assert.False(t, hasCBState, "external-api should not have circuit_breaker_state")
+}
+
+// ---------------------------------------------------------------------------
+// Test 6: Circuit recovery — open → half-open → closed after timeout.
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Health_CircuitRecovery(t *testing.T) {
+ logger := log.NewNop()
+ cfg := testConfig() // Timeout is 1s — the breaker moves to half-open after this.
+
+ mgr, err := circuitbreaker.NewManager(logger)
+ require.NoError(t, err)
+
+ _, err = mgr.GetOrCreate("postgres", cfg)
+ require.NoError(t, err)
+
+ // Trip the breaker to open.
+ tripCircuitBreaker(t, mgr, "postgres", int(cfg.ConsecutiveFailures))
+
+ // ---- Phase 1: health should report degraded while circuit is open ----
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "database",
+ CircuitBreaker: mgr,
+ ServiceName: "postgres",
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ result, deps := parseHealthResponse(t, resp)
+ assert.Equal(t, "degraded", result["status"])
+
+ dbStatus := depStatus(t, deps, "database")
+ assert.Equal(t, false, dbStatus["healthy"])
+ assert.Equal(t, "open", dbStatus["circuit_breaker_state"])
+
+ // ---- Phase 2: wait for timeout so breaker moves to half-open ----
+
+ // Sleep slightly beyond the configured Timeout (1s) to allow the
+ // gobreaker state machine to transition from open → half-open.
+ time.Sleep(cfg.Timeout + 200*time.Millisecond)
+
+ // In half-open state, gobreaker allows MaxRequests probe requests.
+ // A successful probe moves the breaker back to closed.
+ _, err = mgr.Execute("postgres", func() (any, error) {
+ return "recovered", nil
+ })
+ require.NoError(t, err)
+
+ // Verify the breaker is now closed.
+ assert.Equal(t, circuitbreaker.StateClosed, mgr.GetState("postgres"),
+ "circuit breaker should transition back to closed after successful probe")
+
+ // ---- Phase 3: health should report available after recovery ----
+
+ req = httptest.NewRequest(http.MethodGet, "/health", nil)
+
+ resp2, err := app.Test(req)
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, resp2.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp2.StatusCode)
+
+ result2, deps2 := parseHealthResponse(t, resp2)
+ assert.Equal(t, "available", result2["status"])
+
+ dbStatus2 := depStatus(t, deps2, "database")
+ assert.Equal(t, true, dbStatus2["healthy"])
+ assert.Equal(t, "closed", dbStatus2["circuit_breaker_state"])
+}
diff --git a/commons/net/http/health_test.go b/commons/net/http/health_test.go
new file mode 100644
index 00000000..11b921d1
--- /dev/null
+++ b/commons/net/http/health_test.go
@@ -0,0 +1,313 @@
+//go:build unit
+
+package http
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/circuitbreaker"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// mockCBManager implements circuitbreaker.Manager for testing.
+type mockCBManager struct {
+ state circuitbreaker.State
+ counts circuitbreaker.Counts
+ healthy bool
+}
+
+func (m *mockCBManager) GetOrCreate(string, circuitbreaker.Config) (circuitbreaker.CircuitBreaker, error) {
+ return nil, nil
+}
+
+func (m *mockCBManager) Execute(string, func() (any, error)) (any, error) { return nil, nil }
+func (m *mockCBManager) GetState(string) circuitbreaker.State { return m.state }
+func (m *mockCBManager) GetCounts(string) circuitbreaker.Counts { return m.counts }
+func (m *mockCBManager) IsHealthy(string) bool { return m.healthy }
+func (m *mockCBManager) Reset(string) {}
+func (m *mockCBManager) RegisterStateChangeListener(circuitbreaker.StateChangeListener) {
+}
+
+func TestHealthWithDependencies_NoDeps(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies())
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ var result map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+ assert.Equal(t, "available", result["status"])
+}
+
+func TestHealthWithDependencies_AllHealthy(t *testing.T) {
+ t.Parallel()
+
+ mgr := &mockCBManager{
+ state: circuitbreaker.StateClosed,
+ counts: circuitbreaker.Counts{Requests: 10, TotalSuccesses: 10},
+ healthy: true,
+ }
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{Name: "database", CircuitBreaker: mgr, ServiceName: "pg"},
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ var result map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+ assert.Equal(t, "available", result["status"])
+}
+
+func TestHealthWithDependencies_MixedHealthy(t *testing.T) {
+ t.Parallel()
+
+ healthyMgr := &mockCBManager{state: circuitbreaker.StateClosed, healthy: true}
+ unhealthyMgr := &mockCBManager{
+ state: circuitbreaker.StateOpen, healthy: false,
+ counts: circuitbreaker.Counts{TotalFailures: 5, ConsecutiveFailures: 3},
+ }
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{Name: "database", CircuitBreaker: healthyMgr, ServiceName: "pg"},
+ DependencyCheck{Name: "cache", CircuitBreaker: unhealthyMgr, ServiceName: "redis"},
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ var result map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+ assert.Equal(t, "degraded", result["status"])
+}
+
+func TestHealthWithDependencies_CustomHealthCheck(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "external-api",
+ HealthCheck: func() bool { return true },
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ var result map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+ assert.Equal(t, "available", result["status"])
+
+ deps, ok := result["dependencies"].(map[string]any)
+ require.True(t, ok, "expected dependencies map")
+ require.Len(t, deps, 1)
+
+ dep, ok := deps["external-api"].(map[string]any)
+ require.True(t, ok)
+ assert.Equal(t, true, dep["healthy"])
+}
+
+func TestHealthWithDependencies_CustomHealthCheckUnhealthy(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "external-api",
+ HealthCheck: func() bool { return false },
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ var result map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+ assert.Equal(t, "degraded", result["status"])
+
+ deps, ok := result["dependencies"].(map[string]any)
+ require.True(t, ok, "expected dependencies map")
+ require.Len(t, deps, 1)
+
+ dep, ok := deps["external-api"].(map[string]any)
+ require.True(t, ok)
+ assert.Equal(t, false, dep["healthy"])
+}
+
+func TestHealthWithDependencies_HealthCheckOverridesCB(t *testing.T) {
+ t.Parallel()
+
+ mgr := &mockCBManager{state: circuitbreaker.StateClosed, healthy: true}
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "database",
+ CircuitBreaker: mgr,
+ ServiceName: "pg",
+ HealthCheck: func() bool { return false },
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+}
+
+// ---------------------------------------------------------------------------
+// AND semantics: Both CB and HealthCheck must pass
+// ---------------------------------------------------------------------------
+
+func TestHealthWithDependencies_ANDSemantics_CBHealthyButHealthCheckFails(t *testing.T) {
+ t.Parallel()
+
+ mgr := &mockCBManager{state: circuitbreaker.StateClosed, healthy: true}
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "database",
+ CircuitBreaker: mgr,
+ ServiceName: "pg",
+ HealthCheck: func() bool { return false }, // HealthCheck fails
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // CB is healthy but HealthCheck returns false -> overall degraded
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ var result map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+ assert.Equal(t, "degraded", result["status"])
+
+ deps, ok := result["dependencies"].(map[string]any)
+ require.True(t, ok)
+
+ dep, ok := deps["database"].(map[string]any)
+ require.True(t, ok)
+ assert.Equal(t, false, dep["healthy"])
+}
+
+func TestHealthWithDependencies_ANDSemantics_CBUnhealthyButHealthCheckPasses(t *testing.T) {
+ t.Parallel()
+
+ mgr := &mockCBManager{state: circuitbreaker.StateOpen, healthy: false}
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "database",
+ CircuitBreaker: mgr,
+ ServiceName: "pg",
+ HealthCheck: func() bool { return true }, // HealthCheck passes
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // CB is unhealthy -> overall degraded (HealthCheck can't override CB's false)
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ var result map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+ assert.Equal(t, "degraded", result["status"])
+}
+
+func TestHealthWithDependencies_ANDSemantics_BothHealthy(t *testing.T) {
+ t.Parallel()
+
+ mgr := &mockCBManager{state: circuitbreaker.StateClosed, healthy: true}
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{
+ Name: "database",
+ CircuitBreaker: mgr,
+ ServiceName: "pg",
+ HealthCheck: func() bool { return true },
+ },
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // Both healthy -> available
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ var result map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+ assert.Equal(t, "available", result["status"])
+}
+
+func TestHealthWithDependencies_CBWithoutServiceName(t *testing.T) {
+ t.Parallel()
+
+ mgr := &mockCBManager{state: circuitbreaker.StateOpen, healthy: false}
+
+ app := fiber.New()
+ app.Get("/health", HealthWithDependencies(
+ DependencyCheck{Name: "orphan-cb", CircuitBreaker: mgr, ServiceName: ""},
+ ))
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // CircuitBreaker provided without ServiceName is a misconfiguration.
+ // The handler returns 503 with an error to make it immediately visible.
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+}
diff --git a/commons/net/http/matcher_response.go b/commons/net/http/matcher_response.go
new file mode 100644
index 00000000..044a173a
--- /dev/null
+++ b/commons/net/http/matcher_response.go
@@ -0,0 +1,71 @@
+package http
+
+import (
+ "errors"
+ "net/http"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/gofiber/fiber/v2"
+)
+
+// ErrorResponse provides a consistent error structure for API responses.
+// @Description Standard error response returned by all API endpoints
+type ErrorResponse struct {
+ // HTTP status code
+ Code int `json:"code" example:"400"`
+ // Error type identifier
+ Title string `json:"title" example:"invalid_request"`
+ // Human-readable error message
+ Message string `json:"message" example:"context name is required"`
+}
+
+// Error allows ErrorResponse to satisfy the error interface.
+func (e ErrorResponse) Error() string {
+ return e.Message
+}
+
+// RenderError writes all transport errors through a single, stable contract.
+func RenderError(ctx *fiber.Ctx, err error) error {
+ if ctx == nil {
+ return ErrContextNotFound
+ }
+
+ if err == nil {
+ return nil
+ }
+
+ // errors.As with a value target matches both ErrorResponse and *ErrorResponse,
+ // since ErrorResponse implements error via a value receiver.
+ var responseErr ErrorResponse
+ if errors.As(err, &responseErr) {
+ return renderErrorResponse(ctx, responseErr)
+ }
+
+ var fiberErr *fiber.Error
+ if errors.As(err, &fiberErr) {
+ return RespondError(ctx, fiberErr.Code, cn.DefaultErrorTitle, fiberErr.Message)
+ }
+
+ return RespondError(ctx, fiber.StatusInternalServerError, cn.DefaultErrorTitle, cn.DefaultInternalErrorMessage)
+}
+
+// renderErrorResponse normalizes and sends an ErrorResponse with safe defaults.
+func renderErrorResponse(ctx *fiber.Ctx, resp ErrorResponse) error {
+ status := fiber.StatusInternalServerError
+
+ if resp.Code >= http.StatusContinue && resp.Code <= 599 {
+ status = resp.Code
+ }
+
+ title := resp.Title
+ if title == "" {
+ title = cn.DefaultErrorTitle
+ }
+
+ message := resp.Message
+ if message == "" {
+ message = http.StatusText(status)
+ }
+
+ return RespondError(ctx, status, title, message)
+}
diff --git a/commons/net/http/matcher_response_test.go b/commons/net/http/matcher_response_test.go
new file mode 100644
index 00000000..700c181a
--- /dev/null
+++ b/commons/net/http/matcher_response_test.go
@@ -0,0 +1,431 @@
+//go:build unit
+
+package http
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// ErrorResponse edge cases
+// ---------------------------------------------------------------------------
+
+func TestErrorResponse_EmptyMessageReturnsEmpty(t *testing.T) {
+ t.Parallel()
+
+ errResp := ErrorResponse{
+ Code: 400,
+ Title: "bad_request",
+ Message: "",
+ }
+
+ assert.Equal(t, "", errResp.Error())
+}
+
+func TestErrorResponse_JSONDeserializationFromString(t *testing.T) {
+ t.Parallel()
+
+ jsonData := `{"code":503,"title":"service_unavailable","message":"try again later"}`
+
+ var errResp ErrorResponse
+ require.NoError(t, json.Unmarshal([]byte(jsonData), &errResp))
+
+ assert.Equal(t, 503, errResp.Code)
+ assert.Equal(t, "service_unavailable", errResp.Title)
+ assert.Equal(t, "try again later", errResp.Message)
+}
+
+func TestErrorResponse_PartialJSONDeserializationOnlyCode(t *testing.T) {
+ t.Parallel()
+
+ // Only code field present
+ jsonData := `{"code":400}`
+
+ var errResp ErrorResponse
+ require.NoError(t, json.Unmarshal([]byte(jsonData), &errResp))
+
+ assert.Equal(t, 400, errResp.Code)
+ assert.Equal(t, "", errResp.Title)
+ assert.Equal(t, "", errResp.Message)
+}
+
+func TestErrorResponse_PartialJSONDeserializationOnlyMessage(t *testing.T) {
+ t.Parallel()
+
+ jsonData := `{"message":"something went wrong"}`
+
+ var errResp ErrorResponse
+ require.NoError(t, json.Unmarshal([]byte(jsonData), &errResp))
+
+ assert.Equal(t, 0, errResp.Code)
+ assert.Equal(t, "", errResp.Title)
+ assert.Equal(t, "something went wrong", errResp.Message)
+}
+
+func TestErrorResponse_JSONRoundTripWithSpecialChars(t *testing.T) {
+ t.Parallel()
+
+ original := ErrorResponse{
+ Code: 418,
+ Title: "im_a_teapot",
+ Message: "I'm a teapot with \"quotes\" and ",
+ }
+
+ data, err := json.Marshal(original)
+ require.NoError(t, err)
+
+ var decoded ErrorResponse
+ require.NoError(t, json.Unmarshal(data, &decoded))
+
+ assert.Equal(t, original, decoded)
+}
+
+func TestErrorResponse_EmptyJSON(t *testing.T) {
+ t.Parallel()
+
+ jsonData := `{}`
+
+ var errResp ErrorResponse
+ require.NoError(t, json.Unmarshal([]byte(jsonData), &errResp))
+
+ assert.Equal(t, 0, errResp.Code)
+ assert.Equal(t, "", errResp.Title)
+ assert.Equal(t, "", errResp.Message)
+}
+
+// ---------------------------------------------------------------------------
+// Nil guard tests
+// ---------------------------------------------------------------------------
+
+func TestRenderError_NilContext(t *testing.T) {
+ t.Parallel()
+
+ err := RenderError(nil, ErrorResponse{Code: 400, Title: "bad", Message: "nil ctx"})
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestRenderError_NilError(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, nil)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // RenderError(c, nil) returns nil, so no response body is written -> Fiber defaults to 200
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+// ---------------------------------------------------------------------------
+// RenderError code boundary tests
+// ---------------------------------------------------------------------------
+
+func TestRenderError_CodeBoundaryAt100(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, ErrorResponse{
+ Code: 100,
+ Title: "continue",
+ Message: "boundary test at 100",
+ })
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, 100, resp.StatusCode)
+}
+
+func TestRenderError_CodeBoundaryAt599(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, ErrorResponse{
+ Code: 599,
+ Title: "custom_error",
+ Message: "boundary test at 599",
+ })
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, 599, resp.StatusCode)
+}
+
+func TestRenderError_CodeAt99FallsBackTo500(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, ErrorResponse{
+ Code: 99,
+ Title: "test_error",
+ Message: "code 99 should fall back",
+ })
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
+}
+
+func TestRenderError_CodeAt600FallsBackTo500(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, ErrorResponse{
+ Code: 600,
+ Title: "test_error",
+ Message: "code 600 should fall back",
+ })
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
+}
+
+// ---------------------------------------------------------------------------
+// RenderError with both empty title and message
+// ---------------------------------------------------------------------------
+
+func TestRenderError_EmptyTitleAndMessageDefaultsBoth(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, ErrorResponse{
+ Code: 500,
+ Title: "",
+ Message: "",
+ })
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+
+ // Both should be filled with defaults
+ assert.Equal(t, "request_failed", result["title"])
+ assert.Equal(t, "Internal Server Error", result["message"])
+}
+
+// ---------------------------------------------------------------------------
+// RenderError response structure validation
+// ---------------------------------------------------------------------------
+
+func TestRenderError_ResponseHasExactlyThreeFields(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, ErrorResponse{
+ Code: 409,
+ Title: "conflict",
+ Message: "resource already exists",
+ })
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+
+ assert.Len(t, result, 3, "response should have exactly code, title, and message")
+ assert.Contains(t, result, "code")
+ assert.Contains(t, result, "title")
+ assert.Contains(t, result, "message")
+}
+
+// ---------------------------------------------------------------------------
+// RenderError across HTTP methods
+// ---------------------------------------------------------------------------
+
+func TestRenderError_WorksForAllHTTPMethods(t *testing.T) {
+ t.Parallel()
+
+ methods := []string{
+ http.MethodGet,
+ http.MethodPost,
+ http.MethodPut,
+ http.MethodPatch,
+ http.MethodDelete,
+ }
+
+ for _, method := range methods {
+ t.Run(method, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ handler := func(c *fiber.Ctx) error {
+ return RenderError(c, ErrorResponse{
+ Code: 400,
+ Title: "bad_request",
+ Message: "test",
+ })
+ }
+
+ switch method {
+ case http.MethodGet:
+ app.Get("/test", handler)
+ case http.MethodPost:
+ app.Post("/test", handler)
+ case http.MethodPut:
+ app.Put("/test", handler)
+ case http.MethodPatch:
+ app.Patch("/test", handler)
+ case http.MethodDelete:
+ app.Delete("/test", handler)
+ }
+
+ req := httptest.NewRequest(method, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// RenderError with fiber.Error with default message
+// ---------------------------------------------------------------------------
+
+func TestRenderError_FiberErrorDefaultMessage(t *testing.T) {
+ t.Parallel()
+
+ // fiber.NewError with just a code uses the default HTTP status text
+ fiberErr := fiber.NewError(fiber.StatusGatewayTimeout)
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, fiberErr)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, fiber.StatusGatewayTimeout, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var result map[string]any
+ require.NoError(t, json.Unmarshal(body, &result))
+ assert.Equal(t, "request_failed", result["title"])
+}
+
+// ---------------------------------------------------------------------------
+// RenderError content type
+// ---------------------------------------------------------------------------
+
+func TestRenderError_ReturnsJSON(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, ErrorResponse{
+ Code: 400,
+ Title: "bad_request",
+ Message: "test",
+ })
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ contentType := resp.Header.Get("Content-Type")
+ assert.Contains(t, contentType, "application/json")
+}
+
+// ---------------------------------------------------------------------------
+// RenderError with various 2xx/3xx codes (unusual but valid)
+// ---------------------------------------------------------------------------
+
+func TestRenderError_UnusualValidCodes(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ code int
+ }{
+ {"200 OK (unusual for error)", 200},
+ {"201 Created (unusual for error)", 201},
+ {"301 Moved Permanently", 301},
+ {"302 Found", 302},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return RenderError(c, ErrorResponse{
+ Code: tt.code,
+ Title: "test",
+ Message: "unusual code",
+ })
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // Valid HTTP codes between 100-599 should be used as-is
+ assert.Equal(t, tt.code, resp.StatusCode)
+ })
+ }
+}
diff --git a/commons/net/http/middleware_example_test.go b/commons/net/http/middleware_example_test.go
new file mode 100644
index 00000000..9d461d68
--- /dev/null
+++ b/commons/net/http/middleware_example_test.go
@@ -0,0 +1,35 @@
+//go:build unit
+
+package http_test
+
+import (
+ "encoding/base64"
+ "fmt"
+ "net/http/httptest"
+
+ uhttp "github.com/LerianStudio/lib-commons/v4/commons/net/http"
+ "github.com/gofiber/fiber/v2"
+)
+
+func ExampleWithBasicAuth() {
+ app := fiber.New()
+ app.Use(uhttp.WithBasicAuth(uhttp.FixedBasicAuthFunc("fred", "secret"), "admin"))
+ app.Get("/private", func(c *fiber.Ctx) error {
+ return c.SendStatus(fiber.StatusNoContent)
+ })
+
+ unauthorizedReq := httptest.NewRequest("GET", "/private", nil)
+ unauthorizedResp, _ := app.Test(unauthorizedReq)
+
+ authHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte("fred:secret"))
+ authorizedReq := httptest.NewRequest("GET", "/private", nil)
+ authorizedReq.Header.Set("Authorization", authHeader)
+ authorizedResp, _ := app.Test(authorizedReq)
+
+ fmt.Println(unauthorizedResp.StatusCode)
+ fmt.Println(authorizedResp.StatusCode)
+
+ // Output:
+ // 401
+ // 204
+}
diff --git a/commons/net/http/pagination.go b/commons/net/http/pagination.go
new file mode 100644
index 00000000..6c55279e
--- /dev/null
+++ b/commons/net/http/pagination.go
@@ -0,0 +1,125 @@
+package http
+
+import (
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "strconv"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/gofiber/fiber/v2"
+ "github.com/google/uuid"
+)
+
+// ErrLimitMustBePositive is returned when limit is below 1.
+var ErrLimitMustBePositive = errors.New("limit must be greater than zero")
+
+// ErrInvalidCursor is returned when the cursor cannot be decoded.
+var ErrInvalidCursor = errors.New("invalid cursor format")
+
+// ValidateLimitStrict validates a pagination limit without silently coercing
+// non-positive values. It returns ErrLimitMustBePositive when limit < 1 and
+// caps values above maxLimit.
+func ValidateLimitStrict(limit, maxLimit int) (int, error) {
+ if limit <= 0 {
+ return 0, ErrLimitMustBePositive
+ }
+
+ if limit > maxLimit {
+ return maxLimit, nil
+ }
+
+ return limit, nil
+}
+
+// ParsePagination parses limit/offset query params with defaults.
+// Non-numeric values return an error. Negative or zero limits are coerced to
+// DefaultLimit; negative offsets are coerced to DefaultOffset; limits above
+// MaxLimit are capped.
+func ParsePagination(fiberCtx *fiber.Ctx) (int, int, error) {
+ if fiberCtx == nil {
+ return 0, 0, ErrContextNotFound
+ }
+
+ limit := cn.DefaultLimit
+ offset := cn.DefaultOffset
+
+ if limitValue := fiberCtx.Query("limit"); limitValue != "" {
+ parsed, err := strconv.Atoi(limitValue)
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid limit value: %w", err)
+ }
+
+ limit = parsed
+ }
+
+ if offsetValue := fiberCtx.Query("offset"); offsetValue != "" {
+ parsed, err := strconv.Atoi(offsetValue)
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid offset value: %w", err)
+ }
+
+ offset = parsed
+ }
+
+ if limit <= 0 {
+ limit = cn.DefaultLimit
+ }
+
+ if limit > cn.MaxLimit {
+ limit = cn.MaxLimit
+ }
+
+ if offset < 0 {
+ offset = cn.DefaultOffset
+ }
+
+ return limit, offset, nil
+}
+
+// ParseOpaqueCursorPagination parses cursor/limit query params for opaque cursor pagination.
+// It validates limit but does not attempt to decode the cursor string.
+// Returns the raw cursor string (empty for first page), limit, and any error.
+func ParseOpaqueCursorPagination(fiberCtx *fiber.Ctx) (string, int, error) {
+ if fiberCtx == nil {
+ return "", 0, ErrContextNotFound
+ }
+
+ limit := cn.DefaultLimit
+
+ if limitValue := fiberCtx.Query("limit"); limitValue != "" {
+ parsed, err := strconv.Atoi(limitValue)
+ if err != nil {
+ return "", 0, fmt.Errorf("invalid limit value: %w", err)
+ }
+
+ limit = ValidateLimit(parsed, cn.DefaultLimit, cn.MaxLimit)
+ }
+
+ cursorParam := fiberCtx.Query("cursor")
+ if cursorParam == "" {
+ return "", limit, nil
+ }
+
+ return cursorParam, limit, nil
+}
+
+// EncodeUUIDCursor encodes a UUID into a base64 cursor string.
+func EncodeUUIDCursor(id uuid.UUID) string {
+ return base64.StdEncoding.EncodeToString([]byte(id.String()))
+}
+
+// DecodeUUIDCursor decodes a base64 cursor string into a UUID.
+func DecodeUUIDCursor(cursor string) (uuid.UUID, error) {
+ decoded, err := base64.StdEncoding.DecodeString(cursor)
+ if err != nil {
+ return uuid.Nil, fmt.Errorf("%w: decode failed: %w", ErrInvalidCursor, err)
+ }
+
+ id, err := uuid.Parse(string(decoded))
+ if err != nil {
+ return uuid.Nil, fmt.Errorf("%w: parse failed: %w", ErrInvalidCursor, err)
+ }
+
+ return id, nil
+}
diff --git a/commons/net/http/pagination_cursor_timestamp_test.go b/commons/net/http/pagination_cursor_timestamp_test.go
new file mode 100644
index 00000000..04f4d837
--- /dev/null
+++ b/commons/net/http/pagination_cursor_timestamp_test.go
@@ -0,0 +1,289 @@
+//go:build unit
+
+package http
+
+import (
+ "encoding/base64"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestEncodeTimestampCursor(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ timestamp time.Time
+ id uuid.UUID
+ }{
+ {
+ name: "valid timestamp and UUID",
+ timestamp: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
+ id: uuid.MustParse("550e8400-e29b-41d4-a716-446655440000"),
+ },
+ {
+ name: "zero timestamp",
+ timestamp: time.Time{},
+ id: uuid.MustParse("550e8400-e29b-41d4-a716-446655440000"),
+ },
+ {
+ name: "non-UTC timestamp gets converted to UTC",
+ timestamp: time.Date(2025, 1, 15, 10, 30, 0, 0, time.FixedZone("EST", -5*60*60)),
+ id: uuid.MustParse("550e8400-e29b-41d4-a716-446655440000"),
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ encoded, err := EncodeTimestampCursor(tc.timestamp, tc.id)
+ require.NoError(t, err)
+ assert.NotEmpty(t, encoded)
+
+ decoded, err := DecodeTimestampCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, tc.id, decoded.ID)
+ assert.Equal(t, tc.timestamp.UTC(), decoded.Timestamp)
+ })
+ }
+}
+
+func TestDecodeTimestampCursor(t *testing.T) {
+ t.Parallel()
+
+ validTimestamp := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
+ validID := uuid.MustParse("550e8400-e29b-41d4-a716-446655440000")
+ validCursor, encErr := EncodeTimestampCursor(validTimestamp, validID)
+ require.NoError(t, encErr)
+
+ tests := []struct {
+ name string
+ cursor string
+ expectedTimestamp time.Time
+ expectedID uuid.UUID
+ errContains string
+ }{
+ {
+ name: "valid cursor",
+ cursor: validCursor,
+ expectedTimestamp: validTimestamp,
+ expectedID: validID,
+ },
+ {
+ name: "empty string",
+ cursor: "",
+ errContains: "unmarshal failed",
+ },
+ {
+ name: "whitespace only",
+ cursor: " ",
+ errContains: "decode failed",
+ },
+ {
+ name: "invalid base64",
+ cursor: "not-valid-base64!!!",
+ errContains: "decode failed",
+ },
+ {
+ name: "valid base64 but invalid JSON",
+ cursor: base64.StdEncoding.EncodeToString([]byte("not-json")),
+ errContains: "unmarshal failed",
+ },
+ {
+ name: "valid JSON but missing ID",
+ cursor: base64.StdEncoding.EncodeToString([]byte(`{"t":"2025-01-15T10:30:00Z"}`)),
+ errContains: "missing id",
+ },
+ {
+ name: "valid JSON with nil UUID",
+ cursor: base64.StdEncoding.EncodeToString(
+ []byte(`{"t":"2025-01-15T10:30:00Z","i":"00000000-0000-0000-0000-000000000000"}`),
+ ),
+ errContains: "missing id",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ decoded, err := DecodeTimestampCursor(tc.cursor)
+
+ if tc.errContains != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tc.errContains)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+ assert.Nil(t, decoded)
+
+ return
+ }
+
+ require.NoError(t, err)
+ require.NotNil(t, decoded)
+ assert.Equal(t, tc.expectedTimestamp, decoded.Timestamp)
+ assert.Equal(t, tc.expectedID, decoded.ID)
+ })
+ }
+}
+
+func TestParseTimestampCursorPagination(t *testing.T) {
+ t.Parallel()
+
+ validTimestamp := time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC)
+ validID := uuid.MustParse("550e8400-e29b-41d4-a716-446655440000")
+ validCursor, encErr := EncodeTimestampCursor(validTimestamp, validID)
+ require.NoError(t, encErr)
+
+ tests := []struct {
+ name string
+ queryString string
+ expectedLimit int
+ expectedTimestamp *time.Time
+ expectedID *uuid.UUID
+ errContains string
+ errIs error
+ }{
+ {
+ name: "default values when no query params",
+ queryString: "",
+ expectedLimit: 20,
+ },
+ {
+ name: "valid limit only",
+ queryString: "limit=50",
+ expectedLimit: 50,
+ },
+ {
+ name: "valid cursor and limit",
+ queryString: "cursor=" + validCursor + "&limit=30",
+ expectedLimit: 30,
+ expectedTimestamp: &validTimestamp,
+ expectedID: &validID,
+ },
+ {
+ name: "cursor only uses default limit",
+ queryString: "cursor=" + validCursor,
+ expectedLimit: 20,
+ expectedTimestamp: &validTimestamp,
+ expectedID: &validID,
+ },
+ {
+ name: "limit capped at maxLimit",
+ queryString: "limit=500",
+ expectedLimit: 200,
+ },
+ {
+ name: "invalid limit non-numeric",
+ queryString: "limit=abc",
+ errContains: "invalid limit value",
+ },
+ {
+ name: "limit zero uses default limit",
+ queryString: "limit=0",
+ expectedLimit: 20,
+ },
+ {
+ name: "negative limit uses default limit",
+ queryString: "limit=-5",
+ expectedLimit: 20,
+ },
+ {
+ name: "invalid cursor",
+ queryString: "cursor=invalid",
+ errContains: "invalid cursor format",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var cursor *TimestampCursor
+ var limit int
+ var err error
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ cursor, limit, err = ParseTimestampCursorPagination(c)
+ return nil
+ })
+
+ req := httptest.NewRequest("GET", "/test?"+tc.queryString, nil)
+ resp, testErr := app.Test(req)
+ require.NoError(t, testErr)
+ require.NoError(t, resp.Body.Close())
+
+ if tc.errContains != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tc.errContains)
+ if tc.errIs != nil {
+ assert.ErrorIs(t, err, tc.errIs)
+ }
+
+ return
+ }
+
+ require.NoError(t, err)
+ assert.Equal(t, tc.expectedLimit, limit)
+
+ if tc.expectedTimestamp == nil {
+ assert.Nil(t, cursor)
+ } else {
+ require.NotNil(t, cursor)
+ assert.Equal(t, *tc.expectedTimestamp, cursor.Timestamp)
+ assert.Equal(t, *tc.expectedID, cursor.ID)
+ }
+ })
+ }
+}
+
+func TestTimestampCursor_RoundTrip(t *testing.T) {
+ t.Parallel()
+
+ // Use fixed deterministic values for reproducible tests
+ timestamp := time.Date(2025, 6, 15, 14, 30, 45, 0, time.UTC)
+ id := uuid.MustParse("a1b2c3d4-e5f6-7890-abcd-ef1234567890")
+
+ encoded, encErr := EncodeTimestampCursor(timestamp, id)
+ require.NoError(t, encErr)
+ decoded, err := DecodeTimestampCursor(encoded)
+
+ require.NoError(t, err)
+ require.NotNil(t, decoded)
+ assert.Equal(t, timestamp, decoded.Timestamp)
+ assert.Equal(t, id, decoded.ID)
+}
+
+func TestParseTimestampCursorPagination_NilContext(t *testing.T) {
+ t.Parallel()
+
+ cursor, limit, err := ParseTimestampCursorPagination(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+ assert.Nil(t, cursor)
+ assert.Zero(t, limit)
+}
+
+func TestEncodeTimestampCursor_Success(t *testing.T) {
+ t.Parallel()
+
+ ts := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
+ id := uuid.MustParse("a1b2c3d4-e5f6-7890-abcd-ef1234567890")
+
+ encoded, err := EncodeTimestampCursor(ts, id)
+ require.NoError(t, err)
+ assert.NotEmpty(t, encoded)
+
+ decoded, err := DecodeTimestampCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, ts, decoded.Timestamp)
+ assert.Equal(t, id, decoded.ID)
+}
diff --git a/commons/net/http/pagination_cursor_uuid_test.go b/commons/net/http/pagination_cursor_uuid_test.go
new file mode 100644
index 00000000..2f710949
--- /dev/null
+++ b/commons/net/http/pagination_cursor_uuid_test.go
@@ -0,0 +1,222 @@
+//go:build unit
+
+package http
+
+import (
+ "encoding/base64"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseOpaqueCursorPagination(t *testing.T) {
+ t.Parallel()
+
+ opaqueCursor := "opaque-cursor-value"
+
+ tests := []struct {
+ name string
+ queryString string
+ expectedLimit int
+ expectedCursor string
+ errContains string
+ errIs error
+ }{
+ {
+ name: "default values when no query params",
+ queryString: "",
+ expectedLimit: 20,
+ expectedCursor: "",
+ },
+ {
+ name: "valid limit only",
+ queryString: "limit=50",
+ expectedLimit: 50,
+ expectedCursor: "",
+ },
+ {
+ name: "valid cursor and limit",
+ queryString: "cursor=" + opaqueCursor + "&limit=30",
+ expectedLimit: 30,
+ expectedCursor: opaqueCursor,
+ },
+ {
+ name: "cursor only uses default limit",
+ queryString: "cursor=" + opaqueCursor,
+ expectedLimit: 20,
+ expectedCursor: opaqueCursor,
+ },
+ {
+ name: "limit capped at maxLimit",
+ queryString: "limit=500",
+ expectedLimit: 200,
+ expectedCursor: "",
+ },
+ {
+ name: "invalid limit non-numeric",
+ queryString: "limit=abc",
+ errContains: "invalid limit value",
+ },
+ {
+ name: "limit zero uses default limit",
+ queryString: "limit=0",
+ expectedLimit: 20,
+ expectedCursor: "",
+ },
+ {
+ name: "negative limit uses default limit",
+ queryString: "limit=-5",
+ expectedLimit: 20,
+ expectedCursor: "",
+ },
+ {
+ name: "opaque cursor is accepted without validation",
+ queryString: "cursor=not-base64-$$$",
+ expectedLimit: 20,
+ expectedCursor: "not-base64-$$$",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var cursor string
+ var limit int
+ var err error
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ cursor, limit, err = ParseOpaqueCursorPagination(c)
+ return nil
+ })
+
+ req := httptest.NewRequest("GET", "/test?"+tc.queryString, nil)
+ resp, testErr := app.Test(req)
+ require.NoError(t, testErr)
+ require.NoError(t, resp.Body.Close())
+
+ if tc.errContains != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tc.errContains)
+ if tc.errIs != nil {
+ assert.ErrorIs(t, err, tc.errIs)
+ }
+
+ return
+ }
+
+ require.NoError(t, err)
+ assert.Equal(t, tc.expectedLimit, limit)
+ assert.Equal(t, tc.expectedCursor, cursor)
+ })
+ }
+}
+
+func TestEncodeUUIDCursor(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ id uuid.UUID
+ }{
+ {
+ name: "valid UUID",
+ id: uuid.MustParse("550e8400-e29b-41d4-a716-446655440000"),
+ },
+ {
+ name: "nil UUID",
+ id: uuid.Nil,
+ },
+ {
+ name: "random UUID",
+ id: uuid.New(),
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ encoded := EncodeUUIDCursor(tc.id)
+ assert.NotEmpty(t, encoded)
+
+ decoded, err := DecodeUUIDCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, tc.id, decoded)
+ })
+ }
+}
+
+func TestDecodeUUIDCursor(t *testing.T) {
+ t.Parallel()
+
+ validUUID := uuid.MustParse("550e8400-e29b-41d4-a716-446655440000")
+ validCursor := EncodeUUIDCursor(validUUID)
+
+ tests := []struct {
+ name string
+ cursor string
+ expected uuid.UUID
+ errContains string
+ }{
+ {
+ name: "valid cursor",
+ cursor: validCursor,
+ expected: validUUID,
+ },
+ {
+ name: "invalid base64",
+ cursor: "not-valid-base64!!!",
+ expected: uuid.Nil,
+ errContains: "decode failed",
+ },
+ {
+ name: "valid base64 but invalid UUID",
+ cursor: base64.StdEncoding.EncodeToString([]byte("not-a-uuid")),
+ expected: uuid.Nil,
+ errContains: "parse failed",
+ },
+ {
+ name: "empty string",
+ cursor: "",
+ expected: uuid.Nil,
+ errContains: "parse failed",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ decoded, err := DecodeUUIDCursor(tc.cursor)
+
+ if tc.errContains != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tc.errContains)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+ assert.Equal(t, uuid.Nil, decoded)
+
+ return
+ }
+
+ require.NoError(t, err)
+ assert.Equal(t, tc.expected, decoded)
+ })
+ }
+}
+
+func TestParseOpaqueCursorPagination_NilContext(t *testing.T) {
+ t.Parallel()
+
+ cursor, limit, err := ParseOpaqueCursorPagination(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+ assert.Empty(t, cursor)
+ assert.Zero(t, limit)
+}
diff --git a/commons/net/http/pagination_sort.go b/commons/net/http/pagination_sort.go
new file mode 100644
index 00000000..1dfd87de
--- /dev/null
+++ b/commons/net/http/pagination_sort.go
@@ -0,0 +1,136 @@
+package http
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "regexp"
+ "strings"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+)
+
+// sortColumnPattern validates sort column names as simple SQL identifiers.
+// Callers must still enforce endpoint-specific allowlists with ValidateSortColumn.
+var sortColumnPattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
+
+// SortCursor encodes a position in a sorted result set for composite keyset pagination.
+// It stores the sort column name, sort value, and record ID, enabling stable cursor
+// pagination when ordering by columns other than id.
+type SortCursor struct {
+ SortColumn string `json:"sc"`
+ SortValue string `json:"sv"`
+ ID string `json:"i"`
+ PointsNext bool `json:"pn"`
+}
+
+// EncodeSortCursor encodes sort cursor data into a base64 string.
+// Returns an error if id is empty or sortColumn is empty, matching the
+// decoder's validation contract.
+func EncodeSortCursor(sortColumn, sortValue, id string, pointsNext bool) (string, error) {
+ if id == "" {
+ return "", fmt.Errorf("%w: id must not be empty", ErrInvalidCursor)
+ }
+
+ if sortColumn == "" {
+ return "", fmt.Errorf("%w: sort column must not be empty", ErrInvalidCursor)
+ }
+
+ cursor := SortCursor{
+ SortColumn: sortColumn,
+ SortValue: sortValue,
+ ID: id,
+ PointsNext: pointsNext,
+ }
+
+ data, err := json.Marshal(cursor)
+ if err != nil {
+ return "", fmt.Errorf("encode sort cursor: %w", err)
+ }
+
+ return base64.StdEncoding.EncodeToString(data), nil
+}
+
+// DecodeSortCursor decodes a base64 cursor string into a SortCursor.
+// It validates identifier syntax only; callers must still validate SortColumn
+// against their endpoint-specific allowlist before building queries.
+func DecodeSortCursor(cursor string) (*SortCursor, error) {
+ decoded, err := base64.StdEncoding.DecodeString(cursor)
+ if err != nil {
+ return nil, fmt.Errorf("%w: decode failed: %w", ErrInvalidCursor, err)
+ }
+
+ var sc SortCursor
+ if err := json.Unmarshal(decoded, &sc); err != nil {
+ return nil, fmt.Errorf("%w: unmarshal failed: %w", ErrInvalidCursor, err)
+ }
+
+ if sc.ID == "" {
+ return nil, fmt.Errorf("%w: missing id", ErrInvalidCursor)
+ }
+
+ if sc.SortColumn == "" || !sortColumnPattern.MatchString(sc.SortColumn) {
+ return nil, fmt.Errorf("%w: invalid sort column", ErrInvalidCursor)
+ }
+
+ return &sc, nil
+}
+
+// SortCursorDirection computes the actual SQL ORDER BY direction and comparison
+// operator for composite keyset pagination based on the requested direction and
+// whether the cursor points forward or backward.
+func SortCursorDirection(requestedDir string, pointsNext bool) (actualDir, operator string) {
+ isAsc := strings.EqualFold(requestedDir, cn.SortDirASC)
+
+ if pointsNext {
+ if isAsc {
+ return cn.SortDirASC, ">"
+ }
+
+ return cn.SortDirDESC, "<"
+ }
+
+ if isAsc {
+ return cn.SortDirDESC, "<"
+ }
+
+ return cn.SortDirASC, ">"
+}
+
+// CalculateSortCursorPagination computes Next/Prev cursor strings for composite keyset pagination.
+func CalculateSortCursorPagination(
+ isFirstPage, hasPagination, pointsNext bool,
+ sortColumn string,
+ firstSortValue, firstID string,
+ lastSortValue, lastID string,
+) (next, prev string, err error) {
+ hasNext := (pointsNext && hasPagination) || (!pointsNext && (hasPagination || isFirstPage))
+
+ if hasNext {
+ next, err = EncodeSortCursor(sortColumn, lastSortValue, lastID, true)
+ if err != nil {
+ return "", "", err
+ }
+ }
+
+ if !isFirstPage {
+ prev, err = EncodeSortCursor(sortColumn, firstSortValue, firstID, false)
+ if err != nil {
+ return "", "", err
+ }
+ }
+
+ return next, prev, nil
+}
+
+// ValidateSortColumn checks whether column is in the allowed list (case-insensitive)
+// and returns the matched allowed value. If no match is found, it returns defaultColumn.
+func ValidateSortColumn(column string, allowed []string, defaultColumn string) string {
+ for _, a := range allowed {
+ if strings.EqualFold(column, a) {
+ return a
+ }
+ }
+
+ return defaultColumn
+}
diff --git a/commons/net/http/pagination_sort_test.go b/commons/net/http/pagination_sort_test.go
new file mode 100644
index 00000000..5d682be9
--- /dev/null
+++ b/commons/net/http/pagination_sort_test.go
@@ -0,0 +1,383 @@
+//go:build unit
+
+package http
+
+import (
+ "encoding/base64"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestEncodeSortCursor_RoundTrip(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ sortColumn string
+ sortValue string
+ id string
+ pointsNext bool
+ }{
+ {
+ name: "timestamp column forward",
+ sortColumn: "created_at",
+ sortValue: "2025-06-15T14:30:45Z",
+ id: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
+ pointsNext: true,
+ },
+ {
+ name: "status column backward",
+ sortColumn: "status",
+ sortValue: "COMPLETED",
+ id: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
+ pointsNext: false,
+ },
+ {
+ name: "empty sort value",
+ sortColumn: "completed_at",
+ sortValue: "",
+ id: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
+ pointsNext: true,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ encoded, err := EncodeSortCursor(tc.sortColumn, tc.sortValue, tc.id, tc.pointsNext)
+ require.NoError(t, err)
+ assert.NotEmpty(t, encoded)
+
+ decoded, err := DecodeSortCursor(encoded)
+ require.NoError(t, err)
+ require.NotNil(t, decoded)
+ assert.Equal(t, tc.sortColumn, decoded.SortColumn)
+ assert.Equal(t, tc.sortValue, decoded.SortValue)
+ assert.Equal(t, tc.id, decoded.ID)
+ assert.Equal(t, tc.pointsNext, decoded.PointsNext)
+ })
+ }
+}
+
+func TestDecodeSortCursor_Errors(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cursor string
+ errContains string
+ }{
+ {
+ name: "empty string",
+ cursor: "",
+ errContains: "unmarshal failed",
+ },
+ {
+ name: "whitespace only",
+ cursor: " ",
+ errContains: "decode failed",
+ },
+ {
+ name: "invalid base64",
+ cursor: "not-valid-base64!!!",
+ errContains: "decode failed",
+ },
+ {
+ name: "valid base64 but invalid JSON",
+ cursor: base64.StdEncoding.EncodeToString([]byte("not-json")),
+ errContains: "unmarshal failed",
+ },
+ {
+ name: "valid JSON but missing ID",
+ cursor: base64.StdEncoding.EncodeToString([]byte(`{"sc":"created_at","sv":"2025-01-01","pn":true}`)),
+ errContains: "missing id",
+ },
+ {
+ name: "invalid sort column",
+ cursor: base64.StdEncoding.EncodeToString([]byte(`{"sc":"created_at;DROP TABLE users","sv":"2025-01-01","i":"abc","pn":true}`)),
+ errContains: "invalid sort column",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ decoded, err := DecodeSortCursor(tc.cursor)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+ assert.Contains(t, err.Error(), tc.errContains)
+ assert.Nil(t, decoded)
+ })
+ }
+}
+
+func TestSortCursorDirection(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ requestedDir string
+ pointsNext bool
+ expectedDir string
+ expectedOp string
+ }{
+ {
+ name: "ASC forward",
+ requestedDir: "ASC",
+ pointsNext: true,
+ expectedDir: "ASC",
+ expectedOp: ">",
+ },
+ {
+ name: "DESC forward",
+ requestedDir: "DESC",
+ pointsNext: true,
+ expectedDir: "DESC",
+ expectedOp: "<",
+ },
+ {
+ name: "ASC backward",
+ requestedDir: "ASC",
+ pointsNext: false,
+ expectedDir: "DESC",
+ expectedOp: "<",
+ },
+ {
+ name: "DESC backward",
+ requestedDir: "DESC",
+ pointsNext: false,
+ expectedDir: "ASC",
+ expectedOp: ">",
+ },
+ {
+ name: "lowercase asc forward",
+ requestedDir: "asc",
+ pointsNext: true,
+ expectedDir: "ASC",
+ expectedOp: ">",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ actualDir, operator := SortCursorDirection(tc.requestedDir, tc.pointsNext)
+ assert.Equal(t, tc.expectedDir, actualDir)
+ assert.Equal(t, tc.expectedOp, operator)
+ })
+ }
+}
+
+func TestCalculateSortCursorPagination(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ isFirstPage bool
+ hasPagination bool
+ pointsNext bool
+ expectNext bool
+ expectPrev bool
+ }{
+ {
+ name: "first page with more results",
+ isFirstPage: true,
+ hasPagination: true,
+ pointsNext: true,
+ expectNext: true,
+ expectPrev: false,
+ },
+ {
+ name: "middle page forward",
+ isFirstPage: false,
+ hasPagination: true,
+ pointsNext: true,
+ expectNext: true,
+ expectPrev: true,
+ },
+ {
+ name: "last page forward",
+ isFirstPage: false,
+ hasPagination: false,
+ pointsNext: true,
+ expectNext: false,
+ expectPrev: true,
+ },
+ {
+ name: "first page no more results",
+ isFirstPage: true,
+ hasPagination: false,
+ pointsNext: true,
+ expectNext: false,
+ expectPrev: false,
+ },
+ {
+ name: "backward navigation with more",
+ isFirstPage: false,
+ hasPagination: true,
+ pointsNext: false,
+ expectNext: true,
+ expectPrev: true,
+ },
+ {
+ name: "backward navigation at start",
+ isFirstPage: true,
+ hasPagination: false,
+ pointsNext: false,
+ expectNext: true,
+ expectPrev: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ next, prev, calcErr := CalculateSortCursorPagination(
+ tc.isFirstPage, tc.hasPagination, tc.pointsNext,
+ "created_at",
+ "2025-01-01T00:00:00Z", "id-first",
+ "2025-01-02T00:00:00Z", "id-last",
+ )
+ require.NoError(t, calcErr)
+
+ if tc.expectNext {
+ assert.NotEmpty(t, next, "expected next cursor")
+
+ decoded, err := DecodeSortCursor(next)
+ require.NoError(t, err)
+ assert.Equal(t, "created_at", decoded.SortColumn)
+ assert.True(t, decoded.PointsNext)
+ } else {
+ assert.Empty(t, next, "expected no next cursor")
+ }
+
+ if tc.expectPrev {
+ assert.NotEmpty(t, prev, "expected prev cursor")
+
+ decoded, err := DecodeSortCursor(prev)
+ require.NoError(t, err)
+ assert.Equal(t, "created_at", decoded.SortColumn)
+ assert.False(t, decoded.PointsNext)
+ } else {
+ assert.Empty(t, prev, "expected no prev cursor")
+ }
+ })
+ }
+}
+
+func TestValidateSortColumn(t *testing.T) {
+ t.Parallel()
+
+ allowed := []string{"id", "created_at", "status"}
+
+ tests := []struct {
+ name string
+ column string
+ expected string
+ }{
+ {
+ name: "exact match returns allowed value",
+ column: "created_at",
+ expected: "created_at",
+ },
+ {
+ name: "case insensitive match uppercase",
+ column: "CREATED_AT",
+ expected: "created_at",
+ },
+ {
+ name: "case insensitive match mixed case",
+ column: "Status",
+ expected: "status",
+ },
+ {
+ name: "empty column returns default",
+ column: "",
+ expected: "id",
+ },
+ {
+ name: "unknown column returns default",
+ column: "nonexistent",
+ expected: "id",
+ },
+ {
+ name: "id returns id",
+ column: "id",
+ expected: "id",
+ },
+ {
+ name: "sql injection attempt returns default",
+ column: "id; DROP TABLE--",
+ expected: "id",
+ },
+ {
+ name: "whitespace only returns default",
+ column: " ",
+ expected: "id",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := ValidateSortColumn(tc.column, allowed, "id")
+ assert.Equal(t, tc.expected, result)
+ })
+ }
+}
+
+func TestValidateSortColumn_EmptyAllowed(t *testing.T) {
+ t.Parallel()
+
+ result := ValidateSortColumn("anything", nil, "fallback")
+ assert.Equal(t, "fallback", result)
+}
+
+func TestValidateSortColumn_CustomDefault(t *testing.T) {
+ t.Parallel()
+
+ result := ValidateSortColumn("unknown", []string{"name"}, "created_at")
+ assert.Equal(t, "created_at", result)
+}
+
+func TestEncodeSortCursor_Success(t *testing.T) {
+ t.Parallel()
+
+ encoded, err := EncodeSortCursor("created_at", "2025-01-01", "some-id", true)
+ require.NoError(t, err)
+ assert.NotEmpty(t, encoded)
+
+ decoded, err := DecodeSortCursor(encoded)
+ require.NoError(t, err)
+ assert.Equal(t, "created_at", decoded.SortColumn)
+ assert.Equal(t, "2025-01-01", decoded.SortValue)
+ assert.Equal(t, "some-id", decoded.ID)
+ assert.True(t, decoded.PointsNext)
+}
+
+func TestEncodeSortCursor_EmptySortColumn_RejectsAtEncodeTime(t *testing.T) {
+ t.Parallel()
+
+ encoded, err := EncodeSortCursor("", "value", "id-1", true)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+ assert.Contains(t, err.Error(), "sort column must not be empty")
+ assert.Empty(t, encoded)
+}
+
+func TestEncodeSortCursor_EmptyID_RejectsAtEncodeTime(t *testing.T) {
+ t.Parallel()
+
+ encoded, err := EncodeSortCursor("created_at", "value", "", true)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidCursor)
+ assert.Contains(t, err.Error(), "id must not be empty")
+ assert.Empty(t, encoded)
+}
diff --git a/commons/net/http/pagination_strict_test.go b/commons/net/http/pagination_strict_test.go
new file mode 100644
index 00000000..5959c972
--- /dev/null
+++ b/commons/net/http/pagination_strict_test.go
@@ -0,0 +1,44 @@
+//go:build unit
+
+package http
+
+import (
+ "testing"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidateLimitStrict(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ limit int
+ want int
+ wantErr error
+ }{
+ {name: "valid limit", limit: 10, want: 10},
+ {name: "limit capped", limit: cn.MaxLimit + 10, want: cn.MaxLimit},
+ {name: "zero rejected", limit: 0, wantErr: ErrLimitMustBePositive},
+ {name: "negative rejected", limit: -5, wantErr: ErrLimitMustBePositive},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ got, err := ValidateLimitStrict(tc.limit, cn.MaxLimit)
+ if tc.wantErr != nil {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, tc.wantErr)
+ assert.Zero(t, got)
+ return
+ }
+
+ require.NoError(t, err)
+ assert.Equal(t, tc.want, got)
+ })
+ }
+}
diff --git a/commons/net/http/pagination_test.go b/commons/net/http/pagination_test.go
new file mode 100644
index 00000000..b9d7b3e2
--- /dev/null
+++ b/commons/net/http/pagination_test.go
@@ -0,0 +1,264 @@
+//go:build unit
+
+package http
+
+import (
+ "net/http/httptest"
+ "testing"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParsePagination(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ queryString string
+ expectedLimit int
+ expectedOffset int
+ expectedErr error
+ errContains string
+ }{
+ {
+ name: "default values when no query params",
+ queryString: "",
+ expectedLimit: 20,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "valid limit and offset",
+ queryString: "limit=10&offset=5",
+ expectedLimit: 10,
+ expectedOffset: 5,
+ expectedErr: nil,
+ },
+ {
+ name: "limit capped at maxLimit",
+ queryString: "limit=500",
+ expectedLimit: 200,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "limit exactly at maxLimit",
+ queryString: "limit=200",
+ expectedLimit: 200,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "limit just below maxLimit",
+ queryString: "limit=199",
+ expectedLimit: 199,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "limit just above maxLimit gets capped",
+ queryString: "limit=201",
+ expectedLimit: 200,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "invalid limit non-numeric",
+ queryString: "limit=abc",
+ expectedErr: nil,
+ errContains: "invalid limit value",
+ },
+ {
+ name: "invalid offset non-numeric",
+ queryString: "offset=xyz",
+ expectedErr: nil,
+ errContains: "invalid offset value",
+ },
+ {
+ name: "limit zero uses default",
+ queryString: "limit=0",
+ expectedLimit: 20,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "negative limit uses default",
+ queryString: "limit=-5",
+ expectedLimit: 20,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "negative offset coerces to default",
+ queryString: "limit=10&offset=-1",
+ expectedLimit: 10,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "very large limit gets capped",
+ queryString: "limit=999999999",
+ expectedLimit: 200,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "very large offset is valid",
+ queryString: "limit=10&offset=999999999",
+ expectedLimit: 10,
+ expectedOffset: 999999999,
+ expectedErr: nil,
+ },
+ {
+ name: "empty limit param uses default",
+ queryString: "limit=&offset=10",
+ expectedLimit: 20,
+ expectedOffset: 10,
+ expectedErr: nil,
+ },
+ {
+ name: "empty offset param uses default",
+ queryString: "limit=25&offset=",
+ expectedLimit: 25,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "only limit provided",
+ queryString: "limit=75",
+ expectedLimit: 75,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "only offset provided",
+ queryString: "offset=100",
+ expectedLimit: 20,
+ expectedOffset: 100,
+ expectedErr: nil,
+ },
+ {
+ name: "offset zero is valid",
+ queryString: "offset=0",
+ expectedLimit: 20,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "limit one is valid minimum",
+ queryString: "limit=1",
+ expectedLimit: 1,
+ expectedOffset: 0,
+ expectedErr: nil,
+ },
+ {
+ name: "limit with decimal is invalid",
+ queryString: "limit=10.5",
+ errContains: "invalid limit value",
+ },
+ {
+ name: "offset with decimal is invalid",
+ queryString: "offset=5.5",
+ errContains: "invalid offset value",
+ },
+ {
+ name: "limit with special characters",
+ queryString: "limit=10@#",
+ errContains: "invalid limit value",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var limit, offset int
+ var err error
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ limit, offset, err = ParsePagination(c)
+ return nil
+ })
+
+ req := httptest.NewRequest("GET", "/test?"+tc.queryString, nil)
+ resp, testErr := app.Test(req)
+ require.NoError(t, testErr)
+ require.NoError(t, resp.Body.Close())
+
+ if tc.expectedErr != nil {
+ require.ErrorIs(t, err, tc.expectedErr)
+ assert.Zero(t, limit)
+ assert.Zero(t, offset)
+
+ return
+ }
+
+ if tc.errContains != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tc.errContains)
+ assert.Zero(t, limit)
+ assert.Zero(t, offset)
+
+ return
+ }
+
+ require.NoError(t, err)
+ assert.Equal(t, tc.expectedLimit, limit)
+ assert.Equal(t, tc.expectedOffset, offset)
+ })
+ }
+}
+
+func TestPaginationConstants(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, 20, cn.DefaultLimit)
+ assert.Equal(t, 0, cn.DefaultOffset)
+ assert.Equal(t, 200, cn.MaxLimit)
+}
+
+// ---------------------------------------------------------------------------
+// Nil guard tests
+// ---------------------------------------------------------------------------
+
+func TestParsePagination_NilContext(t *testing.T) {
+ t.Parallel()
+
+ limit, offset, err := ParsePagination(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+ assert.Zero(t, limit)
+ assert.Zero(t, offset)
+}
+
+// ---------------------------------------------------------------------------
+// Lenient negative offset coercion
+// ---------------------------------------------------------------------------
+
+func TestParsePagination_NegativeOffsetCoercesToZero(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+
+ var limit, offset int
+ var err error
+
+ app.Get("/test", func(c *fiber.Ctx) error {
+ limit, offset, err = ParsePagination(c)
+ return nil
+ })
+
+ req := httptest.NewRequest("GET", "/test?limit=10&offset=-100", nil)
+ resp, testErr := app.Test(req)
+ require.NoError(t, testErr)
+ require.NoError(t, resp.Body.Close())
+
+ require.NoError(t, err)
+ assert.Equal(t, 10, limit)
+ assert.Equal(t, 0, offset, "negative offset should be coerced to 0 (DefaultOffset)")
+}
diff --git a/commons/net/http/pagination_timestamp.go b/commons/net/http/pagination_timestamp.go
new file mode 100644
index 00000000..8d64dfab
--- /dev/null
+++ b/commons/net/http/pagination_timestamp.go
@@ -0,0 +1,90 @@
+package http
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "strconv"
+ "time"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/gofiber/fiber/v2"
+ "github.com/google/uuid"
+)
+
+// TimestampCursor represents a cursor for keyset pagination with timestamp + ID ordering.
+// This ensures correct pagination when records are ordered by (timestamp DESC, id DESC).
+type TimestampCursor struct {
+ Timestamp time.Time `json:"t"`
+ ID uuid.UUID `json:"i"`
+}
+
+// EncodeTimestampCursor encodes a timestamp and UUID into a base64 cursor string.
+// Returns an error if id is uuid.Nil, matching the decoder's validation contract.
+func EncodeTimestampCursor(timestamp time.Time, id uuid.UUID) (string, error) {
+ if id == uuid.Nil {
+ return "", fmt.Errorf("%w: id must not be nil UUID", ErrInvalidCursor)
+ }
+
+ cursor := TimestampCursor{
+ Timestamp: timestamp.UTC(),
+ ID: id,
+ }
+
+ data, err := json.Marshal(cursor)
+ if err != nil {
+ return "", fmt.Errorf("encode timestamp cursor: %w", err)
+ }
+
+ return base64.StdEncoding.EncodeToString(data), nil
+}
+
+// DecodeTimestampCursor decodes a base64 cursor string into a TimestampCursor.
+func DecodeTimestampCursor(cursor string) (*TimestampCursor, error) {
+ decoded, err := base64.StdEncoding.DecodeString(cursor)
+ if err != nil {
+ return nil, fmt.Errorf("%w: decode failed: %w", ErrInvalidCursor, err)
+ }
+
+ var tc TimestampCursor
+ if err := json.Unmarshal(decoded, &tc); err != nil {
+ return nil, fmt.Errorf("%w: unmarshal failed: %w", ErrInvalidCursor, err)
+ }
+
+ if tc.ID == uuid.Nil {
+ return nil, fmt.Errorf("%w: missing id", ErrInvalidCursor)
+ }
+
+ return &tc, nil
+}
+
+// ParseTimestampCursorPagination parses cursor/limit query params for timestamp-based cursor pagination.
+// Returns the decoded TimestampCursor (nil for first page), limit, and any error.
+func ParseTimestampCursorPagination(fiberCtx *fiber.Ctx) (*TimestampCursor, int, error) {
+ if fiberCtx == nil {
+ return nil, 0, ErrContextNotFound
+ }
+
+ limit := cn.DefaultLimit
+
+ if limitValue := fiberCtx.Query("limit"); limitValue != "" {
+ parsed, err := strconv.Atoi(limitValue)
+ if err != nil {
+ return nil, 0, fmt.Errorf("invalid limit value: %w", err)
+ }
+
+ limit = ValidateLimit(parsed, cn.DefaultLimit, cn.MaxLimit)
+ }
+
+ cursorParam := fiberCtx.Query("cursor")
+ if cursorParam == "" {
+ return nil, limit, nil
+ }
+
+ tc, err := DecodeTimestampCursor(cursorParam)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ return tc, limit, nil
+}
diff --git a/commons/net/http/proxy.go b/commons/net/http/proxy.go
index 983f4c1a..0f14a08b 100644
--- a/commons/net/http/proxy.go
+++ b/commons/net/http/proxy.go
@@ -1,31 +1,130 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package http
import (
- constant "github.com/LerianStudio/lib-commons/v2/commons/constants"
+ "errors"
"net/http"
"net/http/httputil"
"net/url"
+
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+)
+
+var (
+ // ErrInvalidProxyTarget indicates the proxy target URL is malformed or empty.
+ ErrInvalidProxyTarget = errors.New("invalid proxy target")
+ // ErrUntrustedProxyScheme indicates the proxy target uses a disallowed URL scheme.
+ ErrUntrustedProxyScheme = errors.New("untrusted proxy scheme")
+ // ErrUntrustedProxyHost indicates the proxy target hostname is not in the allowed list.
+ ErrUntrustedProxyHost = errors.New("untrusted proxy host")
+ // ErrUnsafeProxyDestination indicates the proxy target resolves to a private or loopback address.
+ ErrUnsafeProxyDestination = errors.New("unsafe proxy destination")
+ // ErrNilProxyRequest indicates a nil HTTP request was passed to the reverse proxy.
+ ErrNilProxyRequest = errors.New("proxy request cannot be nil")
+ // ErrNilProxyResponse indicates a nil HTTP response writer was passed to the reverse proxy.
+ ErrNilProxyResponse = errors.New("proxy response writer cannot be nil")
+ // ErrNilProxyRequestURL indicates the HTTP request has a nil URL.
+ ErrNilProxyRequestURL = errors.New("proxy request URL cannot be nil")
+ // ErrDNSResolutionFailed indicates the proxy target hostname could not be resolved.
+ ErrDNSResolutionFailed = errors.New("DNS resolution failed for proxy target")
+ // ErrNoResolvedIPs indicates DNS resolution returned zero IP addresses for the proxy target.
+ ErrNoResolvedIPs = errors.New("no resolved IPs for proxy target")
)
-// ServeReverseProxy serves a reverse proxy for a given url.
-func ServeReverseProxy(target string, res http.ResponseWriter, req *http.Request) {
+// ReverseProxyPolicy defines strict trust boundaries for reverse proxy targets.
+type ReverseProxyPolicy struct {
+ AllowedSchemes []string
+ // AllowedHosts restricts proxy targets to the listed hostnames (case-insensitive).
+ // An empty or nil slice rejects all hosts (secure-by-default), matching AllowedSchemes behavior.
+ // This allowlist is hostname-based only and does not restrict destination ports.
+ // Callers must explicitly populate this to permit proxy targets.
+ // See isAllowedHost and ErrUntrustedProxyHost for enforcement details.
+ AllowedHosts []string
+ AllowUnsafeDestinations bool
+ // Logger is an optional structured logger for security-relevant events.
+ // When nil, no logging is performed.
+ Logger log.Logger
+}
+
+// DefaultReverseProxyPolicy returns a strict-by-default reverse proxy policy.
+func DefaultReverseProxyPolicy() ReverseProxyPolicy {
+ return ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: nil,
+ AllowUnsafeDestinations: false,
+ }
+}
+
+// ServeReverseProxy serves a reverse proxy for a given URL,
+// enforcing explicit policy checks.
+//
+// Security: Uses a custom transport that validates resolved IPs at connection time
+// to prevent DNS rebinding attacks.
+func ServeReverseProxy(target string, policy ReverseProxyPolicy, res http.ResponseWriter, req *http.Request) error {
+ if req == nil {
+ return ErrNilProxyRequest
+ }
+
+ if req.URL == nil {
+ return ErrNilProxyRequestURL
+ }
+
+ if nilcheck.Interface(res) {
+ return ErrNilProxyResponse
+ }
+
targetURL, err := url.Parse(target)
if err != nil {
- http.Error(res, err.Error(), http.StatusInternalServerError)
- return
+ return ErrInvalidProxyTarget
+ }
+
+ if err := validateProxyTarget(targetURL, policy); err != nil {
+ if !nilcheck.Interface(policy.Logger) {
+ policy.Logger.Log(req.Context(), log.LevelWarn, "reverse proxy target rejected",
+ log.String("target_host", targetURL.Host),
+ log.String("target_scheme", targetURL.Scheme),
+ log.Err(err),
+ )
+ }
+
+ return err
+ }
+
+ ctx, span := otel.Tracer("http.proxy").Start(
+ req.Context(),
+ "http.reverse_proxy",
+ trace.WithSpanKind(trace.SpanKindClient),
+ )
+ defer span.End()
+
+ span.SetAttributes(
+ attribute.String("http.url", targetURL.Host),
+ attribute.String("http.method", req.Method),
+ )
+
+ req = req.WithContext(ctx)
+ if req.Header == nil {
+ req.Header = make(http.Header)
}
proxy := httputil.NewSingleHostReverseProxy(targetURL)
+ proxy.Transport = newSSRFSafeTransport(policy)
- // Update the headers to allow for SSL redirection
+ // Preserve current v4 forwarding semantics for existing consumers.
+ // This retains caller headers as-is, including auth/session headers, per user decision.
+ opentelemetry.InjectHTTPContext(req.Context(), req.Header)
req.URL.Host = targetURL.Host
req.URL.Scheme = targetURL.Scheme
- req.Header.Set(constant.HeaderForwardedHost, req.Header.Get(constant.HeaderHost))
+ req.Header.Set(constant.HeaderForwardedHost, req.Host)
req.Host = targetURL.Host
- proxy.ServeHTTP(res, req) //#nosec G704 -- target URL is application-configured, not user input
+ // #nosec G704 -- target validated via validateProxyTarget with scheme/host allowlists and IP safety; ssrfSafeTransport re-validates resolved IPs at connection time
+ proxy.ServeHTTP(res, req)
+
+ return nil
}
diff --git a/commons/net/http/proxy_defensive_test.go b/commons/net/http/proxy_defensive_test.go
new file mode 100644
index 00000000..cc151d88
--- /dev/null
+++ b/commons/net/http/proxy_defensive_test.go
@@ -0,0 +1,196 @@
+//go:build unit
+
+package http
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ liblog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+type typedNilProxyResponseWriter struct{}
+
+func (*typedNilProxyResponseWriter) Header() http.Header {
+ panic("typedNilProxyResponseWriter should not be used")
+}
+
+func (*typedNilProxyResponseWriter) Write([]byte) (int, error) {
+ panic("typedNilProxyResponseWriter should not be used")
+}
+
+func (*typedNilProxyResponseWriter) WriteHeader(int) {
+ panic("typedNilProxyResponseWriter should not be used")
+}
+
+type typedNilProxyLogger struct{}
+
+func (*typedNilProxyLogger) Log(context.Context, liblog.Level, string, ...liblog.Field) {
+ panic("typedNilProxyLogger should not be used")
+}
+
+func (*typedNilProxyLogger) With(...liblog.Field) liblog.Logger {
+ panic("typedNilProxyLogger should not be used")
+}
+
+func (*typedNilProxyLogger) WithGroup(string) liblog.Logger {
+ panic("typedNilProxyLogger should not be used")
+}
+
+func (*typedNilProxyLogger) Enabled(liblog.Level) bool {
+ panic("typedNilProxyLogger should not be used")
+}
+
+func (*typedNilProxyLogger) Sync(context.Context) error {
+ panic("typedNilProxyLogger should not be used")
+}
+
+func TestServeReverseProxy_NilRequestURL(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ req.URL = nil
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://example.com", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"example.com"},
+ }, rr, req)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilProxyRequestURL)
+}
+
+func TestServeReverseProxy_NilHeaderMap(t *testing.T) {
+ t.Parallel()
+
+ target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("ok"))
+ }))
+ defer target.Close()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ req.Header = nil
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy(target.URL, ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{requestHostFromURL(t, target.URL)},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(t, readErr)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "ok", string(body))
+}
+
+func TestServeReverseProxy_TypedNilResponseWriter(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ var res *typedNilProxyResponseWriter
+
+ err := ServeReverseProxy("https://example.com", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"example.com"},
+ }, res, req)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilProxyResponse)
+}
+
+func TestServeReverseProxy_TypedNilLoggerOnValidationError(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+ var logger *typedNilProxyLogger
+
+ assert.NotPanics(t, func() {
+ err := ServeReverseProxy("http://example.com", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"example.com"},
+ Logger: logger,
+ }, rr, req)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUntrustedProxyScheme)
+ })
+}
+
+func TestValidateResolvedIPs_NoIPs(t *testing.T) {
+ t.Parallel()
+
+ ip, err := validateResolvedIPs(context.Background(), nil, "example.com", nil)
+ require.Error(t, err)
+ assert.Nil(t, ip)
+ assert.ErrorIs(t, err, ErrNoResolvedIPs)
+}
+
+func TestSSRFSafeTransport_DNSResolutionFailure(t *testing.T) {
+ t.Parallel()
+
+ transport := newSSRFSafeTransportWithDeps(ReverseProxyPolicy{}, func(context.Context, string) ([]net.IPAddr, error) {
+ return nil, errors.New("lookup failed")
+ })
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ _, err := transport.base.DialContext(ctx, "tcp", "example.com:443")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrDNSResolutionFailed)
+}
+
+func TestValidateResolvedIPs_UnsafeAddressRejected(t *testing.T) {
+ t.Parallel()
+
+ ip, err := validateResolvedIPs(context.Background(), []net.IPAddr{{IP: net.ParseIP("127.0.0.1")}}, "example.com", nil)
+ require.Error(t, err)
+ assert.Nil(t, ip)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+}
+
+func TestValidateResolvedIPs_MixedAddressesRejected(t *testing.T) {
+ t.Parallel()
+
+ ip, err := validateResolvedIPs(context.Background(), []net.IPAddr{
+ {IP: net.ParseIP("8.8.8.8")},
+ {IP: net.ParseIP("127.0.0.1")},
+ }, "example.com", nil)
+ require.Error(t, err)
+ assert.Nil(t, ip)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+}
+
+func TestValidateResolvedIPs_AllSafeReturnsFirst(t *testing.T) {
+ t.Parallel()
+
+ ip, err := validateResolvedIPs(context.Background(), []net.IPAddr{
+ {IP: net.ParseIP("8.8.8.8")},
+ {IP: net.ParseIP("1.1.1.1")},
+ }, "example.com", nil)
+ require.NoError(t, err)
+ assert.Equal(t, net.ParseIP("8.8.8.8"), ip)
+}
+
+func TestValidateResolvedIPs_TypedNilLogger(t *testing.T) {
+ t.Parallel()
+
+ var logger *typedNilProxyLogger
+
+ assert.NotPanics(t, func() {
+ ip, err := validateResolvedIPs(context.Background(), nil, "example.com", logger)
+ require.Error(t, err)
+ assert.Nil(t, ip)
+ assert.ErrorIs(t, err, ErrNoResolvedIPs)
+ })
+}
diff --git a/commons/net/http/proxy_forwarding_test.go b/commons/net/http/proxy_forwarding_test.go
new file mode 100644
index 00000000..41362d6a
--- /dev/null
+++ b/commons/net/http/proxy_forwarding_test.go
@@ -0,0 +1,165 @@
+//go:build unit
+
+package http
+
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestServeReverseProxy_HeaderForwarding(t *testing.T) {
+ t.Parallel()
+
+ var receivedHost string
+ var receivedForwardedHost string
+ var receivedForwardedFor string
+ var receivedForwardedProto string
+ var receivedForwarded string
+ var receivedRealIP string
+ var receivedAuthorization string
+ var receivedCookie string
+
+ target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ receivedHost = r.Host
+ receivedForwardedHost = r.Header.Get("X-Forwarded-Host")
+ receivedForwardedFor = r.Header.Get("X-Forwarded-For")
+ receivedForwardedProto = r.Header.Get("X-Forwarded-Proto")
+ receivedForwarded = r.Header.Get("Forwarded")
+ receivedRealIP = r.Header.Get("X-Real-Ip")
+ receivedAuthorization = r.Header.Get("Authorization")
+ receivedCookie = r.Header.Get("Cookie")
+ _, _ = w.Write([]byte("headers checked"))
+ }))
+ defer target.Close()
+
+ req := httptest.NewRequest(http.MethodGet, "http://original-host.local/proxy", nil)
+ req.Header.Set("X-Forwarded-For", "203.0.113.10")
+ req.Header.Set("X-Forwarded-Proto", "https")
+ req.Header.Set("Forwarded", "for=203.0.113.10;proto=https")
+ req.Header.Set("X-Real-Ip", "203.0.113.10")
+ req.Header.Set("Authorization", "Bearer test-token")
+ req.Header.Set("Cookie", "session=abc123")
+ rr := httptest.NewRecorder()
+
+ host := requestHostFromURL(t, target.URL)
+
+ err := ServeReverseProxy(target.URL, ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{host},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Equal(t, "headers checked", string(body))
+ assert.Contains(t, receivedHost, host)
+ assert.Equal(t, "original-host.local", receivedForwardedHost)
+ assert.Equal(t, "https", receivedForwardedProto)
+ assert.Contains(t, receivedForwardedFor, "203.0.113.10")
+ assert.Equal(t, "for=203.0.113.10;proto=https", receivedForwarded)
+ assert.Equal(t, "203.0.113.10", receivedRealIP)
+ assert.Equal(t, "Bearer test-token", receivedAuthorization)
+ assert.Equal(t, "session=abc123", receivedCookie)
+}
+
+func TestServeReverseProxy_ProxyPassesResponseBody(t *testing.T) {
+ t.Parallel()
+
+ target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusCreated)
+ _, _ = w.Write([]byte(`{"status":"created"}`))
+ }))
+ defer target.Close()
+
+ req := httptest.NewRequest(http.MethodPost, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ host := requestHostFromURL(t, target.URL)
+
+ err := ServeReverseProxy(target.URL, ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{host},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusCreated, resp.StatusCode)
+ assert.Equal(t, "application/json", resp.Header.Get("Content-Type"))
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.JSONEq(t, `{"status":"created"}`, string(body))
+}
+
+func TestServeReverseProxy_CaseInsensitiveScheme(t *testing.T) {
+ t.Parallel()
+
+ target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("ok"))
+ }))
+ defer target.Close()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ host := requestHostFromURL(t, target.URL)
+
+ err := ServeReverseProxy(target.URL, ReverseProxyPolicy{
+ AllowedSchemes: []string{"HTTP"},
+ AllowedHosts: []string{host},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(t, readErr)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "ok", string(body))
+}
+
+func TestServeReverseProxy_MultipleAllowedSchemes(t *testing.T) {
+ t.Parallel()
+
+ target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("multi-scheme"))
+ }))
+ defer target.Close()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ host := requestHostFromURL(t, target.URL)
+
+ err := ServeReverseProxy(target.URL, ReverseProxyPolicy{
+ AllowedSchemes: []string{"https", "http"},
+ AllowedHosts: []string{host},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(t, readErr)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "multi-scheme", string(body))
+}
diff --git a/commons/net/http/proxy_ssrf_test.go b/commons/net/http/proxy_ssrf_test.go
new file mode 100644
index 00000000..8cdd4189
--- /dev/null
+++ b/commons/net/http/proxy_ssrf_test.go
@@ -0,0 +1,303 @@
+//go:build unit
+
+package http
+
+import (
+ "io"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestServeReverseProxy_SSRF_LoopbackIPv4(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://127.0.0.1:8080/admin", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"127.0.0.1"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+}
+
+func TestServeReverseProxy_SSRF_LoopbackIPv4_AltAddresses(t *testing.T) {
+ t.Parallel()
+
+ loopbacks := []string{"127.0.0.1", "127.0.0.2", "127.255.255.255"}
+
+ for _, ip := range loopbacks {
+ t.Run(ip, func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://"+ip+":8080", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{ip},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+ })
+ }
+}
+
+func TestServeReverseProxy_SSRF_PrivateClassA(t *testing.T) {
+ t.Parallel()
+
+ privateIPs := []string{"10.0.0.1", "10.0.0.0", "10.255.255.255", "10.1.2.3"}
+
+ for _, ip := range privateIPs {
+ t.Run(ip, func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://"+ip, ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{ip},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+ })
+ }
+}
+
+func TestServeReverseProxy_SSRF_PrivateClassB(t *testing.T) {
+ t.Parallel()
+
+ privateIPs := []string{"172.16.0.1", "172.16.0.0", "172.31.255.255", "172.20.10.1"}
+
+ for _, ip := range privateIPs {
+ t.Run(ip, func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://"+ip, ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{ip},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+ })
+ }
+}
+
+func TestServeReverseProxy_SSRF_PrivateClassC(t *testing.T) {
+ t.Parallel()
+
+ privateIPs := []string{"192.168.0.1", "192.168.0.0", "192.168.255.255", "192.168.1.1"}
+
+ for _, ip := range privateIPs {
+ t.Run(ip, func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://"+ip, ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{ip},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+ })
+ }
+}
+
+func TestServeReverseProxy_SSRF_LinkLocal(t *testing.T) {
+ t.Parallel()
+
+ linkLocalIPs := []string{"169.254.0.1", "169.254.169.254", "169.254.255.255"}
+
+ for _, ip := range linkLocalIPs {
+ t.Run(ip, func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://"+ip, ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{ip},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+ })
+ }
+}
+
+func TestServeReverseProxy_SSRF_IPv6Loopback(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://[::1]:8080", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"::1"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+}
+
+func TestServeReverseProxy_SSRF_UnspecifiedAddress(t *testing.T) {
+ t.Parallel()
+
+ t.Run("0.0.0.0", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://0.0.0.0", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"0.0.0.0"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+ })
+
+ t.Run("IPv6 unspecified [::]", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://[::]:8080", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"::"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+ })
+}
+
+func TestServeReverseProxy_SSRF_AllowUnsafeOverride(t *testing.T) {
+ t.Parallel()
+
+ target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("ok"))
+ }))
+ defer target.Close()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy(target.URL, ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{requestHostFromURL(t, target.URL)},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Equal(t, "ok", string(body))
+}
+
+func TestServeReverseProxy_SSRF_LocalhostAllowedWhenUnsafe(t *testing.T) {
+ t.Parallel()
+
+ target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("localhost-ok"))
+ }))
+ defer target.Close()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy(target.URL, ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{requestHostFromURL(t, target.URL)},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(t, readErr)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "localhost-ok", string(body))
+}
+
+func TestIsUnsafeIP(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ ip string
+ unsafe bool
+ }{
+ {"IPv4 loopback 127.0.0.1", "127.0.0.1", true},
+ {"IPv4 loopback 127.0.0.2", "127.0.0.2", true},
+ {"IPv6 loopback ::1", "::1", true},
+ {"10.0.0.1", "10.0.0.1", true},
+ {"10.255.255.255", "10.255.255.255", true},
+ {"172.16.0.1", "172.16.0.1", true},
+ {"172.31.255.255", "172.31.255.255", true},
+ {"192.168.0.1", "192.168.0.1", true},
+ {"192.168.255.255", "192.168.255.255", true},
+ {"169.254.0.1", "169.254.0.1", true},
+ {"169.254.169.254 AWS metadata", "169.254.169.254", true},
+ {"0.0.0.0", "0.0.0.0", true},
+ {"IPv6 unspecified ::", "::", true},
+ {"IPv4-mapped loopback ::ffff:127.0.0.1", "::ffff:127.0.0.1", true},
+ {"IPv4-mapped private ::ffff:10.0.0.1", "::ffff:10.0.0.1", true},
+ {"Documentation 192.0.0.1", "192.0.0.1", true},
+ {"Documentation 192.0.2.1", "192.0.2.1", true},
+ {"IPv4-mapped documentation ::ffff:198.51.100.10", "::ffff:198.51.100.10", true},
+ {"Documentation 198.51.100.10", "198.51.100.10", true},
+ {"Documentation 203.0.113.10", "203.0.113.10", true},
+ {"224.0.0.1", "224.0.0.1", true},
+ {"239.255.255.255", "239.255.255.255", true},
+ {"8.8.8.8 Google DNS", "8.8.8.8", false},
+ {"1.1.1.1 Cloudflare DNS", "1.1.1.1", false},
+ {"93.184.216.34 example.com", "93.184.216.34", false},
+ {"CGNAT 100.64.0.1", "100.64.0.1", true},
+ {"Benchmark 198.18.0.1", "198.18.0.1", true},
+ {"Reserved 240.0.0.1", "240.0.0.1", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ ip := parseTestIP(t, tt.ip)
+ assert.Equal(t, tt.unsafe, isUnsafeIP(ip))
+ })
+ }
+}
+
+func parseTestIP(t *testing.T, s string) net.IP {
+ t.Helper()
+
+ ip := net.ParseIP(s)
+ require.NotNil(t, ip, "failed to parse IP: %s", s)
+
+ return ip
+}
diff --git a/commons/net/http/proxy_test.go b/commons/net/http/proxy_test.go
new file mode 100644
index 00000000..632cde6b
--- /dev/null
+++ b/commons/net/http/proxy_test.go
@@ -0,0 +1,326 @@
+//go:build unit
+
+package http
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestServeReverseProxy(t *testing.T) {
+ t.Parallel()
+
+ t.Run("rejects untrusted scheme", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("http://api.partner.com", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"api.partner.com"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.True(t, errors.Is(err, ErrUntrustedProxyScheme))
+ })
+
+ t.Run("rejects untrusted host", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://api.partner.com", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"trusted.partner.com"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.True(t, errors.Is(err, ErrUntrustedProxyHost))
+ })
+
+ t.Run("rejects localhost destination", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://localhost", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"localhost"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.True(t, errors.Is(err, ErrUnsafeProxyDestination))
+ })
+
+ t.Run("proxies request when policy allows target", func(t *testing.T) {
+ t.Parallel()
+
+ target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("proxied"))
+ }))
+ defer target.Close()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy(target.URL, ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{requestHostFromURL(t, target.URL)},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(t, readErr)
+ assert.Equal(t, "proxied", string(body))
+ })
+}
+
+func requestHostFromURL(t *testing.T, rawURL string) string {
+ t.Helper()
+
+ req, err := http.NewRequest(http.MethodGet, rawURL, nil)
+ require.NoError(t, err)
+
+ return req.URL.Hostname()
+}
+
+func TestDefaultReverseProxyPolicy(t *testing.T) {
+ t.Parallel()
+
+ policy := DefaultReverseProxyPolicy()
+
+ assert.Equal(t, []string{"https"}, policy.AllowedSchemes)
+ assert.Nil(t, policy.AllowedHosts)
+ assert.False(t, policy.AllowUnsafeDestinations)
+}
+
+func TestIsAllowedScheme(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ scheme string
+ allowed []string
+ want bool
+ }{
+ {"https in https list", "https", []string{"https"}, true},
+ {"http in http/https list", "http", []string{"http", "https"}, true},
+ {"ftp not in http/https list", "ftp", []string{"http", "https"}, false},
+ {"case insensitive", "HTTPS", []string{"https"}, true},
+ {"empty allowed list", "https", []string{}, false},
+ {"nil allowed list", "https", nil, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, tt.want, isAllowedScheme(tt.scheme, tt.allowed))
+ })
+ }
+}
+
+func TestIsAllowedHost(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ host string
+ allowed []string
+ want bool
+ }{
+ {"exact match", "example.com", []string{"example.com"}, true},
+ {"case insensitive", "Example.COM", []string{"example.com"}, true},
+ {"not in list", "evil.com", []string{"good.com"}, false},
+ {"empty list", "example.com", []string{}, false},
+ {"nil list", "example.com", nil, false},
+ {"multiple hosts", "api.example.com", []string{"web.example.com", "api.example.com"}, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, tt.want, isAllowedHost(tt.host, tt.allowed))
+ })
+ }
+}
+
+func TestServeReverseProxy_NilRequest(t *testing.T) {
+ t.Parallel()
+
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://example.com", DefaultReverseProxyPolicy(), rr, nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilProxyRequest)
+}
+
+func TestServeReverseProxy_NilResponseWriter(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+
+ err := ServeReverseProxy("https://example.com", DefaultReverseProxyPolicy(), nil, req)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilProxyResponse)
+}
+
+func TestServeReverseProxy_InvalidTargetURL(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("://invalid", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"invalid"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidProxyTarget)
+}
+
+func TestServeReverseProxy_EmptyTarget(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"example.com"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidProxyTarget)
+}
+
+func TestServeReverseProxy_SchemeValidation(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ target string
+ schemes []string
+ hosts []string
+ wantErr error
+ }{
+ {name: "file scheme rejected (no host)", target: "file:///etc/passwd", schemes: []string{"https"}, hosts: []string{""}, wantErr: ErrInvalidProxyTarget},
+ {name: "gopher scheme rejected", target: "gopher://evil.com", schemes: []string{"https"}, hosts: []string{"evil.com"}, wantErr: ErrUntrustedProxyScheme},
+ {name: "ftp scheme rejected", target: "ftp://files.example.com", schemes: []string{"https"}, hosts: []string{"files.example.com"}, wantErr: ErrUntrustedProxyScheme},
+ {name: "data scheme rejected", target: "data:text/html,Hello
", schemes: []string{"https"}, hosts: []string{""}, wantErr: ErrInvalidProxyTarget},
+ {name: "empty allowed schemes rejects everything", target: "https://example.com", schemes: []string{}, hosts: []string{"example.com"}, wantErr: ErrUntrustedProxyScheme},
+ {name: "javascript scheme rejected", target: "javascript://evil.com", schemes: []string{"https"}, hosts: []string{"evil.com"}, wantErr: ErrUntrustedProxyScheme},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy(tt.target, ReverseProxyPolicy{
+ AllowedSchemes: tt.schemes,
+ AllowedHosts: tt.hosts,
+ }, rr, req)
+
+ if tt.wantErr != nil {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestServeReverseProxy_AllowedHostEnforcement(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty allowed hosts rejects all", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://example.com", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUntrustedProxyHost)
+ })
+
+ t.Run("nil allowed hosts rejects all", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://example.com", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: nil,
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUntrustedProxyHost)
+ })
+
+ t.Run("case insensitive host matching", func(t *testing.T) {
+ t.Parallel()
+
+ target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("ok"))
+ }))
+ defer target.Close()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ host := requestHostFromURL(t, target.URL)
+
+ err := ServeReverseProxy(target.URL, ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{host},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(t, readErr)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "ok", string(body))
+ })
+
+ t.Run("host not in list is rejected", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err := ServeReverseProxy("https://evil.com", ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"trusted.com", "also-trusted.com"},
+ }, rr, req)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUntrustedProxyHost)
+ })
+}
diff --git a/commons/net/http/proxy_transport.go b/commons/net/http/proxy_transport.go
new file mode 100644
index 00000000..8beea9b3
--- /dev/null
+++ b/commons/net/http/proxy_transport.go
@@ -0,0 +1,131 @@
+package http
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+// ssrfSafeTransport wraps an http.Transport with a DialContext that validates
+// resolved IP addresses against the SSRF policy at connection time.
+// This prevents DNS rebinding attacks where a hostname resolves to a safe IP
+// during validation but a private IP at connection time.
+//
+// It also implements http.RoundTripper so each outbound request is re-validated
+// immediately before dialing with the current proxy policy.
+type ssrfSafeTransport struct {
+ policy ReverseProxyPolicy
+ base *http.Transport
+}
+
+// newSSRFSafeTransport creates a transport that enforces the given proxy policy
+// on DNS resolution (via DialContext) and on each outbound request validated by RoundTrip.
+func newSSRFSafeTransport(policy ReverseProxyPolicy) *ssrfSafeTransport {
+ return newSSRFSafeTransportWithDeps(policy, net.DefaultResolver.LookupIPAddr)
+}
+
+func newSSRFSafeTransportWithDeps(
+ policy ReverseProxyPolicy,
+ lookupIPAddr func(context.Context, string) ([]net.IPAddr, error),
+) *ssrfSafeTransport {
+ dialer := &net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+
+ transport := &http.Transport{
+ TLSHandshakeTimeout: 10 * time.Second,
+ }
+
+ if !policy.AllowUnsafeDestinations {
+ policyLogger := policy.Logger
+
+ transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ host, port, err := net.SplitHostPort(addr)
+ if err != nil {
+ host = addr
+ }
+
+ ips, err := lookupIPAddr(ctx, host)
+ if err != nil {
+ if !nilcheck.Interface(policyLogger) {
+ policyLogger.Log(ctx, log.LevelWarn, "proxy DNS resolution failed",
+ log.String("host", host),
+ log.Err(err),
+ )
+ }
+
+ return nil, fmt.Errorf("%w: %w", ErrDNSResolutionFailed, err)
+ }
+
+ safeIP, err := validateResolvedIPs(ctx, ips, host, policyLogger)
+ if err != nil {
+ return nil, err
+ }
+
+ if safeIP != nil && port != "" {
+ addr = net.JoinHostPort(safeIP.String(), port)
+ } else if safeIP != nil {
+ addr = safeIP.String()
+ }
+
+ return dialer.DialContext(ctx, network, addr)
+ }
+ } else {
+ transport.DialContext = dialer.DialContext
+ }
+
+ return &ssrfSafeTransport{
+ policy: policy,
+ base: transport,
+ }
+}
+
+// RoundTrip validates each outbound request against the proxy policy before forwarding.
+func (t *ssrfSafeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ if err := validateProxyTarget(req.URL, t.policy); err != nil {
+ return nil, err
+ }
+
+ return t.base.RoundTrip(req)
+}
+
+// validateResolvedIPs checks all resolved IPs against the SSRF policy.
+// Returns the first safe IP for use in the connection, or an error if any IP
+// is unsafe or if no IPs were resolved.
+func validateResolvedIPs(ctx context.Context, ips []net.IPAddr, host string, logger log.Logger) (net.IP, error) {
+ if len(ips) == 0 {
+ if !nilcheck.Interface(logger) {
+ logger.Log(ctx, log.LevelWarn, "proxy target resolved to no IPs",
+ log.String("host", host),
+ )
+ }
+
+ return nil, ErrNoResolvedIPs
+ }
+
+ var safeIP net.IP
+
+ for _, ipAddr := range ips {
+ if isUnsafeIP(ipAddr.IP) {
+ if !nilcheck.Interface(logger) {
+ logger.Log(ctx, log.LevelWarn, "proxy target resolved to unsafe IP",
+ log.String("host", host),
+ )
+ }
+
+ return nil, ErrUnsafeProxyDestination
+ }
+
+ if safeIP == nil {
+ safeIP = ipAddr.IP
+ }
+ }
+
+ return safeIP, nil
+}
diff --git a/commons/net/http/proxy_transport_test.go b/commons/net/http/proxy_transport_test.go
new file mode 100644
index 00000000..2bf09dd6
--- /dev/null
+++ b/commons/net/http/proxy_transport_test.go
@@ -0,0 +1,140 @@
+//go:build unit
+
+package http
+
+import (
+ "context"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestServeReverseProxy_UpstreamTransportFailureReturns502(t *testing.T) {
+ t.Parallel()
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ addr := listener.Addr().String()
+ require.NoError(t, listener.Close())
+
+ req := httptest.NewRequest(http.MethodGet, "http://gateway.local/proxy", nil)
+ rr := httptest.NewRecorder()
+
+ err = ServeReverseProxy("http://"+addr, ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{"127.0.0.1"},
+ AllowUnsafeDestinations: true,
+ }, rr, req)
+ require.NoError(t, err)
+
+ resp := rr.Result()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
+}
+
+func TestSSRFSafeTransport_DialContext_RejectsPrivateIP(t *testing.T) {
+ t.Parallel()
+
+ transport := newSSRFSafeTransport(ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{"localhost"},
+ AllowUnsafeDestinations: false,
+ })
+
+ require.NotNil(t, transport)
+ require.NotNil(t, transport.base)
+ require.NotNil(t, transport.base.DialContext, "DialContext should be set when AllowUnsafeDestinations is false")
+
+ _, err := transport.base.DialContext(context.Background(), "tcp", "localhost:80")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+}
+
+func TestSSRFSafeTransport_DialContext_AllowsWhenUnsafeEnabled(t *testing.T) {
+ t.Parallel()
+
+ transport := newSSRFSafeTransport(ReverseProxyPolicy{
+ AllowedSchemes: []string{"http"},
+ AllowedHosts: []string{"localhost"},
+ AllowUnsafeDestinations: true,
+ })
+
+ require.NotNil(t, transport)
+ require.NotNil(t, transport.base)
+ require.NotNil(t, transport.base.DialContext)
+}
+
+func TestSSRFSafeTransport_RoundTrip_RejectsUntrustedScheme(t *testing.T) {
+ t.Parallel()
+
+ transport := newSSRFSafeTransport(ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"example.com"},
+ AllowUnsafeDestinations: false,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "http://example.com/path", nil)
+
+ _, err := transport.RoundTrip(req)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUntrustedProxyScheme)
+}
+
+func TestSSRFSafeTransport_RoundTrip_RejectsUntrustedHost(t *testing.T) {
+ t.Parallel()
+
+ transport := newSSRFSafeTransport(ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"trusted.com"},
+ AllowUnsafeDestinations: false,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "https://evil.com/path", nil)
+
+ _, err := transport.RoundTrip(req)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUntrustedProxyHost)
+}
+
+func TestSSRFSafeTransport_RoundTrip_RejectsPrivateIPInRedirect(t *testing.T) {
+ t.Parallel()
+
+ transport := newSSRFSafeTransport(ReverseProxyPolicy{
+ AllowedSchemes: []string{"https"},
+ AllowedHosts: []string{"127.0.0.1"},
+ AllowUnsafeDestinations: false,
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "https://127.0.0.1/admin", nil)
+
+ _, err := transport.RoundTrip(req)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrUnsafeProxyDestination)
+}
+
+func TestNewSSRFSafeTransport_PolicyIsStored(t *testing.T) {
+ t.Parallel()
+
+ policy := ReverseProxyPolicy{
+ AllowedSchemes: []string{"https", "http"},
+ AllowedHosts: []string{"api.example.com"},
+ AllowUnsafeDestinations: false,
+ }
+
+ transport := newSSRFSafeTransport(policy)
+
+ assert.Equal(t, policy.AllowedSchemes, transport.policy.AllowedSchemes)
+ assert.Equal(t, policy.AllowedHosts, transport.policy.AllowedHosts)
+ assert.Equal(t, policy.AllowUnsafeDestinations, transport.policy.AllowUnsafeDestinations)
+}
+
+func TestErrDNSResolutionFailed_Exists(t *testing.T) {
+ t.Parallel()
+
+ assert.NotNil(t, ErrDNSResolutionFailed)
+ assert.Contains(t, ErrDNSResolutionFailed.Error(), "DNS resolution failed")
+}
diff --git a/commons/net/http/proxy_validation.go b/commons/net/http/proxy_validation.go
new file mode 100644
index 00000000..2cc5f8ed
--- /dev/null
+++ b/commons/net/http/proxy_validation.go
@@ -0,0 +1,105 @@
+package http
+
+import (
+ "net"
+ "net/netip"
+ "net/url"
+ "strings"
+)
+
+var blockedProxyPrefixes = []netip.Prefix{
+ netip.MustParsePrefix("0.0.0.0/8"),
+ netip.MustParsePrefix("100.64.0.0/10"),
+ netip.MustParsePrefix("192.0.0.0/24"),
+ netip.MustParsePrefix("192.0.2.0/24"),
+ netip.MustParsePrefix("198.18.0.0/15"),
+ netip.MustParsePrefix("198.51.100.0/24"),
+ netip.MustParsePrefix("203.0.113.0/24"),
+ netip.MustParsePrefix("240.0.0.0/4"),
+}
+
+// validateProxyTarget checks a parsed URL against the reverse proxy policy.
+func validateProxyTarget(targetURL *url.URL, policy ReverseProxyPolicy) error {
+ if targetURL == nil || targetURL.Scheme == "" || targetURL.Host == "" {
+ return ErrInvalidProxyTarget
+ }
+
+ if !isAllowedScheme(targetURL.Scheme, policy.AllowedSchemes) {
+ return ErrUntrustedProxyScheme
+ }
+
+ hostname := targetURL.Hostname()
+ if hostname == "" {
+ return ErrInvalidProxyTarget
+ }
+
+ if strings.EqualFold(hostname, "localhost") && !policy.AllowUnsafeDestinations {
+ return ErrUnsafeProxyDestination
+ }
+
+ if !isAllowedHost(hostname, policy.AllowedHosts) {
+ return ErrUntrustedProxyHost
+ }
+
+ if ip := net.ParseIP(hostname); ip != nil && isUnsafeIP(ip) && !policy.AllowUnsafeDestinations {
+ return ErrUnsafeProxyDestination
+ }
+
+ return nil
+}
+
+// isAllowedScheme reports whether scheme is in the allowed list (case-insensitive).
+func isAllowedScheme(scheme string, allowed []string) bool {
+ if len(allowed) == 0 {
+ return false
+ }
+
+ for _, candidate := range allowed {
+ if strings.EqualFold(scheme, candidate) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// isAllowedHost reports whether host is in the allowed list (case-insensitive).
+func isAllowedHost(host string, allowedHosts []string) bool {
+ if len(allowedHosts) == 0 {
+ return false
+ }
+
+ for _, candidate := range allowedHosts {
+ if strings.EqualFold(host, candidate) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// isUnsafeIP reports whether ip is a loopback, private, or otherwise non-routable address.
+func isUnsafeIP(ip net.IP) bool {
+ if ip == nil {
+ return true
+ }
+
+ if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() {
+ return true
+ }
+
+ addr, ok := netip.AddrFromSlice(ip)
+ if !ok {
+ return true
+ }
+
+ addr = addr.Unmap()
+
+ for _, prefix := range blockedProxyPrefixes {
+ if prefix.Contains(addr) {
+ return true
+ }
+ }
+
+ return false
+}
diff --git a/commons/net/http/ratelimit/doc.go b/commons/net/http/ratelimit/doc.go
new file mode 100644
index 00000000..1b5859be
--- /dev/null
+++ b/commons/net/http/ratelimit/doc.go
@@ -0,0 +1,42 @@
+// Package ratelimit provides distributed rate limiting for Fiber HTTP servers backed
+// by Redis. It uses a fixed-window counter implemented as an atomic Lua script
+// (INCR + PEXPIRE) to guarantee that no key is left without a TTL even under
+// concurrent load or connection failures.
+//
+// # Quick start
+//
+// conn, _ := redis.New(ctx, cfg)
+//
+// rl := ratelimit.New(conn,
+// ratelimit.WithKeyPrefix("my-service"),
+// ratelimit.WithLogger(logger),
+// )
+//
+// // Fixed tier — applied globally
+// app.Use(rl.WithRateLimit(ratelimit.DefaultTier()))
+//
+// // Dynamic tier — write operations are rate-limited more aggressively
+// app.Use(rl.WithDynamicRateLimit(ratelimit.MethodTierSelector(
+// ratelimit.AggressiveTier(), // POST, PUT, PATCH, DELETE
+// ratelimit.DefaultTier(), // GET, HEAD, OPTIONS
+// )))
+//
+// # Nil-safe usage
+//
+// New returns nil when the rate limiter is disabled (RATE_LIMIT_ENABLED=false)
+// or when the Redis connection is nil. A nil *RateLimiter is always safe to use:
+// WithRateLimit and WithDynamicRateLimit return a pass-through handler that calls
+// c.Next() without enforcing any limit.
+//
+// # Identity functions
+//
+// The identity function determines how clients are grouped for rate limiting.
+// IdentityFromIPAndHeader combines the client IP with an HTTP header value using
+// a # separator — not : — to avoid ambiguity with IPv6 addresses (e.g.
+// "2001:db8::1#tenant-abc" instead of "2001:db8::1:tenant-abc").
+//
+// # Redis key format
+//
+// Keys follow the pattern: [prefix:]ratelimit::
+// Example: "my-service:ratelimit:default:192.168.1.1"
+package ratelimit
diff --git a/commons/net/http/ratelimit/middleware.go b/commons/net/http/ratelimit/middleware.go
new file mode 100644
index 00000000..bb26d9d0
--- /dev/null
+++ b/commons/net/http/ratelimit/middleware.go
@@ -0,0 +1,447 @@
+package ratelimit
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "net/http"
+ "strconv"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ chttp "github.com/LerianStudio/lib-commons/v4/commons/net/http"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis"
+ "github.com/gofiber/fiber/v2"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
+)
+
+const (
+ // headerRetryAfter is the standard HTTP Retry-After header.
+ headerRetryAfter = "Retry-After"
+
+ // fallback values when environment variables are not set.
+ fallbackDefaultMax = 500
+ fallbackAggressiveMax = 100
+ fallbackRelaxedMax = 1000
+ fallbackWindowSec = 60
+ fallbackRedisTimeoutMS = 500
+
+ // rateLimitTitle is the error title returned when rate limit is exceeded.
+ rateLimitTitle = "rate_limit_exceeded"
+ // rateLimitMessage is the error message returned when rate limit is exceeded.
+ rateLimitMessage = "rate limit exceeded"
+
+ // serviceUnavailableTitle is the error title returned when Redis is unavailable and fail-closed.
+ serviceUnavailableTitle = "service_unavailable"
+ // serviceUnavailableMessage is the error message returned when Redis is unavailable and fail-closed.
+ serviceUnavailableMessage = "rate limiter temporarily unavailable"
+
+ // maxReasonableTierMax is the threshold above which a configuration warning is logged.
+ maxReasonableTierMax = 100_000
+
+ // invalidWindowTitle is the error title returned when a tier has a zero or sub-millisecond window.
+ invalidWindowTitle = "misconfigured_rate_limiter"
+ // invalidWindowMessage is the error message returned when a tier has a zero or sub-millisecond window.
+ invalidWindowMessage = "rate limiter tier window is zero; contact the service operator"
+
+ // luaIncrExpire is an atomic Lua script that increments the counter, sets expiry on the
+ // first request in a window, and returns both the current count and the remaining TTL in
+ // milliseconds. Executed atomically by Redis — no other command can interleave, eliminating
+ // the race condition present in sequential INCR + EXPIRE calls. Returning the TTL from the
+ // same script avoids an extra PTTL roundtrip and ensures the value is consistent with the
+ // counter read above.
+ luaIncrExpire = `
+local count = redis.call('INCR', KEYS[1])
+if count == 1 then
+ redis.call('PEXPIRE', KEYS[1], tonumber(ARGV[1]))
+ return {count, tonumber(ARGV[1])}
+end
+local pttl = redis.call('PTTL', KEYS[1])
+return {count, pttl}
+`
+)
+
+// hashKey returns the first 16 hex characters of the SHA-256 hash of key (64-bit prefix).
+// Used in logs and traces instead of the raw key to avoid leaking client identifiers
+// (IP addresses, tenant IDs) and to keep telemetry cardinality low.
+func hashKey(key string) string {
+ h := sha256.Sum256([]byte(key))
+ return hex.EncodeToString(h[:8])
+}
+
+// Tier defines a rate limiting level with its own limits and window.
+type Tier struct {
+ // Name is a human-readable identifier for the tier (e.g., "default", "export", "dispatch").
+ Name string
+ // Max is the maximum number of requests allowed within the window.
+ Max int
+ // Window is the duration of the rate limit window.
+ Window time.Duration
+}
+
+// DefaultTier returns a tier configured via environment variables with sensible defaults.
+//
+// Environment variables:
+// - RATE_LIMIT_MAX: maximum requests (default: 500)
+// - RATE_LIMIT_WINDOW_SEC: window duration in seconds (default: 60)
+func DefaultTier() Tier {
+ return Tier{
+ Name: "default",
+ Max: int(commons.GetenvIntOrDefault("RATE_LIMIT_MAX", fallbackDefaultMax)),
+ Window: time.Duration(commons.GetenvIntOrDefault("RATE_LIMIT_WINDOW_SEC", fallbackWindowSec)) * time.Second,
+ }
+}
+
+// AggressiveTier returns a stricter tier configured via environment variables.
+//
+// Environment variables:
+// - AGGRESSIVE_RATE_LIMIT_MAX: maximum requests (default: 100)
+// - AGGRESSIVE_RATE_LIMIT_WINDOW_SEC: window duration in seconds (default: 60)
+func AggressiveTier() Tier {
+ return Tier{
+ Name: "aggressive",
+ Max: int(commons.GetenvIntOrDefault("AGGRESSIVE_RATE_LIMIT_MAX", fallbackAggressiveMax)),
+ Window: time.Duration(commons.GetenvIntOrDefault("AGGRESSIVE_RATE_LIMIT_WINDOW_SEC", fallbackWindowSec)) * time.Second,
+ }
+}
+
+// RelaxedTier returns a more permissive tier configured via environment variables.
+//
+// Environment variables:
+// - RELAXED_RATE_LIMIT_MAX: maximum requests (default: 1000)
+// - RELAXED_RATE_LIMIT_WINDOW_SEC: window duration in seconds (default: 60)
+func RelaxedTier() Tier {
+ return Tier{
+ Name: "relaxed",
+ Max: int(commons.GetenvIntOrDefault("RELAXED_RATE_LIMIT_MAX", fallbackRelaxedMax)),
+ Window: time.Duration(commons.GetenvIntOrDefault("RELAXED_RATE_LIMIT_WINDOW_SEC", fallbackWindowSec)) * time.Second,
+ }
+}
+
+// RateLimiter provides distributed rate limiting via Redis.
+// It uses a fixed window counter pattern with an atomic Lua script (INCR + PEXPIRE)
+// to prevent keys from being left without TTL on connection failures.
+//
+// A nil RateLimiter is safe to use: WithRateLimit returns a pass-through handler.
+type RateLimiter struct {
+ conn *libRedis.Client
+ logger log.Logger
+ keyPrefix string
+ identityFunc IdentityFunc
+ failOpen bool
+ onLimited func(c *fiber.Ctx, tier Tier)
+ redisTimeout time.Duration
+}
+
+// New creates a RateLimiter. Returns nil when:
+// - conn is nil
+// - RATE_LIMIT_ENABLED environment variable is set to "false"
+//
+// A nil RateLimiter is safe to use: WithRateLimit returns a pass-through handler.
+func New(conn *libRedis.Client, opts ...Option) *RateLimiter {
+ timeoutMS := commons.GetenvIntOrDefault("RATE_LIMIT_REDIS_TIMEOUT_MS", fallbackRedisTimeoutMS)
+ if timeoutMS <= 0 {
+ timeoutMS = fallbackRedisTimeoutMS
+ }
+
+ rl := &RateLimiter{
+ logger: log.NewNop(),
+ identityFunc: IdentityFromIP(),
+ failOpen: true,
+ redisTimeout: time.Duration(timeoutMS) * time.Millisecond,
+ }
+
+ for _, opt := range opts {
+ opt(rl)
+ }
+
+ if commons.GetenvOrDefault("RATE_LIMIT_ENABLED", "true") == "false" {
+ rl.logger.Log(context.Background(), log.LevelInfo,
+ "rate limiter disabled via RATE_LIMIT_ENABLED=false; all requests will pass through")
+
+ return nil
+ }
+
+ if conn == nil {
+ asserter := assert.New(context.Background(), rl.logger, "http.ratelimit", "New")
+ _ = asserter.Never(context.Background(), "redis connection is nil; rate limiter disabled")
+
+ return nil
+ }
+
+ rl.conn = conn
+
+ return rl
+}
+
+// WithRateLimit returns a fiber.Handler that applies rate limiting for the given tier.
+// If the RateLimiter is nil, it returns a pass-through handler that calls c.Next().
+func (rl *RateLimiter) WithRateLimit(tier Tier) fiber.Handler {
+ if rl == nil {
+ return func(c *fiber.Ctx) error {
+ return c.Next()
+ }
+ }
+
+ if tier.Window <= 0 || tier.Window.Milliseconds() == 0 {
+ rl.logger.Log(context.Background(), log.LevelError,
+ "rate limit tier has invalid window; all requests will be rejected",
+ log.String("tier", tier.Name),
+ log.Int("max", tier.Max),
+ )
+
+ return func(c *fiber.Ctx) error {
+ return chttp.Respond(c, http.StatusInternalServerError, chttp.ErrorResponse{
+ Code: http.StatusInternalServerError,
+ Title: invalidWindowTitle,
+ Message: invalidWindowMessage,
+ })
+ }
+ }
+
+ if tier.Max > maxReasonableTierMax {
+ rl.logger.Log(context.Background(), log.LevelWarn,
+ "rate limit tier max is unusually high; verify configuration",
+ log.String("tier", tier.Name),
+ log.Int("max", tier.Max),
+ log.Int("threshold", maxReasonableTierMax),
+ )
+ }
+
+ return func(c *fiber.Ctx) error {
+ return rl.check(c, tier)
+ }
+}
+
+// WithDefaultRateLimit is a convenience function that creates a RateLimiter and returns
+// a fiber.Handler with the default tier (500 req/60s).
+func WithDefaultRateLimit(conn *libRedis.Client, opts ...Option) fiber.Handler {
+ return New(conn, opts...).WithRateLimit(DefaultTier())
+}
+
+// WithDynamicRateLimit returns a fiber.Handler that selects the rate limit tier per
+// request using the provided TierFunc. This allows applying different limits based on
+// request attributes such as HTTP method, path, or identity.
+//
+// If the RateLimiter is nil, it returns a pass-through handler that calls c.Next().
+//
+// Example — stricter limits for write operations:
+//
+// app.Use(rl.WithDynamicRateLimit(ratelimit.MethodTierSelector(
+// ratelimit.AggressiveTier(),
+// ratelimit.DefaultTier(),
+// )))
+func (rl *RateLimiter) WithDynamicRateLimit(fn TierFunc) fiber.Handler {
+ if rl == nil || fn == nil {
+ return func(c *fiber.Ctx) error {
+ return c.Next()
+ }
+ }
+
+ return func(c *fiber.Ctx) error {
+ tier := fn(c)
+
+ if tier.Window <= 0 || tier.Window.Milliseconds() == 0 {
+ ctx := c.UserContext()
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ rl.logger.Log(ctx, log.LevelError,
+ "rate limit tier has invalid window; request rejected",
+ log.String("tier", tier.Name),
+ log.Int("max", tier.Max),
+ )
+
+ return chttp.Respond(c, http.StatusInternalServerError, chttp.ErrorResponse{
+ Code: http.StatusInternalServerError,
+ Title: invalidWindowTitle,
+ Message: invalidWindowMessage,
+ })
+ }
+
+ return rl.check(c, tier)
+ }
+}
+
+// check is the shared core of WithRateLimit and WithDynamicRateLimit. It runs the rate
+// limit check for the given tier and either passes the request through or returns an
+// appropriate error response.
+func (rl *RateLimiter) check(c *fiber.Ctx, tier Tier) error {
+ ctx := c.UserContext()
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ _, tracer, _, _ := commons.NewTrackingFromContext(ctx) //nolint:dogsled
+
+ ctx, span := tracer.Start(ctx, "middleware.ratelimit.check")
+ defer span.End()
+
+ identity := rl.identityFunc(c)
+ key := rl.buildKey(tier, identity)
+ keyHash := hashKey(key)
+
+ span.SetAttributes(
+ attribute.String("ratelimit.tier", tier.Name),
+ attribute.String("ratelimit.key_hash", keyHash),
+ )
+
+ count, ttl, err := rl.incrementCounter(ctx, key, tier)
+ if err != nil {
+ return rl.handleRedisError(c, ctx, span, tier, keyHash, err)
+ }
+
+ allowed := count <= int64(tier.Max)
+ span.SetAttributes(attribute.Bool("ratelimit.allowed", allowed))
+
+ if !allowed {
+ return rl.handleLimitExceeded(c, ctx, span, tier, keyHash, ttl)
+ }
+
+ remaining := max(int64(tier.Max)-count, 0)
+ resetAt := time.Now().Add(ttl).Unix()
+
+ c.Set(constant.RateLimitLimit, strconv.Itoa(tier.Max))
+ c.Set(constant.RateLimitRemaining, strconv.FormatInt(remaining, 10))
+ c.Set(constant.RateLimitReset, strconv.FormatInt(resetAt, 10))
+
+ return c.Next()
+}
+
+// buildKey constructs the Redis key for the rate limit counter.
+// Format: {keyPrefix}:ratelimit:{tier.Name}:{identity} (with prefix)
+// Format: ratelimit:{tier.Name}:{identity} (without prefix)
+func (rl *RateLimiter) buildKey(tier Tier, identity string) string {
+ if rl.keyPrefix != "" {
+ return fmt.Sprintf("%s:ratelimit:%s:%s", rl.keyPrefix, tier.Name, identity)
+ }
+
+ return fmt.Sprintf("ratelimit:%s:%s", tier.Name, identity)
+}
+
+// incrementCounter atomically increments the counter and sets expiry using a Lua script.
+// Returns the current count and the remaining TTL of the key. On the first request of a
+// window the TTL equals the full window; on subsequent requests it reflects the actual
+// remaining time, which is used for accurate Retry-After and X-RateLimit-Reset headers.
+func (rl *RateLimiter) incrementCounter(ctx context.Context, key string, tier Tier) (count int64, ttl time.Duration, err error) {
+ client, err := rl.conn.GetClient(ctx)
+ if err != nil {
+ return 0, 0, fmt.Errorf("get redis client: %w", err)
+ }
+
+ timeoutCtx, cancel := context.WithTimeout(ctx, rl.redisTimeout)
+ defer cancel()
+
+ vals, err := client.Eval(timeoutCtx, luaIncrExpire, []string{key}, tier.Window.Milliseconds()).Slice()
+ if err != nil {
+ return 0, 0, fmt.Errorf("redis eval failed for tier %s: %w", tier.Name, err)
+ }
+
+ if len(vals) < 2 {
+ return 0, 0, fmt.Errorf("unexpected lua result length %d for tier %s", len(vals), tier.Name)
+ }
+
+ count, ok := vals[0].(int64)
+ if !ok {
+ return 0, 0, fmt.Errorf("unexpected lua result type %T for count (tier %s)", vals[0], tier.Name)
+ }
+
+ ttlMs, ok := vals[1].(int64)
+ if !ok {
+ return 0, 0, fmt.Errorf("unexpected lua result type %T for ttl (tier %s)", vals[1], tier.Name)
+ }
+
+ // Guard against -1 (no expiry) or -2 (key not found) from PTTL; fall back to full window.
+ if ttlMs <= 0 {
+ ttlMs = tier.Window.Milliseconds()
+ }
+
+ return count, time.Duration(ttlMs) * time.Millisecond, nil
+}
+
+// handleRedisError handles a Redis communication failure during rate limit check.
+func (rl *RateLimiter) handleRedisError(
+ c *fiber.Ctx,
+ ctx context.Context,
+ span trace.Span,
+ tier Tier,
+ keyHash string,
+ err error,
+) error {
+ rl.logger.Log(ctx, log.LevelWarn, "rate limiter redis error",
+ log.String("tier", tier.Name),
+ log.String("key_hash", keyHash),
+ log.Err(err),
+ )
+
+ libOpentelemetry.HandleSpanError(span, "rate limiter redis error", err)
+
+ if rl.failOpen {
+ return c.Next()
+ }
+
+ return chttp.Respond(c, http.StatusServiceUnavailable, chttp.ErrorResponse{
+ Code: http.StatusServiceUnavailable,
+ Title: serviceUnavailableTitle,
+ Message: serviceUnavailableMessage,
+ })
+}
+
+// handleLimitExceeded handles the case when the rate limit has been exceeded.
+// ttl is the actual remaining TTL of the Redis key, used for accurate Retry-After
+// and X-RateLimit-Reset headers instead of the full window duration.
+func (rl *RateLimiter) handleLimitExceeded(
+ c *fiber.Ctx,
+ ctx context.Context,
+ span trace.Span,
+ tier Tier,
+ keyHash string,
+ ttl time.Duration,
+) error {
+ rl.logger.Log(ctx, log.LevelWarn, "rate limit exceeded",
+ log.String("tier", tier.Name),
+ log.String("key_hash", keyHash),
+ log.Int("max", tier.Max),
+ )
+
+ libOpentelemetry.HandleSpanBusinessErrorEvent(
+ span,
+ "rate limit exceeded",
+ fiber.NewError(http.StatusTooManyRequests, rateLimitMessage),
+ )
+
+ if rl.onLimited != nil {
+ rl.onLimited(c, tier)
+ }
+
+ // Ceiling division: round up to the nearest second so the client never receives a
+ // Retry-After value that has already elapsed by the time they read the response.
+ retryAfterSec := int(ttl / time.Second)
+ if ttl%time.Second > 0 {
+ retryAfterSec++
+ }
+
+ if retryAfterSec < 1 {
+ retryAfterSec = 1
+ }
+
+ resetAt := time.Now().Add(ttl).Unix()
+
+ c.Set(headerRetryAfter, strconv.Itoa(retryAfterSec))
+ c.Set(constant.RateLimitLimit, strconv.Itoa(tier.Max))
+ c.Set(constant.RateLimitRemaining, "0")
+ c.Set(constant.RateLimitReset, strconv.FormatInt(resetAt, 10))
+
+ return chttp.Respond(c, http.StatusTooManyRequests, chttp.ErrorResponse{
+ Code: http.StatusTooManyRequests,
+ Title: rateLimitTitle,
+ Message: rateLimitMessage,
+ })
+}
diff --git a/commons/net/http/ratelimit/middleware_options.go b/commons/net/http/ratelimit/middleware_options.go
new file mode 100644
index 00000000..74946204
--- /dev/null
+++ b/commons/net/http/ratelimit/middleware_options.go
@@ -0,0 +1,145 @@
+package ratelimit
+
+import (
+ "net/url"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/gofiber/fiber/v2"
+)
+
+// IdentityFunc extracts the client identity from a Fiber request context.
+// The returned string is used as part of the Redis key for rate limiting.
+type IdentityFunc func(c *fiber.Ctx) string
+
+// Option configures a RateLimiter via functional options.
+type Option func(*RateLimiter)
+
+// WithLogger provides a structured logger for rate limiter warnings and errors.
+// When not provided, a no-op logger is used.
+func WithLogger(l log.Logger) Option {
+ return func(rl *RateLimiter) {
+ if l != nil {
+ rl.logger = l
+ }
+ }
+}
+
+// WithKeyPrefix sets a service-specific prefix for Redis keys.
+// For example, WithKeyPrefix("tenant-manager") produces keys like
+// "tenant-manager:ratelimit:default:192.168.1.1".
+// When not provided, keys have no prefix: "ratelimit:default:192.168.1.1".
+func WithKeyPrefix(prefix string) Option {
+ return func(rl *RateLimiter) {
+ rl.keyPrefix = prefix
+ }
+}
+
+// WithIdentityFunc sets a custom identity extractor for rate limiting.
+// The identity function determines how clients are grouped for rate limiting.
+// When not provided, IdentityFromIP() is used.
+func WithIdentityFunc(fn IdentityFunc) Option {
+ return func(rl *RateLimiter) {
+ if fn != nil {
+ rl.identityFunc = fn
+ }
+ }
+}
+
+// WithFailOpen controls the behavior when Redis is unavailable.
+// When true (default), requests are allowed through on Redis failures.
+// When false, requests receive a 503 Service Unavailable response on Redis failures.
+func WithFailOpen(failOpen bool) Option {
+ return func(rl *RateLimiter) {
+ rl.failOpen = failOpen
+ }
+}
+
+// WithOnLimited sets an optional callback that is invoked when a request exceeds the rate limit.
+// This can be used for custom metrics, alerting, or logging beyond the built-in behavior.
+func WithOnLimited(fn func(c *fiber.Ctx, tier Tier)) Option {
+ return func(rl *RateLimiter) {
+ rl.onLimited = fn
+ }
+}
+
+// IdentityFromIP returns an IdentityFunc that extracts the client IP address.
+// This is the default identity function.
+func IdentityFromIP() IdentityFunc {
+ return func(c *fiber.Ctx) string {
+ return c.IP()
+ }
+}
+
+// IdentityFromHeader returns an IdentityFunc that extracts the value of the given
+// HTTP header, returned as "hdr:". If the header is empty, it falls
+// back to the client IP address encoded as "ip:". The type prefix
+// prevents a header value that happens to equal an IP address from colliding with the
+// IP-based fallback identity.
+func IdentityFromHeader(header string) IdentityFunc {
+ return func(c *fiber.Ctx) string {
+ if val := c.Get(header); val != "" {
+ return "hdr:" + url.QueryEscape(val)
+ }
+
+ return "ip:" + url.QueryEscape(c.IP())
+ }
+}
+
+// IdentityFromIPAndHeader returns an IdentityFunc that combines the client IP address
+// with the value of the given HTTP header. The resulting identity has the form
+// "ip:#hdr:". Both components are URL-encoded so that
+// IPv6 colons (encoded as %3A) and '#' characters (encoded as %23) cannot appear as
+// raw values, making '#' an unambiguous inter-component separator. If the header is
+// empty, only the encoded IP is returned: "ip:".
+func IdentityFromIPAndHeader(header string) IdentityFunc {
+ return func(c *fiber.Ctx) string {
+ encodedIP := url.QueryEscape(c.IP())
+ if val := c.Get(header); val != "" {
+ return "ip:" + encodedIP + "#hdr:" + url.QueryEscape(val)
+ }
+
+ return "ip:" + encodedIP
+ }
+}
+
+// WithRedisTimeout sets the timeout for Redis operations in the rate limiter.
+// If a Redis operation does not complete within the timeout, it is treated as a Redis
+// error and handled according to the fail-open/fail-closed policy (WithFailOpen).
+// Default is 500ms. Values <= 0 are ignored.
+func WithRedisTimeout(d time.Duration) Option {
+ return func(rl *RateLimiter) {
+ if d > 0 {
+ rl.redisTimeout = d
+ }
+ }
+}
+
+// TierFunc selects a rate limit Tier for the incoming request.
+// It is used with WithDynamicRateLimit to apply different limits per request attribute
+// (e.g., HTTP method, path, or authenticated identity).
+type TierFunc func(c *fiber.Ctx) Tier
+
+// MethodTierSelector returns a TierFunc that applies different tiers based on HTTP method:
+// - write: applied to POST, PUT, PATCH, DELETE (state-mutating methods)
+// - read: applied to GET, HEAD, OPTIONS and all other methods
+//
+// This mirrors the pattern where write operations are rate-limited more aggressively
+// than read operations on the same endpoint group.
+//
+// Example:
+//
+// rl.WithDynamicRateLimit(ratelimit.MethodTierSelector(
+// ratelimit.AggressiveTier(), // write
+// ratelimit.DefaultTier(), // read
+// ))
+func MethodTierSelector(write, read Tier) TierFunc {
+ return func(c *fiber.Ctx) Tier {
+ switch c.Method() {
+ case fiber.MethodPost, fiber.MethodPut, fiber.MethodPatch, fiber.MethodDelete:
+ return write
+ default:
+ return read
+ }
+ }
+}
diff --git a/commons/net/http/ratelimit/middleware_test.go b/commons/net/http/ratelimit/middleware_test.go
new file mode 100644
index 00000000..6e46df0a
--- /dev/null
+++ b/commons/net/http/ratelimit/middleware_test.go
@@ -0,0 +1,1503 @@
+//go:build unit
+
+package ratelimit
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ chttp "github.com/LerianStudio/lib-commons/v4/commons/net/http"
+ libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis"
+ "github.com/alicebob/miniredis/v2"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// warnSpy is a minimal log.Logger that captures warning messages for assertions.
+type warnSpy struct {
+ mu sync.Mutex
+ msgs []string
+}
+
+func (s *warnSpy) Log(_ context.Context, level libLog.Level, msg string, _ ...libLog.Field) {
+ if level == libLog.LevelWarn {
+ s.mu.Lock()
+ s.msgs = append(s.msgs, msg)
+ s.mu.Unlock()
+ }
+}
+
+func (s *warnSpy) With(_ ...libLog.Field) libLog.Logger { return s }
+func (s *warnSpy) WithGroup(_ string) libLog.Logger { return s }
+func (s *warnSpy) Enabled(_ libLog.Level) bool { return true }
+func (s *warnSpy) Sync(_ context.Context) error { return nil }
+
+func (s *warnSpy) hasWarn(substr string) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for _, m := range s.msgs {
+ if strings.Contains(m, substr) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func newTestMiddlewareRedisConnection(t *testing.T, mr *miniredis.Miniredis) *libRedis.Client {
+ t.Helper()
+
+ conn, err := libRedis.New(t.Context(), libRedis.Config{
+ Topology: libRedis.Topology{
+ Standalone: &libRedis.StandaloneTopology{Address: mr.Addr()},
+ },
+ Logger: &libLog.NopLogger{},
+ })
+ require.NoError(t, err)
+
+ t.Cleanup(func() { _ = conn.Close() })
+
+ return conn
+}
+
+func newTestApp(handler fiber.Handler) *fiber.App {
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ app.Use(handler)
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ return app
+}
+
+// newTestAppWithProxyHeader creates a Fiber app that reads the client IP from
+// X-Forwarded-For. This lets tests inject any address — including IPv6 — without
+// depending on the synthetic RemoteAddr assigned by app.Test().
+func newTestAppWithProxyHeader(handler fiber.Handler) *fiber.App {
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ ProxyHeader: fiber.HeaderXForwardedFor,
+ })
+ app.Use(handler)
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ return app
+}
+
+func doRequest(t *testing.T, app *fiber.App) *http.Response {
+ t.Helper()
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("X-Forwarded-For", "10.0.0.1")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ return resp
+}
+
+func doRequestWithHeader(t *testing.T, app *fiber.App, header, value string) *http.Response {
+ t.Helper()
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set(header, value)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ return resp
+}
+
+func TestNew(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ conn *libRedis.Client
+ opts []Option
+ wantNil bool
+ checkFn func(t *testing.T, rl *RateLimiter)
+ }{
+ {
+ name: "nil connection returns nil",
+ conn: nil,
+ wantNil: true,
+ },
+ {
+ name: "valid connection returns non-nil",
+ conn: func() *libRedis.Client {
+ mr := miniredis.RunT(t)
+ return newTestMiddlewareRedisConnection(t, mr)
+ }(),
+ wantNil: false,
+ },
+ {
+ name: "with options applied",
+ conn: func() *libRedis.Client {
+ mr := miniredis.RunT(t)
+ return newTestMiddlewareRedisConnection(t, mr)
+ }(),
+ opts: []Option{
+ WithKeyPrefix("test"),
+ WithFailOpen(false),
+ },
+ wantNil: false,
+ checkFn: func(t *testing.T, rl *RateLimiter) {
+ t.Helper()
+ assert.Equal(t, "test", rl.keyPrefix)
+ assert.False(t, rl.failOpen)
+ },
+ },
+ {
+ name: "with logger option",
+ conn: func() *libRedis.Client {
+ mr := miniredis.RunT(t)
+ return newTestMiddlewareRedisConnection(t, mr)
+ }(),
+ opts: []Option{
+ WithLogger(&libLog.NopLogger{}),
+ },
+ wantNil: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ rl := New(tt.conn, tt.opts...)
+
+ if tt.wantNil {
+ assert.Nil(t, rl)
+ return
+ }
+
+ require.NotNil(t, rl)
+
+ if tt.checkFn != nil {
+ tt.checkFn(t, rl)
+ }
+ })
+ }
+}
+
+func TestMiddleware_NilRateLimiter(t *testing.T) {
+ t.Parallel()
+
+ var rl *RateLimiter
+
+ handler := rl.WithRateLimit(DefaultTier())
+ app := newTestApp(handler)
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestMiddleware_AllowsWithinLimit(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test", Max: 5, Window: 60 * time.Second}
+ rl := New(conn)
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ for range 5 {
+ resp := doRequest(t, app)
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ }
+}
+
+func TestMiddleware_BlocksExceedingLimit(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-block", Max: 3, Window: 60 * time.Second}
+ rl := New(conn)
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ // Use all allowed requests
+ for range 3 {
+ resp := doRequest(t, app)
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ }
+
+ // Fourth request should be blocked
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+}
+
+func TestMiddleware_RetryAfterHeader(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-retry", Max: 1, Window: 120 * time.Second}
+ rl := New(conn)
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ // First request passes
+ resp := doRequest(t, app)
+ resp.Body.Close()
+
+ // Second request is blocked
+ resp = doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+ assert.Equal(t, "120", resp.Header.Get("Retry-After"))
+}
+
+func TestMiddleware_RateLimitHeaders(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-headers", Max: 10, Window: 60 * time.Second}
+ rl := New(conn)
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "10", resp.Header.Get("X-RateLimit-Limit"))
+ assert.Equal(t, "9", resp.Header.Get("X-RateLimit-Remaining"))
+ assert.NotEmpty(t, resp.Header.Get("X-RateLimit-Reset"))
+
+ // Verify reset is a valid unix timestamp in the future
+ resetStr := resp.Header.Get("X-RateLimit-Reset")
+ resetUnix, err := strconv.ParseInt(resetStr, 10, 64)
+ require.NoError(t, err)
+ assert.Greater(t, resetUnix, time.Now().Unix()-1)
+}
+
+func TestMiddleware_RateLimitHeadersOnBlocked(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-headers-block", Max: 1, Window: 60 * time.Second}
+ rl := New(conn)
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ // First request passes
+ resp := doRequest(t, app)
+ resp.Body.Close()
+
+ // Second request is blocked — check headers
+ resp = doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+ assert.Equal(t, "1", resp.Header.Get("X-RateLimit-Limit"))
+ assert.Equal(t, "0", resp.Header.Get("X-RateLimit-Remaining"))
+ assert.NotEmpty(t, resp.Header.Get("X-RateLimit-Reset"))
+ assert.Equal(t, "60", resp.Header.Get("Retry-After"))
+}
+
+func TestMiddleware_ResponseBody(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-body", Max: 1, Window: 60 * time.Second}
+ rl := New(conn)
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ // First request passes
+ resp := doRequest(t, app)
+ resp.Body.Close()
+
+ // Second request is blocked — check body
+ resp = doRequest(t, app)
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var errResp chttp.ErrorResponse
+ require.NoError(t, json.Unmarshal(body, &errResp))
+
+ assert.Equal(t, http.StatusTooManyRequests, errResp.Code)
+ assert.Equal(t, "rate_limit_exceeded", errResp.Title)
+ assert.Equal(t, "rate limit exceeded", errResp.Message)
+}
+
+func TestMiddleware_TierIsolation(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tierA := Tier{Name: "tier-a", Max: 2, Window: 60 * time.Second}
+ tierB := Tier{Name: "tier-b", Max: 2, Window: 60 * time.Second}
+ rl := New(conn)
+
+ appA := fiber.New(fiber.Config{DisableStartupMessage: true})
+ appA.Use(rl.WithRateLimit(tierA))
+ appA.Get("/test", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ appB := fiber.New(fiber.Config{DisableStartupMessage: true})
+ appB.Use(rl.WithRateLimit(tierB))
+ appB.Get("/test", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ // Exhaust tier A
+ for range 2 {
+ resp := doRequest(t, appA)
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ }
+
+ // Tier A is now blocked
+ resp := doRequest(t, appA)
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+
+ // Tier B should still allow requests
+ resp = doRequest(t, appB)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestMiddleware_FailOpen(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-failopen", Max: 10, Window: 60 * time.Second}
+ rl := New(conn, WithFailOpen(true))
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ // Close miniredis to simulate Redis failure
+ mr.Close()
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ // Should pass through (fail-open)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestMiddleware_FailClosed(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-failclosed", Max: 10, Window: 60 * time.Second}
+ rl := New(conn, WithFailOpen(false))
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ // Close miniredis to simulate Redis failure
+ mr.Close()
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ // Should return 503 (fail-closed)
+ assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ var errResp chttp.ErrorResponse
+ require.NoError(t, json.Unmarshal(body, &errResp))
+
+ assert.Equal(t, http.StatusServiceUnavailable, errResp.Code)
+ assert.Equal(t, "service_unavailable", errResp.Title)
+}
+
+func TestMiddleware_CustomIdentityFunc(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-custom-id", Max: 2, Window: 60 * time.Second}
+ rl := New(conn, WithIdentityFunc(IdentityFromHeader("X-User-ID")))
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ // User A: 2 requests allowed
+ for range 2 {
+ resp := doRequestWithHeader(t, app, "X-User-ID", "user-a")
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ }
+
+ // User A: 3rd request blocked
+ resp := doRequestWithHeader(t, app, "X-User-ID", "user-a")
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+
+ // User B: should still be allowed (different identity)
+ resp = doRequestWithHeader(t, app, "X-User-ID", "user-b")
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestMiddleware_KeyPrefix(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-prefix", Max: 1, Window: 60 * time.Second}
+ rl := New(conn, WithKeyPrefix("my-svc"))
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ // Verify the key was created with the expected prefix in Redis
+ keys := mr.Keys()
+ require.Len(t, keys, 1)
+ assert.Contains(t, keys[0], "my-svc:ratelimit:test-prefix:")
+}
+
+func TestMiddleware_MultipleTiers(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ globalTier := Tier{Name: "global", Max: 10, Window: 60 * time.Second}
+ strictTier := Tier{Name: "strict", Max: 2, Window: 60 * time.Second}
+
+ rl := New(conn)
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ app.Use(rl.WithRateLimit(globalTier))
+
+ strict := app.Group("/strict")
+ strict.Use(rl.WithRateLimit(strictTier))
+ strict.Get("/endpoint", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ app.Get("/normal", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ // Strict endpoint: 2 requests allowed, 3rd blocked by strict tier
+ for range 2 {
+ req := httptest.NewRequest(http.MethodGet, "/strict/endpoint", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ }
+
+ req := httptest.NewRequest(http.MethodGet, "/strict/endpoint", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+
+ // Normal endpoint should still be allowed under global tier
+ req = httptest.NewRequest(http.MethodGet, "/normal", nil)
+
+ resp, err = app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestIdentityFromIP(t *testing.T) {
+ t.Parallel()
+
+ fn := IdentityFromIP()
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ app.Get("/test", func(c *fiber.Ctx) error {
+ identity := fn(c)
+ return c.SendString(identity)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ // Fiber returns "0.0.0.0" for test requests without a real connection
+ assert.NotEmpty(t, string(body))
+}
+
+func TestIdentityFromHeader(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ header string
+ headerVal string
+ wantPrefix string
+ }{
+ {
+ name: "header present",
+ header: "X-User-ID",
+ headerVal: "user-123",
+ wantPrefix: "hdr:user-123",
+ },
+ {
+ name: "header absent falls back to IP",
+ header: "X-User-ID",
+ headerVal: "",
+ wantPrefix: "", // will be "ip:", just check non-empty
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ fn := IdentityFromHeader(tt.header)
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString(fn(c))
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ if tt.headerVal != "" {
+ req.Header.Set(tt.header, tt.headerVal)
+ }
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ if tt.wantPrefix != "" {
+ assert.Equal(t, tt.wantPrefix, string(body))
+ } else {
+ assert.NotEmpty(t, string(body))
+ }
+ })
+ }
+}
+
+func TestIdentityFromIPAndHeader(t *testing.T) {
+ t.Parallel()
+
+ fn := IdentityFromIPAndHeader("X-Tenant-ID")
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString(fn(c))
+ })
+
+ t.Run("with header", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("X-Tenant-ID", "tenant-abc")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ // Should contain the URL-encoded, prefixed form of the tenant header.
+ assert.Contains(t, string(body), "hdr:tenant-abc")
+ })
+
+ t.Run("without header", func(t *testing.T) {
+ t.Parallel()
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ // Should not contain the tenant ID — only the IP is used as identity.
+ assert.NotContains(t, string(body), "tenant-abc")
+ })
+}
+
+func TestBuildKey(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ prefix string
+ tier Tier
+ identity string
+ wantKey string
+ }{
+ {
+ name: "no prefix",
+ prefix: "",
+ tier: Tier{Name: "global"},
+ identity: "192.168.1.1",
+ wantKey: "ratelimit:global:192.168.1.1",
+ },
+ {
+ name: "with prefix",
+ prefix: "tenant-manager",
+ tier: Tier{Name: "export"},
+ identity: "10.0.0.1",
+ wantKey: "tenant-manager:ratelimit:export:10.0.0.1",
+ },
+ {
+ name: "with service prefix",
+ prefix: "my-svc",
+ tier: Tier{Name: "dispatch"},
+ identity: "user-123",
+ wantKey: "my-svc:ratelimit:dispatch:user-123",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ rl := New(conn, WithKeyPrefix(tt.prefix))
+ require.NotNil(t, rl)
+
+ key := rl.buildKey(tt.tier, tt.identity)
+ assert.Equal(t, tt.wantKey, key)
+ })
+ }
+}
+
+func TestWithDefaultRateLimit(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ handler := WithDefaultRateLimit(conn)
+ app := newTestApp(handler)
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "500", resp.Header.Get("X-RateLimit-Limit"))
+}
+
+func TestWithDefaultRateLimit_NilConnection(t *testing.T) {
+ t.Parallel()
+
+ // WithDefaultRateLimit with nil conn should return a pass-through handler
+ handler := WithDefaultRateLimit(nil)
+ app := newTestApp(handler)
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestMiddleware_OnLimitedCallback(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ var callbackCalled atomic.Bool
+
+ var (
+ mu sync.Mutex
+ callbackTier Tier
+ )
+
+ tier := Tier{Name: "test-callback", Max: 1, Window: 60 * time.Second}
+ rl := New(conn, WithOnLimited(func(_ *fiber.Ctx, t Tier) {
+ callbackCalled.Store(true)
+ mu.Lock()
+ callbackTier = t
+ mu.Unlock()
+ }))
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ // First request passes
+ resp := doRequest(t, app)
+ resp.Body.Close()
+
+ // Second request triggers callback
+ resp = doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+ assert.True(t, callbackCalled.Load())
+ mu.Lock()
+ tierName := callbackTier.Name
+ mu.Unlock()
+ assert.Equal(t, "test-callback", tierName)
+}
+
+func TestTierPresets(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tier Tier
+ wantName string
+ wantMax int
+ wantWindow time.Duration
+ }{
+ {
+ name: "DefaultTier",
+ tier: DefaultTier(),
+ wantName: "default",
+ wantMax: 500,
+ wantWindow: 60 * time.Second,
+ },
+ {
+ name: "AggressiveTier",
+ tier: AggressiveTier(),
+ wantName: "aggressive",
+ wantMax: 100,
+ wantWindow: 60 * time.Second,
+ },
+ {
+ name: "RelaxedTier",
+ tier: RelaxedTier(),
+ wantName: "relaxed",
+ wantMax: 1000,
+ wantWindow: 60 * time.Second,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, tt.wantName, tt.tier.Name)
+ assert.Equal(t, tt.wantMax, tt.tier.Max)
+ assert.Equal(t, tt.wantWindow, tt.tier.Window)
+ })
+ }
+}
+
+func TestMiddleware_RemainingDecrementsCorrectly(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "test-remaining", Max: 5, Window: 60 * time.Second}
+ rl := New(conn)
+
+ app := newTestApp(rl.WithRateLimit(tier))
+
+ for i := range 5 {
+ resp := doRequest(t, app)
+
+ expectedRemaining := strconv.Itoa(4 - i)
+ assert.Equal(t, expectedRemaining, resp.Header.Get("X-RateLimit-Remaining"),
+ "request %d should have remaining=%s", i+1, expectedRemaining)
+
+ resp.Body.Close()
+ }
+}
+
+func TestMiddleware_NilIdentityFuncIgnored(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ // WithIdentityFunc(nil) should keep the default (IP-based)
+ rl := New(conn, WithIdentityFunc(nil))
+ require.NotNil(t, rl)
+ require.NotNil(t, rl.identityFunc)
+}
+
+func TestMiddleware_NilLoggerIgnored(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ // WithLogger(nil) should keep the default (nop logger)
+ rl := New(conn, WithLogger(nil))
+ require.NotNil(t, rl)
+ require.NotNil(t, rl.logger)
+}
+
+func TestNew_RateLimitEnabledEnv(t *testing.T) {
+ tests := []struct {
+ name string
+ envVal string
+ wantNil bool
+ }{
+ {
+ name: "disabled when RATE_LIMIT_ENABLED=false",
+ envVal: "false",
+ wantNil: true,
+ },
+ {
+ name: "enabled when RATE_LIMIT_ENABLED=true",
+ envVal: "true",
+ wantNil: false,
+ },
+ {
+ name: "enabled when RATE_LIMIT_ENABLED is empty",
+ envVal: "",
+ wantNil: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Setenv("RATE_LIMIT_ENABLED", tt.envVal)
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ rl := New(conn)
+
+ if tt.wantNil {
+ assert.Nil(t, rl)
+ } else {
+ assert.NotNil(t, rl)
+ }
+ })
+ }
+}
+
+func TestNew_RateLimitDisabled_PassThrough(t *testing.T) {
+ t.Setenv("RATE_LIMIT_ENABLED", "false")
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ rl := New(conn)
+ require.Nil(t, rl)
+
+ // nil receiver should return pass-through handler
+ handler := rl.WithRateLimit(DefaultTier())
+ app := newTestApp(handler)
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestIncrementCounter_TTLSetOnFirstIncrement(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "ttl-test", Max: 10, Window: 60 * time.Second}
+ rl := New(conn, WithKeyPrefix("svc"))
+
+ // Make one request to trigger INCR
+ app := newTestApp(rl.WithRateLimit(tier))
+ resp := doRequest(t, app)
+ resp.Body.Close()
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ // Verify that the key has a TTL set (Lua script atomicity guarantee)
+ keys := mr.Keys()
+ require.Len(t, keys, 1)
+ ttl := mr.TTL(keys[0])
+ assert.Greater(t, ttl, time.Duration(0), "key must have TTL after first increment")
+}
+
+func TestWithRedisTimeout_Applied(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ rl := New(conn, WithRedisTimeout(200*time.Millisecond))
+ require.NotNil(t, rl)
+ assert.Equal(t, 200*time.Millisecond, rl.redisTimeout)
+}
+
+func TestWithRedisTimeout_ZeroIgnored(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ rl := New(conn, WithRedisTimeout(0))
+ require.NotNil(t, rl)
+ assert.Equal(t, 500*time.Millisecond, rl.redisTimeout, "zero value should keep default timeout")
+}
+
+func TestMiddleware_DefaultRedisTimeout(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ rl := New(conn)
+ require.NotNil(t, rl)
+ assert.Equal(t, 500*time.Millisecond, rl.redisTimeout)
+}
+
+func TestMethodTierSelector_WriteMethods(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ writeTier := Tier{Name: "write", Max: 2, Window: 60 * time.Second}
+ readTier := Tier{Name: "read", Max: 10, Window: 60 * time.Second}
+ rl := New(conn)
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ app.Use(rl.WithDynamicRateLimit(MethodTierSelector(writeTier, readTier)))
+ app.Post("/test", func(c *fiber.Ctx) error { return c.SendString("ok") })
+ app.Get("/test", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ // POST uses write tier (max 2)
+ for range 2 {
+ req := httptest.NewRequest(http.MethodPost, "/test", nil)
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ }
+
+ // 3rd POST is blocked by write tier
+ req := httptest.NewRequest(http.MethodPost, "/test", nil)
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ resp.Body.Close()
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+
+ // GET uses read tier (max 10) — still allowed
+ req = httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err = app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestMethodTierSelector_ReadMethods(t *testing.T) {
+ t.Parallel()
+
+ writeTier := Tier{Name: "write", Max: 5, Window: 60 * time.Second}
+ readTier := Tier{Name: "read", Max: 100, Window: 60 * time.Second}
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+ rl := New(conn)
+
+ for _, method := range []string{
+ fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions,
+ } {
+ t.Run(method, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ app.Use(rl.WithDynamicRateLimit(MethodTierSelector(writeTier, readTier)))
+ app.Add(method, "/test", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ req := httptest.NewRequest(method, "/test", nil)
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // read tier (max 100) — first request must be allowed
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ // X-RateLimit-Limit header reflects the read tier max
+ assert.Equal(t, "100", resp.Header.Get("X-RateLimit-Limit"),
+ "method %s should use read tier (max 100)", method)
+ })
+ }
+}
+
+func TestWithDynamicRateLimit_NilTierFunc(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ // Non-nil receiver with nil TierFunc should return a pass-through handler,
+ // not panic. This differs from the nil-receiver test below.
+ rl := New(conn)
+ require.NotNil(t, rl)
+
+ handler := rl.WithDynamicRateLimit(nil)
+ app := newTestApp(handler)
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ // No rate-limit headers should be set — the request passed through without counting.
+ assert.Empty(t, resp.Header.Get("X-RateLimit-Limit"))
+}
+
+func TestWithDynamicRateLimit_NilRateLimiter(t *testing.T) {
+ t.Parallel()
+
+ var rl *RateLimiter
+
+ fn := MethodTierSelector(DefaultTier(), RelaxedTier())
+ handler := rl.WithDynamicRateLimit(fn)
+ app := newTestApp(handler)
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestTierPresets_FromEnv(t *testing.T) {
+ tests := []struct {
+ name string
+ envVars map[string]string
+ tierFn func() Tier
+ wantMax int
+ wantWindow time.Duration
+ }{
+ {
+ name: "DefaultTier reads RATE_LIMIT_MAX",
+ envVars: map[string]string{"RATE_LIMIT_MAX": "200"},
+ tierFn: DefaultTier,
+ wantMax: 200,
+ wantWindow: 60 * time.Second,
+ },
+ {
+ name: "DefaultTier reads RATE_LIMIT_WINDOW_SEC",
+ envVars: map[string]string{"RATE_LIMIT_WINDOW_SEC": "30"},
+ tierFn: DefaultTier,
+ wantMax: 500,
+ wantWindow: 30 * time.Second,
+ },
+ {
+ name: "AggressiveTier reads AGGRESSIVE_RATE_LIMIT_MAX",
+ envVars: map[string]string{"AGGRESSIVE_RATE_LIMIT_MAX": "50"},
+ tierFn: AggressiveTier,
+ wantMax: 50,
+ wantWindow: 60 * time.Second,
+ },
+ {
+ name: "AggressiveTier reads AGGRESSIVE_RATE_LIMIT_WINDOW_SEC",
+ envVars: map[string]string{"AGGRESSIVE_RATE_LIMIT_WINDOW_SEC": "120"},
+ tierFn: AggressiveTier,
+ wantMax: 100,
+ wantWindow: 120 * time.Second,
+ },
+ {
+ name: "RelaxedTier reads RELAXED_RATE_LIMIT_MAX",
+ envVars: map[string]string{"RELAXED_RATE_LIMIT_MAX": "5000"},
+ tierFn: RelaxedTier,
+ wantMax: 5000,
+ wantWindow: 60 * time.Second,
+ },
+ {
+ name: "RelaxedTier reads RELAXED_RATE_LIMIT_WINDOW_SEC",
+ envVars: map[string]string{"RELAXED_RATE_LIMIT_WINDOW_SEC": "300"},
+ tierFn: RelaxedTier,
+ wantMax: 1000,
+ wantWindow: 300 * time.Second,
+ },
+ {
+ name: "invalid env falls back to default",
+ envVars: map[string]string{"RATE_LIMIT_MAX": "not-a-number"},
+ tierFn: DefaultTier,
+ wantMax: 500,
+ wantWindow: 60 * time.Second,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ for k, v := range tt.envVars {
+ t.Setenv(k, v)
+ }
+
+ tier := tt.tierFn()
+
+ assert.Equal(t, tt.wantMax, tier.Max)
+ assert.Equal(t, tt.wantWindow, tier.Window)
+ })
+ }
+}
+
+// ── IPv6 tests ────────────────────────────────────────────────────────────────
+//
+// These tests verify that identity extractors and the rate limit middleware handle
+// IPv6 client addresses correctly. IPv6 addresses contain colons (e.g. "2001:db8::1"),
+// which is why the previous assertion in TestIdentityFromIPAndHeader ("without header"
+// sub-test) used NotContains(":") — it would have incorrectly failed for IPv6 clients.
+
+func TestIdentityFromIP_IPv6(t *testing.T) {
+ t.Parallel()
+
+ fn := IdentityFromIP()
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ ProxyHeader: fiber.HeaderXForwardedFor,
+ })
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString(fn(c))
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("X-Forwarded-For", "2001:db8::1")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Equal(t, "2001:db8::1", string(body))
+}
+
+func TestIdentityFromIPAndHeader_IPv6_WithoutHeader(t *testing.T) {
+ t.Parallel()
+
+ fn := IdentityFromIPAndHeader("X-Tenant-ID")
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ ProxyHeader: fiber.HeaderXForwardedFor,
+ })
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString(fn(c))
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("X-Forwarded-For", "2001:db8::1")
+ // No X-Tenant-ID — only the IPv6 address is used.
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ identity := string(body)
+
+ // With URL encoding, the IPv6 address becomes "2001%3Adb8%3A%3A1" and the identity
+ // is prefixed with "ip:". No tenant header is present so there is no ":hdr:" segment.
+ assert.Equal(t, "ip:2001%3Adb8%3A%3A1", identity)
+ assert.NotContains(t, identity, "tenant-abc")
+}
+
+func TestIdentityFromIPAndHeader_IPv6_WithHeader(t *testing.T) {
+ t.Parallel()
+
+ fn := IdentityFromIPAndHeader("X-Tenant-ID")
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ ProxyHeader: fiber.HeaderXForwardedFor,
+ })
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString(fn(c))
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("X-Forwarded-For", "2001:db8::1")
+ req.Header.Set("X-Tenant-ID", "tenant-abc")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ // Combined identity: "ip:#hdr:" — # is the inter-component
+ // separator; IPv6 colons are URL-encoded to %3A so they can't be confused with it.
+ assert.Equal(t, "ip:2001%3Adb8%3A%3A1#hdr:tenant-abc", string(body))
+}
+
+func TestMiddleware_IPv6_RateLimiting(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "ipv6-test", Max: 2, Window: 60 * time.Second}
+ rl := New(conn)
+
+ app := newTestAppWithProxyHeader(rl.WithRateLimit(tier))
+
+ doIPv6Req := func() *http.Response {
+ t.Helper()
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("X-Forwarded-For", "2001:db8::1")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ return resp
+ }
+
+ // First two requests are allowed.
+ for range 2 {
+ resp := doIPv6Req()
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ }
+
+ // Third request is blocked.
+ resp := doIPv6Req()
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+
+ // IdentityFromIP() returns the raw IP without encoding, so the Redis key embeds
+ // the IPv6 address as-is. URL encoding only applies to IdentityFromHeader and
+ // IdentityFromIPAndHeader.
+ keys := mr.Keys()
+ require.Len(t, keys, 1)
+ assert.Contains(t, keys[0], "2001:db8::1")
+}
+
+func TestMiddleware_IPv6_Isolation(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ tier := Tier{Name: "ipv6-isolation", Max: 1, Window: 60 * time.Second}
+ rl := New(conn)
+
+ app := newTestAppWithProxyHeader(rl.WithRateLimit(tier))
+
+ doReq := func(ip string) *http.Response {
+ t.Helper()
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("X-Forwarded-For", ip)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ return resp
+ }
+
+ // IPv6 client exhausts its quota.
+ resp := doReq("2001:db8::1")
+ resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ resp = doReq("2001:db8::1")
+ resp.Body.Close()
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+
+ // A different IPv6 address has its own independent counter.
+ resp = doReq("2001:db8::2")
+ resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ // An IPv4 client also has its own independent counter.
+ resp = doReq("192.168.1.1")
+ defer resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+// TestWithRateLimit_HighTierWarning verifies that configuring a tier with Max above
+// maxReasonableTierMax causes a warning to be logged at setup time (not per request).
+func TestWithRateLimit_ZeroWindow(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+ rl := New(conn)
+ require.NotNil(t, rl)
+
+ // A zero window rounds down to PEXPIRE 0, immediately expiring all keys.
+ // The middleware must reject all requests rather than silently bypassing the limit.
+ zeroTier := Tier{Name: "bad-window", Max: 100, Window: 0}
+ handler := rl.WithRateLimit(zeroTier)
+ app := newTestApp(handler)
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+}
+
+func TestWithRateLimit_SubMillisecondWindow(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+ rl := New(conn)
+ require.NotNil(t, rl)
+
+ // A window smaller than 1ms truncates to 0 when converted via .Milliseconds() — also invalid.
+ subMsTier := Tier{Name: "subms-window", Max: 100, Window: 999 * time.Microsecond}
+ handler := rl.WithRateLimit(subMsTier)
+ app := newTestApp(handler)
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+}
+
+func TestWithDynamicRateLimit_ZeroWindow(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+ rl := New(conn)
+ require.NotNil(t, rl)
+
+ // TierFunc returns a zero-window tier on every request — must be rejected per request.
+ handler := rl.WithDynamicRateLimit(func(_ *fiber.Ctx) Tier {
+ return Tier{Name: "dynamic-bad-window", Max: 100, Window: 0}
+ })
+ app := newTestApp(handler)
+
+ resp := doRequest(t, app)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+}
+
+func TestNew_RedisTimeoutNonPositiveEnv(t *testing.T) {
+ tests := []struct {
+ name string
+ envVal string
+ }{
+ {name: "zero", envVal: "0"},
+ {name: "negative", envVal: "-100"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Setenv("RATE_LIMIT_REDIS_TIMEOUT_MS", tt.envVal)
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ rl := New(conn)
+ require.NotNil(t, rl)
+
+ assert.Equal(t, 500*time.Millisecond, rl.redisTimeout,
+ "non-positive env value should clamp to fallback timeout")
+ })
+ }
+}
+
+func TestWithRateLimit_HighTierWarning(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+
+ spy := &warnSpy{}
+ rl := New(conn, WithLogger(spy))
+ require.NotNil(t, rl)
+
+ highTier := Tier{Name: "high", Max: maxReasonableTierMax + 1, Window: 60 * time.Second}
+ handler := rl.WithRateLimit(highTier)
+ require.NotNil(t, handler)
+
+ assert.True(t, spy.hasWarn("rate limit tier max is unusually high"),
+ "expected warning when tier.Max exceeds %d", maxReasonableTierMax)
+}
+
+// TestMethodTierSelector_OtherWriteMethods verifies that PUT, PATCH, and DELETE are
+// treated as write-tier methods, consistent with POST.
+func TestMethodTierSelector_OtherWriteMethods(t *testing.T) {
+ t.Parallel()
+
+ writeTier := Tier{Name: "write", Max: 5, Window: 60 * time.Second}
+ readTier := Tier{Name: "read", Max: 100, Window: 60 * time.Second}
+
+ mr := miniredis.RunT(t)
+ conn := newTestMiddlewareRedisConnection(t, mr)
+ rl := New(conn)
+
+ for _, method := range []string{
+ fiber.MethodPut, fiber.MethodPatch, fiber.MethodDelete,
+ } {
+ m := method
+ t.Run(m, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ app.Use(rl.WithDynamicRateLimit(MethodTierSelector(writeTier, readTier)))
+ app.Add(m, "/test", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ req := httptest.NewRequest(m, "/test", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "5", resp.Header.Get("X-RateLimit-Limit"),
+ "method %s should use write tier (max 5)", m)
+ })
+ }
+}
diff --git a/commons/net/http/ratelimit/redis_storage.go b/commons/net/http/ratelimit/redis_storage.go
new file mode 100644
index 00000000..4687322b
--- /dev/null
+++ b/commons/net/http/ratelimit/redis_storage.go
@@ -0,0 +1,258 @@
+package ratelimit
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/redis/go-redis/v9"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+
+ libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis"
+)
+
+const (
+ keyPrefix = "ratelimit:"
+ scanBatchSize = 100
+)
+
+// ErrStorageUnavailable is returned when Redis storage is nil or not initialized.
+var ErrStorageUnavailable = errors.New("ratelimit redis storage is unavailable")
+
+// RedisStorageOption is a functional option for configuring RedisStorage.
+type RedisStorageOption func(*RedisStorage)
+
+// WithRedisStorageLogger provides a structured logger for assertion and error logging.
+func WithRedisStorageLogger(l log.Logger) RedisStorageOption {
+ return func(s *RedisStorage) {
+ if l != nil {
+ s.logger = l
+ }
+ }
+}
+
+func (storage *RedisStorage) unavailableStorageError(operation string) error {
+ var logger log.Logger
+ if storage != nil {
+ logger = storage.logger
+ }
+
+ asserter := assert.New(context.Background(), logger, "http.ratelimit", operation)
+ _ = asserter.Never(context.Background(), "ratelimit redis storage is unavailable")
+
+ return ErrStorageUnavailable
+}
+
+// RedisStorage implements fiber.Storage interface using lib-commons Redis connection.
+// This enables distributed rate limiting across multiple application instances.
+type RedisStorage struct {
+ conn *libRedis.Client
+ logger log.Logger
+}
+
+// NewRedisStorage creates a new Redis-backed storage for Fiber rate limiting.
+// Returns nil if the Redis connection is nil. Options can configure a logger.
+func NewRedisStorage(conn *libRedis.Client, opts ...RedisStorageOption) *RedisStorage {
+ storage := &RedisStorage{}
+
+ for _, opt := range opts {
+ opt(storage)
+ }
+
+ if conn == nil {
+ asserter := assert.New(context.Background(), storage.logger, "http.ratelimit", "NewRedisStorage")
+ _ = asserter.Never(context.Background(), "redis connection is nil; ratelimit storage disabled")
+
+ return nil
+ }
+
+ storage.conn = conn
+
+ return storage
+}
+
+// Get retrieves the value for the given key.
+// Returns nil, nil when the key does not exist.
+func (storage *RedisStorage) Get(key string) ([]byte, error) {
+ if storage == nil || storage.conn == nil {
+ return nil, storage.unavailableStorageError("Get")
+ }
+
+ ctx := context.Background()
+ tracer := otel.Tracer("ratelimit")
+
+ ctx, span := tracer.Start(ctx, "ratelimit.get")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRedis))
+
+ client, err := storage.conn.GetClient(ctx)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "Failed to get redis client for ratelimit", err)
+ return nil, fmt.Errorf("get redis client: %w", err)
+ }
+
+ val, err := client.Get(ctx, keyPrefix+key).Bytes()
+ if errors.Is(err, redis.Nil) {
+ return nil, nil
+ }
+
+ if err != nil {
+ storage.logError(ctx, "redis get failed", err, "key", key)
+ libOpentelemetry.HandleSpanError(span, "Ratelimit redis get failed", err)
+
+ return nil, fmt.Errorf("redis get: %w", err)
+ }
+
+ return val, nil
+}
+
+// Set stores the given value for the given key with an expiration.
+// 0 expiration means no expiration. Empty key or value will be ignored.
+func (storage *RedisStorage) Set(key string, val []byte, exp time.Duration) error {
+ if storage == nil || storage.conn == nil {
+ return storage.unavailableStorageError("Set")
+ }
+
+ if key == "" || len(val) == 0 {
+ return nil
+ }
+
+ ctx := context.Background()
+ tracer := otel.Tracer("ratelimit")
+
+ ctx, span := tracer.Start(ctx, "ratelimit.set")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRedis))
+
+ client, err := storage.conn.GetClient(ctx)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "Failed to get redis client for ratelimit", err)
+ return fmt.Errorf("get redis client: %w", err)
+ }
+
+ if err := client.Set(ctx, keyPrefix+key, val, exp).Err(); err != nil {
+ storage.logError(ctx, "redis set failed", err, "key", key)
+ libOpentelemetry.HandleSpanError(span, "Ratelimit redis set failed", err)
+
+ return fmt.Errorf("redis set: %w", err)
+ }
+
+ return nil
+}
+
+// Delete removes the value for the given key.
+// Returns no error if the key does not exist.
+func (storage *RedisStorage) Delete(key string) error {
+ if storage == nil || storage.conn == nil {
+ return storage.unavailableStorageError("Delete")
+ }
+
+ ctx := context.Background()
+ tracer := otel.Tracer("ratelimit")
+
+ ctx, span := tracer.Start(ctx, "ratelimit.delete")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRedis))
+
+ client, err := storage.conn.GetClient(ctx)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "Failed to get redis client for ratelimit", err)
+ return fmt.Errorf("get redis client: %w", err)
+ }
+
+ if err := client.Del(ctx, keyPrefix+key).Err(); err != nil {
+ storage.logError(ctx, "redis delete failed", err, "key", key)
+ libOpentelemetry.HandleSpanError(span, "Ratelimit redis delete failed", err)
+
+ return fmt.Errorf("redis delete: %w", err)
+ }
+
+ return nil
+}
+
+// Reset clears all rate limit keys from the storage.
+// This uses SCAN to find and delete keys with the rate limit prefix.
+func (storage *RedisStorage) Reset() error {
+ if storage == nil || storage.conn == nil {
+ return storage.unavailableStorageError("Reset")
+ }
+
+ ctx := context.Background()
+ tracer := otel.Tracer("ratelimit")
+
+ ctx, span := tracer.Start(ctx, "ratelimit.reset")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRedis))
+
+ client, err := storage.conn.GetClient(ctx)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "Failed to get redis client for ratelimit", err)
+ return fmt.Errorf("get redis client: %w", err)
+ }
+
+ var cursor uint64
+
+ for {
+ keys, nextCursor, err := client.Scan(ctx, cursor, keyPrefix+"*", scanBatchSize).Result()
+ if err != nil {
+ storage.logError(ctx, "redis scan failed during reset", err)
+ libOpentelemetry.HandleSpanError(span, "Ratelimit redis scan failed", err)
+
+ return fmt.Errorf("redis scan: %w", err)
+ }
+
+ if len(keys) > 0 {
+ if err := client.Del(ctx, keys...).Err(); err != nil {
+ storage.logError(ctx, "redis batch delete failed during reset", err)
+ libOpentelemetry.HandleSpanError(span, "Ratelimit redis batch delete failed", err)
+
+ return fmt.Errorf("redis batch delete: %w", err)
+ }
+ }
+
+ cursor = nextCursor
+ if cursor == 0 {
+ break
+ }
+ }
+
+ return nil
+}
+
+// logError logs a Redis operation error if a logger is configured.
+func (storage *RedisStorage) logError(_ context.Context, msg string, err error, kv ...string) {
+ if storage == nil || storage.logger == nil {
+ return
+ }
+
+ fields := make([]log.Field, 0, 1+(len(kv)+1)/2)
+ fields = append(fields, log.Err(err))
+
+ for i := 0; i+1 < len(kv); i += 2 {
+ fields = append(fields, log.String(kv[i], kv[i+1]))
+ }
+
+ // Defensively handle odd-length kv: use a sentinel so missing values are obvious in logs.
+ if len(kv)%2 != 0 {
+ const missingValue = ""
+
+ fields = append(fields, log.String(kv[len(kv)-1], missingValue))
+ }
+
+ storage.logger.Log(context.Background(), log.LevelWarn, msg, fields...)
+}
+
+// Close is a no-op as the Redis connection is managed by the application lifecycle.
+func (*RedisStorage) Close() error {
+ return nil
+}
diff --git a/commons/net/http/ratelimit/redis_storage_integration_test.go b/commons/net/http/ratelimit/redis_storage_integration_test.go
new file mode 100644
index 00000000..2c774c35
--- /dev/null
+++ b/commons/net/http/ratelimit/redis_storage_integration_test.go
@@ -0,0 +1,249 @@
+//go:build integration
+
+package ratelimit
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/testcontainers/testcontainers-go"
+ tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
+)
+
+// setupRedisContainer starts a disposable Redis container via testcontainers
+// and returns a connected libRedis.Client plus a teardown function.
+// The container is terminated when the returned cleanup is invoked (typically
+// via t.Cleanup).
+func setupRedisContainer(t *testing.T) (*libRedis.Client, func()) {
+ t.Helper()
+
+ ctx := context.Background()
+
+ container, err := tcredis.Run(ctx, "redis:7-alpine")
+ require.NoError(t, err, "failed to start Redis container")
+
+ // Endpoint returns "host:port" which is exactly what StandaloneTopology expects.
+ endpoint, err := container.Endpoint(ctx, "")
+ require.NoError(t, err, "failed to get Redis container endpoint")
+
+ client, err := libRedis.New(ctx, libRedis.Config{
+ Topology: libRedis.Topology{
+ Standalone: &libRedis.StandaloneTopology{Address: endpoint},
+ },
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err, "failed to create libRedis.Client")
+
+ cleanup := func() {
+ _ = client.Close()
+
+ if err := testcontainers.TerminateContainer(container); err != nil {
+ t.Logf("warning: failed to terminate Redis container: %v", err)
+ }
+ }
+
+ return client, cleanup
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_RateLimitStorage_SetAndGet
+// ---------------------------------------------------------------------------
+
+func TestIntegration_RateLimitStorage_SetAndGet(t *testing.T) {
+ client, cleanup := setupRedisContainer(t)
+ t.Cleanup(cleanup)
+
+ storage := NewRedisStorage(client, WithRedisStorageLogger(log.NewNop()))
+ require.NotNil(t, storage, "storage must not be nil with a valid connection")
+
+ key := "integration-test-key"
+ value := []byte("integration-test-value")
+
+ // Verify key does not exist before Set.
+ got, err := storage.Get(key)
+ require.NoError(t, err, "Get on non-existent key should not error")
+ assert.Nil(t, got, "Get on non-existent key should return nil")
+
+ // Set the key with a reasonable TTL.
+ err = storage.Set(key, value, 30*time.Second)
+ require.NoError(t, err, "Set should succeed")
+
+ // Get it back and verify the round-trip.
+ got, err = storage.Get(key)
+ require.NoError(t, err, "Get after Set should not error")
+ assert.Equal(t, value, got, "Get should return the exact value that was Set")
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_RateLimitStorage_Expiration
+// ---------------------------------------------------------------------------
+
+func TestIntegration_RateLimitStorage_Expiration(t *testing.T) {
+ client, cleanup := setupRedisContainer(t)
+ t.Cleanup(cleanup)
+
+ storage := NewRedisStorage(client, WithRedisStorageLogger(log.NewNop()))
+ require.NotNil(t, storage)
+
+ key := "expiring-key"
+ value := []byte("temporary-value")
+
+ // Set with a 1-second TTL.
+ err := storage.Set(key, value, 1*time.Second)
+ require.NoError(t, err, "Set with short TTL should succeed")
+
+ // Verify it exists immediately.
+ got, err := storage.Get(key)
+ require.NoError(t, err)
+ assert.Equal(t, value, got, "key should exist immediately after Set")
+
+ // Wait for the key to expire. We use 1.5s to give real Redis enough
+ // headroom for its lazy/active expiry cycle.
+ time.Sleep(1500 * time.Millisecond)
+
+ // Key should now be gone.
+ got, err = storage.Get(key)
+ require.NoError(t, err, "Get on expired key should not error")
+ assert.Nil(t, got, "key should have expired and return nil")
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_RateLimitStorage_Delete
+// ---------------------------------------------------------------------------
+
+func TestIntegration_RateLimitStorage_Delete(t *testing.T) {
+ client, cleanup := setupRedisContainer(t)
+ t.Cleanup(cleanup)
+
+ storage := NewRedisStorage(client, WithRedisStorageLogger(log.NewNop()))
+ require.NotNil(t, storage)
+
+ key := "delete-me"
+ value := []byte("soon-to-be-gone")
+
+ // Set the key.
+ err := storage.Set(key, value, 30*time.Second)
+ require.NoError(t, err, "Set should succeed")
+
+ // Confirm it exists.
+ got, err := storage.Get(key)
+ require.NoError(t, err)
+ assert.Equal(t, value, got, "key should exist after Set")
+
+ // Delete it.
+ err = storage.Delete(key)
+ require.NoError(t, err, "Delete should succeed")
+
+ // Confirm it is gone.
+ got, err = storage.Get(key)
+ require.NoError(t, err, "Get after Delete should not error")
+ assert.Nil(t, got, "key should be nil after Delete")
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_RateLimitStorage_Reset
+// ---------------------------------------------------------------------------
+
+func TestIntegration_RateLimitStorage_Reset(t *testing.T) {
+ client, cleanup := setupRedisContainer(t)
+ t.Cleanup(cleanup)
+
+ storage := NewRedisStorage(client, WithRedisStorageLogger(log.NewNop()))
+ require.NotNil(t, storage)
+
+ // Populate multiple keys.
+ keys := []string{"reset-a", "reset-b", "reset-c", "reset-d", "reset-e"}
+ for i, k := range keys {
+ err := storage.Set(k, []byte(fmt.Sprintf("value-%d", i)), 30*time.Second)
+ require.NoError(t, err, "Set(%s) should succeed", k)
+ }
+
+ // Verify all keys exist before Reset.
+ for _, k := range keys {
+ got, err := storage.Get(k)
+ require.NoError(t, err)
+ assert.NotNil(t, got, "key %s should exist before Reset", k)
+ }
+
+ // Reset all ratelimit keys.
+ err := storage.Reset()
+ require.NoError(t, err, "Reset should succeed")
+
+ // Verify all keys are gone.
+ for _, k := range keys {
+ got, err := storage.Get(k)
+ require.NoError(t, err, "Get(%s) after Reset should not error", k)
+ assert.Nil(t, got, "key %s should be nil after Reset", k)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_RateLimitStorage_ConcurrentAccess
+// ---------------------------------------------------------------------------
+
+func TestIntegration_RateLimitStorage_ConcurrentAccess(t *testing.T) {
+ client, cleanup := setupRedisContainer(t)
+ t.Cleanup(cleanup)
+
+ storage := NewRedisStorage(client, WithRedisStorageLogger(log.NewNop()))
+ require.NotNil(t, storage)
+
+ const goroutines = 20
+
+ var (
+ wg sync.WaitGroup
+ errCount atomic.Int32
+ )
+
+ wg.Add(goroutines)
+
+ // Each goroutine writes its own key and reads it back, exercising
+ // concurrent Set/Get against a real Redis server.
+ for i := range goroutines {
+ go func(idx int) {
+ defer wg.Done()
+
+ key := "concurrent-" + strconv.Itoa(idx)
+ value := []byte("value-" + strconv.Itoa(idx))
+
+ if err := storage.Set(key, value, 30*time.Second); err != nil {
+ errCount.Add(1)
+ return
+ }
+
+ got, err := storage.Get(key)
+ if err != nil {
+ errCount.Add(1)
+ return
+ }
+
+ if string(got) != string(value) {
+ errCount.Add(1)
+ }
+ }(i)
+ }
+
+ wg.Wait()
+
+ assert.Equal(t, int32(0), errCount.Load(),
+ "no errors should occur during concurrent Set/Get operations")
+
+ // Verify all keys are readable after the concurrent burst.
+ for i := range goroutines {
+ key := "concurrent-" + strconv.Itoa(i)
+ expected := []byte("value-" + strconv.Itoa(i))
+
+ got, err := storage.Get(key)
+ require.NoError(t, err, "Get(%s) should succeed after concurrent writes", key)
+ assert.Equal(t, expected, got, "key %s should hold the correct value", key)
+ }
+}
diff --git a/commons/net/http/ratelimit/redis_storage_test.go b/commons/net/http/ratelimit/redis_storage_test.go
new file mode 100644
index 00000000..4ebcc77a
--- /dev/null
+++ b/commons/net/http/ratelimit/redis_storage_test.go
@@ -0,0 +1,661 @@
+//go:build unit
+
+package ratelimit
+
+import (
+ "context"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis"
+ "github.com/alicebob/miniredis/v2"
+ "github.com/redis/go-redis/v9"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func newTestRedisConnection(t *testing.T, mr *miniredis.Miniredis) *libRedis.Client {
+ t.Helper()
+
+ conn, err := libRedis.New(context.Background(), libRedis.Config{
+ Topology: libRedis.Topology{
+ Standalone: &libRedis.StandaloneTopology{Address: mr.Addr()},
+ },
+ Logger: &libLog.NopLogger{},
+ })
+ require.NoError(t, err)
+
+ t.Cleanup(func() { _ = conn.Close() })
+
+ return conn
+}
+
+func TestNewRedisStorage_NilConnection(t *testing.T) {
+ t.Parallel()
+
+ storage := NewRedisStorage(nil)
+ assert.Nil(t, storage)
+}
+
+func TestNewRedisStorage_ValidConnection(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+ require.NotNil(t, storage.conn)
+}
+
+func TestRedisStorage_GetSetDelete(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ key := "test-key"
+ value := []byte("test-value")
+
+ val, err := storage.Get(key)
+ require.NoError(t, err)
+ assert.Nil(t, val)
+
+ err = storage.Set(key, value, time.Minute)
+ require.NoError(t, err)
+
+ val, err = storage.Get(key)
+ require.NoError(t, err)
+ assert.Equal(t, value, val)
+
+ err = storage.Delete(key)
+ require.NoError(t, err)
+
+ val, err = storage.Get(key)
+ require.NoError(t, err)
+ assert.Nil(t, val)
+}
+
+func TestRedisStorage_SetEmptyKeyIgnored(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ err := storage.Set("", []byte("value"), time.Minute)
+ require.NoError(t, err)
+
+ err = storage.Set("key", nil, time.Minute)
+ require.NoError(t, err)
+}
+
+func TestRedisStorage_Reset(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ require.NoError(t, storage.Set("key1", []byte("val1"), time.Minute))
+ require.NoError(t, storage.Set("key2", []byte("val2"), time.Minute))
+
+ err := storage.Reset()
+ require.NoError(t, err)
+
+ val, err := storage.Get("key1")
+ require.NoError(t, err)
+ assert.Nil(t, val)
+
+ val, err = storage.Get("key2")
+ require.NoError(t, err)
+ assert.Nil(t, val)
+}
+
+func TestRedisStorage_Close(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ err := storage.Close()
+ require.NoError(t, err)
+}
+
+func TestRedisStorage_NilStorageOperations(t *testing.T) {
+ t.Parallel()
+
+ var storage *RedisStorage
+
+ val, err := storage.Get("key")
+ require.ErrorIs(t, err, ErrStorageUnavailable)
+ assert.Nil(t, val)
+
+ err = storage.Set("key", []byte("value"), time.Minute)
+ require.ErrorIs(t, err, ErrStorageUnavailable)
+
+ err = storage.Delete("key")
+ require.ErrorIs(t, err, ErrStorageUnavailable)
+
+ err = storage.Reset()
+ require.ErrorIs(t, err, ErrStorageUnavailable)
+
+ err = storage.Close()
+ require.NoError(t, err)
+}
+
+func TestRedisStorage_KeyPrefix(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ require.NoError(t, storage.Set("test", []byte("value"), time.Minute))
+
+ client := redis.NewClient(&redis.Options{
+ Addr: mr.Addr(),
+ })
+
+ t.Cleanup(func() { _ = client.Close() })
+
+ val, err := client.Get(t.Context(), "ratelimit:test").Bytes()
+ require.NoError(t, err)
+ assert.Equal(t, []byte("value"), val)
+}
+
+// --- New comprehensive test coverage below ---
+
+func TestRedisStorage_ConcurrentIncrementOperations(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ const workers = 50
+ const key = "concurrent-counter"
+
+ var wg sync.WaitGroup
+
+ wg.Add(workers)
+
+ var errCount int32
+
+ for range workers {
+ go func() {
+ defer wg.Done()
+
+ // Simulate incrementing a counter: read, parse, increment, write
+ val, err := storage.Get(key)
+ if err != nil {
+ atomic.AddInt32(&errCount, 1)
+ return
+ }
+
+ counter := 0
+ if val != nil {
+ counter, _ = strconv.Atoi(string(val))
+ }
+
+ counter++
+
+ if err := storage.Set(key, []byte(strconv.Itoa(counter)), time.Minute); err != nil {
+ atomic.AddInt32(&errCount, 1)
+ }
+ }()
+ }
+
+ wg.Wait()
+
+ // Verify no errors occurred (no panics, no crashes)
+ assert.Equal(t, int32(0), atomic.LoadInt32(&errCount))
+
+ // Verify the key exists and has a value (exact value depends on race ordering,
+ // which is expected - this test validates no crashes under contention)
+ val, err := storage.Get(key)
+ require.NoError(t, err)
+ assert.NotNil(t, val)
+
+ counter, err := strconv.Atoi(string(val))
+ require.NoError(t, err)
+ assert.Greater(t, counter, 0)
+}
+
+func TestRedisStorage_ConcurrentSetGet(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ const workers = 20
+ var wg sync.WaitGroup
+
+ wg.Add(workers * 2) // writers + readers
+
+ // Writers
+ for i := range workers {
+ go func(idx int) {
+ defer wg.Done()
+
+ key := "concurrent-key-" + strconv.Itoa(idx)
+ val := []byte("value-" + strconv.Itoa(idx))
+
+ _ = storage.Set(key, val, time.Minute)
+ }(i)
+ }
+
+ // Readers (concurrent with writers)
+ for i := range workers {
+ go func(idx int) {
+ defer wg.Done()
+
+ key := "concurrent-key-" + strconv.Itoa(idx)
+
+ _, _ = storage.Get(key)
+ }(i)
+ }
+
+ wg.Wait()
+
+ // Verify all keys were written
+ for i := range workers {
+ key := "concurrent-key-" + strconv.Itoa(i)
+ val, err := storage.Get(key)
+ require.NoError(t, err)
+ assert.Equal(t, []byte("value-"+strconv.Itoa(i)), val)
+ }
+}
+
+func TestRedisStorage_TTLExpiration(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ key := "expiring-key"
+ value := []byte("temporary")
+
+ // Set with short TTL
+ err := storage.Set(key, value, time.Second)
+ require.NoError(t, err)
+
+ // Verify it exists
+ val, err := storage.Get(key)
+ require.NoError(t, err)
+ assert.Equal(t, value, val)
+
+ // Fast-forward miniredis time past the TTL
+ mr.FastForward(2 * time.Second)
+
+ // Now the key should be expired
+ val, err = storage.Get(key)
+ require.NoError(t, err)
+ assert.Nil(t, val, "key should have expired after TTL")
+}
+
+func TestRedisStorage_ZeroTTL(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ key := "no-expiry-key"
+ value := []byte("persistent")
+
+ // Set with 0 TTL (no expiration)
+ err := storage.Set(key, value, 0)
+ require.NoError(t, err)
+
+ // Fast-forward time significantly
+ mr.FastForward(24 * time.Hour)
+
+ // Key should still exist
+ val, err := storage.Get(key)
+ require.NoError(t, err)
+ assert.Equal(t, value, val)
+}
+
+func TestRedisStorage_MultipleKeySimultaneous(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ keys := map[string][]byte{
+ "key-alpha": []byte("value-alpha"),
+ "key-beta": []byte("value-beta"),
+ "key-gamma": []byte("value-gamma"),
+ "key-delta": []byte("value-delta"),
+ "key-epsilon": []byte("value-epsilon"),
+ }
+
+ // Set all keys
+ for k, v := range keys {
+ require.NoError(t, storage.Set(k, v, time.Minute))
+ }
+
+ // Verify all keys exist with correct values
+ for k, expected := range keys {
+ val, err := storage.Get(k)
+ require.NoError(t, err)
+ assert.Equal(t, expected, val, "key %s should have correct value", k)
+ }
+
+ // Delete one key
+ require.NoError(t, storage.Delete("key-gamma"))
+
+ // Verify deleted key returns nil
+ val, err := storage.Get("key-gamma")
+ require.NoError(t, err)
+ assert.Nil(t, val)
+
+ // Verify other keys still exist
+ val, err = storage.Get("key-alpha")
+ require.NoError(t, err)
+ assert.Equal(t, []byte("value-alpha"), val)
+}
+
+func TestRedisStorage_LargeCounterValues(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ // Store a very large counter value
+ largeValue := strconv.Itoa(999999999)
+ err := storage.Set("large-counter", []byte(largeValue), time.Minute)
+ require.NoError(t, err)
+
+ val, err := storage.Get("large-counter")
+ require.NoError(t, err)
+ assert.Equal(t, []byte(largeValue), val)
+}
+
+func TestRedisStorage_LargeByteValue(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ // Store a large byte slice
+ largeVal := make([]byte, 1024*10) // 10KB
+ for i := range largeVal {
+ largeVal[i] = byte(i % 256)
+ }
+
+ err := storage.Set("large-value", largeVal, time.Minute)
+ require.NoError(t, err)
+
+ val, err := storage.Get("large-value")
+ require.NoError(t, err)
+ assert.Equal(t, largeVal, val)
+}
+
+func TestRedisStorage_SetOverwrite(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ key := "overwrite-key"
+
+ // Set initial value
+ require.NoError(t, storage.Set(key, []byte("original"), time.Minute))
+
+ val, err := storage.Get(key)
+ require.NoError(t, err)
+ assert.Equal(t, []byte("original"), val)
+
+ // Overwrite with new value
+ require.NoError(t, storage.Set(key, []byte("updated"), time.Minute))
+
+ val, err = storage.Get(key)
+ require.NoError(t, err)
+ assert.Equal(t, []byte("updated"), val)
+}
+
+func TestRedisStorage_DeleteNonExistentKey(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ // Delete a key that doesn't exist should not error
+ err := storage.Delete("non-existent-key")
+ require.NoError(t, err)
+}
+
+func TestRedisStorage_GetNonExistentKey(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ val, err := storage.Get("non-existent-key")
+ require.NoError(t, err)
+ assert.Nil(t, val)
+}
+
+func TestRedisStorage_ResetOnlyRateLimitKeys(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ // Set rate limit keys through storage (these get the ratelimit: prefix)
+ require.NoError(t, storage.Set("limit-key-1", []byte("1"), time.Minute))
+ require.NoError(t, storage.Set("limit-key-2", []byte("2"), time.Minute))
+
+ // Set a non-rate-limit key directly via Redis (no ratelimit: prefix)
+ client := redis.NewClient(&redis.Options{Addr: mr.Addr()})
+ t.Cleanup(func() { _ = client.Close() })
+
+ require.NoError(t, client.Set(t.Context(), "other:key", "non-ratelimit", time.Minute).Err())
+
+ // Reset should only clear ratelimit: prefixed keys
+ err := storage.Reset()
+ require.NoError(t, err)
+
+ // Rate limit keys should be gone
+ val, err := storage.Get("limit-key-1")
+ require.NoError(t, err)
+ assert.Nil(t, val)
+
+ val, err = storage.Get("limit-key-2")
+ require.NoError(t, err)
+ assert.Nil(t, val)
+
+ // Non-rate-limit key should still exist
+ otherVal, err := client.Get(t.Context(), "other:key").Result()
+ require.NoError(t, err)
+ assert.Equal(t, "non-ratelimit", otherVal)
+}
+
+func TestRedisStorage_ResetEmptyStorage(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ // Reset on empty storage should not error
+ err := storage.Reset()
+ require.NoError(t, err)
+}
+
+func TestRedisStorage_NilConnOperations(t *testing.T) {
+ t.Parallel()
+
+ // Storage with nil conn field (manually constructed)
+ storage := &RedisStorage{conn: nil}
+
+ val, err := storage.Get("key")
+ require.ErrorIs(t, err, ErrStorageUnavailable)
+ assert.Nil(t, val)
+
+ err = storage.Set("key", []byte("value"), time.Minute)
+ require.ErrorIs(t, err, ErrStorageUnavailable)
+
+ err = storage.Delete("key")
+ require.ErrorIs(t, err, ErrStorageUnavailable)
+
+ err = storage.Reset()
+ require.ErrorIs(t, err, ErrStorageUnavailable)
+}
+
+func TestRedisStorage_SetEmptyValue(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ // Empty byte slice should be ignored (same as nil)
+ err := storage.Set("key", []byte{}, time.Minute)
+ require.NoError(t, err)
+
+ // Key should not exist since empty value is ignored
+ val, err := storage.Get("key")
+ require.NoError(t, err)
+ assert.Nil(t, val)
+}
+
+func TestRedisStorage_CloseIsNoop(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ // Close should be a no-op and return nil
+ err := storage.Close()
+ require.NoError(t, err)
+
+ // Storage should still work after Close (since Close is a no-op)
+ err = storage.Set("after-close", []byte("value"), time.Minute)
+ require.NoError(t, err)
+
+ val, err := storage.Get("after-close")
+ require.NoError(t, err)
+ assert.Equal(t, []byte("value"), val)
+}
+
+func TestRedisStorage_ResetManyKeys(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ // Set more keys than scanBatchSize (100) to test pagination in SCAN
+ const numKeys = 150
+ for i := range numKeys {
+ key := "batch-key-" + strconv.Itoa(i)
+ require.NoError(t, storage.Set(key, []byte(strconv.Itoa(i)), time.Minute))
+ }
+
+ // Reset should clear all of them
+ err := storage.Reset()
+ require.NoError(t, err)
+
+ // Verify they're all gone
+ for i := range numKeys {
+ key := "batch-key-" + strconv.Itoa(i)
+ val, err := storage.Get(key)
+ require.NoError(t, err)
+ assert.Nil(t, val, "key %s should be deleted after reset", key)
+ }
+}
+
+func TestRedisStorage_SetWithDifferentTTLs(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newTestRedisConnection(t, mr)
+
+ storage := NewRedisStorage(conn)
+ require.NotNil(t, storage)
+
+ // Set keys with different TTLs
+ require.NoError(t, storage.Set("short-ttl", []byte("short"), 1*time.Second))
+ require.NoError(t, storage.Set("long-ttl", []byte("long"), 1*time.Hour))
+
+ // Both should exist initially
+ val, err := storage.Get("short-ttl")
+ require.NoError(t, err)
+ assert.Equal(t, []byte("short"), val)
+
+ val, err = storage.Get("long-ttl")
+ require.NoError(t, err)
+ assert.Equal(t, []byte("long"), val)
+
+ // Fast-forward past short TTL but before long TTL
+ mr.FastForward(5 * time.Second)
+
+ // Short TTL should be gone
+ val, err = storage.Get("short-ttl")
+ require.NoError(t, err)
+ assert.Nil(t, val, "short-ttl should have expired")
+
+ // Long TTL should still exist
+ val, err = storage.Get("long-ttl")
+ require.NoError(t, err)
+ assert.Equal(t, []byte("long"), val, "long-ttl should still exist")
+}
diff --git a/commons/net/http/ratelimit/server_test.go b/commons/net/http/ratelimit/server_test.go
new file mode 100644
index 00000000..199ba441
--- /dev/null
+++ b/commons/net/http/ratelimit/server_test.go
@@ -0,0 +1,439 @@
+//go:build unit
+
+// Package ratelimit_test demonstrates the rate limit middleware in realistic Fiber
+// server configurations. Unlike middleware_test.go (which tests individual behaviors
+// in isolation using white-box access), this file uses only the public API and builds
+// complete API servers that mirror production usage patterns.
+package ratelimit_test
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "testing"
+ "time"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/net/http/ratelimit"
+ libRedis "github.com/LerianStudio/lib-commons/v4/commons/redis"
+ "github.com/alicebob/miniredis/v2"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ── Helpers ───────────────────────────────────────────────────────────────────
+
+func newServerTestConn(t *testing.T, mr *miniredis.Miniredis) *libRedis.Client {
+ t.Helper()
+
+ conn, err := libRedis.New(t.Context(), libRedis.Config{
+ Topology: libRedis.Topology{
+ Standalone: &libRedis.StandaloneTopology{Address: mr.Addr()},
+ },
+ Logger: &libLog.NopLogger{},
+ })
+ require.NoError(t, err)
+
+ t.Cleanup(func() { _ = conn.Close() })
+
+ return conn
+}
+
+// buildAPIServer returns a Fiber app that resembles a production multi-tier API:
+//
+// GET /public/info → relaxed tier (max = 20 req / 60 s)
+// GET /admin/config → strict tier (max = 3 req / 60 s)
+// GET /api/items → read tier (max = 10 req / 60 s)
+// POST /api/items → write tier (max = 3 req / 60 s)
+//
+// Write/read tiers are applied via WithDynamicRateLimit + MethodTierSelector,
+// demonstrating method-sensitive rate limiting on the same route group.
+func buildAPIServer(rl *ratelimit.RateLimiter) *fiber.App {
+ relaxed := ratelimit.Tier{Name: "public", Max: 20, Window: 60 * time.Second}
+ strict := ratelimit.Tier{Name: "admin", Max: 3, Window: 60 * time.Second}
+ write := ratelimit.Tier{Name: "write", Max: 3, Window: 60 * time.Second}
+ read := ratelimit.Tier{Name: "read", Max: 10, Window: 60 * time.Second}
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ // ProxyHeader lets tests inject any IP (including IPv6) via X-Forwarded-For.
+ ProxyHeader: fiber.HeaderXForwardedFor,
+ })
+
+ public := app.Group("/public")
+ public.Use(rl.WithRateLimit(relaxed))
+ public.Get("/info", func(c *fiber.Ctx) error {
+ return c.JSON(fiber.Map{"status": "ok"})
+ })
+
+ admin := app.Group("/admin")
+ admin.Use(rl.WithRateLimit(strict))
+ admin.Get("/config", func(c *fiber.Ctx) error {
+ return c.JSON(fiber.Map{"config": "redacted"})
+ })
+
+ api := app.Group("/api")
+ api.Use(rl.WithDynamicRateLimit(ratelimit.MethodTierSelector(write, read)))
+ api.Get("/items", func(c *fiber.Ctx) error {
+ return c.JSON(fiber.Map{"items": []string{"a", "b", "c"}})
+ })
+ api.Post("/items", func(c *fiber.Ctx) error {
+ return c.JSON(fiber.Map{"created": true})
+ })
+
+ return app
+}
+
+func serverGet(t *testing.T, app *fiber.App, path, ip string) *http.Response {
+ t.Helper()
+
+ req := httptest.NewRequest(http.MethodGet, path, nil)
+ if ip != "" {
+ req.Header.Set("X-Forwarded-For", ip)
+ }
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ return resp
+}
+
+func serverPost(t *testing.T, app *fiber.App, path, ip string) *http.Response {
+ t.Helper()
+
+ req := httptest.NewRequest(http.MethodPost, path, nil)
+ if ip != "" {
+ req.Header.Set("X-Forwarded-For", ip)
+ }
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ return resp
+}
+
+// ── Tests ─────────────────────────────────────────────────────────────────────
+
+// TestServer_MultiTierRouteGroups verifies that different route groups enforce
+// their own independent limits. Exhausting the /admin tier must not affect /public.
+func TestServer_MultiTierRouteGroups(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newServerTestConn(t, mr)
+
+ rl := ratelimit.New(conn, ratelimit.WithKeyPrefix("svc"))
+ app := buildAPIServer(rl)
+
+ const ip = "10.0.0.1"
+
+ // Exhaust the strict admin tier (max = 3).
+ for i := 1; i <= 3; i++ {
+ resp := serverGet(t, app, "/admin/config", ip)
+ resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode, "admin request %d should pass", i)
+ }
+
+ resp := serverGet(t, app, "/admin/config", ip)
+ resp.Body.Close()
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "4th admin request should be blocked")
+
+ // The public tier (max = 20) must be completely unaffected.
+ resp = serverGet(t, app, "/public/info", ip)
+ defer resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode, "public route should remain accessible")
+ assert.Equal(t, "20", resp.Header.Get("X-RateLimit-Limit"))
+}
+
+// TestServer_WindowReset verifies that the rate limit counter resets automatically
+// after the window elapses. miniredis.FastForward is used to advance the internal
+// clock without real wall-clock waiting.
+func TestServer_WindowReset(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newServerTestConn(t, mr)
+
+ shortWindow := ratelimit.Tier{Name: "short", Max: 2, Window: 5 * time.Second}
+ rl := ratelimit.New(conn)
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ ProxyHeader: fiber.HeaderXForwardedFor,
+ })
+ app.Use(rl.WithRateLimit(shortWindow))
+ app.Get("/ping", func(c *fiber.Ctx) error { return c.SendString("pong") })
+
+ doReq := func() *http.Response {
+ req := httptest.NewRequest(http.MethodGet, "/ping", nil)
+ req.Header.Set("X-Forwarded-For", "10.0.0.3")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ return resp
+ }
+
+ // Exhaust the window (max = 2).
+ for i := 1; i <= 2; i++ {
+ resp := doReq()
+ resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode, "request %d within window should pass", i)
+ }
+
+ // Third request is blocked.
+ resp := doReq()
+ resp.Body.Close()
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "3rd request should be blocked")
+
+ // Advance miniredis clock past the 5-second window.
+ mr.FastForward(6 * time.Second)
+
+ // First request of the new window must pass, remaining resets to max-1.
+ resp = doReq()
+ defer resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode, "request after window reset should pass")
+ assert.Equal(t, "1", resp.Header.Get("X-RateLimit-Remaining"), "counter should have reset")
+}
+
+// TestServer_ClientIsolation verifies that each client IP (IPv4 and IPv6) maintains
+// its own independent counter, so one client exhausting its quota does not affect others.
+func TestServer_ClientIsolation(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newServerTestConn(t, mr)
+
+ tier := ratelimit.Tier{Name: "iso", Max: 2, Window: 60 * time.Second}
+ rl := ratelimit.New(conn)
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ ProxyHeader: fiber.HeaderXForwardedFor,
+ })
+ app.Use(rl.WithRateLimit(tier))
+ app.Get("/data", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ clients := []string{
+ "10.1.1.1", // IPv4
+ "10.1.1.2", // IPv4 different subnet
+ "2001:db8::1", // IPv6
+ "2001:db8::2", // IPv6 different host
+ }
+
+ type result struct {
+ ip string
+ blocked bool
+ err error
+ }
+
+ results := make([]result, len(clients))
+
+ var wg sync.WaitGroup
+
+ for i, ip := range clients {
+ wg.Add(1)
+
+ go func() {
+ defer wg.Done()
+
+ // Each client fires 2 requests within quota.
+ for range 2 {
+ req := httptest.NewRequest(http.MethodGet, "/data", nil)
+ req.Header.Set("X-Forwarded-For", ip)
+
+ resp, err := app.Test(req, -1)
+ if err != nil {
+ results[i] = result{ip: ip, err: err}
+ return
+ }
+
+ resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ results[i] = result{ip: ip, blocked: true}
+ return
+ }
+ }
+
+ // 3rd request should be blocked for this client only.
+ req := httptest.NewRequest(http.MethodGet, "/data", nil)
+ req.Header.Set("X-Forwarded-For", ip)
+
+ resp, err := app.Test(req, -1)
+ if err != nil {
+ results[i] = result{ip: ip, err: err}
+ return
+ }
+
+ resp.Body.Close()
+ results[i] = result{ip: ip, blocked: resp.StatusCode == http.StatusTooManyRequests}
+ }()
+ }
+
+ wg.Wait()
+
+ for _, r := range results {
+ require.NoError(t, r.err, "client %s: unexpected request error", r.ip)
+ assert.True(t, r.blocked, "client %s: 3rd request should have been blocked", r.ip)
+ }
+}
+
+// TestServer_TenantIsolation verifies that IdentityFromIPAndHeader creates per-tenant
+// buckets. The same IP with different X-Tenant-ID values gets independent counters.
+func TestServer_TenantIsolation(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newServerTestConn(t, mr)
+
+ tier := ratelimit.Tier{Name: "tenant", Max: 2, Window: 60 * time.Second}
+ rl := ratelimit.New(conn,
+ ratelimit.WithIdentityFunc(ratelimit.IdentityFromIPAndHeader("X-Tenant-ID")),
+ )
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ app.Use(rl.WithRateLimit(tier))
+ app.Get("/resource", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ doTenantReq := func(tenantID string) *http.Response {
+ req := httptest.NewRequest(http.MethodGet, "/resource", nil)
+ req.Header.Set("X-Tenant-ID", tenantID)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ return resp
+ }
+
+ // Tenant A exhausts its quota (max = 2).
+ for range 2 {
+ resp := doTenantReq("tenant-a")
+ resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ }
+
+ resp := doTenantReq("tenant-a")
+ resp.Body.Close()
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode, "tenant-a should be blocked")
+
+ // Tenant B has its own counter — completely unaffected.
+ resp = doTenantReq("tenant-b")
+ defer resp.Body.Close()
+ assert.Equal(t, http.StatusOK, resp.StatusCode, "tenant-b should not be affected")
+ assert.Equal(t, "1", resp.Header.Get("X-RateLimit-Remaining"))
+}
+
+// TestServer_HeadersProgression verifies that X-RateLimit-Remaining decrements
+// accurately across a full sequence of requests and that Retry-After is set on
+// the blocking response with a ceiling-rounded value.
+func TestServer_HeadersProgression(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newServerTestConn(t, mr)
+
+ tier := ratelimit.Tier{Name: "progress", Max: 5, Window: 60 * time.Second}
+ rl := ratelimit.New(conn)
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ ProxyHeader: fiber.HeaderXForwardedFor,
+ })
+ app.Use(rl.WithRateLimit(tier))
+ app.Get("/count", func(c *fiber.Ctx) error { return c.SendString("ok") })
+
+ type snapshot struct{ limit, remaining string }
+
+ expected := []snapshot{
+ {"5", "4"},
+ {"5", "3"},
+ {"5", "2"},
+ {"5", "1"},
+ {"5", "0"},
+ }
+
+ for i, want := range expected {
+ req := httptest.NewRequest(http.MethodGet, "/count", nil)
+ req.Header.Set("X-Forwarded-For", "172.16.0.1")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode, "request %d", i+1)
+ assert.Equal(t, want.limit, resp.Header.Get("X-RateLimit-Limit"), "request %d: limit", i+1)
+ assert.Equal(t, want.remaining, resp.Header.Get("X-RateLimit-Remaining"), "request %d: remaining", i+1)
+ assert.NotEmpty(t, resp.Header.Get("X-RateLimit-Reset"), "request %d: reset timestamp", i+1)
+ }
+
+ // 6th request — verify the blocking response headers.
+ req := httptest.NewRequest(http.MethodGet, "/count", nil)
+ req.Header.Set("X-Forwarded-For", "172.16.0.1")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+ assert.Equal(t, "5", resp.Header.Get("X-RateLimit-Limit"))
+ assert.Equal(t, "0", resp.Header.Get("X-RateLimit-Remaining"))
+ assert.NotEmpty(t, resp.Header.Get("X-RateLimit-Reset"))
+ // Retry-After must be ≥ 1 (ceiling division guarantees this).
+ retryAfter := resp.Header.Get("Retry-After")
+ assert.NotEmpty(t, retryAfter)
+ assert.NotEqual(t, "0", retryAfter, "Retry-After must be at least 1 second")
+}
+
+// TestServer_RetryAfter_CeilingDivision verifies that when the remaining TTL is
+// sub-second (e.g. 100ms), the Retry-After header is 1, not 0.
+//
+// This exercises the ceiling division in handleLimitExceeded:
+//
+// retryAfterSec := int(ttl / time.Second) // = 0 for 100ms
+// if ttl%time.Second > 0 { retryAfterSec++ } // ceil → 1
+func TestServer_RetryAfter_CeilingDivision(t *testing.T) {
+ t.Parallel()
+
+ mr := miniredis.RunT(t)
+ conn := newServerTestConn(t, mr)
+
+ // 3-second window so FastForward can leave a sub-second TTL remainder.
+ tier := ratelimit.Tier{Name: "ceiling", Max: 1, Window: 3 * time.Second}
+ rl := ratelimit.New(conn)
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ ProxyHeader: fiber.HeaderXForwardedFor,
+ })
+ app.Use(rl.WithRateLimit(tier))
+ app.Get("/ping", func(c *fiber.Ctx) error { return c.SendString("pong") })
+
+ doReq := func() *http.Response {
+ req := httptest.NewRequest(http.MethodGet, "/ping", nil)
+ req.Header.Set("X-Forwarded-For", "10.99.0.1")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ return resp
+ }
+
+ // First request exhausts the quota and sets TTL = 3s.
+ resp := doReq()
+ resp.Body.Close()
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ // Advance to ~100ms before expiry: TTL drops to sub-second.
+ mr.FastForward(2900 * time.Millisecond)
+
+ // Blocked response: TTL ≈ 100ms → ceiling division must yield 1, not 0.
+ resp = doReq()
+ defer resp.Body.Close()
+
+ require.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+ assert.Equal(t, "1", resp.Header.Get("Retry-After"),
+ "Retry-After should be 1 (ceiling of sub-second TTL), not 0")
+}
+
diff --git a/commons/net/http/response.go b/commons/net/http/response.go
index d53fc1e2..1c7df28f 100644
--- a/commons/net/http/response.go
+++ b/commons/net/http/response.go
@@ -1,124 +1,33 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package http
import (
- "github.com/LerianStudio/lib-commons/v2/commons"
- "github.com/gofiber/fiber/v2"
"net/http"
- "strconv"
-)
-
-const NotImplementedMessage = "Not implemented yet"
-
-// Unauthorized sends an HTTP 401 Unauthorized response with a custom code, title and message.
-func Unauthorized(c *fiber.Ctx, code, title, message string) error {
- return c.Status(http.StatusUnauthorized).JSON(commons.Response{
- Code: code,
- Title: title,
- Message: message,
- })
-}
-
-// Forbidden sends an HTTP 403 Forbidden response with a custom code, title and message.
-func Forbidden(c *fiber.Ctx, code, title, message string) error {
- return c.Status(http.StatusForbidden).JSON(commons.Response{
- Code: code,
- Title: title,
- Message: message,
- })
-}
-
-// BadRequest sends an HTTP 400 Bad Request response with a custom body.
-func BadRequest(c *fiber.Ctx, s any) error {
- return c.Status(http.StatusBadRequest).JSON(s)
-}
-
-// Created sends an HTTP 201 Created response with a custom body.
-func Created(c *fiber.Ctx, s any) error {
- return c.Status(http.StatusCreated).JSON(s)
-}
-
-// OK sends an HTTP 200 OK response with a custom body.
-func OK(c *fiber.Ctx, s any) error {
- return c.Status(http.StatusOK).JSON(s)
-}
-
-// NoContent sends an HTTP 204 No Content response without anybody.
-func NoContent(c *fiber.Ctx) error {
- return c.SendStatus(http.StatusNoContent)
-}
-
-// Accepted sends an HTTP 202 Accepted response with a custom body.
-func Accepted(c *fiber.Ctx, s any) error {
- return c.Status(http.StatusAccepted).JSON(s)
-}
-
-// PartialContent sends an HTTP 206 Partial Content response with a custom body.
-func PartialContent(c *fiber.Ctx, s any) error {
- return c.Status(http.StatusPartialContent).JSON(s)
-}
-
-// RangeNotSatisfiable sends an HTTP 416 Requested Range Not Satisfiable response.
-func RangeNotSatisfiable(c *fiber.Ctx) error {
- return c.SendStatus(http.StatusRequestedRangeNotSatisfiable)
-}
-// NotFound sends an HTTP 404 Not Found response with a custom code, title and message.
-func NotFound(c *fiber.Ctx, code, title, message string) error {
- return c.Status(http.StatusNotFound).JSON(commons.Response{
- Code: code,
- Title: title,
- Message: message,
- })
-}
+ "github.com/gofiber/fiber/v2"
+)
-// Conflict sends an HTTP 409 Conflict response with a custom code, title and message.
-func Conflict(c *fiber.Ctx, code, title, message string) error {
- return c.Status(http.StatusConflict).JSON(commons.Response{
- Code: code,
- Title: title,
- Message: message,
- })
-}
+// Respond sends a JSON response with explicit status.
+func Respond(c *fiber.Ctx, status int, payload any) error {
+ if c == nil {
+ return ErrContextNotFound
+ }
-// NotImplemented sends an HTTP 501 Not Implemented response with a custom message.
-func NotImplemented(c *fiber.Ctx, message string) error {
- return c.Status(http.StatusNotImplemented).JSON(commons.Response{
- Code: strconv.Itoa(http.StatusNotImplemented),
- Title: NotImplementedMessage,
- Message: message,
- })
-}
+ if status < http.StatusContinue || status > 599 {
+ status = http.StatusInternalServerError
+ }
-// UnprocessableEntity sends an HTTP 422 Unprocessable Entity response with a custom code, title and message.
-func UnprocessableEntity(c *fiber.Ctx, code, title, message string) error {
- return c.Status(http.StatusUnprocessableEntity).JSON(commons.Response{
- Code: code,
- Title: title,
- Message: message,
- })
+ return c.Status(status).JSON(payload)
}
-// InternalServerError sends an HTTP 500 Internal Server Response response
-func InternalServerError(c *fiber.Ctx, code, title, message string) error {
- return c.Status(http.StatusInternalServerError).JSON(commons.Response{
- Code: code,
- Title: title,
- Message: message,
- })
-}
-
-// JSONResponseError sends a JSON formatted error response with a custom error struct.
-func JSONResponseError(c *fiber.Ctx, err commons.Response) error {
- code, _ := strconv.Atoi(err.Code)
+// RespondStatus sends a status-only response with no body.
+func RespondStatus(c *fiber.Ctx, status int) error {
+ if c == nil {
+ return ErrContextNotFound
+ }
- return c.Status(code).JSON(err)
-}
+ if status < http.StatusContinue || status > 599 {
+ status = http.StatusInternalServerError
+ }
-// JSONResponse sends a custom status code and body as a JSON response.
-func JSONResponse(c *fiber.Ctx, status int, s any) error {
- return c.Status(status).JSON(s)
+ return c.SendStatus(status)
}
diff --git a/commons/net/http/response_test.go b/commons/net/http/response_test.go
new file mode 100644
index 00000000..d2a72caa
--- /dev/null
+++ b/commons/net/http/response_test.go
@@ -0,0 +1,135 @@
+//go:build unit
+
+package http
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestRespond_NegativeStatus(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", func(c *fiber.Ctx) error {
+ return Respond(c, -1, fiber.Map{"ok": true})
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+}
+
+func TestRespond_Status599IsValid(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", func(c *fiber.Ctx) error {
+ return Respond(c, 599, fiber.Map{"ok": true})
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, 599, resp.StatusCode)
+}
+
+func TestRespond_Status100IsValid(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", func(c *fiber.Ctx) error {
+ return Respond(c, http.StatusContinue, fiber.Map{"data": "x"})
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusContinue, resp.StatusCode)
+}
+
+func TestRespond_EmptyPayload(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", func(c *fiber.Ctx) error {
+ return Respond(c, http.StatusOK, fiber.Map{})
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+ var result map[string]any
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+ assert.Empty(t, result)
+}
+
+func TestRespondStatus_Status600ClampedTo500(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", func(c *fiber.Ctx) error {
+ return RespondStatus(c, 600)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+}
+
+// ---------------------------------------------------------------------------
+// Nil guard tests
+// ---------------------------------------------------------------------------
+
+func TestRespond_NilContext(t *testing.T) {
+ t.Parallel()
+
+ err := Respond(nil, 200, fiber.Map{"ok": true})
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestRespondStatus_NilContext(t *testing.T) {
+ t.Parallel()
+
+ err := RespondStatus(nil, 200)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
+
+func TestRespondStatus_NoContent(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Delete("/", func(c *fiber.Ctx) error {
+ return RespondStatus(c, http.StatusNoContent)
+ })
+
+ req := httptest.NewRequest(http.MethodDelete, "/", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusNoContent, resp.StatusCode)
+}
diff --git a/commons/net/http/validation.go b/commons/net/http/validation.go
new file mode 100644
index 00000000..7c3e1a22
--- /dev/null
+++ b/commons/net/http/validation.go
@@ -0,0 +1,295 @@
+package http
+
+import (
+ "errors"
+ "fmt"
+ "mime"
+ "strings"
+ "sync"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/go-playground/validator/v10"
+ "github.com/gofiber/fiber/v2"
+ "github.com/shopspring/decimal"
+)
+
+// Validation errors.
+var (
+ // ErrValidationFailed is returned when struct validation fails.
+ ErrValidationFailed = errors.New("validation failed")
+ // ErrFieldRequired is returned when a required field is missing.
+ ErrFieldRequired = errors.New("field is required")
+ // ErrFieldMaxLength is returned when a field exceeds maximum length.
+ ErrFieldMaxLength = errors.New("field exceeds maximum length")
+ // ErrQueryParamTooLong is returned when a query parameter exceeds its maximum length.
+ ErrQueryParamTooLong = errors.New("query parameter exceeds maximum length")
+ // ErrFieldMinLength is returned when a field is below minimum length.
+ ErrFieldMinLength = errors.New("field below minimum length")
+ // ErrFieldGreaterThan is returned when a field must be greater than a value.
+ ErrFieldGreaterThan = errors.New("field must be greater than constraint")
+ // ErrFieldGreaterThanOrEqual is returned when a field must be greater than or equal to a value.
+ ErrFieldGreaterThanOrEqual = errors.New("field must be greater than or equal to constraint")
+ // ErrFieldLessThan is returned when a field must be less than a value.
+ ErrFieldLessThan = errors.New("field must be less than constraint")
+ // ErrFieldLessThanOrEqual is returned when a field must be less than or equal to a value.
+ ErrFieldLessThanOrEqual = errors.New("field must be less than or equal to constraint")
+ // ErrFieldOneOf is returned when a field must be one of allowed values.
+ ErrFieldOneOf = errors.New("field must be one of allowed values")
+ // ErrFieldEmail is returned when a field must be a valid email.
+ ErrFieldEmail = errors.New("field must be a valid email")
+ // ErrFieldURL is returned when a field must be a valid URL.
+ ErrFieldURL = errors.New("field must be a valid URL")
+ // ErrFieldUUID is returned when a field must be a valid UUID.
+ ErrFieldUUID = errors.New("field must be a valid UUID")
+ // ErrFieldPositiveAmount is returned when a field must be a positive amount.
+ ErrFieldPositiveAmount = errors.New("field must be a positive amount")
+ // ErrFieldNonNegativeAmount is returned when a field must be a non-negative amount.
+ ErrFieldNonNegativeAmount = errors.New("field must be a non-negative amount")
+ // ErrBodyParseFailed is returned when request body parsing fails.
+ ErrBodyParseFailed = errors.New("failed to parse request body")
+ // ErrUnsupportedContentType is returned when the Content-Type is not application/json.
+ ErrUnsupportedContentType = errors.New("Content-Type must be application/json")
+)
+
+// ErrValidatorInit is returned when custom validator registration fails during initialization.
+var ErrValidatorInit = errors.New("validator initialization failed")
+
+var (
+ validate *validator.Validate
+ validateOnce sync.Once
+ errValidate error
+)
+
+// initValidators creates and configures the validator with custom validation rules.
+// Returns an error if any custom validator registration fails.
+func initValidators() (*validator.Validate, error) {
+ vld := validator.New(validator.WithRequiredStructEnabled())
+
+ // Note: We do NOT register a custom type function for decimal.Decimal
+ // because returning the same type causes an infinite loop in the validator.
+ // Instead, custom validators like positive_decimal access the field directly.
+
+ // Register custom validator for decimal amounts that must be positive
+ if err := vld.RegisterValidation("positive_decimal", func(fl validator.FieldLevel) bool {
+ value, ok := fl.Field().Interface().(decimal.Decimal)
+ if !ok {
+ return false
+ }
+
+ return value.IsPositive()
+ }); err != nil {
+ return nil, fmt.Errorf("%w: failed to register 'positive_decimal': %w", ErrValidatorInit, err)
+ }
+
+ // Register custom validator for string amounts that must be positive
+ if err := vld.RegisterValidation("positive_amount", func(fl validator.FieldLevel) bool {
+ str := fl.Field().String()
+ if str == "" {
+ return true // Let required tag handle empty strings
+ }
+
+ d, parseErr := decimal.NewFromString(str)
+ if parseErr != nil {
+ return false
+ }
+
+ return d.IsPositive()
+ }); err != nil {
+ return nil, fmt.Errorf("%w: failed to register 'positive_amount': %w", ErrValidatorInit, err)
+ }
+
+ // Register custom validator for string amounts that must be non-negative
+ if err := vld.RegisterValidation("nonnegative_amount", func(fl validator.FieldLevel) bool {
+ str := fl.Field().String()
+ if str == "" {
+ return true // Let required tag handle empty strings
+ }
+
+ d, parseErr := decimal.NewFromString(str)
+ if parseErr != nil {
+ return false
+ }
+
+ return !d.IsNegative()
+ }); err != nil {
+ return nil, fmt.Errorf("%w: failed to register 'nonnegative_amount': %w", ErrValidatorInit, err)
+ }
+
+ return vld, nil
+}
+
+// GetValidator returns the singleton validator instance.
+// Returns the validator and any initialization error that may have occurred.
+func GetValidator() (*validator.Validate, error) {
+ validateOnce.Do(func() {
+ validate, errValidate = initValidators()
+ })
+
+ return validate, errValidate
+}
+
+// ValidateStruct validates a struct using the go-playground/validator tags.
+// Returns nil if validation passes, or the first validation error.
+func ValidateStruct(payload any) error {
+ vld, initErr := GetValidator()
+ if initErr != nil {
+ return fmt.Errorf("%w: %w", ErrValidationFailed, initErr)
+ }
+
+ if err := vld.Struct(payload); err != nil {
+ var validationErrors validator.ValidationErrors
+ if errors.As(err, &validationErrors) && len(validationErrors) > 0 {
+ return formatValidationError(validationErrors[0])
+ }
+
+ return fmt.Errorf("%w: %w", ErrValidationFailed, err)
+ }
+
+ return nil
+}
+
+// validationErrorFormatters maps validation tags to their error formatting functions.
+// Using a map-based approach reduces cyclomatic complexity compared to a large switch.
+var validationErrorFormatters = map[string]func(field, param string) error{
+ "required": func(field, _ string) error {
+ return fmt.Errorf("%w: '%s'", ErrFieldRequired, field)
+ },
+ "max": func(field, param string) error {
+ return fmt.Errorf("%w: '%s' must be at most %s", ErrFieldMaxLength, field, param)
+ },
+ "min": func(field, param string) error {
+ return fmt.Errorf("%w: '%s' must be at least %s", ErrFieldMinLength, field, param)
+ },
+ "gt": func(field, param string) error {
+ return fmt.Errorf("%w: '%s' must be greater than %s", ErrFieldGreaterThan, field, param)
+ },
+ "gte": func(field, param string) error {
+ return fmt.Errorf("%w: '%s' must be at least %s", ErrFieldGreaterThanOrEqual, field, param)
+ },
+ "lt": func(field, param string) error {
+ return fmt.Errorf("%w: '%s' must be less than %s", ErrFieldLessThan, field, param)
+ },
+ "lte": func(field, param string) error {
+ return fmt.Errorf("%w: '%s' must be at most %s", ErrFieldLessThanOrEqual, field, param)
+ },
+ "oneof": func(field, param string) error {
+ return fmt.Errorf("%w: '%s' must be one of [%s]", ErrFieldOneOf, field, param)
+ },
+ "email": func(field, _ string) error {
+ return fmt.Errorf("%w: '%s'", ErrFieldEmail, field)
+ },
+ "url": func(field, _ string) error {
+ return fmt.Errorf("%w: '%s'", ErrFieldURL, field)
+ },
+ "uuid": func(field, _ string) error {
+ return fmt.Errorf("%w: '%s'", ErrFieldUUID, field)
+ },
+ "positive_amount": func(field, _ string) error {
+ return fmt.Errorf("%w: '%s'", ErrFieldPositiveAmount, field)
+ },
+ "positive_decimal": func(field, _ string) error {
+ return fmt.Errorf("%w: '%s'", ErrFieldPositiveAmount, field)
+ },
+ "nonnegative_amount": func(field, _ string) error {
+ return fmt.Errorf("%w: '%s'", ErrFieldNonNegativeAmount, field)
+ },
+}
+
+// formatValidationError creates a user-friendly error message from a validation error.
+func formatValidationError(fe validator.FieldError) error {
+ field := toSnakeCase(fe.Field())
+
+ if formatter, ok := validationErrorFormatters[fe.Tag()]; ok {
+ return formatter(field, fe.Param())
+ }
+
+ return fmt.Errorf("%w: '%s' failed '%s' check", ErrValidationFailed, field, fe.Tag())
+}
+
+// toSnakeCase converts a PascalCase or camelCase string to snake_case.
+func toSnakeCase(s string) string {
+ var result strings.Builder
+
+ for i, r := range s {
+ if i > 0 && r >= 'A' && r <= 'Z' {
+ result.WriteByte('_')
+ }
+
+ result.WriteRune(r)
+ }
+
+ return strings.ToLower(result.String())
+}
+
+// ParseBodyAndValidate parses the request body into the given struct and validates it.
+// Returns a bad request error if parsing or validation fails.
+// Rejects requests with explicit non-JSON Content-Type headers to provide clear
+// error messages while preserving existing parser behavior when the header is absent.
+func ParseBodyAndValidate(fiberCtx *fiber.Ctx, payload any) error {
+ if fiberCtx == nil {
+ return ErrContextNotFound
+ }
+
+ ct := strings.TrimSpace(fiberCtx.Get(fiber.HeaderContentType))
+ if ct != "" {
+ mediaType, _, err := mime.ParseMediaType(ct)
+ if err != nil {
+ mediaType = strings.TrimSpace(strings.SplitN(ct, ";", 2)[0])
+ }
+
+ if !strings.EqualFold(mediaType, fiber.MIMEApplicationJSON) {
+ return ErrUnsupportedContentType
+ }
+
+ fiberCtx.Request().Header.SetContentType(fiber.MIMEApplicationJSON)
+ }
+
+ if err := fiberCtx.BodyParser(payload); err != nil {
+ return fmt.Errorf("%w: %w", ErrBodyParseFailed, err)
+ }
+
+ return ValidateStruct(payload)
+}
+
+// ValidateSortDirection validates and normalizes a sort direction string.
+// Only "ASC" and "DESC" (case-insensitive) are allowed.
+// Returns "ASC" as the safe default for any invalid input.
+func ValidateSortDirection(dir string) string {
+ upper := strings.ToUpper(strings.TrimSpace(dir))
+ if upper == cn.SortDirDESC {
+ return cn.SortDirDESC
+ }
+
+ return cn.SortDirASC
+}
+
+// ValidateLimit validates and normalizes a pagination limit.
+// It ensures the limit is within the allowed range [1, maxLimit].
+// If limit is <= 0, returns defaultLimit. If limit > maxLimit, returns maxLimit.
+func ValidateLimit(limit, defaultLimit, maxLimit int) int {
+ if limit <= 0 {
+ return defaultLimit
+ }
+
+ if limit > maxLimit {
+ return maxLimit
+ }
+
+ return limit
+}
+
+// MaxQueryParamLengthShort is the maximum length for short query parameters (action, entity_type, status).
+const MaxQueryParamLengthShort = 50
+
+// MaxQueryParamLengthLong is the maximum length for long query parameters (actor, assigned_to).
+const MaxQueryParamLengthLong = 255
+
+// ValidateQueryParamLength checks that a query parameter value does not exceed maxLen.
+// Returns nil if the value is within bounds, or a descriptive error if it exceeds the limit.
+func ValidateQueryParamLength(value, name string, maxLen int) error {
+ if len(value) > maxLen {
+ return fmt.Errorf("%w: '%s' must be at most %d characters", ErrQueryParamTooLong, name, maxLen)
+ }
+
+ return nil
+}
diff --git a/commons/net/http/validation_parse_test.go b/commons/net/http/validation_parse_test.go
new file mode 100644
index 00000000..792e87c2
--- /dev/null
+++ b/commons/net/http/validation_parse_test.go
@@ -0,0 +1,179 @@
+//go:build unit
+
+package http
+
+import (
+ "bytes"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseBodyAndValidate(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ body string
+ contentType string
+ payload any
+ wantErr bool
+ errContains string
+ errIs error
+ }{
+ {
+ name: "valid JSON payload",
+ body: `{"name":"test","email":"test@example.com","priority":1}`,
+ contentType: "application/json",
+ payload: &testPayload{},
+ wantErr: false,
+ },
+ {
+ name: "invalid JSON",
+ body: `{"name": invalid}`,
+ contentType: "application/json",
+ payload: &testPayload{},
+ wantErr: true,
+ errContains: "failed to parse request body",
+ errIs: ErrBodyParseFailed,
+ },
+ {
+ name: "valid JSON but validation fails",
+ body: `{"name":"","email":"test@example.com","priority":1}`,
+ contentType: "application/json",
+ payload: &testPayload{},
+ wantErr: true,
+ errContains: "field is required: 'name'",
+ errIs: ErrFieldRequired,
+ },
+ {
+ name: "empty body",
+ body: "",
+ contentType: "application/json",
+ payload: &testPayload{},
+ wantErr: true,
+ errContains: "failed to parse request body",
+ errIs: ErrBodyParseFailed,
+ },
+ {
+ name: "application/json with charset is accepted",
+ body: `{"name":"test","email":"test@example.com","priority":1}`,
+ contentType: "application/json; charset=utf-8",
+ payload: &testPayload{},
+ wantErr: false,
+ },
+ {
+ name: "empty Content-Type falls through to body parser",
+ body: `{"name":"test","email":"test@example.com","priority":1}`,
+ contentType: "",
+ payload: &testPayload{},
+ wantErr: true,
+ errContains: "failed to parse request body",
+ errIs: ErrBodyParseFailed,
+ },
+ {
+ name: "JSON Content-Type with surrounding whitespace is accepted",
+ body: `{"name":"test","email":"test@example.com","priority":1}`,
+ contentType: " application/json ; charset=utf-8 ",
+ payload: &testPayload{},
+ wantErr: false,
+ },
+ {
+ name: "JSON Content-Type is case-insensitive",
+ body: `{"name":"test","email":"test@example.com","priority":1}`,
+ contentType: "Application/JSON",
+ payload: &testPayload{},
+ wantErr: false,
+ },
+ {
+ name: "text/plain Content-Type is rejected",
+ body: `{"name":"test","email":"test@example.com","priority":1}`,
+ contentType: "text/plain",
+ payload: &testPayload{},
+ wantErr: true,
+ errContains: "Content-Type must be application/json",
+ errIs: ErrUnsupportedContentType,
+ },
+ {
+ name: "text/xml Content-Type is rejected",
+ body: ``,
+ contentType: "text/xml",
+ payload: &testPayload{},
+ wantErr: true,
+ errContains: "Content-Type must be application/json",
+ errIs: ErrUnsupportedContentType,
+ },
+ {
+ name: "multipart/form-data Content-Type is rejected",
+ body: `{"name":"test"}`,
+ contentType: "multipart/form-data",
+ payload: &testPayload{},
+ wantErr: true,
+ errContains: "Content-Type must be application/json",
+ errIs: ErrUnsupportedContentType,
+ },
+ {
+ name: "application/jsonx is rejected",
+ body: `{"name":"test","email":"test@example.com","priority":1}`,
+ contentType: "application/jsonx",
+ payload: &testPayload{},
+ wantErr: true,
+ errContains: "Content-Type must be application/json",
+ errIs: ErrUnsupportedContentType,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var gotErr error
+ app.Post("/test", func(c *fiber.Ctx) error {
+ gotErr = ParseBodyAndValidate(c, tc.payload)
+ if gotErr != nil {
+ return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": gotErr.Error()})
+ }
+
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewBufferString(tc.body))
+ if tc.contentType != "" {
+ req.Header.Set("Content-Type", tc.contentType)
+ }
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+
+ defer func() {
+ require.NoError(t, resp.Body.Close())
+ }()
+
+ if tc.wantErr {
+ require.Error(t, gotErr)
+ assert.Contains(t, gotErr.Error(), tc.errContains)
+ if tc.errIs != nil {
+ assert.ErrorIs(t, gotErr, tc.errIs)
+ }
+ assert.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
+ } else {
+ require.NoError(t, gotErr)
+ assert.Equal(t, fiber.StatusOK, resp.StatusCode)
+ }
+ })
+ }
+}
+
+func TestParseBodyAndValidate_NilContext(t *testing.T) {
+ t.Parallel()
+
+ payload := &testPayload{}
+ err := ParseBodyAndValidate(nil, payload)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrContextNotFound)
+}
diff --git a/commons/net/http/validation_query_test.go b/commons/net/http/validation_query_test.go
new file mode 100644
index 00000000..33a39474
--- /dev/null
+++ b/commons/net/http/validation_query_test.go
@@ -0,0 +1,140 @@
+//go:build unit
+
+package http
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidateSortDirection(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {name: "uppercase ASC", input: "ASC", want: "ASC"},
+ {name: "uppercase DESC", input: "DESC", want: "DESC"},
+ {name: "lowercase asc", input: "asc", want: "ASC"},
+ {name: "lowercase desc", input: "desc", want: "DESC"},
+ {name: "mixed case Asc", input: "Asc", want: "ASC"},
+ {name: "mixed case Desc", input: "Desc", want: "DESC"},
+ {name: "empty string defaults to ASC", input: "", want: "ASC"},
+ {name: "whitespace only defaults to ASC", input: " ", want: "ASC"},
+ {name: "with leading whitespace", input: " DESC", want: "DESC"},
+ {name: "with trailing whitespace", input: "ASC ", want: "ASC"},
+ {name: "invalid value defaults to ASC", input: "INVALID", want: "ASC"},
+ {name: "SQL injection attempt defaults to ASC", input: "ASC; DROP TABLE users;--", want: "ASC"},
+ {name: "partial match defaults to ASC", input: "ASCENDING", want: "ASC"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ got := ValidateSortDirection(tt.input)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestValidateLimit(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ limit int
+ defaultLimit int
+ maxLimit int
+ expected int
+ }{
+ {"zero uses default", 0, 20, 100, 20},
+ {"negative uses default", -5, 20, 100, 20},
+ {"valid limit unchanged", 50, 20, 100, 50},
+ {"exceeds max capped", 150, 20, 100, 100},
+ {"equals max unchanged", 100, 20, 100, 100},
+ {"equals default", 20, 20, 100, 20},
+ {"min valid (1)", 1, 20, 100, 1},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := ValidateLimit(tc.limit, tc.defaultLimit, tc.maxLimit)
+ assert.Equal(t, tc.expected, result)
+ })
+ }
+}
+
+func TestValidateQueryParamLength(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ value string
+ paramName string
+ maxLen int
+ wantErr bool
+ errContains string
+ }{
+ {
+ name: "value within limit",
+ value: "CREATE",
+ paramName: "action",
+ maxLen: 50,
+ wantErr: false,
+ },
+ {
+ name: "value at exact limit",
+ value: strings.Repeat("a", 50),
+ paramName: "action",
+ maxLen: 50,
+ wantErr: false,
+ },
+ {
+ name: "value exceeds limit",
+ value: strings.Repeat("a", 51),
+ paramName: "action",
+ maxLen: 50,
+ wantErr: true,
+ errContains: "'action' must be at most 50 characters",
+ },
+ {
+ name: "empty value always valid",
+ value: "",
+ paramName: "actor",
+ maxLen: 255,
+ wantErr: false,
+ },
+ {
+ name: "long value exceeds short limit",
+ value: strings.Repeat("x", 256),
+ paramName: "entity_type",
+ maxLen: 255,
+ wantErr: true,
+ errContains: "'entity_type' must be at most 255 characters",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := ValidateQueryParamLength(tc.value, tc.paramName, tc.maxLen)
+
+ if tc.wantErr {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrQueryParamTooLong)
+ assert.Contains(t, err.Error(), tc.errContains)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/commons/net/http/validation_rules_amount_test.go b/commons/net/http/validation_rules_amount_test.go
new file mode 100644
index 00000000..d6d14775
--- /dev/null
+++ b/commons/net/http/validation_rules_amount_test.go
@@ -0,0 +1,106 @@
+//go:build unit
+
+package http
+
+import (
+ "testing"
+
+ "github.com/shopspring/decimal"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestPositiveDecimalValidator(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ amount decimal.Decimal
+ wantErr bool
+ }{
+ {name: "positive amount is valid", amount: decimal.NewFromFloat(100.50), wantErr: false},
+ {name: "zero is invalid", amount: decimal.Zero, wantErr: true},
+ {name: "negative is invalid", amount: decimal.NewFromFloat(-50.00), wantErr: true},
+ {name: "small positive is valid", amount: decimal.NewFromFloat(0.01), wantErr: false},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ payload := testPositiveDecimalPayload{Amount: tc.amount}
+ err := ValidateStruct(&payload)
+
+ if tc.wantErr {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "amount")
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestPositiveAmountValidator(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ amount string
+ wantErr bool
+ }{
+ {name: "positive amount is valid", amount: "100.50", wantErr: false},
+ {name: "zero is invalid", amount: "0", wantErr: true},
+ {name: "negative is invalid", amount: "-50.00", wantErr: true},
+ {name: "empty string is valid (let required handle it)", amount: "", wantErr: false},
+ {name: "invalid decimal string", amount: "not-a-number", wantErr: true},
+ {name: "small positive is valid", amount: "0.01", wantErr: false},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ payload := testPositiveAmountPayload{Amount: tc.amount}
+ err := ValidateStruct(&payload)
+
+ if tc.wantErr {
+ require.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestNonNegativeAmountValidator(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ amount string
+ wantErr bool
+ }{
+ {name: "positive amount is valid", amount: "100.50", wantErr: false},
+ {name: "zero is valid", amount: "0", wantErr: false},
+ {name: "negative is invalid", amount: "-50.00", wantErr: true},
+ {name: "empty string is valid (let required handle it)", amount: "", wantErr: false},
+ {name: "invalid decimal string", amount: "not-a-number", wantErr: true},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ payload := testNonNegativeAmountPayload{Amount: tc.amount}
+ err := ValidateStruct(&payload)
+
+ if tc.wantErr {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "amount")
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/commons/net/http/validation_rules_field_test.go b/commons/net/http/validation_rules_field_test.go
new file mode 100644
index 00000000..eada85ab
--- /dev/null
+++ b/commons/net/http/validation_rules_field_test.go
@@ -0,0 +1,159 @@
+//go:build unit
+
+package http
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestURLValidator(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ website string
+ wantErr bool
+ }{
+ {name: "valid HTTP URL", website: "http://example.com", wantErr: false},
+ {name: "valid HTTPS URL", website: "https://example.com/path", wantErr: false},
+ {name: "invalid URL", website: "not-a-url", wantErr: true},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ payload := testURLPayload{Website: tc.website}
+ err := ValidateStruct(&payload)
+
+ if tc.wantErr {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrFieldURL)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestUUIDValidator(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ id string
+ wantErr bool
+ }{
+ {name: "valid UUID", id: "550e8400-e29b-41d4-a716-446655440000", wantErr: false},
+ {name: "invalid UUID", id: "not-a-uuid", wantErr: true},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ payload := testUUIDPayload{ID: tc.id}
+ err := ValidateStruct(&payload)
+
+ if tc.wantErr {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrFieldUUID)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestLteValidator(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ value int
+ wantErr bool
+ }{
+ {name: "value less than constraint is valid", value: 50, wantErr: false},
+ {name: "value equal to constraint is valid", value: 100, wantErr: false},
+ {name: "value greater than constraint is invalid", value: 101, wantErr: true},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ payload := testLtePayload{Value: tc.value}
+ err := ValidateStruct(&payload)
+
+ if tc.wantErr {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrFieldLessThanOrEqual)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestLtValidator(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ value int
+ wantErr bool
+ }{
+ {name: "value less than constraint is valid", value: 50, wantErr: false},
+ {name: "value equal to constraint is invalid", value: 100, wantErr: true},
+ {name: "value greater than constraint is invalid", value: 101, wantErr: true},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ payload := testLtPayload{Value: tc.value}
+ err := ValidateStruct(&payload)
+
+ if tc.wantErr {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrFieldLessThan)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestMinValidator(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ value string
+ wantErr bool
+ }{
+ {name: "value at minimum is valid", value: "hello", wantErr: false},
+ {name: "value above minimum is valid", value: "hello world", wantErr: false},
+ {name: "value below minimum is invalid", value: "hi", wantErr: true},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ payload := testMinPayload{Name: tc.value}
+ err := ValidateStruct(&payload)
+
+ if tc.wantErr {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrFieldMinLength)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/commons/net/http/validation_test.go b/commons/net/http/validation_test.go
new file mode 100644
index 00000000..cdea2150
--- /dev/null
+++ b/commons/net/http/validation_test.go
@@ -0,0 +1,362 @@
+//go:build unit
+
+package http
+
+import (
+ "testing"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/shopspring/decimal"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+type testPayload struct {
+ Name string `json:"name" validate:"required,max=50"`
+ Email string `json:"email" validate:"required,email"`
+ Priority int `json:"priority" validate:"required,gt=0"`
+}
+
+type testOptionalPayload struct {
+ Name string `json:"name" validate:"omitempty,max=50"`
+ Value int `json:"value" validate:"omitempty,gte=0"`
+}
+
+type testPositiveDecimalPayload struct {
+ Amount decimal.Decimal `json:"amount" validate:"positive_decimal"`
+}
+
+type testPositiveAmountPayload struct {
+ Amount string `json:"amount" validate:"positive_amount"`
+}
+
+type testNonNegativeAmountPayload struct {
+ Amount string `json:"amount" validate:"nonnegative_amount"`
+}
+
+type testURLPayload struct {
+ Website string `json:"website" validate:"required,url"`
+}
+
+type testUUIDPayload struct {
+ ID string `json:"id" validate:"required,uuid"`
+}
+
+type testLtePayload struct {
+ Value int `json:"value" validate:"lte=100"`
+}
+
+type testLtPayload struct {
+ Value int `json:"value" validate:"lt=100"`
+}
+
+type testMinPayload struct {
+ Name string `json:"name" validate:"min=5"`
+}
+
+func TestGetValidator(t *testing.T) {
+ t.Parallel()
+
+ v1, err1 := GetValidator()
+ v2, err2 := GetValidator()
+
+ require.NoError(t, err1)
+ require.NoError(t, err2)
+ assert.NotNil(t, v1)
+ assert.Same(t, v1, v2, "GetValidator should return singleton")
+}
+
+func TestValidateStruct(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ payload any
+ wantErr bool
+ errContains string
+ }{
+ {
+ name: "valid payload",
+ payload: &testPayload{
+ Name: "test",
+ Email: "test@example.com",
+ Priority: 1,
+ },
+ wantErr: false,
+ },
+ {
+ name: "missing required name",
+ payload: &testPayload{
+ Email: "test@example.com",
+ Priority: 1,
+ },
+ wantErr: true,
+ errContains: "field is required: 'name'",
+ },
+ {
+ name: "invalid email",
+ payload: &testPayload{
+ Name: "test",
+ Email: "not-an-email",
+ Priority: 1,
+ },
+ wantErr: true,
+ errContains: "field must be a valid email: 'email'",
+ },
+ {
+ name: "priority must be greater than 0",
+ payload: &testPayload{
+ Name: "test",
+ Email: "test@example.com",
+ Priority: 0,
+ },
+ wantErr: true,
+ errContains: "'priority'",
+ },
+ {
+ name: "name exceeds max length",
+ payload: &testPayload{
+ Name: "this is a very long name that exceeds the maximum allowed length of fifty characters",
+ Email: "test@example.com",
+ Priority: 1,
+ },
+ wantErr: true,
+ errContains: "field exceeds maximum length: 'name'",
+ },
+ {
+ name: "optional fields can be empty",
+ payload: &testOptionalPayload{},
+ wantErr: false,
+ },
+ {
+ name: "optional field with valid value",
+ payload: &testOptionalPayload{
+ Name: "test",
+ Value: 10,
+ },
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := ValidateStruct(tt.payload)
+
+ if tt.wantErr {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errContains)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestToSnakeCase(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"Name", "name"},
+ {"FirstName", "first_name"},
+ {"HTMLParser", "h_t_m_l_parser"},
+ {"userID", "user_i_d"},
+ {"simple", "simple"},
+ {"", ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ t.Parallel()
+
+ got := toSnakeCase(tt.input)
+ assert.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestFormatValidationError(t *testing.T) {
+ t.Parallel()
+
+ type testStruct struct {
+ Required string `validate:"required"`
+ Max string `validate:"max=10"`
+ Min string `validate:"min=5"`
+ Gt int `validate:"gt=0"`
+ Gte int `validate:"gte=10"`
+ Lt int `validate:"lt=100"`
+ Lte int `validate:"lte=50"`
+ OneOf string `validate:"oneof=a b c"`
+ Email string `validate:"email"`
+ URL string `validate:"url"`
+ UUID string `validate:"uuid"`
+ }
+
+ tests := []struct {
+ name string
+ payload testStruct
+ errTag string
+ }{
+ {
+ name: "required tag",
+ payload: testStruct{},
+ errTag: "required",
+ },
+ {
+ name: "max tag",
+ payload: testStruct{Required: "x", Max: "this is too long"},
+ errTag: "max",
+ },
+ {
+ name: "oneof tag",
+ payload: testStruct{Required: "x", OneOf: "invalid"},
+ errTag: "oneof",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := ValidateStruct(&tt.payload)
+ require.Error(t, err)
+ })
+ }
+}
+
+func TestValidationSentinelErrors(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ expected string
+ }{
+ {
+ name: "ErrValidationFailed",
+ err: ErrValidationFailed,
+ expected: "validation failed",
+ },
+ {
+ name: "ErrFieldRequired",
+ err: ErrFieldRequired,
+ expected: "field is required",
+ },
+ {
+ name: "ErrFieldMaxLength",
+ err: ErrFieldMaxLength,
+ expected: "field exceeds maximum length",
+ },
+ {
+ name: "ErrFieldMinLength",
+ err: ErrFieldMinLength,
+ expected: "field below minimum length",
+ },
+ {
+ name: "ErrFieldGreaterThan",
+ err: ErrFieldGreaterThan,
+ expected: "field must be greater than constraint",
+ },
+ {
+ name: "ErrFieldGreaterThanOrEqual",
+ err: ErrFieldGreaterThanOrEqual,
+ expected: "field must be greater than or equal to constraint",
+ },
+ {
+ name: "ErrFieldLessThan",
+ err: ErrFieldLessThan,
+ expected: "field must be less than constraint",
+ },
+ {
+ name: "ErrFieldLessThanOrEqual",
+ err: ErrFieldLessThanOrEqual,
+ expected: "field must be less than or equal to constraint",
+ },
+ {
+ name: "ErrFieldOneOf",
+ err: ErrFieldOneOf,
+ expected: "field must be one of allowed values",
+ },
+ {
+ name: "ErrFieldEmail",
+ err: ErrFieldEmail,
+ expected: "field must be a valid email",
+ },
+ {
+ name: "ErrFieldURL",
+ err: ErrFieldURL,
+ expected: "field must be a valid URL",
+ },
+ {
+ name: "ErrFieldUUID",
+ err: ErrFieldUUID,
+ expected: "field must be a valid UUID",
+ },
+ {
+ name: "ErrFieldPositiveAmount",
+ err: ErrFieldPositiveAmount,
+ expected: "field must be a positive amount",
+ },
+ {
+ name: "ErrFieldNonNegativeAmount",
+ err: ErrFieldNonNegativeAmount,
+ expected: "field must be a non-negative amount",
+ },
+ {
+ name: "ErrBodyParseFailed",
+ err: ErrBodyParseFailed,
+ expected: "failed to parse request body",
+ },
+ {
+ name: "ErrQueryParamTooLong",
+ err: ErrQueryParamTooLong,
+ expected: "query parameter exceeds maximum length",
+ },
+ {
+ name: "ErrUnsupportedContentType",
+ err: ErrUnsupportedContentType,
+ expected: "Content-Type must be application/json",
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, tc.expected, tc.err.Error())
+ })
+ }
+}
+
+func TestPaginationConstants_Validation(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, 20, cn.DefaultLimit)
+ assert.Equal(t, 200, cn.MaxLimit)
+}
+
+func TestQueryParamLengthConstants(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, 50, MaxQueryParamLengthShort)
+ assert.Equal(t, 255, MaxQueryParamLengthLong)
+}
+
+func TestUnknownValidationTag(t *testing.T) {
+ t.Parallel()
+
+ type customPayload struct {
+ Value string `validate:"alphanum"`
+ }
+
+ payload := customPayload{Value: "hello@world"}
+ err := ValidateStruct(&payload)
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrValidationFailed)
+ assert.Contains(t, err.Error(), "failed 'alphanum' check")
+}
diff --git a/commons/net/http/withBasicAuth.go b/commons/net/http/withBasicAuth.go
index e18b3a55..e2490843 100644
--- a/commons/net/http/withBasicAuth.go
+++ b/commons/net/http/withBasicAuth.go
@@ -1,17 +1,13 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package http
import (
"crypto/subtle"
"encoding/base64"
- "github.com/LerianStudio/lib-commons/v2/commons"
- "github.com/LerianStudio/lib-commons/v2/commons/constants"
"net/http"
"strings"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+
"github.com/gofiber/fiber/v2"
)
@@ -32,41 +28,51 @@ func FixedBasicAuthFunc(username, password string) BasicAuthFunc {
// WithBasicAuth creates a basic authentication middleware.
func WithBasicAuth(f BasicAuthFunc, realm string) fiber.Handler {
+ safeRealm := sanitizeBasicAuthRealm(realm)
+
return func(c *fiber.Ctx) error {
+ if f == nil {
+ return unauthorizedResponse(c, safeRealm)
+ }
+
auth := c.Get(constant.Authorization)
if auth == "" {
- return unauthorizedResponse(c, realm)
+ return unauthorizedResponse(c, safeRealm)
}
parts := strings.SplitN(auth, " ", 2)
- if len(parts) != 2 || parts[0] != constant.Basic {
- return unauthorizedResponse(c, realm)
+ if len(parts) != 2 || !strings.EqualFold(parts[0], constant.Basic) {
+ return unauthorizedResponse(c, safeRealm)
}
cred, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
- return unauthorizedResponse(c, realm)
+ return unauthorizedResponse(c, safeRealm)
}
pair := strings.SplitN(string(cred), ":", 2)
if len(pair) != 2 {
- return unauthorizedResponse(c, realm)
+ return unauthorizedResponse(c, safeRealm)
}
if f(pair[0], pair[1]) {
return c.Next()
}
- return unauthorizedResponse(c, realm)
+ return unauthorizedResponse(c, safeRealm)
}
}
+// sanitizeBasicAuthRealm strips CR, LF, and double-quote characters from the realm string.
+func sanitizeBasicAuthRealm(realm string) string {
+ realm = strings.TrimSpace(realm)
+
+ return strings.NewReplacer("\r", "", "\n", "", "\"", "").Replace(realm)
+}
+
+// unauthorizedResponse sends a 401 response with a WWW-Authenticate header.
func unauthorizedResponse(c *fiber.Ctx, realm string) error {
c.Set(constant.WWWAuthenticate, `Basic realm="`+realm+`"`)
- return c.Status(http.StatusUnauthorized).JSON(commons.Response{
- Code: "401",
- Title: "Invalid Credentials",
- Message: "The provided credentials are invalid. Please provide valid credentials and try again.",
- })
+ return RespondError(c, http.StatusUnauthorized, "invalid_credentials", "The provided credentials are invalid. Please provide valid credentials and try again.")
}
diff --git a/commons/net/http/withBasicAuth_test.go b/commons/net/http/withBasicAuth_test.go
new file mode 100644
index 00000000..ae40e94e
--- /dev/null
+++ b/commons/net/http/withBasicAuth_test.go
@@ -0,0 +1,100 @@
+//go:build unit
+
+package http
+
+import (
+ "encoding/base64"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestWithBasicAuth_NilAuthFunc(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", WithBasicAuth(nil, "realm"), func(c *fiber.Ctx) error {
+ return c.SendStatus(http.StatusOK)
+ })
+
+ cred := base64.StdEncoding.EncodeToString([]byte("user:pass"))
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(constant.Authorization, "Basic "+cred)
+
+ res, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, res.Body.Close()) }()
+
+ assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
+}
+
+func TestWithBasicAuth_SanitizesRealmHeader(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", WithBasicAuth(FixedBasicAuthFunc("user", "pass"), "safe\r\nrealm\"name"), func(c *fiber.Ctx) error {
+ return c.SendStatus(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ res, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, res.Body.Close()) }()
+
+ assert.Equal(t, `Basic realm="saferealmname"`, res.Header.Get(constant.WWWAuthenticate))
+}
+
+func TestWithBasicAuth_AllowsValidCredentials(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", WithBasicAuth(FixedBasicAuthFunc("user", "pass"), "realm"), func(c *fiber.Ctx) error {
+ return c.SendStatus(http.StatusOK)
+ })
+
+ cred := base64.StdEncoding.EncodeToString([]byte("user:pass"))
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(constant.Authorization, "Basic "+cred)
+
+ res, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, res.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, res.StatusCode)
+}
+
+func TestWithBasicAuth_RejectsMalformedAuthorization(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/", WithBasicAuth(FixedBasicAuthFunc("user", "pass"), "realm"), func(c *fiber.Ctx) error {
+ return c.SendStatus(http.StatusOK)
+ })
+
+ testCases := []struct {
+ name string
+ header string
+ }{
+ {name: "wrong scheme", header: "Bearer token"},
+ {name: "invalid base64", header: "Basic !!!"},
+ {name: "missing colon", header: "Basic " + base64.StdEncoding.EncodeToString([]byte("userpass"))},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(constant.Authorization, tc.header)
+
+ res, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, res.Body.Close()) }()
+
+ assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
+ })
+ }
+}
diff --git a/commons/net/http/withCORS.go b/commons/net/http/withCORS.go
index 3c97269a..a52f2894 100644
--- a/commons/net/http/withCORS.go
+++ b/commons/net/http/withCORS.go
@@ -1,37 +1,101 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package http
import (
- "github.com/LerianStudio/lib-commons/v2/commons"
+ "context"
+ "strconv"
+
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
)
const (
- defaultAccessControlAllowOrigin = "*"
- defaultAccessControlAllowMethods = "POST, GET, OPTIONS, PUT, DELETE, PATCH"
- defaultAccessControlAllowHeaders = "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization"
+ // defaultAccessControlAllowOrigin is the default value for the Access-Control-Allow-Origin header.
+ defaultAccessControlAllowOrigin = "*"
+ // defaultAccessControlAllowMethods is the default value for the Access-Control-Allow-Methods header.
+ defaultAccessControlAllowMethods = "POST, GET, OPTIONS, PUT, DELETE, PATCH"
+ // defaultAccessControlAllowHeaders is the default value for the Access-Control-Allow-Headers header.
+ defaultAccessControlAllowHeaders = "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization"
+ // defaultAccessControlExposeHeaders is the default value for the Access-Control-Expose-Headers header.
defaultAccessControlExposeHeaders = ""
+ // defaultAllowCredentials is the default value for the Access-Control-Allow-Credentials header.
+ defaultAllowCredentials = false
)
+// CORSOption is a functional option for CORS middleware configuration.
+type CORSOption func(*corsConfig)
+
+type corsConfig struct {
+ logger libLog.Logger
+}
+
+// WithCORSLogger provides a structured logger for CORS security warnings.
+// When not provided, warnings are logged via stdlib log.
+func WithCORSLogger(logger libLog.Logger) CORSOption {
+ return func(c *corsConfig) {
+ if logger != nil {
+ c.logger = logger
+ }
+ }
+}
+
// WithCORS is a middleware that enables CORS.
-// Replace it with a real CORS middleware implementation.
-func WithCORS() fiber.Handler {
+// Reads configuration from environment variables with sensible defaults.
+//
+// WARNING: The default AllowOrigins is "*" (wildcard). For financial services,
+// configure ACCESS_CONTROL_ALLOW_ORIGIN to specific trusted origins.
+func WithCORS(opts ...CORSOption) fiber.Handler {
+ cfg := &corsConfig{}
+
+ for _, opt := range opts {
+ opt(cfg)
+ }
+
+ // Default to GoLogger so CORS warnings are always emitted, even without explicit logger.
+ if cfg.logger == nil {
+ cfg.logger = &libLog.GoLogger{Level: libLog.LevelWarn}
+ }
+
+ allowCredentials := defaultAllowCredentials
+
+ if parsed, err := strconv.ParseBool(commons.GetenvOrDefault("ACCESS_CONTROL_ALLOW_CREDENTIALS", "false")); err == nil {
+ allowCredentials = parsed
+ }
+
+ origins := commons.GetenvOrDefault("ACCESS_CONTROL_ALLOW_ORIGIN", defaultAccessControlAllowOrigin)
+
+ if origins == "*" || origins == "" {
+ cfg.logger.Log(context.Background(), libLog.LevelWarn,
+ "CORS: AllowOrigins is set to wildcard (*); "+
+ "this allows ANY website to make cross-origin requests to your API; "+
+ "for financial services, set ACCESS_CONTROL_ALLOW_ORIGIN to specific trusted origins",
+ )
+ }
+
+ if origins == "*" && allowCredentials {
+ cfg.logger.Log(context.Background(), libLog.LevelWarn,
+ "CORS: AllowOrigins=* with AllowCredentials=true is REJECTED by browsers per the CORS spec; "+
+ "credentials will NOT work; configure specific origins via ACCESS_CONTROL_ALLOW_ORIGIN",
+ )
+ }
+
return cors.New(cors.Config{
- AllowOrigins: commons.GetenvOrDefault("ACCESS_CONTROL_ALLOW_ORIGIN", defaultAccessControlAllowOrigin),
+ AllowOrigins: origins,
AllowMethods: commons.GetenvOrDefault("ACCESS_CONTROL_ALLOW_METHODS", defaultAccessControlAllowMethods),
AllowHeaders: commons.GetenvOrDefault("ACCESS_CONTROL_ALLOW_HEADERS", defaultAccessControlAllowHeaders),
ExposeHeaders: commons.GetenvOrDefault("ACCESS_CONTROL_EXPOSE_HEADERS", defaultAccessControlExposeHeaders),
- AllowCredentials: true,
+ AllowCredentials: allowCredentials,
})
}
// AllowFullOptionsWithCORS set r.Use(WithCORS) and allow every request to use OPTION method.
-func AllowFullOptionsWithCORS(app *fiber.App) {
- app.Use(WithCORS())
+func AllowFullOptionsWithCORS(app *fiber.App, opts ...CORSOption) {
+ if app == nil {
+ return
+ }
+
+ app.Use(WithCORS(opts...))
app.Options("/*", func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent)
diff --git a/commons/net/http/withCORS_test.go b/commons/net/http/withCORS_test.go
new file mode 100644
index 00000000..720e2ee2
--- /dev/null
+++ b/commons/net/http/withCORS_test.go
@@ -0,0 +1,197 @@
+//go:build unit
+
+package http
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestAllowFullOptionsWithCORS_NilApp(t *testing.T) {
+ t.Parallel()
+
+ require.NotPanics(t, func() {
+ AllowFullOptionsWithCORS(nil)
+ })
+}
+
+func TestWithCORS_UsesEnvironmentConfiguration(t *testing.T) {
+ require.NoError(t, os.Setenv("ACCESS_CONTROL_ALLOW_ORIGIN", "https://example.com"))
+ require.NoError(t, os.Setenv("ACCESS_CONTROL_ALLOW_METHODS", "GET,POST,OPTIONS"))
+ require.NoError(t, os.Setenv("ACCESS_CONTROL_ALLOW_HEADERS", "Authorization,Content-Type"))
+ require.NoError(t, os.Setenv("ACCESS_CONTROL_EXPOSE_HEADERS", "X-Trace-ID"))
+ require.NoError(t, os.Setenv("ACCESS_CONTROL_ALLOW_CREDENTIALS", "true"))
+ t.Cleanup(func() {
+ require.NoError(t, os.Unsetenv("ACCESS_CONTROL_ALLOW_ORIGIN"))
+ require.NoError(t, os.Unsetenv("ACCESS_CONTROL_ALLOW_METHODS"))
+ require.NoError(t, os.Unsetenv("ACCESS_CONTROL_ALLOW_HEADERS"))
+ require.NoError(t, os.Unsetenv("ACCESS_CONTROL_EXPOSE_HEADERS"))
+ require.NoError(t, os.Unsetenv("ACCESS_CONTROL_ALLOW_CREDENTIALS"))
+ })
+
+ app := fiber.New()
+ app.Use(WithCORS())
+ app.Get("/", func(c *fiber.Ctx) error {
+ return c.SendStatus(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ req.Header.Set(fiber.HeaderOrigin, "https://example.com")
+ req.Header.Set(fiber.HeaderAccessControlRequestMethod, http.MethodGet)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.Equal(t, http.StatusNoContent, resp.StatusCode)
+ require.Equal(t, "https://example.com", resp.Header.Get(fiber.HeaderAccessControlAllowOrigin))
+ require.Equal(t, "true", resp.Header.Get(fiber.HeaderAccessControlAllowCredentials))
+ require.Contains(t, resp.Header.Get(fiber.HeaderAccessControlAllowMethods), http.MethodGet)
+ require.Contains(t, resp.Header.Get(fiber.HeaderAccessControlAllowHeaders), constant.Authorization)
+}
+
+func TestAllowFullOptionsWithCORS_RegistersOptionsRoute(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ AllowFullOptionsWithCORS(app)
+
+ req := httptest.NewRequest(http.MethodOptions, "/health", nil)
+ req.Header.Set(fiber.HeaderOrigin, "https://example.com")
+ req.Header.Set(fiber.HeaderAccessControlRequestMethod, http.MethodGet)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.Equal(t, http.StatusNoContent, resp.StatusCode)
+}
+
+func TestWithCORS_ExplicitFalseCredentials(t *testing.T) {
+ require.NoError(t, os.Setenv("ACCESS_CONTROL_ALLOW_CREDENTIALS", "false"))
+ t.Cleanup(func() {
+ require.NoError(t, os.Unsetenv("ACCESS_CONTROL_ALLOW_CREDENTIALS"))
+ })
+
+ app := fiber.New()
+ app.Use(WithCORS())
+ app.Get("/", func(c *fiber.Ctx) error {
+ return c.SendStatus(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ req.Header.Set(fiber.HeaderOrigin, "https://example.com")
+ req.Header.Set(fiber.HeaderAccessControlRequestMethod, http.MethodGet)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.Equal(t, http.StatusNoContent, resp.StatusCode)
+ require.Equal(t, "", resp.Header.Get(fiber.HeaderAccessControlAllowCredentials))
+}
+
+// ---------------------------------------------------------------------------
+// WithCORSLogger option
+// ---------------------------------------------------------------------------
+
+func TestWithCORSLogger_NilDoesNotPanic(t *testing.T) {
+ t.Parallel()
+
+ // WithCORSLogger(nil) should not override the default (nil logger stays nil)
+ cfg := &corsConfig{}
+ opt := WithCORSLogger(nil)
+ opt(cfg)
+ assert.Nil(t, cfg.logger)
+}
+
+func TestWithCORSLogger_SetsLogger(t *testing.T) {
+ t.Parallel()
+
+ logger := &testCORSLogger{}
+ cfg := &corsConfig{}
+ opt := WithCORSLogger(logger)
+ opt(cfg)
+ assert.Equal(t, logger, cfg.logger)
+}
+
+func TestWithCORS_WithLoggerOption(t *testing.T) {
+ // This test verifies that WithCORS accepts the WithCORSLogger option
+ // and uses it for the wildcard warning.
+ // Not parallel because it sets env vars.
+ require.NoError(t, os.Setenv("ACCESS_CONTROL_ALLOW_ORIGIN", "*"))
+ t.Cleanup(func() {
+ require.NoError(t, os.Unsetenv("ACCESS_CONTROL_ALLOW_ORIGIN"))
+ })
+
+ logger := &testCORSLogger{}
+
+ app := fiber.New()
+ app.Use(WithCORS(WithCORSLogger(logger)))
+ app.Get("/", func(c *fiber.Ctx) error {
+ return c.SendStatus(200)
+ })
+
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ req.Header.Set(fiber.HeaderOrigin, "https://example.com")
+ req.Header.Set(fiber.HeaderAccessControlRequestMethod, http.MethodGet)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ // The logger should have received at least the wildcard warning
+ assert.True(t, logger.logCalled, "expected the logger to be called with wildcard warning")
+ assert.Equal(t, libLog.LevelWarn, logger.lastLevel)
+ assert.Contains(t, logger.lastMessage, "AllowOrigins is set to wildcard")
+}
+
+// testCORSLogger is a test logger that records whether Log was called.
+type testCORSLogger struct {
+ logCalled bool
+ lastLevel libLog.Level
+ lastMessage string
+}
+
+func (l *testCORSLogger) Log(_ context.Context, level libLog.Level, msg string, _ ...libLog.Field) {
+ l.logCalled = true
+ l.lastLevel = level
+ l.lastMessage = msg
+}
+func (l *testCORSLogger) With(_ ...libLog.Field) libLog.Logger { return l }
+func (l *testCORSLogger) WithGroup(string) libLog.Logger { return l }
+func (l *testCORSLogger) Enabled(libLog.Level) bool { return true }
+func (l *testCORSLogger) Sync(context.Context) error { return nil }
+
+func TestWithCORS_InvalidAllowCredentialsFallsBackToDefault(t *testing.T) {
+ require.NoError(t, os.Setenv("ACCESS_CONTROL_ALLOW_CREDENTIALS", "not-a-bool"))
+ t.Cleanup(func() {
+ require.NoError(t, os.Unsetenv("ACCESS_CONTROL_ALLOW_CREDENTIALS"))
+ })
+
+ app := fiber.New()
+ app.Use(WithCORS())
+ app.Get("/", func(c *fiber.Ctx) error {
+ return c.SendStatus(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodOptions, "/", nil)
+ req.Header.Set(fiber.HeaderOrigin, "https://example.com")
+ req.Header.Set(fiber.HeaderAccessControlRequestMethod, http.MethodGet)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.Equal(t, http.StatusNoContent, resp.StatusCode)
+ require.Equal(t, "", resp.Header.Get(fiber.HeaderAccessControlAllowCredentials))
+}
diff --git a/commons/net/http/withLogging.go b/commons/net/http/withLogging.go
index e05ff4cd..71ee09b8 100644
--- a/commons/net/http/withLogging.go
+++ b/commons/net/http/withLogging.go
@@ -1,27 +1,16 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package http
import (
- "context"
"encoding/json"
+ stdlog "log"
"net/url"
"os"
"strconv"
"strings"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons"
- cn "github.com/LerianStudio/lib-commons/v2/commons/constants"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
- "github.com/LerianStudio/lib-commons/v2/commons/security"
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
"github.com/gofiber/fiber/v2"
- "github.com/google/uuid"
- "go.opentelemetry.io/otel/attribute"
- "google.golang.org/grpc"
- "google.golang.org/grpc/metadata"
)
// maxObfuscationDepth limits recursion depth when obfuscating nested JSON structures
@@ -32,6 +21,12 @@ const maxObfuscationDepth = 32
// to avoid repeated syscalls on every request.
var logObfuscationDisabled = os.Getenv("LOG_OBFUSCATION_DISABLED") == "true"
+func init() {
+ if logObfuscationDisabled {
+ stdlog.Println("[WARN] LOG_OBFUSCATION_DISABLED is set to true. Sensitive data may appear in logs. Ensure this is not enabled in production.")
+ }
+}
+
// RequestInfo is a struct design to store http access log data.
type RequestInfo struct {
Method string
@@ -49,17 +44,23 @@ type RequestInfo struct {
Body string
}
-// ResponseMetricsWrapper is a Wrapper responsible for collect the response data such as status code and size
-// It implements built-in ResponseWriter interface.
+// ResponseMetricsWrapper is a Wrapper responsible for collecting the response data such as status code and size.
type ResponseMetricsWrapper struct {
Context *fiber.Ctx
StatusCode int
Size int
- Body string
}
// NewRequestInfo creates an instance of RequestInfo.
-func NewRequestInfo(c *fiber.Ctx) *RequestInfo {
+// The obfuscationDisabled parameter controls whether sensitive fields in the
+// request body are obfuscated. Pass the middleware's effective setting (which
+// combines the global LOG_OBFUSCATION_DISABLED env var with per-middleware
+// overrides via WithObfuscationDisabled) to honour per-middleware configuration.
+func NewRequestInfo(c *fiber.Ctx, obfuscationDisabled bool) *RequestInfo {
+ if c == nil {
+ return &RequestInfo{Date: time.Now().UTC()}
+ }
+
username, referer := "-", "-"
rawURL := string(c.Request().URI().FullURI())
@@ -70,16 +71,15 @@ func NewRequestInfo(c *fiber.Ctx) *RequestInfo {
}
}
- if c.Get("Referer") != "" {
- referer = c.Get("Referer")
+ if c.Get(cn.HeaderReferer) != "" {
+ referer = sanitizeReferer(c.Get(cn.HeaderReferer))
}
body := ""
if c.Request().Header.ContentLength() > 0 {
bodyBytes := c.Body()
-
- if !logObfuscationDisabled {
+ if !obfuscationDisabled {
body = getBodyObfuscatedString(c, bodyBytes)
} else {
body = string(bodyBytes)
@@ -89,10 +89,10 @@ func NewRequestInfo(c *fiber.Ctx) *RequestInfo {
return &RequestInfo{
TraceID: c.Get(cn.HeaderID),
Method: c.Method(),
- URI: c.OriginalURL(),
+ URI: sanitizeURL(c.OriginalURL()),
Username: username,
Referer: referer,
- UserAgent: c.Get(cn.HeaderUserAgent),
+ UserAgent: sanitizeLogValue(c.Get(cn.HeaderUserAgent)),
RemoteAddress: c.IP(),
Protocol: c.Protocol(),
Date: time.Now().UTC(),
@@ -104,16 +104,16 @@ func NewRequestInfo(c *fiber.Ctx) *RequestInfo {
// Ref: https://httpd.apache.org/docs/trunk/logs.html#common
func (r *RequestInfo) CLFString() string {
return strings.Join([]string{
- r.RemoteAddress,
+ sanitizeLogValue(r.RemoteAddress),
"-",
- r.Username,
- r.Protocol,
+ sanitizeLogValue(r.Username),
+ sanitizeLogValue(r.Protocol),
r.Date.Format("[02/Jan/2006:15:04:05 -0700]"),
- `"` + r.Method + " " + r.URI + `"`,
+ `"` + sanitizeLogValue(r.Method) + " " + sanitizeLogValue(r.URI) + `"`,
strconv.Itoa(r.Status),
strconv.Itoa(r.Size),
- r.Referer,
- r.UserAgent,
+ sanitizeLogValue(r.Referer),
+ sanitizeLogValue(r.UserAgent),
}, " ")
}
@@ -125,295 +125,36 @@ func (r *RequestInfo) String() string {
// FinishRequestInfo calculates the duration of RequestInfo automatically using time.Now()
// It also set StatusCode and Size of RequestInfo passed by ResponseMetricsWrapper.
func (r *RequestInfo) FinishRequestInfo(rw *ResponseMetricsWrapper) {
+ if rw == nil {
+ return
+ }
+
r.Duration = time.Now().UTC().Sub(r.Date)
r.Status = rw.StatusCode
r.Size = rw.Size
}
-type logMiddleware struct {
- Logger log.Logger
-}
-
-// LogMiddlewareOption represents the log middleware function as an implementation.
-type LogMiddlewareOption func(l *logMiddleware)
-
-// WithCustomLogger is a functional option for logMiddleware.
-func WithCustomLogger(logger log.Logger) LogMiddlewareOption {
- return func(l *logMiddleware) {
- l.Logger = logger
- }
-}
-
-// buildOpts creates an instance of logMiddleware with options.
-func buildOpts(opts ...LogMiddlewareOption) *logMiddleware {
- mid := &logMiddleware{
- Logger: &log.GoLogger{},
- }
-
- for _, opt := range opts {
- opt(mid)
- }
-
- return mid
-}
-
-// WithHTTPLogging is a middleware to log access to http server.
-// It logs access log according to Apache Standard Logs which uses Common Log Format (CLF)
-// Ref: https://httpd.apache.org/docs/trunk/logs.html#common
-func WithHTTPLogging(opts ...LogMiddlewareOption) fiber.Handler {
- return func(c *fiber.Ctx) error {
- if c.Path() == "/health" {
- return c.Next()
- }
-
- if strings.Contains(c.Path(), "swagger") && c.Path() != "/swagger/index.html" {
- return c.Next()
- }
-
- setRequestHeaderID(c)
-
- info := NewRequestInfo(c)
-
- headerID := c.Get(cn.HeaderID)
-
- mid := buildOpts(opts...)
- logger := mid.Logger.WithFields(
- cn.HeaderID, info.TraceID,
- ).WithDefaultMessageTemplate(headerID + cn.LoggerDefaultSeparator)
-
- ctx := commons.ContextWithLogger(c.UserContext(), logger)
- c.SetUserContext(ctx)
-
- err := c.Next()
-
- rw := ResponseMetricsWrapper{
- Context: c,
- StatusCode: c.Response().StatusCode(),
- Size: len(c.Response().Body()),
- Body: "",
- }
-
- info.FinishRequestInfo(&rw)
-
- logger.Info(info.CLFString())
-
- return err
- }
-}
-
-// WithGrpcLogging is a gRPC unary interceptor to log access to gRPC server.
-func WithGrpcLogging(opts ...LogMiddlewareOption) grpc.UnaryServerInterceptor {
- return func(
- ctx context.Context,
- req any,
- info *grpc.UnaryServerInfo,
- handler grpc.UnaryHandler,
- ) (any, error) {
- // Prefer request_id from the gRPC request body when available and valid.
- if rid, ok := getValidBodyRequestID(req); ok {
- // Emit a debug log if overriding a different metadata id
- if prev := getMetadataID(ctx); prev != "" && prev != rid {
- mid := buildOpts(opts...)
- mid.Logger.Debugf("Overriding correlation id from metadata (%s) with body request_id (%s)", prev, rid)
- }
- // Override correlation id to match the body-provided, validated UUID request_id
- ctx = commons.ContextWithHeaderID(ctx, rid)
- // Ensure standardized span attribute is present
- ctx = commons.ContextWithSpanAttributes(ctx, attribute.String("app.request.request_id", rid))
- } else {
- // Fallback to metadata path only if body is empty/invalid or accessor not present
- ctx = setGRPCRequestHeaderID(ctx)
- }
-
- _, _, reqId, _ := commons.NewTrackingFromContext(ctx)
-
- mid := buildOpts(opts...)
- logger := mid.Logger.
- WithFields(cn.HeaderID, reqId).
- WithDefaultMessageTemplate(reqId + cn.LoggerDefaultSeparator)
-
- ctx = commons.ContextWithLogger(ctx, logger)
-
- start := time.Now()
- resp, err := handler(ctx, req)
- duration := time.Since(start)
-
- logger.Infof("gRPC method: %s, Duration: %s, Error: %v", info.FullMethod, duration, err)
-
- return resp, err
- }
-}
-
-func setRequestHeaderID(c *fiber.Ctx) {
- headerID := c.Get(cn.HeaderID)
-
- if commons.IsNilOrEmpty(&headerID) {
- headerID = uuid.New().String()
- c.Set(cn.HeaderID, headerID)
- c.Request().Header.Set(cn.HeaderID, headerID)
- c.Response().Header.Set(cn.HeaderID, headerID)
- }
-
- ctx := commons.ContextWithHeaderID(c.UserContext(), headerID)
- c.SetUserContext(ctx)
-}
-
-func setGRPCRequestHeaderID(ctx context.Context) context.Context {
- md, ok := metadata.FromIncomingContext(ctx)
- if ok {
- headerID := md.Get(cn.MetadataID)
- if len(headerID) > 0 && !commons.IsNilOrEmpty(&headerID[0]) {
- return commons.ContextWithHeaderID(ctx, headerID[0])
- }
- }
-
- // If metadata is not present, or if the header ID is missing or empty, generate a new one.
- return commons.ContextWithHeaderID(ctx, uuid.New().String())
-}
-
-func getBodyObfuscatedString(c *fiber.Ctx, bodyBytes []byte) string {
- contentType := c.Get("Content-Type")
-
- var obfuscatedBody string
-
- if strings.Contains(contentType, "application/json") {
- obfuscatedBody = handleJSONBody(bodyBytes)
- } else if strings.Contains(contentType, "application/x-www-form-urlencoded") {
- obfuscatedBody = handleURLEncodedBody(bodyBytes)
- } else if strings.Contains(contentType, "multipart/form-data") {
- obfuscatedBody = handleMultipartBody(c)
- } else {
- obfuscatedBody = string(bodyBytes)
- }
-
- return obfuscatedBody
-}
-
+// handleJSONBody obfuscates sensitive fields in a JSON request body.
+// Handles both top-level objects and arrays.
func handleJSONBody(bodyBytes []byte) string {
- var bodyData map[string]any
+ var bodyData any
if err := json.Unmarshal(bodyBytes, &bodyData); err != nil {
return string(bodyBytes)
}
- obfuscateMapRecursively(bodyData, 0)
-
- updatedBody, err := json.Marshal(bodyData)
- if err != nil {
+ switch v := bodyData.(type) {
+ case map[string]any:
+ obfuscateMapRecursively(v, 0)
+ case []any:
+ obfuscateSliceRecursively(v, 0)
+ default:
return string(bodyBytes)
}
- return string(updatedBody)
-}
-
-func obfuscateMapRecursively(data map[string]any, depth int) {
- if depth >= maxObfuscationDepth {
- return
- }
-
- for key, value := range data {
- if security.IsSensitiveField(key) {
- data[key] = cn.ObfuscatedValue
- continue
- }
-
- switch v := value.(type) {
- case map[string]any:
- obfuscateMapRecursively(v, depth+1)
- case []any:
- obfuscateSliceRecursively(v, depth+1)
- }
- }
-}
-
-func obfuscateSliceRecursively(data []any, depth int) {
- if depth >= maxObfuscationDepth {
- return
- }
-
- for _, item := range data {
- switch v := item.(type) {
- case map[string]any:
- obfuscateMapRecursively(v, depth+1)
- case []any:
- obfuscateSliceRecursively(v, depth+1)
- }
- }
-}
-
-func handleURLEncodedBody(bodyBytes []byte) string {
- formData, err := url.ParseQuery(string(bodyBytes))
+ updatedBody, err := json.Marshal(bodyData)
if err != nil {
return string(bodyBytes)
}
- updatedBody := url.Values{}
-
- for key, values := range formData {
- if security.IsSensitiveField(key) {
- for range values {
- updatedBody.Add(key, cn.ObfuscatedValue)
- }
- } else {
- for _, value := range values {
- updatedBody.Add(key, value)
- }
- }
- }
-
- return updatedBody.Encode()
-}
-
-func handleMultipartBody(c *fiber.Ctx) string {
- form, err := c.MultipartForm()
- if err != nil {
- return "[multipart/form-data]"
- }
-
- result := url.Values{}
-
- for key, values := range form.Value {
- if security.IsSensitiveField(key) {
- for range values {
- result.Add(key, cn.ObfuscatedValue)
- }
- } else {
- for _, value := range values {
- result.Add(key, value)
- }
- }
- }
-
- for key := range form.File {
- if security.IsSensitiveField(key) {
- result.Add(key, cn.ObfuscatedValue)
- } else {
- result.Add(key, "[file]")
- }
- }
-
- return result.Encode()
-}
-
-// getValidBodyRequestID extracts and validates the request_id from the gRPC request body.
-// Returns (id, true) when present and valid UUID; otherwise ("", false).
-func getValidBodyRequestID(req any) (string, bool) {
- if r, ok := req.(interface{ GetRequestId() string }); ok {
- if rid := strings.TrimSpace(r.GetRequestId()); rid != "" && commons.IsUUID(rid) {
- return rid, true
- }
- }
-
- return "", false
-}
-
-// getMetadataID extracts a correlation id from incoming gRPC metadata if present.
-func getMetadataID(ctx context.Context) string {
- if md, ok := metadata.FromIncomingContext(ctx); ok && md != nil {
- headerID := md.Get(cn.MetadataID)
- if len(headerID) > 0 && !commons.IsNilOrEmpty(&headerID[0]) {
- return headerID[0]
- }
- }
-
- return ""
+ return string(updatedBody)
}
diff --git a/commons/net/http/withLogging_grpc_test.go b/commons/net/http/withLogging_grpc_test.go
new file mode 100644
index 00000000..a0fd93f3
--- /dev/null
+++ b/commons/net/http/withLogging_grpc_test.go
@@ -0,0 +1,324 @@
+//go:build unit
+
+package http
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/trace"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+)
+
+type grpcRequestWithID struct {
+ requestID string
+}
+
+func (r grpcRequestWithID) GetRequestId() string {
+ return r.requestID
+}
+
+type grpcPointerRequestWithID struct {
+ requestID string
+}
+
+func (r *grpcPointerRequestWithID) GetRequestId() string {
+ return r.requestID
+}
+
+type capturedLogEntry struct {
+ level libLog.Level
+ msg string
+ fields []libLog.Field
+}
+
+type capturedLogState struct {
+ mu sync.Mutex
+ entries []capturedLogEntry
+}
+
+type captureLogger struct {
+ state *capturedLogState
+ bound []libLog.Field
+}
+
+func newCaptureLogger() *captureLogger {
+ return &captureLogger{state: &capturedLogState{}}
+}
+
+func (l *captureLogger) Log(_ context.Context, level libLog.Level, msg string, fields ...libLog.Field) {
+ merged := make([]libLog.Field, 0, len(l.bound)+len(fields))
+ merged = append(merged, l.bound...)
+ merged = append(merged, fields...)
+
+ l.state.mu.Lock()
+ defer l.state.mu.Unlock()
+ l.state.entries = append(l.state.entries, capturedLogEntry{level: level, msg: msg, fields: merged})
+}
+
+func (l *captureLogger) With(fields ...libLog.Field) libLog.Logger {
+ bound := make([]libLog.Field, 0, len(l.bound)+len(fields))
+ bound = append(bound, l.bound...)
+ bound = append(bound, fields...)
+
+ return &captureLogger{state: l.state, bound: bound}
+}
+
+func (l *captureLogger) WithGroup(string) libLog.Logger { return l }
+func (l *captureLogger) Enabled(libLog.Level) bool { return true }
+func (l *captureLogger) Sync(context.Context) error { return nil }
+
+func (l *captureLogger) entries() []capturedLogEntry {
+ l.state.mu.Lock()
+ defer l.state.mu.Unlock()
+
+ entries := make([]capturedLogEntry, len(l.state.entries))
+ copy(entries, l.state.entries)
+ return entries
+}
+
+func TestWithGrpcLogging_BodyRequestIDOverridesMetadata(t *testing.T) {
+ t.Parallel()
+
+ logger := newCaptureLogger()
+ interceptor := WithGrpcLogging(WithCustomLogger(logger))
+ bodyID := uuid.NewString()
+ metadataID := uuid.NewString()
+
+ ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(cn.MetadataID, metadataID))
+
+ var seenRequestID string
+ resp, err := interceptor(ctx, grpcRequestWithID{requestID: bodyID}, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ _, _, seenRequestID, _ = commons.NewTrackingFromContext(ctx)
+ return "ok", nil
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "ok", resp)
+ assert.Equal(t, bodyID, seenRequestID)
+
+ entries := logger.entries()
+ require.Len(t, entries, 2)
+ assert.Equal(t, libLog.LevelDebug, entries[0].level)
+ assert.Contains(t, entries[0].msg, "Overriding correlation id")
+ assert.Equal(t, libLog.LevelInfo, entries[1].level)
+ assert.Equal(t, "gRPC request finished", entries[1].msg)
+ assert.Contains(t, entries[1].fields, libLog.String(cn.HeaderID, bodyID))
+ assert.Contains(t, entries[1].fields, libLog.String("message_prefix", bodyID+cn.LoggerDefaultSeparator))
+}
+
+func TestWithGrpcLogging_InvalidBodyRequestIDFallsBackToMetadata(t *testing.T) {
+ t.Parallel()
+
+ logger := newCaptureLogger()
+ interceptor := WithGrpcLogging(WithCustomLogger(logger))
+ metadataID := uuid.NewString()
+ ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(cn.MetadataID, metadataID))
+
+ var seenRequestID string
+ _, err := interceptor(ctx, grpcRequestWithID{requestID: "not-a-uuid"}, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ _, _, seenRequestID, _ = commons.NewTrackingFromContext(ctx)
+ return nil, nil
+ })
+ require.NoError(t, err)
+ assert.Equal(t, metadataID, seenRequestID)
+
+ entries := logger.entries()
+ require.Len(t, entries, 1)
+ assert.Equal(t, libLog.LevelInfo, entries[0].level)
+ assert.Contains(t, entries[0].fields, libLog.String(cn.HeaderID, metadataID))
+}
+
+func TestWithGrpcLogging_GeneratesRequestIDWhenMissing(t *testing.T) {
+ t.Parallel()
+
+ interceptor := WithGrpcLogging()
+
+ var seenRequestID string
+ _, err := interceptor(context.Background(), struct{}{}, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ _, _, seenRequestID, _ = commons.NewTrackingFromContext(ctx)
+ return nil, nil
+ })
+ require.NoError(t, err)
+ assert.NotEmpty(t, seenRequestID)
+ _, parseErr := uuid.Parse(seenRequestID)
+ require.NoError(t, parseErr)
+}
+
+func TestGetValidBodyRequestID_TypedNilRequestReturnsFalse(t *testing.T) {
+ t.Parallel()
+
+ var req *grpcPointerRequestWithID
+
+ assert.NotPanics(t, func() {
+ requestID, ok := getValidBodyRequestID(req)
+ assert.False(t, ok)
+ assert.Empty(t, requestID)
+ })
+}
+
+func TestWithGrpcLogging_TypedNilBodyRequestIDFallsBackToMetadata(t *testing.T) {
+ t.Parallel()
+
+ logger := newCaptureLogger()
+ interceptor := WithGrpcLogging(WithCustomLogger(logger))
+ metadataID := uuid.NewString()
+ ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(cn.MetadataID, metadataID))
+
+ var req *grpcPointerRequestWithID
+ var seenRequestID string
+
+ assert.NotPanics(t, func() {
+ _, err := interceptor(ctx, req, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ _, _, seenRequestID, _ = commons.NewTrackingFromContext(ctx)
+ return nil, nil
+ })
+ require.NoError(t, err)
+ })
+
+ assert.Equal(t, metadataID, seenRequestID)
+}
+
+func TestWithGrpcLogging_LogsHandlerErrors(t *testing.T) {
+ t.Parallel()
+
+ logger := newCaptureLogger()
+ interceptor := WithGrpcLogging(WithCustomLogger(logger))
+ handlerErr := errors.New("boom")
+
+ _, err := interceptor(context.Background(), struct{}{}, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ return nil, handlerErr
+ })
+ require.ErrorIs(t, err, handlerErr)
+
+ entries := logger.entries()
+ require.Len(t, entries, 1)
+ assert.Equal(t, libLog.LevelInfo, entries[0].level)
+ assert.Contains(t, entries[0].fields, libLog.Err(handlerErr))
+}
+
+func TestWithGrpcLogging_NilContextDoesNotPanic(t *testing.T) {
+ t.Parallel()
+
+ interceptor := WithGrpcLogging()
+
+ assert.NotPanics(t, func() {
+ var seenRequestID string
+ _, err := interceptor(nil, struct{}{}, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ _, _, seenRequestID, _ = commons.NewTrackingFromContext(ctx)
+ return nil, nil
+ })
+ require.NoError(t, err)
+ assert.NotEmpty(t, seenRequestID)
+ })
+}
+
+func TestWithGrpcLogging_NilInfoUsesUnknownMethod(t *testing.T) {
+ t.Parallel()
+
+ logger := newCaptureLogger()
+ interceptor := WithGrpcLogging(WithCustomLogger(logger))
+
+ _, err := interceptor(context.Background(), struct{}{}, nil, func(ctx context.Context, req any) (any, error) {
+ return nil, nil
+ })
+ require.NoError(t, err)
+
+ entries := logger.entries()
+ require.Len(t, entries, 1)
+ assert.Contains(t, entries[0].fields, libLog.String("method", "unknown"))
+}
+
+func TestWithTelemetryInterceptor_NilContextDoesNotPanic(t *testing.T) {
+ t.Parallel()
+
+ tp, _ := setupTestTracer()
+ defer func() { _ = tp.Shutdown(context.Background()) }()
+
+ telemetry := &opentelemetry.Telemetry{
+ TelemetryConfig: opentelemetry.TelemetryConfig{LibraryName: "test-library", EnableTelemetry: true},
+ TracerProvider: tp,
+ }
+ interceptor := NewTelemetryMiddleware(telemetry).WithTelemetryInterceptor(telemetry)
+
+ assert.NotPanics(t, func() {
+ _, err := interceptor(nil, struct{}{}, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ _, _, requestID, _ := commons.NewTrackingFromContext(ctx)
+ assert.NotEmpty(t, requestID)
+ return nil, nil
+ })
+ require.NoError(t, err)
+ })
+}
+
+func TestWithTelemetryInterceptor_TypedNilBodyRequestIDFallsBackToMetadata(t *testing.T) {
+ t.Parallel()
+
+ tp, _ := setupTestTracer()
+ defer func() { _ = tp.Shutdown(context.Background()) }()
+
+ telemetry := &opentelemetry.Telemetry{
+ TelemetryConfig: opentelemetry.TelemetryConfig{LibraryName: "test-library", EnableTelemetry: true},
+ TracerProvider: tp,
+ }
+ interceptor := NewTelemetryMiddleware(telemetry).WithTelemetryInterceptor(telemetry)
+ metadataID := uuid.NewString()
+ ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(cn.MetadataID, metadataID))
+
+ var req *grpcPointerRequestWithID
+
+ assert.NotPanics(t, func() {
+ _, err := interceptor(ctx, req, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ _, _, requestID, _ := commons.NewTrackingFromContext(ctx)
+ assert.Equal(t, metadataID, requestID)
+ return nil, nil
+ })
+ require.NoError(t, err)
+ })
+}
+
+func TestWithGrpcLoggingAndTelemetryInterceptor_ShareResolvedRequestID(t *testing.T) {
+ t.Parallel()
+
+ tp, _ := setupTestTracer()
+ defer func() { _ = tp.Shutdown(context.Background()) }()
+
+ telemetry := &opentelemetry.Telemetry{
+ TelemetryConfig: opentelemetry.TelemetryConfig{LibraryName: "test-library", EnableTelemetry: true},
+ TracerProvider: tp,
+ }
+ telemetryInterceptor := NewTelemetryMiddleware(telemetry).WithTelemetryInterceptor(telemetry)
+ logger := newCaptureLogger()
+ loggingInterceptor := WithGrpcLogging(WithCustomLogger(logger))
+ bodyID := uuid.NewString()
+ metadataID := uuid.NewString()
+
+ ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(cn.MetadataID, metadataID, "user-agent", "midaz/1.0.0 LerianStudio"))
+
+ var seenRequestID string
+ var spanContext trace.SpanContext
+ resp, err := loggingInterceptor(ctx, grpcRequestWithID{requestID: bodyID}, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ return telemetryInterceptor(ctx, req, &grpc.UnaryServerInfo{FullMethod: "/svc.Method"}, func(ctx context.Context, req any) (any, error) {
+ _, _, seenRequestID, _ = commons.NewTrackingFromContext(ctx)
+ spanContext = trace.SpanContextFromContext(ctx)
+ return "ok", nil
+ })
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "ok", resp)
+ assert.Equal(t, bodyID, seenRequestID)
+ assert.True(t, spanContext.IsValid())
+
+ entries := logger.entries()
+ require.NotEmpty(t, entries)
+ assert.Contains(t, entries[len(entries)-1].fields, libLog.String(cn.HeaderID, bodyID))
+}
diff --git a/commons/net/http/withLogging_middleware.go b/commons/net/http/withLogging_middleware.go
new file mode 100644
index 00000000..0a6dfcc6
--- /dev/null
+++ b/commons/net/http/withLogging_middleware.go
@@ -0,0 +1,247 @@
+package http
+
+import (
+ "context"
+ "strings"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/gofiber/fiber/v2"
+ "github.com/google/uuid"
+ "go.opentelemetry.io/otel/attribute"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+)
+
+// logMiddleware holds the logger and configuration used by HTTP and gRPC logging middleware.
+type logMiddleware struct {
+ Logger log.Logger
+ ObfuscationDisabled bool
+}
+
+// LogMiddlewareOption represents the log middleware function as an implementation.
+type LogMiddlewareOption func(l *logMiddleware)
+
+// WithCustomLogger is a functional option for logMiddleware.
+func WithCustomLogger(logger log.Logger) LogMiddlewareOption {
+ return func(l *logMiddleware) {
+ if !nilcheck.Interface(logger) {
+ l.Logger = logger
+ }
+ }
+}
+
+// WithObfuscationDisabled is a functional option that disables log body obfuscation.
+// This is primarily intended for testing and local development.
+// In production, use the LOG_OBFUSCATION_DISABLED environment variable.
+func WithObfuscationDisabled(disabled bool) LogMiddlewareOption {
+ return func(l *logMiddleware) {
+ l.ObfuscationDisabled = disabled
+ }
+}
+
+// buildOpts creates an instance of logMiddleware with options.
+func buildOpts(opts ...LogMiddlewareOption) *logMiddleware {
+ mid := &logMiddleware{
+ Logger: &log.GoLogger{},
+ ObfuscationDisabled: logObfuscationDisabled,
+ }
+
+ for _, opt := range opts {
+ opt(mid)
+ }
+
+ return mid
+}
+
+// WithHTTPLogging is a middleware to log access to http server.
+// It logs access log according to Apache Standard Logs which uses Common Log Format (CLF)
+// Ref: https://httpd.apache.org/docs/trunk/logs.html#common
+func WithHTTPLogging(opts ...LogMiddlewareOption) fiber.Handler {
+ return func(c *fiber.Ctx) error {
+ if c.Path() == "/health" {
+ return c.Next()
+ }
+
+ if strings.Contains(c.Path(), "swagger") && c.Path() != "/swagger/index.html" {
+ return c.Next()
+ }
+
+ setRequestHeaderID(c)
+
+ mid := buildOpts(opts...)
+ info := NewRequestInfo(c, mid.ObfuscationDisabled)
+
+ headerID := c.Get(cn.HeaderID)
+ logger := mid.Logger.
+ With(log.String(cn.HeaderID, info.TraceID)).
+ With(log.String("message_prefix", headerID+cn.LoggerDefaultSeparator))
+
+ ctx := commons.ContextWithLogger(c.UserContext(), logger)
+ c.SetUserContext(ctx)
+
+ err := c.Next()
+
+ rw := ResponseMetricsWrapper{
+ Context: c,
+ StatusCode: c.Response().StatusCode(),
+ Size: len(c.Response().Body()),
+ }
+
+ info.FinishRequestInfo(&rw)
+ logger.Log(c.UserContext(), log.LevelInfo, info.CLFString())
+
+ return err
+ }
+}
+
+// WithGrpcLogging is a gRPC unary interceptor to log access to gRPC server.
+func WithGrpcLogging(opts ...LogMiddlewareOption) grpc.UnaryServerInterceptor {
+ return func(
+ ctx context.Context,
+ req any,
+ info *grpc.UnaryServerInfo,
+ handler grpc.UnaryHandler,
+ ) (any, error) {
+ ctx = normalizeGRPCContext(ctx)
+ requestID := resolveGRPCRequestID(ctx, req)
+
+ if rid, ok := getValidBodyRequestID(req); ok {
+ if prev := getMetadataID(ctx); prev != "" && prev != rid {
+ mid := buildOpts(opts...)
+ mid.Logger.Log(ctx, log.LevelDebug, "Overriding correlation id from metadata with body request_id",
+ log.String("metadata_id", prev),
+ log.String("body_request_id", rid),
+ )
+ }
+ }
+
+ ctx = commons.ContextWithHeaderID(ctx, requestID)
+ ctx = commons.ContextWithSpanAttributes(ctx, attribute.String("app.request.request_id", requestID))
+
+ _, _, reqId, _ := commons.NewTrackingFromContext(ctx)
+
+ mid := buildOpts(opts...)
+ logger := mid.Logger.
+ With(log.String(cn.HeaderID, reqId)).
+ With(log.String("message_prefix", reqId+cn.LoggerDefaultSeparator))
+
+ ctx = commons.ContextWithLogger(ctx, logger)
+
+ start := time.Now()
+ resp, err := handler(ctx, req)
+ duration := time.Since(start)
+
+ methodName := "unknown"
+ if info != nil {
+ methodName = info.FullMethod
+ }
+
+ fields := []log.Field{
+ log.String("method", methodName),
+ log.String("duration", duration.String()),
+ }
+ if err != nil {
+ fields = append(fields, log.Err(err))
+ }
+
+ logger.Log(ctx, log.LevelInfo, "gRPC request finished", fields...)
+
+ return resp, err
+ }
+}
+
+func normalizeGRPCContext(ctx context.Context) context.Context {
+ if ctx == nil {
+ return context.Background()
+ }
+
+ return ctx
+}
+
+func getContextHeaderID(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+
+ values, ok := ctx.Value(commons.CustomContextKey).(*commons.CustomContextKeyValue)
+ if !ok || values == nil {
+ return ""
+ }
+
+ return normalizeRequestID(values.HeaderID)
+}
+
+func normalizeRequestID(raw string) string {
+ return strings.TrimSpace(sanitizeLogValue(raw))
+}
+
+func resolveGRPCRequestID(ctx context.Context, req any) string {
+ if rid, ok := getValidBodyRequestID(req); ok {
+ return rid
+ }
+
+ if existing := getContextHeaderID(ctx); existing != "" {
+ return existing
+ }
+
+ if rid := getMetadataID(ctx); rid != "" {
+ return rid
+ }
+
+ return uuid.New().String()
+}
+
+// setRequestHeaderID ensures the Fiber request carries a unique correlation ID header.
+// The effective ID is always echoed back on the response so that callers can
+// correlate their request regardless of whether the ID was client-supplied or
+// server-generated.
+func setRequestHeaderID(c *fiber.Ctx) {
+ headerID := normalizeRequestID(c.Get(cn.HeaderID))
+
+ if commons.IsNilOrEmpty(&headerID) {
+ headerID = uuid.New().String()
+ }
+
+ c.Request().Header.Set(cn.HeaderID, headerID)
+ c.Set(cn.HeaderID, headerID)
+ c.Response().Header.Set(cn.HeaderID, headerID)
+
+ ctx := commons.ContextWithHeaderID(c.UserContext(), headerID)
+ c.SetUserContext(ctx)
+}
+
+// getValidBodyRequestID extracts and validates the request_id from the gRPC request body.
+// Returns (id, true) when present and valid UUID; otherwise ("", false).
+func getValidBodyRequestID(req any) (string, bool) {
+ if r, ok := req.(interface{ GetRequestId() string }); ok {
+ if nilcheck.Interface(r) {
+ return "", false
+ }
+
+ if rid := strings.TrimSpace(r.GetRequestId()); rid != "" && commons.IsUUID(rid) {
+ return rid, true
+ }
+ }
+
+ return "", false
+}
+
+// getMetadataID extracts a correlation id from incoming gRPC metadata if present.
+func getMetadataID(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+
+ if md, ok := metadata.FromIncomingContext(ctx); ok && md != nil {
+ headerID := md.Get(cn.MetadataID)
+ if len(headerID) > 0 && !commons.IsNilOrEmpty(&headerID[0]) {
+ return normalizeRequestID(headerID[0])
+ }
+ }
+
+ return ""
+}
diff --git a/commons/net/http/withLogging_obfuscation.go b/commons/net/http/withLogging_obfuscation.go
new file mode 100644
index 00000000..cba43856
--- /dev/null
+++ b/commons/net/http/withLogging_obfuscation.go
@@ -0,0 +1,123 @@
+package http
+
+import (
+ "net/url"
+ "strings"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/security"
+ "github.com/gofiber/fiber/v2"
+)
+
+// getBodyObfuscatedString returns the request body with sensitive fields obfuscated.
+func getBodyObfuscatedString(c *fiber.Ctx, bodyBytes []byte) string {
+ contentType := c.Get(cn.HeaderContentType)
+
+ var obfuscatedBody string
+
+ switch {
+ case strings.Contains(contentType, "application/json"):
+ obfuscatedBody = handleJSONBody(bodyBytes)
+ case strings.Contains(contentType, "application/x-www-form-urlencoded"):
+ obfuscatedBody = handleURLEncodedBody(bodyBytes)
+ case strings.Contains(contentType, "multipart/form-data"):
+ obfuscatedBody = handleMultipartBody(c)
+ default:
+ obfuscatedBody = string(bodyBytes)
+ }
+
+ return obfuscatedBody
+}
+
+// obfuscateMapRecursively replaces sensitive map values up to maxObfuscationDepth levels.
+func obfuscateMapRecursively(data map[string]any, depth int) {
+ if depth >= maxObfuscationDepth {
+ return
+ }
+
+ for key, value := range data {
+ if security.IsSensitiveField(key) {
+ data[key] = cn.ObfuscatedValue
+ continue
+ }
+
+ switch v := value.(type) {
+ case map[string]any:
+ obfuscateMapRecursively(v, depth+1)
+ case []any:
+ obfuscateSliceRecursively(v, depth+1)
+ }
+ }
+}
+
+// obfuscateSliceRecursively walks slice elements and obfuscates nested sensitive fields.
+func obfuscateSliceRecursively(data []any, depth int) {
+ if depth >= maxObfuscationDepth {
+ return
+ }
+
+ for _, item := range data {
+ switch v := item.(type) {
+ case map[string]any:
+ obfuscateMapRecursively(v, depth+1)
+ case []any:
+ obfuscateSliceRecursively(v, depth+1)
+ }
+ }
+}
+
+// handleURLEncodedBody obfuscates sensitive fields in a URL-encoded request body.
+func handleURLEncodedBody(bodyBytes []byte) string {
+ formData, err := url.ParseQuery(string(bodyBytes))
+ if err != nil {
+ return string(bodyBytes)
+ }
+
+ updatedBody := url.Values{}
+
+ for key, values := range formData {
+ if security.IsSensitiveField(key) {
+ for range values {
+ updatedBody.Add(key, cn.ObfuscatedValue)
+ }
+ } else {
+ for _, value := range values {
+ updatedBody.Add(key, value)
+ }
+ }
+ }
+
+ return updatedBody.Encode()
+}
+
+// handleMultipartBody obfuscates sensitive fields in a multipart/form-data request body.
+func handleMultipartBody(c *fiber.Ctx) string {
+ form, err := c.MultipartForm()
+ if err != nil {
+ return "[multipart/form-data]"
+ }
+
+ result := url.Values{}
+
+ for key, values := range form.Value {
+ if security.IsSensitiveField(key) {
+ for range values {
+ result.Add(key, cn.ObfuscatedValue)
+ }
+ } else {
+ for _, value := range values {
+ result.Add(key, value)
+ }
+ }
+ }
+
+ for key := range form.File {
+ if security.IsSensitiveField(key) {
+ result.Add(key, cn.ObfuscatedValue)
+ } else {
+ result.Add(key, "[file]")
+ }
+ }
+
+ return result.Encode()
+}
diff --git a/commons/net/http/withLogging_sanitize.go b/commons/net/http/withLogging_sanitize.go
new file mode 100644
index 00000000..35f4a61e
--- /dev/null
+++ b/commons/net/http/withLogging_sanitize.go
@@ -0,0 +1,26 @@
+package http
+
+import (
+ "net/url"
+ "strings"
+)
+
+// sanitizeReferer strips query parameters and userinfo from a Referer header value
+// before it is written into logs, preventing credential/token leakage.
+func sanitizeReferer(raw string) string {
+ parsed, err := url.Parse(raw)
+ if err != nil {
+ return "-"
+ }
+
+ parsed.User = nil
+ parsed.RawQuery = ""
+ parsed.Fragment = ""
+
+ return parsed.String()
+}
+
+func sanitizeLogValue(raw string) string {
+ replacer := strings.NewReplacer("\r", "", "\n", "", "\x00", "")
+ return replacer.Replace(raw)
+}
diff --git a/commons/net/http/withLogging_test.go b/commons/net/http/withLogging_test.go
new file mode 100644
index 00000000..36528b95
--- /dev/null
+++ b/commons/net/http/withLogging_test.go
@@ -0,0 +1,722 @@
+//go:build unit
+
+package http
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// NewRequestInfo
+// ---------------------------------------------------------------------------
+
+func TestNewRequestInfo_Basic(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var info *RequestInfo
+
+ app.Get("/api/test", func(c *fiber.Ctx) error {
+ info = NewRequestInfo(c, false)
+ return c.SendStatus(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
+ req.Header.Set(cn.HeaderID, "trace-123")
+ req.Header.Set(cn.HeaderUserAgent, "test-agent")
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.NotNil(t, info)
+ assert.Equal(t, http.MethodGet, info.Method)
+ assert.Equal(t, "/api/test", info.URI)
+ assert.Equal(t, "trace-123", info.TraceID)
+ assert.Equal(t, "test-agent", info.UserAgent)
+ assert.Equal(t, "-", info.Username)
+ assert.Equal(t, "-", info.Referer)
+ assert.False(t, info.Date.IsZero())
+}
+
+func TestNewRequestInfo_WithReferer(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var info *RequestInfo
+
+ app.Get("/", func(c *fiber.Ctx) error {
+ info = NewRequestInfo(c, false)
+ return c.SendStatus(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set("Referer", "https://example.com")
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, "https://example.com", info.Referer)
+}
+
+func TestSanitizeReferer_StripsCredentialsQueryAndFragment(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "https://example.com/path", sanitizeReferer("https://user:pass@example.com/path?token=123#frag"))
+}
+
+func TestSanitizeReferer_InvalidValueFallsBackToDash(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "-", sanitizeReferer("://bad-url"))
+}
+
+func TestNewRequestInfo_SanitizesUserAgentControlCharacters(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var info *RequestInfo
+
+ app.Get("/", func(c *fiber.Ctx) error {
+ info = NewRequestInfo(c, false)
+ return c.SendStatus(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ req.Header.Set(cn.HeaderUserAgent, "good-agent\r\nforged")
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ require.NotNil(t, info)
+ assert.NotContains(t, info.UserAgent, "\r")
+ assert.NotContains(t, info.UserAgent, "\n")
+ assert.Contains(t, info.UserAgent, "good-agent")
+ assert.Contains(t, info.UserAgent, "forged")
+}
+
+// ---------------------------------------------------------------------------
+// CLFString
+// ---------------------------------------------------------------------------
+
+func TestCLFString(t *testing.T) {
+ t.Parallel()
+
+ info := &RequestInfo{
+ RemoteAddress: "192.168.1.1",
+ Username: "admin",
+ Protocol: "http",
+ Date: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
+ Method: "POST",
+ URI: "/api/v1/resource",
+ Status: 200,
+ Size: 1024,
+ Referer: "-",
+ UserAgent: "curl/7.68.0",
+ }
+
+ clf := info.CLFString()
+
+ assert.Contains(t, clf, "192.168.1.1")
+ assert.Contains(t, clf, "admin")
+ assert.Contains(t, clf, `"POST /api/v1/resource"`)
+ assert.Contains(t, clf, "200")
+ assert.Contains(t, clf, "1024")
+ assert.Contains(t, clf, "curl/7.68.0")
+}
+
+func TestCLFString_DoesNotIncludeControlCharactersFromUserAgent(t *testing.T) {
+ t.Parallel()
+
+ info := &RequestInfo{
+ RemoteAddress: "192.168.1.1",
+ Username: "admin",
+ Protocol: "http",
+ Date: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
+ Method: "POST",
+ URI: "/api/v1/resource",
+ Status: 200,
+ Size: 1024,
+ Referer: "-",
+ UserAgent: "curl/7.68.0\r\nforged\x00",
+ }
+
+ clf := info.CLFString()
+ assert.NotContains(t, clf, "\r")
+ assert.NotContains(t, clf, "\n")
+ assert.NotContains(t, clf, "\x00")
+ assert.Contains(t, clf, "curl/7.68.0forged")
+}
+
+func TestSanitizeLogValue_RemovesNullByte(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "abcdef", sanitizeLogValue("abc\x00def"))
+}
+
+func TestStringImplementsStringer(t *testing.T) {
+ t.Parallel()
+
+ info := &RequestInfo{
+ RemoteAddress: "127.0.0.1",
+ Username: "-",
+ Protocol: "http",
+ Date: time.Now(),
+ Method: "GET",
+ URI: "/",
+ Referer: "-",
+ UserAgent: "-",
+ }
+
+ assert.Equal(t, info.CLFString(), info.String())
+}
+
+// ---------------------------------------------------------------------------
+// FinishRequestInfo
+// ---------------------------------------------------------------------------
+
+func TestFinishRequestInfo(t *testing.T) {
+ t.Parallel()
+
+ info := &RequestInfo{
+ Date: time.Now().Add(-100 * time.Millisecond),
+ }
+
+ rw := &ResponseMetricsWrapper{
+ StatusCode: 201,
+ Size: 512,
+ }
+
+ info.FinishRequestInfo(rw)
+
+ assert.Equal(t, 201, info.Status)
+ assert.Equal(t, 512, info.Size)
+ assert.True(t, info.Duration >= 90*time.Millisecond, "expected duration >= 90ms, got %v", info.Duration)
+}
+
+// ---------------------------------------------------------------------------
+// buildOpts / WithCustomLogger
+// ---------------------------------------------------------------------------
+
+func TestBuildOpts_Default(t *testing.T) {
+ t.Parallel()
+
+ mid := buildOpts()
+ assert.NotNil(t, mid.Logger)
+ assert.IsType(t, &log.GoLogger{}, mid.Logger)
+}
+
+func TestBuildOpts_WithCustomLogger(t *testing.T) {
+ t.Parallel()
+
+ custom := &mockLogger{}
+ mid := buildOpts(WithCustomLogger(custom))
+ assert.Equal(t, custom, mid.Logger)
+}
+
+func TestWithCustomLogger_NilDoesNotOverride(t *testing.T) {
+ t.Parallel()
+
+ mid := buildOpts(WithCustomLogger(nil))
+ assert.NotNil(t, mid.Logger)
+ assert.IsType(t, &log.GoLogger{}, mid.Logger)
+}
+
+func TestWithCustomLogger_TypedNilDoesNotOverride(t *testing.T) {
+ t.Parallel()
+
+ var typedNil *mockLogger
+ mid := buildOpts(WithCustomLogger(typedNil))
+ assert.NotNil(t, mid.Logger)
+ assert.IsType(t, &log.GoLogger{}, mid.Logger)
+}
+
+func TestWithHTTPLogging_TypedNilCustomLoggerFallsBackToDefault(t *testing.T) {
+ t.Parallel()
+
+ var typedNil *mockLogger
+ app := fiber.New()
+ app.Use(WithHTTPLogging(WithCustomLogger(typedNil)))
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendStatus(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestNormalizeRequestID_TrimsWhitespaceAndControlCharacters(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "trace-123", normalizeRequestID(" \r\ntrace-123\x00 "))
+ assert.Empty(t, normalizeRequestID(" \r\n\x00 "))
+}
+
+// ---------------------------------------------------------------------------
+// Body obfuscation
+// ---------------------------------------------------------------------------
+
+func TestHandleJSONBody_SensitiveFields(t *testing.T) {
+ t.Parallel()
+
+ input := `{"username":"admin","password":"secret123","email":"a@b.com"}`
+ result := handleJSONBody([]byte(input))
+
+ assert.NotContains(t, result, "secret123")
+ assert.Contains(t, result, cn.ObfuscatedValue)
+ assert.Contains(t, result, "admin")
+}
+
+func TestHandleJSONBody_InvalidJSON(t *testing.T) {
+ t.Parallel()
+
+ input := `not json`
+ result := handleJSONBody([]byte(input))
+ assert.Equal(t, input, result)
+}
+
+func TestHandleJSONBody_NestedSensitive(t *testing.T) {
+ t.Parallel()
+
+ input := `{"user":{"name":"alice","password":"pw"},"items":[{"secret_key":"abc"}]}`
+ result := handleJSONBody([]byte(input))
+
+ assert.NotContains(t, result, "pw")
+ assert.Contains(t, result, "alice")
+}
+
+func TestHandleURLEncodedBody_SensitiveFields(t *testing.T) {
+ t.Parallel()
+
+ input := "username=admin&password=secret123&name=test"
+ result := handleURLEncodedBody([]byte(input))
+
+ assert.NotContains(t, result, "secret123")
+ // ObfuscatedValue gets URL-encoded by url.Values.Encode()
+ assert.Contains(t, result, "password=")
+ assert.Contains(t, result, "admin")
+}
+
+func TestHandleURLEncodedBody_InvalidForm(t *testing.T) {
+ t.Parallel()
+
+ input := "%ZZinvalid"
+ result := handleURLEncodedBody([]byte(input))
+ assert.Equal(t, input, result)
+}
+
+// ---------------------------------------------------------------------------
+// WithHTTPLogging middleware integration
+// ---------------------------------------------------------------------------
+
+func TestWithHTTPLogging_SkipsHealthEndpoint(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Use(WithHTTPLogging())
+ app.Get("/health", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestWithHTTPLogging_SetsHeaderID(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Use(WithHTTPLogging())
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ headerID := resp.Header.Get(cn.HeaderID)
+ assert.NotEmpty(t, headerID)
+}
+
+func TestWithHTTPLogging_NormalizesIncomingHeaderID(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Use(WithHTTPLogging())
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set(cn.HeaderID, " trace-123 ")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "trace-123", resp.Header.Get(cn.HeaderID))
+}
+
+func TestWithHTTPLogging_SkipsSwagger(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Use(WithHTTPLogging())
+ app.Get("/swagger/doc.json", func(c *fiber.Ctx) error {
+ return c.SendString("{}")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/swagger/doc.json", nil)
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestWithHTTPLogging_PostWithJSONBody(t *testing.T) {
+ t.Parallel()
+
+ logger := newCaptureLogger()
+ app := fiber.New()
+ app.Use(WithHTTPLogging(WithCustomLogger(logger)))
+ app.Post("/api", func(c *fiber.Ctx) error {
+ return c.SendStatus(http.StatusCreated)
+ })
+
+ body := strings.NewReader(`{"username":"admin","password":"secret"}`)
+ req := httptest.NewRequest(http.MethodPost, "/api?token=abc123", body)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Referer", "https://user:pass@example.com/path?token=abc123#frag")
+ req.Header.Set(cn.HeaderUserAgent, "good-agent\r\nforged")
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, http.StatusCreated, resp.StatusCode)
+
+ entries := logger.entries()
+ require.Len(t, entries, 1)
+ assert.Equal(t, log.LevelInfo, entries[0].level)
+ assert.NotContains(t, entries[0].msg, "secret")
+ assert.NotContains(t, entries[0].msg, "abc123")
+ assert.NotContains(t, entries[0].msg, "\r")
+ assert.NotContains(t, entries[0].msg, "\n")
+ assert.Contains(t, entries[0].msg, "https://example.com/path")
+}
+
+func TestGetBodyObfuscatedString_DispatchesByContentType(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var got string
+
+ app.Post("/api", func(c *fiber.Ctx) error {
+ got = getBodyObfuscatedString(c, c.Body())
+ return c.SendStatus(http.StatusOK)
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/api", strings.NewReader(`{"password":"secret"}`))
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.NotContains(t, got, "secret")
+ assert.Contains(t, got, cn.ObfuscatedValue)
+}
+
+func TestGetBodyObfuscatedString_UnknownContentTypeReturnsRawBody(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ var got string
+
+ app.Post("/api", func(c *fiber.Ctx) error {
+ got = getBodyObfuscatedString(c, c.Body())
+ return c.SendStatus(http.StatusOK)
+ })
+
+ body := "plain text body"
+ req := httptest.NewRequest(http.MethodPost, "/api", strings.NewReader(body))
+ req.Header.Set("Content-Type", "text/plain")
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Equal(t, body, got)
+}
+
+// ---------------------------------------------------------------------------
+// handleJSONBody: array support
+// ---------------------------------------------------------------------------
+
+func TestHandleJSONBody_ArrayTopLevel(t *testing.T) {
+ t.Parallel()
+
+ input := `[{"name":"alice","password":"secret"},{"name":"bob","api_key":"key123"}]`
+ result := handleJSONBody([]byte(input))
+
+ assert.NotContains(t, result, "secret")
+ assert.NotContains(t, result, "key123")
+ assert.Contains(t, result, "alice")
+ assert.Contains(t, result, "bob")
+ assert.Contains(t, result, cn.ObfuscatedValue)
+}
+
+func TestHandleJSONBody_ArrayOfPrimitives(t *testing.T) {
+ t.Parallel()
+
+ input := `[1, 2, 3]`
+ result := handleJSONBody([]byte(input))
+ assert.Equal(t, `[1,2,3]`, result)
+}
+
+func TestHandleJSONBody_EmptyArray(t *testing.T) {
+ t.Parallel()
+
+ input := `[]`
+ result := handleJSONBody([]byte(input))
+ assert.Equal(t, `[]`, result)
+}
+
+// ---------------------------------------------------------------------------
+// Obfuscation depth limit
+// ---------------------------------------------------------------------------
+
+func nestedMapWithPassword(levels int, password string) map[string]any {
+ node := map[string]any{"password": password}
+
+ for i := 0; i < levels; i++ {
+ node = map[string]any{"level": node}
+ }
+
+ return node
+}
+
+func nestedMapPassword(data map[string]any, levels int) string {
+ current := data
+ for i := 0; i < levels; i++ {
+ next, ok := current["level"].(map[string]any)
+ if !ok {
+ return ""
+ }
+
+ current = next
+ }
+
+ password, _ := current["password"].(string)
+
+ return password
+}
+
+func nestedSliceWithPassword(wrappers int, password string) []any {
+ var node any = map[string]any{"password": password}
+
+ for i := 0; i < wrappers; i++ {
+ node = []any{node}
+ }
+
+ data, _ := node.([]any)
+
+ return data
+}
+
+func nestedSlicePassword(data []any, wrappers int) string {
+ var current any = data
+ for i := 0; i < wrappers; i++ {
+ next, ok := current.([]any)
+ if !ok || len(next) == 0 {
+ return ""
+ }
+
+ current = next[0]
+ }
+
+ node, ok := current.(map[string]any)
+ if !ok {
+ return ""
+ }
+
+ password, _ := node["password"].(string)
+
+ return password
+}
+
+func TestObfuscateMapRecursively_DepthLimit(t *testing.T) {
+ t.Parallel()
+
+ t.Run("obfuscates before boundary", func(t *testing.T) {
+ t.Parallel()
+
+ levels := maxObfuscationDepth - 1 // password at depth 31 when max is 32
+ data := nestedMapWithPassword(levels, "deep-secret")
+
+ obfuscateMapRecursively(data, 0)
+ assert.Equal(t, cn.ObfuscatedValue, nestedMapPassword(data, levels))
+ })
+
+ t.Run("does not obfuscate at boundary", func(t *testing.T) {
+ t.Parallel()
+
+ levels := maxObfuscationDepth // password at depth 32 when max is 32
+ data := nestedMapWithPassword(levels, "deep-secret")
+
+ obfuscateMapRecursively(data, 0)
+ assert.Equal(t, "deep-secret", nestedMapPassword(data, levels))
+ })
+}
+
+func TestObfuscateSliceRecursively_DepthLimit(t *testing.T) {
+ t.Parallel()
+
+ t.Run("obfuscates before boundary", func(t *testing.T) {
+ t.Parallel()
+
+ wrappers := maxObfuscationDepth - 1 // map processed at depth 31
+ data := nestedSliceWithPassword(wrappers, "deep-secret")
+
+ obfuscateSliceRecursively(data, 0)
+ assert.Equal(t, cn.ObfuscatedValue, nestedSlicePassword(data, wrappers))
+ })
+
+ t.Run("does not obfuscate at boundary", func(t *testing.T) {
+ t.Parallel()
+
+ wrappers := maxObfuscationDepth // map reached at depth 32
+ data := nestedSliceWithPassword(wrappers, "deep-secret")
+
+ obfuscateSliceRecursively(data, 0)
+ assert.Equal(t, "deep-secret", nestedSlicePassword(data, wrappers))
+ })
+}
+
+// ---------------------------------------------------------------------------
+// handleMultipartBody
+// ---------------------------------------------------------------------------
+
+func TestHandleMultipartBody_ViaMiddleware(t *testing.T) {
+ t.Parallel()
+
+ // We test multipart by going through the middleware stack.
+ // The handleMultipartBody function requires a fiber.Ctx with a parsed multipart form.
+ boundary := "testboundary"
+ body := "--" + boundary + "\r\n" +
+ "Content-Disposition: form-data; name=\"username\"\r\n\r\n" +
+ "admin\r\n" +
+ "--" + boundary + "\r\n" +
+ "Content-Disposition: form-data; name=\"password\"\r\n\r\n" +
+ "my-secret\r\n" +
+ "--" + boundary + "--\r\n"
+
+ app := fiber.New()
+
+ var capturedBody string
+
+ app.Post("/test", func(c *fiber.Ctx) error {
+ capturedBody = handleMultipartBody(c)
+ return c.SendStatus(200)
+ })
+
+ req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader(body))
+ req.Header.Set("Content-Type", "multipart/form-data; boundary="+boundary)
+
+ resp, err := app.Test(req)
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.NotContains(t, capturedBody, "my-secret")
+ assert.Contains(t, capturedBody, "username=admin")
+}
+
+// ---------------------------------------------------------------------------
+// NewRequestInfo and FinishRequestInfo nil guards
+// ---------------------------------------------------------------------------
+
+func TestNewRequestInfo_NilContext(t *testing.T) {
+ t.Parallel()
+
+ info := NewRequestInfo(nil, false)
+ require.NotNil(t, info)
+ assert.False(t, info.Date.IsZero(), "should set Date even with nil context")
+}
+
+func TestFinishRequestInfo_NilWrapper(t *testing.T) {
+ t.Parallel()
+
+ info := &RequestInfo{Date: time.Now().Add(-50 * time.Millisecond)}
+
+ // Should not panic
+ info.FinishRequestInfo(nil)
+
+ // Status and Size should remain zero
+ assert.Equal(t, 0, info.Status)
+ assert.Equal(t, 0, info.Size)
+}
+
+// ---------------------------------------------------------------------------
+// WithObfuscationDisabled option
+// ---------------------------------------------------------------------------
+
+func TestWithObfuscationDisabled_True(t *testing.T) {
+ t.Parallel()
+
+ mid := buildOpts(WithObfuscationDisabled(true))
+ assert.True(t, mid.ObfuscationDisabled)
+}
+
+func TestWithObfuscationDisabled_False(t *testing.T) {
+ t.Parallel()
+
+ mid := buildOpts(WithObfuscationDisabled(false))
+ assert.False(t, mid.ObfuscationDisabled)
+}
+
+func TestWithObfuscationDisabled_OverridesEnvDefault(t *testing.T) {
+ t.Parallel()
+
+ // Default value comes from env var (logObfuscationDisabled).
+ // WithObfuscationDisabled should override it.
+ mid := buildOpts(WithObfuscationDisabled(true))
+ assert.True(t, mid.ObfuscationDisabled)
+
+ mid2 := buildOpts(WithObfuscationDisabled(false))
+ assert.False(t, mid2.ObfuscationDisabled)
+}
+
+// ---------------------------------------------------------------------------
+// mockLogger for WithCustomLogger tests
+// ---------------------------------------------------------------------------
+
+type mockLogger struct{}
+
+func (m *mockLogger) Log(context.Context, log.Level, string, ...log.Field) {}
+func (m *mockLogger) With(...log.Field) log.Logger { return m }
+func (m *mockLogger) WithGroup(string) log.Logger { return m }
+func (m *mockLogger) Enabled(log.Level) bool { return true }
+func (m *mockLogger) Sync(context.Context) error { return nil }
diff --git a/commons/net/http/withTelemetry.go b/commons/net/http/withTelemetry.go
index eb425446..6842026b 100644
--- a/commons/net/http/withTelemetry.go
+++ b/commons/net/http/withTelemetry.go
@@ -1,25 +1,16 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package http
import (
"context"
- "net/url"
- "os"
- "strings"
- "sync"
+ "fmt"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons"
- cn "github.com/LerianStudio/lib-commons/v2/commons/constants"
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry"
- "github.com/LerianStudio/lib-commons/v2/commons/security"
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
"github.com/gofiber/fiber/v2"
- "go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
- "go.opentelemetry.io/otel/metric"
+ "go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
@@ -30,14 +21,7 @@ import (
// Can be overridden via METRICS_COLLECTION_INTERVAL environment variable.
const DefaultMetricsCollectionInterval = 5 * time.Second
-var (
- metricsCollectorOnce = &sync.Once{}
- metricsCollectorShutdown chan struct{}
- metricsCollectorMu sync.Mutex
- metricsCollectorStarted bool
- metricsCollectorInitErr error
-)
-
+// TelemetryMiddleware wraps HTTP and gRPC handlers with tracing and metrics setup.
type TelemetryMiddleware struct {
Telemetry *opentelemetry.Telemetry
}
@@ -50,72 +34,95 @@ func NewTelemetryMiddleware(tl *opentelemetry.Telemetry) *TelemetryMiddleware {
// WithTelemetry is a middleware that adds tracing to the context.
func (tm *TelemetryMiddleware) WithTelemetry(tl *opentelemetry.Telemetry, excludedRoutes ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
- if len(excludedRoutes) > 0 && tm.isRouteExcluded(c, excludedRoutes) {
+ effectiveTelemetry := tl
+ if effectiveTelemetry == nil && tm != nil {
+ effectiveTelemetry = tm.Telemetry
+ }
+
+ if effectiveTelemetry == nil {
+ return c.Next()
+ }
+
+ if len(excludedRoutes) > 0 && isRouteExcludedFromList(c, excludedRoutes) {
return c.Next()
}
setRequestHeaderID(c)
ctx := c.UserContext()
-
_, _, reqId, _ := commons.NewTrackingFromContext(ctx)
c.SetUserContext(commons.ContextWithSpanAttributes(ctx,
attribute.String("app.request.request_id", reqId),
))
- tracer := otel.Tracer(tl.LibraryName)
+ if effectiveTelemetry.TracerProvider == nil {
+ return c.Next()
+ }
+
+ tracer := effectiveTelemetry.TracerProvider.Tracer(effectiveTelemetry.LibraryName)
routePathWithMethod := c.Method() + " " + commons.ReplaceUUIDWithPlaceholder(c.Path())
traceCtx := c.UserContext()
+ // Compatibility note: trace extraction currently trusts the internal-service
+ // User-Agent heuristic. This is an interoperability hint, not an authenticated
+ // trust boundary, and is preserved to avoid changing existing caller behavior.
if commons.IsInternalLerianService(c.Get(cn.HeaderUserAgent)) {
- traceCtx = opentelemetry.ExtractHTTPContext(c)
+ traceCtx = opentelemetry.ExtractHTTPContext(traceCtx, c)
}
ctx, span := tracer.Start(traceCtx, routePathWithMethod, trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
- span.SetAttributes(
- attribute.String("http.method", c.Method()),
- attribute.String("http.url", sanitizeURL(c.OriginalURL())),
- attribute.String("http.route", c.Route().Path),
- attribute.String("http.scheme", c.Protocol()),
- attribute.String("http.host", c.Hostname()),
- attribute.String("http.user_agent", c.Get("User-Agent")),
- )
-
ctx = commons.ContextWithTracer(ctx, tracer)
- ctx = commons.ContextWithMetricFactory(ctx, tl.MetricsFactory)
-
+ ctx = commons.ContextWithMetricFactory(ctx, effectiveTelemetry.MetricsFactory)
c.SetUserContext(ctx)
err := tm.collectMetrics(ctx)
if err != nil {
- opentelemetry.HandleSpanError(&span, "Failed to collect metrics", err)
+ opentelemetry.HandleSpanError(span, "Failed to collect metrics", err)
}
err = c.Next()
+ statusCode := c.Response().StatusCode()
span.SetAttributes(
- attribute.Int("http.status_code", c.Response().StatusCode()),
+ attribute.String("http.request.method", c.Method()),
+ attribute.String("url.path", sanitizeURL(c.OriginalURL())),
+ attribute.String("http.route", c.Route().Path),
+ attribute.String("url.scheme", c.Protocol()),
+ attribute.String("server.address", c.Hostname()),
+ attribute.String("user_agent.original", c.Get(cn.HeaderUserAgent)),
+ attribute.Int("http.response.status_code", statusCode),
)
+ if err != nil {
+ opentelemetry.HandleSpanError(span, "handler error", err)
+ } else if statusCode >= 500 {
+ span.SetStatus(codes.Error, fmt.Sprintf("HTTP %d", statusCode))
+ }
+
return err
}
}
// EndTracingSpans is a middleware that ends the tracing spans.
func (tm *TelemetryMiddleware) EndTracingSpans(c *fiber.Ctx) error {
- ctx := c.UserContext()
- if ctx == nil {
- return nil
+ if c == nil {
+ return ErrContextNotFound
}
+ originalCtx := c.UserContext()
err := c.Next()
- go func() {
- trace.SpanFromContext(ctx).End()
- }()
+ endCtx := c.UserContext()
+ if endCtx == nil {
+ endCtx = originalCtx
+ }
+
+ if endCtx != nil {
+ trace.SpanFromContext(endCtx).End()
+ }
return err
}
@@ -128,38 +135,68 @@ func (tm *TelemetryMiddleware) WithTelemetryInterceptor(tl *opentelemetry.Teleme
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (any, error) {
- ctx = setGRPCRequestHeaderID(ctx)
+ ctx = normalizeGRPCContext(ctx)
- _, _, reqId, _ := commons.NewTrackingFromContext(ctx)
- tracer := otel.Tracer(tl.LibraryName)
+ effectiveTelemetry := tl
+ if effectiveTelemetry == nil && tm != nil {
+ effectiveTelemetry = tm.Telemetry
+ }
+
+ if effectiveTelemetry == nil {
+ return handler(ctx, req)
+ }
+
+ requestID := resolveGRPCRequestID(ctx, req)
+ ctx = commons.ContextWithHeaderID(ctx, requestID)
+
+ if effectiveTelemetry.TracerProvider == nil {
+ return handler(ctx, req)
+ }
+
+ tracer := effectiveTelemetry.TracerProvider.Tracer(effectiveTelemetry.LibraryName)
+
+ methodName := "unknown"
+ if info != nil {
+ methodName = info.FullMethod
+ }
ctx = commons.ContextWithSpanAttributes(ctx,
- attribute.String("app.request.request_id", reqId),
- attribute.String("grpc.method", info.FullMethod),
+ attribute.String("app.request.request_id", requestID),
+ attribute.String("grpc.method", methodName),
)
traceCtx := ctx
+ // Compatibility note: trace extraction currently trusts the internal-service
+ // User-Agent heuristic. This is an interoperability hint, not an authenticated
+ // trust boundary, and is preserved to avoid changing existing caller behavior.
if commons.IsInternalLerianService(getGRPCUserAgent(ctx)) {
- traceCtx = opentelemetry.ExtractGRPCContext(ctx)
+ md, _ := metadata.FromIncomingContext(ctx)
+ traceCtx = opentelemetry.ExtractGRPCContext(ctx, md)
}
- ctx, span := tracer.Start(traceCtx, info.FullMethod, trace.WithSpanKind(trace.SpanKindServer))
+ ctx, span := tracer.Start(traceCtx, methodName, trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
ctx = commons.ContextWithTracer(ctx, tracer)
- ctx = commons.ContextWithMetricFactory(ctx, tl.MetricsFactory)
+ ctx = commons.ContextWithMetricFactory(ctx, effectiveTelemetry.MetricsFactory)
err := tm.collectMetrics(ctx)
if err != nil {
- opentelemetry.HandleSpanError(&span, "Failed to collect metrics", err)
+ opentelemetry.HandleSpanError(span, "Failed to collect metrics", err)
}
resp, err := handler(ctx, req)
+ grpcStatusCode := status.Code(err)
span.SetAttributes(
- attribute.Int("grpc.status_code", int(status.Code(err))),
+ attribute.String("rpc.method", methodName),
+ attribute.Int("rpc.grpc.status_code", int(grpcStatusCode)),
)
+ if err != nil {
+ opentelemetry.HandleSpanError(span, "gRPC handler error", err)
+ }
+
return resp, err
}
}
@@ -173,160 +210,8 @@ func (tm *TelemetryMiddleware) EndTracingSpansInterceptor() grpc.UnaryServerInte
handler grpc.UnaryHandler,
) (any, error) {
resp, err := handler(ctx, req)
-
- go func() {
- trace.SpanFromContext(ctx).End()
- }()
+ trace.SpanFromContext(ctx).End()
return resp, err
}
}
-
-func (tm *TelemetryMiddleware) collectMetrics(_ context.Context) error {
- return tm.ensureMetricsCollector()
-}
-
-// getMetricsCollectionInterval returns the metrics collection interval.
-// Can be configured via METRICS_COLLECTION_INTERVAL environment variable.
-// Accepts Go duration format (e.g., "10s", "1m", "500ms").
-// Falls back to DefaultMetricsCollectionInterval if not set or invalid.
-func getMetricsCollectionInterval() time.Duration {
- if envInterval := os.Getenv("METRICS_COLLECTION_INTERVAL"); envInterval != "" {
- if parsed, err := time.ParseDuration(envInterval); err == nil && parsed > 0 {
- return parsed
- }
- }
-
- return DefaultMetricsCollectionInterval
-}
-
-func (tm *TelemetryMiddleware) ensureMetricsCollector() error {
- metricsCollectorMu.Lock()
- defer metricsCollectorMu.Unlock()
-
- if metricsCollectorStarted {
- return nil
- }
-
- if metricsCollectorInitErr != nil {
- // Reset to allow retry after transient init failures
- metricsCollectorOnce = &sync.Once{}
- metricsCollectorInitErr = nil
- }
-
- metricsCollectorOnce.Do(func() {
- cpuGauge, err := otel.Meter(tm.Telemetry.ServiceName).Int64Gauge("system.cpu.usage", metric.WithUnit("percentage"))
- if err != nil {
- metricsCollectorInitErr = err
- return
- }
-
- memGauge, err := otel.Meter(tm.Telemetry.ServiceName).Int64Gauge("system.mem.usage", metric.WithUnit("percentage"))
- if err != nil {
- metricsCollectorInitErr = err
- return
- }
-
- metricsCollectorShutdown = make(chan struct{})
- ticker := time.NewTicker(getMetricsCollectionInterval())
-
- go func() {
- commons.GetCPUUsage(context.Background(), cpuGauge)
- commons.GetMemUsage(context.Background(), memGauge)
-
- for {
- select {
- case <-metricsCollectorShutdown:
- ticker.Stop()
- return
- case <-ticker.C:
- commons.GetCPUUsage(context.Background(), cpuGauge)
- commons.GetMemUsage(context.Background(), memGauge)
- }
- }
- }()
-
- metricsCollectorStarted = true
- })
-
- return metricsCollectorInitErr
-}
-
-// StopMetricsCollector stops the background metrics collector goroutine.
-// Should be called during application shutdown for graceful cleanup.
-// After calling this function, the collector can be restarted by new requests.
-//
-// Implementation note: This function intentionally resets sync.Once to a new instance
-// to allow the collector to be restarted after being stopped. This is an unusual but
-// intentional pattern - the mutex ensures thread-safety during the reset operation,
-// preventing race conditions between Stop and subsequent Start calls.
-func StopMetricsCollector() {
- metricsCollectorMu.Lock()
- defer metricsCollectorMu.Unlock()
-
- if metricsCollectorStarted && metricsCollectorShutdown != nil {
- close(metricsCollectorShutdown)
-
- metricsCollectorStarted = false
- metricsCollectorOnce = &sync.Once{}
- metricsCollectorInitErr = nil
- }
-}
-
-func (tm *TelemetryMiddleware) isRouteExcluded(c *fiber.Ctx, excludedRoutes []string) bool {
- for _, route := range excludedRoutes {
- if strings.HasPrefix(c.Path(), route) {
- return true
- }
- }
-
- return false
-}
-
-// sanitizeURL removes or obfuscates sensitive query parameters from URLs
-// to prevent exposing tokens, API keys, and other sensitive data in telemetry.
-func sanitizeURL(rawURL string) string {
- parsed, err := url.Parse(rawURL)
- if err != nil {
- return rawURL
- }
-
- if parsed.RawQuery == "" {
- return rawURL
- }
-
- query := parsed.Query()
- modified := false
-
- for key := range query {
- if security.IsSensitiveField(key) {
- query.Set(key, cn.ObfuscatedValue)
-
- modified = true
- }
- }
-
- if !modified {
- return rawURL
- }
-
- parsed.RawQuery = query.Encode()
-
- return parsed.String()
-}
-
-// getGRPCUserAgent extracts the User-Agent from incoming gRPC metadata.
-// Returns empty string if the metadata is not present or doesn't contain user-agent.
-func getGRPCUserAgent(ctx context.Context) string {
- md, ok := metadata.FromIncomingContext(ctx)
- if !ok || md == nil {
- return ""
- }
-
- userAgents := md.Get("user-agent")
- if len(userAgents) == 0 {
- return ""
- }
-
- return userAgents[0]
-}
diff --git a/commons/net/http/withTelemetry_helpers.go b/commons/net/http/withTelemetry_helpers.go
new file mode 100644
index 00000000..d7751f6c
--- /dev/null
+++ b/commons/net/http/withTelemetry_helpers.go
@@ -0,0 +1,86 @@
+package http
+
+import (
+ "context"
+ "net/url"
+ "strings"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/security"
+ "github.com/gofiber/fiber/v2"
+ "google.golang.org/grpc/metadata"
+)
+
+// isRouteExcludedFromList reports whether the request path matches any excluded route prefix.
+// This standalone function is used to evaluate route exclusions independently of whether
+// the TelemetryMiddleware receiver is nil.
+func isRouteExcludedFromList(c *fiber.Ctx, excludedRoutes []string) bool {
+ for _, route := range excludedRoutes {
+ if strings.HasPrefix(c.Path(), route) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// sanitizeURL removes or obfuscates sensitive query parameters from URLs
+// to prevent exposing tokens, API keys, and other sensitive data in telemetry.
+func sanitizeURL(rawURL string) string {
+ parsed, err := url.Parse(rawURL)
+ if err != nil {
+ return sanitizeMalformedURL(rawURL)
+ }
+
+ if parsed.RawQuery == "" {
+ return rawURL
+ }
+
+ query := parsed.Query()
+ modified := false
+
+ for key := range query {
+ if security.IsSensitiveField(key) {
+ query.Set(key, cn.ObfuscatedValue)
+
+ modified = true
+ }
+ }
+
+ if !modified {
+ return rawURL
+ }
+
+ parsed.RawQuery = query.Encode()
+
+ return parsed.String()
+}
+
+func sanitizeMalformedURL(rawURL string) string {
+ sanitized := sanitizeLogValue(rawURL)
+ if before, _, ok := strings.Cut(sanitized, "?"); ok {
+ return before + "?redacted"
+ }
+
+ return sanitized
+}
+
+// getGRPCUserAgent extracts the User-Agent from incoming gRPC metadata.
+// Returns empty string if the metadata is not present or doesn't contain user-agent.
+func getGRPCUserAgent(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+
+ md, ok := metadata.FromIncomingContext(ctx)
+ if !ok || md == nil {
+ return ""
+ }
+
+ userAgents := md.Get(strings.ToLower(cn.HeaderUserAgent))
+ if len(userAgents) == 0 {
+ return ""
+ }
+
+ return userAgents[0]
+}
diff --git a/commons/net/http/withTelemetry_metrics.go b/commons/net/http/withTelemetry_metrics.go
new file mode 100644
index 00000000..e7c8c609
--- /dev/null
+++ b/commons/net/http/withTelemetry_metrics.go
@@ -0,0 +1,131 @@
+package http
+
+import (
+ "context"
+ "errors"
+ "os"
+ "sync"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
+)
+
+// Metrics collector singleton state.
+var (
+ metricsCollectorOnce = &sync.Once{}
+ metricsCollectorShutdown chan struct{}
+ metricsCollectorMu sync.Mutex
+ metricsCollectorStarted bool
+ metricsCollectorInitErr error
+)
+
+// telemetryRuntimeLogger returns the runtime logger from the telemetry middleware, or nil.
+func telemetryRuntimeLogger(tm *TelemetryMiddleware) runtime.Logger {
+ if tm == nil || tm.Telemetry == nil {
+ return nil
+ }
+
+ return tm.Telemetry.Logger
+}
+
+// collectMetrics ensures the background metrics collector goroutine is running.
+func (tm *TelemetryMiddleware) collectMetrics(_ context.Context) error {
+ return tm.ensureMetricsCollector()
+}
+
+// getMetricsCollectionInterval returns the metrics collection interval.
+// Can be configured via METRICS_COLLECTION_INTERVAL environment variable.
+// Accepts Go duration format (e.g., "10s", "1m", "500ms").
+// Falls back to DefaultMetricsCollectionInterval if not set or invalid.
+func getMetricsCollectionInterval() time.Duration {
+ if envInterval := os.Getenv("METRICS_COLLECTION_INTERVAL"); envInterval != "" {
+ if parsed, err := time.ParseDuration(envInterval); err == nil && parsed > 0 {
+ return parsed
+ }
+ }
+
+ return DefaultMetricsCollectionInterval
+}
+
+// ensureMetricsCollector lazily starts the background metrics collector singleton.
+func (tm *TelemetryMiddleware) ensureMetricsCollector() error {
+ if tm == nil || tm.Telemetry == nil {
+ return nil
+ }
+
+ if tm.Telemetry.MeterProvider == nil {
+ return nil
+ }
+
+ metricsCollectorMu.Lock()
+ defer metricsCollectorMu.Unlock()
+
+ if metricsCollectorStarted {
+ return nil
+ }
+
+ if metricsCollectorInitErr != nil {
+ metricsCollectorOnce = &sync.Once{}
+ metricsCollectorInitErr = nil
+ }
+
+ metricsCollectorOnce.Do(func() {
+ factory := tm.Telemetry.MetricsFactory
+ if factory == nil {
+ metricsCollectorInitErr = errors.New("telemetry MetricsFactory is nil, cannot start system metrics collector")
+ return
+ }
+
+ metricsCollectorShutdown = make(chan struct{})
+ ticker := time.NewTicker(getMetricsCollectionInterval())
+
+ runtime.SafeGoWithContextAndComponent(
+ context.Background(),
+ telemetryRuntimeLogger(tm),
+ "http",
+ "metrics_collector",
+ runtime.KeepRunning,
+ func(_ context.Context) {
+ commons.GetCPUUsage(context.Background(), factory)
+ commons.GetMemUsage(context.Background(), factory)
+
+ for {
+ select {
+ case <-metricsCollectorShutdown:
+ ticker.Stop()
+ return
+ case <-ticker.C:
+ commons.GetCPUUsage(context.Background(), factory)
+ commons.GetMemUsage(context.Background(), factory)
+ }
+ }
+ },
+ )
+
+ metricsCollectorStarted = true
+ })
+
+ return metricsCollectorInitErr
+}
+
+// StopMetricsCollector stops the background metrics collector goroutine.
+// Should be called during application shutdown for graceful cleanup.
+// After calling this function, the collector can be restarted by new requests.
+//
+// Implementation note: This function intentionally resets sync.Once to a new instance
+// to allow the collector to be restarted after being stopped. This is an unusual but
+// intentional pattern - the mutex ensures thread-safety during the reset operation,
+// preventing race conditions between Stop and subsequent Start calls.
+func StopMetricsCollector() {
+ metricsCollectorMu.Lock()
+ defer metricsCollectorMu.Unlock()
+
+ if metricsCollectorStarted && metricsCollectorShutdown != nil {
+ close(metricsCollectorShutdown)
+
+ metricsCollectorStarted = false
+ metricsCollectorOnce = &sync.Once{}
+ metricsCollectorInitErr = nil
+ }
+}
diff --git a/commons/net/http/withTelemetry_route_test.go b/commons/net/http/withTelemetry_route_test.go
new file mode 100644
index 00000000..972026ab
--- /dev/null
+++ b/commons/net/http/withTelemetry_route_test.go
@@ -0,0 +1,39 @@
+//go:build unit
+
+package http
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestWithTelemetry_UnmatchedRouteDoesNotPanic(t *testing.T) {
+ t.Parallel()
+
+ tp, spanRecorder := setupTestTracer()
+ defer func() { _ = tp.Shutdown(context.Background()) }()
+
+ telemetry := &opentelemetry.Telemetry{
+ TelemetryConfig: opentelemetry.TelemetryConfig{LibraryName: "test-library", EnableTelemetry: true},
+ TracerProvider: tp,
+ }
+
+ app := fiber.New()
+ app.Use(NewTelemetryMiddleware(telemetry).WithTelemetry(telemetry))
+
+ assert.NotPanics(t, func() {
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/missing", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+ assert.Equal(t, http.StatusNotFound, resp.StatusCode)
+ })
+
+ assert.NotEmpty(t, spanRecorder.Ended())
+}
diff --git a/commons/net/http/withTelemetry_test.go b/commons/net/http/withTelemetry_test.go
index d6834cde..b3b31b40 100644
--- a/commons/net/http/withTelemetry_test.go
+++ b/commons/net/http/withTelemetry_test.go
@@ -1,6 +1,4 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package http
@@ -10,16 +8,19 @@ import (
"net/http"
"net/http/httptest"
"strings"
+ "sync"
"testing"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons"
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ otelmetrics "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
+ sdkmetric "go.opentelemetry.io/otel/sdk/metric"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
"go.opentelemetry.io/otel/trace"
@@ -33,10 +34,10 @@ func setupTestTracer() (*sdktrace.TracerProvider, *tracetest.SpanRecorder) {
tracerProvider := sdktrace.NewTracerProvider(
sdktrace.WithSpanProcessor(spanRecorder),
)
-
+
// Set the global propagator to TraceContext
otel.SetTextMapPropagator(propagation.TraceContext{})
-
+
return tracerProvider, spanRecorder
}
@@ -109,18 +110,18 @@ func TestWithTelemetry(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
-
+
// Setup test tracer
tp, spanRecorder := setupTestTracer()
defer func() {
_ = tp.Shutdown(ctx)
}()
-
+
// Replace the global tracer provider for this test
oldTracerProvider := otel.GetTracerProvider()
otel.SetTracerProvider(tp)
defer otel.SetTracerProvider(oldTracerProvider)
-
+
// Setup telemetry
var telemetry *opentelemetry.Telemetry
if !tt.nilTelemetry {
@@ -171,24 +172,24 @@ func TestWithTelemetry(t *testing.T) {
// Execute request
resp, err := app.Test(req)
require.NoError(t, err)
- defer resp.Body.Close()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
// Check status code
assert.Equal(t, tt.expectedStatusCode, resp.StatusCode)
-
+
// Check spans
spans := spanRecorder.Ended()
-
+
if tt.expectSpan && !tt.nilTelemetry && !tt.swaggerPath {
// Should have created a span
require.GreaterOrEqual(t, len(spans), 1, "Expected at least one span to be created")
-
+
// Check span name
expectedPath := tt.path
if strings.Contains(tt.path, "123e4567-e89b-12d3-a456-426614174000") {
expectedPath = commons.ReplaceUUIDWithPlaceholder(tt.path)
}
-
+
spanFound := false
for _, span := range spans {
if span.Name() == tt.method+" "+expectedPath {
@@ -256,18 +257,18 @@ func TestWithTelemetryExcludedRoutes(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
-
+
// Setup test tracer
tp, spanRecorder := setupTestTracer()
defer func() {
_ = tp.Shutdown(ctx)
}()
-
+
// Replace the global tracer provider for this test
oldTracerProvider := otel.GetTracerProvider()
otel.SetTracerProvider(tp)
defer otel.SetTracerProvider(oldTracerProvider)
-
+
// Setup telemetry
telemetry := &opentelemetry.Telemetry{
TelemetryConfig: opentelemetry.TelemetryConfig{
@@ -298,18 +299,18 @@ func TestWithTelemetryExcludedRoutes(t *testing.T) {
// Execute request
resp, err := app.Test(req)
require.NoError(t, err)
- defer resp.Body.Close()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
// Check status code
assert.Equal(t, http.StatusOK, resp.StatusCode)
-
+
// Check spans
spans := spanRecorder.Ended()
-
+
if tt.expectSpan {
// Should have created a span
require.GreaterOrEqual(t, len(spans), 1, "Expected at least one span to be created")
-
+
// Check span name
expectedSpanName := tt.method + " " + commons.ReplaceUUIDWithPlaceholder(tt.path)
spanFound := false
@@ -410,7 +411,7 @@ func TestEndTracingSpans(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
resp, err := app.Test(req)
require.NoError(t, err)
- defer resp.Body.Close()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
// Verify error propagation via status code
if tt.handlerErr != nil {
@@ -440,6 +441,51 @@ func TestEndTracingSpans(t *testing.T) {
}
}
+func TestEndTracingSpans_CallsNextWithoutInitialContext(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ middleware := &TelemetryMiddleware{}
+ handlerCalled := false
+
+ app.Get("/test", middleware.EndTracingSpans, func(c *fiber.Ctx) error {
+ handlerCalled = true
+ return c.SendStatus(http.StatusNoContent)
+ })
+
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/test", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.True(t, handlerCalled)
+ assert.Equal(t, http.StatusNoContent, resp.StatusCode)
+}
+
+func TestEndTracingSpans_EndsFinalContextSpan(t *testing.T) {
+ t.Parallel()
+
+ tp, spanRecorder := setupTestTracer()
+ defer func() { _ = tp.Shutdown(context.Background()) }()
+
+ app := fiber.New()
+ middleware := &TelemetryMiddleware{}
+
+ app.Get("/test", middleware.EndTracingSpans, func(c *fiber.Ctx) error {
+ ctx, _ := tp.Tracer("test").Start(context.Background(), "handler-span")
+ c.SetUserContext(ctx)
+ return c.SendStatus(http.StatusNoContent)
+ })
+
+ resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/test", nil))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, resp.Body.Close()) }()
+
+ assert.Eventually(t, func() bool {
+ return len(spanRecorder.Ended()) == 1
+ }, time.Second, 10*time.Millisecond)
+ assert.Equal(t, "handler-span", spanRecorder.Ended()[0].Name())
+}
+
// TestGetMetricsCollectionInterval tests the getMetricsCollectionInterval function
func TestGetMetricsCollectionInterval(t *testing.T) {
tests := []struct {
@@ -498,6 +544,65 @@ func TestGetMetricsCollectionInterval(t *testing.T) {
}
}
+func resetMetricsCollectorState() {
+ metricsCollectorMu.Lock()
+ defer metricsCollectorMu.Unlock()
+
+ if metricsCollectorStarted && metricsCollectorShutdown != nil {
+ close(metricsCollectorShutdown)
+ time.Sleep(50 * time.Millisecond)
+ }
+
+ metricsCollectorShutdown = nil
+ metricsCollectorStarted = false
+ metricsCollectorOnce = &sync.Once{}
+ metricsCollectorInitErr = nil
+}
+
+func TestEnsureMetricsCollector_ReturnsErrorWhenMetricsFactoryNil(t *testing.T) {
+ resetMetricsCollectorState()
+ t.Cleanup(resetMetricsCollectorState)
+
+ middleware := &TelemetryMiddleware{Telemetry: &opentelemetry.Telemetry{
+ TelemetryConfig: opentelemetry.TelemetryConfig{LibraryName: "test-library", EnableTelemetry: true},
+ MeterProvider: sdkmetric.NewMeterProvider(),
+ }}
+
+ err := middleware.ensureMetricsCollector()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "MetricsFactory is nil")
+ assert.False(t, metricsCollectorStarted)
+}
+
+func TestEnsureMetricsCollector_NoMeterProviderReturnsNil(t *testing.T) {
+ resetMetricsCollectorState()
+ t.Cleanup(resetMetricsCollectorState)
+
+ middleware := &TelemetryMiddleware{Telemetry: &opentelemetry.Telemetry{}}
+ require.NoError(t, middleware.ensureMetricsCollector())
+ assert.False(t, metricsCollectorStarted)
+}
+
+func TestStopMetricsCollector_AllowsRestart(t *testing.T) {
+ resetMetricsCollectorState()
+ t.Cleanup(resetMetricsCollectorState)
+
+ middleware := &TelemetryMiddleware{Telemetry: &opentelemetry.Telemetry{
+ TelemetryConfig: opentelemetry.TelemetryConfig{LibraryName: "test-library", EnableTelemetry: true},
+ MeterProvider: sdkmetric.NewMeterProvider(),
+ MetricsFactory: otelmetrics.NewNopFactory(),
+ }}
+
+ require.NoError(t, middleware.ensureMetricsCollector())
+ assert.True(t, metricsCollectorStarted)
+
+ StopMetricsCollector()
+ assert.False(t, metricsCollectorStarted)
+
+ require.NoError(t, middleware.ensureMetricsCollector())
+ assert.True(t, metricsCollectorStarted)
+}
+
// TestExtractHTTPContext tests the ExtractHTTPContext function
func TestExtractHTTPContext(t *testing.T) {
ctx := context.Background()
@@ -522,7 +627,7 @@ func TestExtractHTTPContext(t *testing.T) {
// Add test route
app.Get("/test", func(c *fiber.Ctx) error {
// Extract context
- ctx := opentelemetry.ExtractHTTPContext(c)
+ ctx := opentelemetry.ExtractHTTPContext(c.UserContext(), c)
// Check if span info was extracted
spanCtx := trace.SpanContextFromContext(ctx)
@@ -546,7 +651,7 @@ func TestExtractHTTPContext(t *testing.T) {
resp1, err := app.Test(req1)
require.NoError(t, err)
- defer resp1.Body.Close()
+ defer func() { require.NoError(t, resp1.Body.Close()) }()
assert.Equal(t, http.StatusOK, resp1.StatusCode)
// Test without traceparent header
@@ -555,18 +660,18 @@ func TestExtractHTTPContext(t *testing.T) {
resp2, err := app.Test(req2)
require.NoError(t, err)
- defer resp2.Body.Close()
+ defer func() { require.NoError(t, resp2.Body.Close()) }()
assert.Equal(t, http.StatusOK, resp2.StatusCode)
}
// TestWithTelemetryConditionalTracePropagation tests the conditional trace propagation based on UserAgent
func TestWithTelemetryConditionalTracePropagation(t *testing.T) {
tests := []struct {
- name string
- userAgent string
- traceparent string
+ name string
+ userAgent string
+ traceparent string
shouldPropagateTrace bool
- description string
+ description string
}{
{
name: "Internal Lerian service - should propagate trace",
@@ -665,7 +770,7 @@ func TestWithTelemetryConditionalTracePropagation(t *testing.T) {
// Execute request
resp, err := app.Test(req)
require.NoError(t, err)
- defer resp.Body.Close()
+ defer func() { require.NoError(t, resp.Body.Close()) }()
// Check status code
assert.Equal(t, http.StatusOK, resp.StatusCode)
@@ -695,10 +800,10 @@ func TestWithTelemetryConditionalTracePropagation(t *testing.T) {
// TestGetGRPCUserAgent tests the getGRPCUserAgent helper function
func TestGetGRPCUserAgent(t *testing.T) {
tests := []struct {
- name string
- setupMetadata func() context.Context
- expectedUA string
- description string
+ name string
+ setupMetadata func() context.Context
+ expectedUA string
+ description string
}{
{
name: "Valid user-agent in metadata",
@@ -758,6 +863,80 @@ func TestGetGRPCUserAgent(t *testing.T) {
}
}
+// ---------------------------------------------------------------------------
+// sanitizeURL tests
+// ---------------------------------------------------------------------------
+
+func TestSanitizeURL_NoQueryParams(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeURL("https://example.com/api/v1/users")
+ assert.Equal(t, "https://example.com/api/v1/users", result)
+}
+
+func TestSanitizeURL_NoSensitiveParams(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeURL("https://example.com/api?page=1&limit=20")
+ assert.Equal(t, "https://example.com/api?page=1&limit=20", result)
+}
+
+func TestSanitizeURL_SensitiveTokenParam(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeURL("https://example.com/callback?token=secret123&state=abc")
+ assert.NotContains(t, result, "secret123")
+ assert.Contains(t, result, "state=abc")
+}
+
+func TestSanitizeURL_SensitivePasswordParam(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeURL("https://example.com/auth?password=hunter2&username=admin")
+ assert.NotContains(t, result, "hunter2")
+ assert.Contains(t, result, "username=admin")
+}
+
+func TestSanitizeURL_SensitiveAPIKeyParam(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeURL("https://example.com/api?api_key=my-secret-key&format=json")
+ assert.NotContains(t, result, "my-secret-key")
+ assert.Contains(t, result, "format=json")
+}
+
+func TestSanitizeURL_InvalidURL_ReturnedAsIs(t *testing.T) {
+ t.Parallel()
+
+ // A URL that cannot be parsed should be returned as-is
+ invalidURL := "://missing-scheme"
+ result := sanitizeURL(invalidURL)
+ assert.Equal(t, invalidURL, result)
+}
+
+func TestSanitizeURL_InvalidURLWithSensitiveQuery_RedactsFallback(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeURL("://missing-scheme?token=secret123")
+ assert.NotContains(t, result, "secret123")
+ assert.Contains(t, result, "?redacted")
+}
+
+func TestSanitizeURL_EmptyQueryReturnsOriginal(t *testing.T) {
+ t.Parallel()
+
+ original := "https://example.com/path"
+ result := sanitizeURL(original)
+ assert.Equal(t, original, result)
+}
+
+func TestSanitizeURL_RelativePath(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeURL("/api/v1/users?token=abc123")
+ assert.NotContains(t, result, "abc123")
+}
+
// TestWithTelemetryInterceptorConditionalTracePropagation tests conditional trace propagation in gRPC interceptor
func TestWithTelemetryInterceptorConditionalTracePropagation(t *testing.T) {
tests := []struct {
diff --git a/commons/opentelemetry/README.md b/commons/opentelemetry/README.md
index 4f99caa4..bc8faa93 100644
--- a/commons/opentelemetry/README.md
+++ b/commons/opentelemetry/README.md
@@ -1,322 +1,73 @@
-# OpenTelemetry Package
+# OpenTelemetry v2
-This package provides OpenTelemetry integration for the LerianStudio commons library, including advanced struct obfuscation capabilities for secure telemetry data.
+This package now exposes a strict v2 API with deliberate breakage from v1.
-## Features
+## Breaking changes
-- **OpenTelemetry Integration**: Complete setup and configuration for tracing, metrics, and logging
-- **Struct Obfuscation**: Advanced field obfuscation for sensitive data in telemetry spans
-- **Flexible Configuration**: Support for custom obfuscation rules and business logic
-- **Backward Compatibility**: Maintains existing API while adding new security features
+- No fatal initializer. Use `NewTelemetry` and handle returned errors.
+- No implicit global mutation during initialization.
+- Span helpers use `trace.Span` (value) instead of `*trace.Span`.
+- Struct-to-single-JSON-attribute helpers were removed.
+- Obfuscation is explicit and deterministic via `Redactor` rules.
+- Metrics factory and metric builders return errors (no silent no-op).
+- High-cardinality label helpers were removed.
-## Quick Start
-
-### Basic Usage (Without Obfuscation)
-
-```go
-import (
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry"
- "go.opentelemetry.io/otel"
-)
-
-// Create a span and add struct data
-tracer := otel.Tracer("my-service")
-_, span := tracer.Start(ctx, "operation")
-defer span.End()
-
-// Add struct data to span (original behavior)
-err := opentelemetry.SetSpanAttributesFromStruct(&span, "user_data", userStruct)
-```
-
-### With Default Obfuscation
-
-```go
-// Create default obfuscator (covers common sensitive fields)
-obfuscator := opentelemetry.NewDefaultObfuscator()
-
-// Add obfuscated struct data to span
-err := opentelemetry.SetSpanAttributesFromStructWithObfuscation(
- &span, "user_data", userStruct, obfuscator)
-```
-
-### With Custom Obfuscation
-
-```go
-// Define custom sensitive fields
-customFields := []string{"email", "phone", "address"}
-customObfuscator := opentelemetry.NewCustomObfuscator(customFields)
-
-// Apply custom obfuscation
-err := opentelemetry.SetSpanAttributesFromStructWithObfuscation(
- &span, "user_data", userStruct, customObfuscator)
-```
-
-## Struct Obfuscation Examples
-
-### Example Data Structure
-
-```go
-type UserLoginRequest struct {
- Username string `json:"username"`
- Password string `json:"password"`
- Email string `json:"email"`
- RememberMe bool `json:"rememberMe"`
- DeviceInfo DeviceInfo `json:"deviceInfo"`
- Credentials AuthCredentials `json:"credentials"`
- Metadata map[string]any `json:"metadata"`
-}
-
-type DeviceInfo struct {
- UserAgent string `json:"userAgent"`
- IPAddress string `json:"ipAddress"`
- DeviceID string `json:"deviceId"`
- SessionToken string `json:"token"` // Will be obfuscated
-}
-
-type AuthCredentials struct {
- APIKey string `json:"apikey"` // Will be obfuscated
- RefreshToken string `json:"refresh_token"` // Will be obfuscated
- ClientSecret string `json:"secret"` // Will be obfuscated
-}
-```
-
-### Example 1: Default Obfuscation
-
-```go
-loginRequest := UserLoginRequest{
- Username: "john.doe",
- Password: "super_secret_password_123",
- Email: "john.doe@example.com",
- RememberMe: true,
- DeviceInfo: DeviceInfo{
- UserAgent: "Mozilla/5.0...",
- IPAddress: "192.168.1.100",
- DeviceID: "device_12345",
- SessionToken: "session_token_abc123xyz",
- },
- Credentials: AuthCredentials{
- APIKey: "api_key_secret_789",
- RefreshToken: "refresh_token_xyz456",
- ClientSecret: "client_secret_ultra_secure",
- },
- Metadata: map[string]any{
- "theme": "dark",
- "language": "en-US",
- "private_key": "private_key_should_be_hidden",
- "public_info": "this is safe to show",
- },
-}
-
-// Apply default obfuscation
-defaultObfuscator := opentelemetry.NewDefaultObfuscator()
-err := opentelemetry.SetSpanAttributesFromStructWithObfuscation(
- &span, "login_request", loginRequest, defaultObfuscator)
-
-// Result: password, token, secret, apikey, private_key fields become "***"
-```
-
-### Example 2: Custom Field Selection
-
-```go
-// Only obfuscate specific fields
-customFields := []string{"username", "email", "deviceId", "ipAddress"}
-customObfuscator := opentelemetry.NewCustomObfuscator(customFields)
-
-err := opentelemetry.SetSpanAttributesFromStructWithObfuscation(
- &span, "login_request", loginRequest, customObfuscator)
-
-// Result: Only username, email, deviceId, ipAddress become "***"
-```
-
-### Example 3: Custom Business Logic Obfuscator
+## Create telemetry instance
```go
-// Implement custom obfuscation logic
-type BusinessLogicObfuscator struct {
- companyPolicy map[string]bool
-}
-
-func (b *BusinessLogicObfuscator) ShouldObfuscate(fieldName string) bool {
- return b.companyPolicy[strings.ToLower(fieldName)]
-}
-
-func (b *BusinessLogicObfuscator) GetObfuscatedValue() string {
- return "[COMPANY_POLICY_REDACTED]"
+cfg := opentelemetry.TelemetryConfig{
+ LibraryName: "payments",
+ ServiceName: "payments-api",
+ ServiceVersion: "2.0.0",
+ DeploymentEnv: "prod",
+ CollectorExporterEndpoint: "otel-collector:4317",
+ EnableTelemetry: true,
+ InsecureExporter: false,
+ Logger: log.NewNop(),
}
-// Use custom obfuscator
-businessObfuscator := &BusinessLogicObfuscator{
- companyPolicy: map[string]bool{
- "email": true,
- "ipaddress": true,
- "deviceinfo": true, // Obfuscates entire nested object
- },
-}
-
-err := opentelemetry.SetSpanAttributesFromStructWithObfuscation(
- &span, "login_request", loginRequest, businessObfuscator)
-```
-
-### Example 4: Standalone Obfuscation Utility
-
-```go
-// Use obfuscation without OpenTelemetry spans
-obfuscator := opentelemetry.NewDefaultObfuscator()
-obfuscatedData, err := opentelemetry.ObfuscateStruct(loginRequest, obfuscator)
+tl, err := opentelemetry.NewTelemetry(cfg)
if err != nil {
- log.Printf("Obfuscation failed: %v", err)
+ return err
}
-
-// obfuscatedData now contains the struct with sensitive fields replaced
+defer tl.ShutdownTelemetry()
```
-## Default Sensitive Fields
-
-The `NewDefaultObfuscator()` uses a shared list of sensitive field names from the `commons/security` package, ensuring consistent obfuscation behavior across HTTP logging, OpenTelemetry spans, and other components.
-
-The following common sensitive field names are automatically obfuscated (case-insensitive):
-
-### Authentication & Security
-- `password`
-- `token`
-- `secret`
-- `key`
-- `authorization`
-- `auth`
-- `credential`
-- `credentials`
-- `apikey`
-- `api_key`
-- `access_token`
-- `refresh_token`
-- `private_key`
-- `privatekey`
-
-> **Note**: This list is shared with HTTP logging middleware and other components via `security.DefaultSensitiveFields` to ensure consistent behavior across the entire commons library.
-
-## API Reference
-
-### Core Functions
-
-#### `SetSpanAttributesFromStruct(span *trace.Span, key string, valueStruct any) error`
-Original function for backward compatibility. Adds struct data to span without obfuscation.
-
-#### `SetSpanAttributesFromStructWithObfuscation(span *trace.Span, key string, valueStruct any, obfuscator FieldObfuscator) error`
-Enhanced function that applies obfuscation before adding struct data to span. If `obfuscator` is `nil`, behaves like the original function.
+If you still want global providers, call it explicitly:
-#### `ObfuscateStruct(valueStruct any, obfuscator FieldObfuscator) (any, error)`
-Standalone utility function that obfuscates a struct and returns the result. Can be used independently of OpenTelemetry spans.
-
-### Obfuscator Constructors
-
-#### `NewDefaultObfuscator() *DefaultObfuscator`
-Creates an obfuscator with predefined common sensitive field names.
-
-#### `NewCustomObfuscator(sensitiveFields []string) *CustomObfuscator`
-Creates an obfuscator with custom sensitive field names. Field matching is case-insensitive and uses exact matching (not word-boundary matching like `DefaultObfuscator`).
-
-### Interface
-
-#### `FieldObfuscator`
```go
-type FieldObfuscator interface {
- // ShouldObfuscate returns true if the given field name should be obfuscated
- ShouldObfuscate(fieldName string) bool
- // GetObfuscatedValue returns the value to use for obfuscated fields
- GetObfuscatedValue() string
-}
+tl.ApplyGlobals()
```
-### Constants
-
-#### `ObfuscatedValue = "***"`
-Default value used to replace sensitive fields. Can be referenced for consistency.
-
-## Advanced Features
+## Span attributes from objects
-### Recursive Obfuscation
-The obfuscation system works recursively on:
-- **Nested structs**: Processes all nested object fields
-- **Arrays and slices**: Processes each element in collections
-- **Maps**: Processes all key-value pairs
-
-### Case-Insensitive Matching
-Field name matching is case-insensitive for flexibility:
```go
-// All these variations will be obfuscated if "password" is in the sensitive list
-"password", "Password", "PASSWORD", "PaSsWoRd"
+err := opentelemetry.SetSpanAttributesFromValue(span, "request", payload, opentelemetry.NewDefaultRedactor())
```
-### Performance Considerations
-- **Efficient processing**: Uses pre-allocated maps for field lookups
-- **Memory conscious**: Minimal allocations during recursive processing
-- **JSON conversion**: Leverages Go's efficient JSON marshaling/unmarshaling
-
-## Best Practices
+This flattens nested values into typed attributes (`request.user.id`, `request.amount`, etc.).
-### Security
-- **Always use obfuscation** for production telemetry data
-- **Review sensitive field lists** regularly to ensure comprehensive coverage
-- **Implement custom obfuscators** for business-specific sensitive data
-- **Test obfuscation rules** to verify sensitive data is properly hidden
-
-### Performance
-- **Reuse obfuscator instances** instead of creating new ones for each call
-- **Use appropriate obfuscation level** - don't over-obfuscate if not needed
-- **Consider caching** obfuscated results for frequently used structs
-
-### Maintainability
-- **Use `NewDefaultObfuscator()`** for most common use cases
-- **Document custom obfuscation rules** in your business logic
-- **Centralize obfuscation policies** for consistency across services
-- **Test obfuscation behavior** in your unit tests
-
-### Migration
-- **Backward compatibility**: Existing code using `SetSpanAttributesFromStruct()` continues to work
-- **Gradual adoption**: Add obfuscation incrementally to existing telemetry code
-- **Monitoring**: Verify obfuscated telemetry data meets security requirements
-
-## Error Handling
-
-The obfuscation functions return errors in these cases:
-- **Invalid JSON**: When the input struct cannot be marshaled to JSON
-- **Malformed data**: When JSON unmarshaling fails during processing
+## Redaction
```go
-err := opentelemetry.SetSpanAttributesFromStructWithObfuscation(
- &span, "data", invalidStruct, obfuscator)
-if err != nil {
- log.Printf("Obfuscation failed: %v", err)
- // Handle error appropriately
-}
+redactor, err := opentelemetry.NewRedactor([]opentelemetry.RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: opentelemetry.RedactionMask},
+ {FieldPattern: `(?i)^document$`, Action: opentelemetry.RedactionHash},
+ {PathPattern: `(?i)^session\.token$`, FieldPattern: `(?i)^token$`, Action: opentelemetry.RedactionDrop},
+}, "***")
```
-## Testing
+Available actions:
-The package includes comprehensive tests covering:
-- **Default obfuscator behavior**
-- **Custom obfuscator functionality**
-- **Recursive obfuscation of nested structures**
-- **Error handling for invalid data**
-- **Integration with OpenTelemetry spans**
-- **Custom obfuscator interface implementations**
-
-Run tests with:
-```bash
-go test ./commons/opentelemetry -v
-```
+- `RedactionMask`
+- `RedactionHash`
+- `RedactionDrop`
-## Examples
+## Propagation
-For complete working examples, see:
-- `obfuscation_test.go` - Comprehensive test cases
-- `examples/opentelemetry_obfuscation_example.go` - Runnable example application
+Use carrier-first APIs:
-## Contributing
+- `InjectTraceContext(ctx, carrier)`
+- `ExtractTraceContext(ctx, carrier)`
-When adding new features:
-1. **Follow the interface pattern** for extensibility
-2. **Add comprehensive tests** for new functionality
-3. **Update documentation** with examples
-4. **Maintain backward compatibility** with existing APIs
-5. **Follow Go best practices** and the project's coding standards
+Transport adapters remain available for HTTP/gRPC/queue integration.
diff --git a/commons/opentelemetry/doc.go b/commons/opentelemetry/doc.go
new file mode 100644
index 00000000..63ebcb6c
--- /dev/null
+++ b/commons/opentelemetry/doc.go
@@ -0,0 +1,8 @@
+// Package opentelemetry provides tracing, metrics, propagation, and redaction helpers.
+//
+// NewTelemetry builds providers/exporters and can run in disabled mode for local/dev
+// environments while preserving API compatibility.
+//
+// The package also includes carrier utilities for HTTP, gRPC, and queue headers, plus
+// redaction-aware attribute extraction for safe span enrichment.
+package opentelemetry
diff --git a/commons/opentelemetry/extract_queue_test.go b/commons/opentelemetry/extract_queue_test.go
deleted file mode 100644
index 13f9dd1e..00000000
--- a/commons/opentelemetry/extract_queue_test.go
+++ /dev/null
@@ -1,147 +0,0 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
-package opentelemetry
-
-import (
- "context"
- "testing"
-
- "go.opentelemetry.io/otel"
- "go.opentelemetry.io/otel/propagation"
- "go.opentelemetry.io/otel/sdk/trace"
-)
-
-func TestExtractTraceContextFromQueueHeaders(t *testing.T) {
- // Setup OpenTelemetry with proper propagator and real tracer
- tp := trace.NewTracerProvider()
- otel.SetTracerProvider(tp)
- otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
- propagation.TraceContext{},
- propagation.Baggage{},
- ))
- tracer := tp.Tracer("extract-queue-test")
-
- // Create a root span and inject headers (simulating producer)
- rootCtx, rootSpan := tracer.Start(context.Background(), "producer-span")
- defer rootSpan.End()
-
- // Inject trace headers (what producer would do)
- traceHeaders := InjectQueueTraceContext(rootCtx)
-
- // Convert to amqp.Table format (simulating RabbitMQ headers)
- amqpHeaders := make(map[string]any)
- for k, v := range traceHeaders {
- amqpHeaders[k] = v
- }
- // Add some non-trace headers
- amqpHeaders["X-Request-Id"] = "test-123"
- amqpHeaders["Content-Type"] = "application/json"
-
- // Test extraction (what consumer would do)
- baseCtx := context.Background()
- extractedCtx := ExtractTraceContextFromQueueHeaders(baseCtx, amqpHeaders)
-
- // Verify trace context was extracted correctly
- originalTraceID := GetTraceIDFromContext(rootCtx)
- extractedTraceID := GetTraceIDFromContext(extractedCtx)
-
- if originalTraceID == "" {
- t.Error("Expected original trace ID to be non-empty")
- }
-
- if extractedTraceID == "" {
- t.Error("Expected extracted trace ID to be non-empty")
- }
-
- if originalTraceID != extractedTraceID {
- t.Errorf("Trace ID mismatch: original=%s, extracted=%s", originalTraceID, extractedTraceID)
- }
-
- t.Logf("✅ Trace ID successfully propagated: %s", extractedTraceID)
-}
-
-func TestExtractTraceContextFromQueueHeadersWithEmptyHeaders(t *testing.T) {
- baseCtx := context.Background()
-
- // Test with nil headers
- extractedCtx := ExtractTraceContextFromQueueHeaders(baseCtx, nil)
- if extractedCtx != baseCtx {
- t.Error("Expected same context when headers are nil")
- }
-
- // Test with empty headers
- extractedCtx = ExtractTraceContextFromQueueHeaders(baseCtx, map[string]any{})
- if extractedCtx != baseCtx {
- t.Error("Expected same context when headers are empty")
- }
-}
-
-func TestExtractTraceContextFromQueueHeadersWithNonStringValues(t *testing.T) {
- baseCtx := context.Background()
-
- // Test with headers containing non-string values
- amqpHeaders := map[string]any{
- "X-Request-Id": "test-123",
- "Retry-Count": 42, // int
- "Timestamp": 1234567890.5, // float
- "Enabled": true, // bool
- }
-
- extractedCtx := ExtractTraceContextFromQueueHeaders(baseCtx, amqpHeaders)
-
- // Should return base context since no valid trace headers
- if extractedCtx != baseCtx {
- t.Error("Expected same context when no valid trace headers present")
- }
-
- // Verify no trace ID extracted
- traceID := GetTraceIDFromContext(extractedCtx)
- if traceID != "" {
- t.Errorf("Expected empty trace ID, got: %s", traceID)
- }
-}
-
-func TestExtractTraceContextFromQueueHeadersWithMixedTypes(t *testing.T) {
- // Setup OpenTelemetry
- tp := trace.NewTracerProvider()
- otel.SetTracerProvider(tp)
- otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
- propagation.TraceContext{},
- propagation.Baggage{},
- ))
- tracer := tp.Tracer("mixed-types-test")
-
- // Create span and get trace headers
- rootCtx, rootSpan := tracer.Start(context.Background(), "test-span")
- defer rootSpan.End()
-
- traceHeaders := InjectQueueTraceContext(rootCtx)
-
- // Create mixed-type headers (simulating real RabbitMQ scenario)
- amqpHeaders := map[string]any{
- "X-Request-Id": "test-123",
- "Retry-Count": 42,
- "Enabled": true,
- }
-
- // Add trace headers as strings
- for k, v := range traceHeaders {
- amqpHeaders[k] = v
- }
-
- // Test extraction
- baseCtx := context.Background()
- extractedCtx := ExtractTraceContextFromQueueHeaders(baseCtx, amqpHeaders)
-
- // Verify trace context was extracted despite mixed types
- originalTraceID := GetTraceIDFromContext(rootCtx)
- extractedTraceID := GetTraceIDFromContext(extractedCtx)
-
- if originalTraceID != extractedTraceID {
- t.Errorf("Trace ID mismatch with mixed types: original=%s, extracted=%s", originalTraceID, extractedTraceID)
- }
-
- t.Logf("✅ Trace extraction works with mixed header types: %s", extractedTraceID)
-}
diff --git a/commons/opentelemetry/inject_trace_test.go b/commons/opentelemetry/inject_trace_test.go
deleted file mode 100644
index 3d99162c..00000000
--- a/commons/opentelemetry/inject_trace_test.go
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
-package opentelemetry
-
-import (
- "context"
- "testing"
-
- "go.opentelemetry.io/otel"
- "go.opentelemetry.io/otel/propagation"
- "go.opentelemetry.io/otel/sdk/trace"
-)
-
-func TestInjectTraceHeadersIntoQueue(t *testing.T) {
- // Setup OpenTelemetry with proper propagator and real tracer
- tp := trace.NewTracerProvider()
- otel.SetTracerProvider(tp)
- otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
- propagation.TraceContext{},
- propagation.Baggage{},
- ))
- tracer := tp.Tracer("inject-trace-test")
-
- // Create a root span
- rootCtx, rootSpan := tracer.Start(context.Background(), "test-span")
- defer rootSpan.End()
-
- // Create initial headers map
- headers := map[string]any{
- "X-Request-Id": "test-request-123",
- "Content-Type": "application/json",
- }
-
- // Test injection into existing headers
- InjectTraceHeadersIntoQueue(rootCtx, &headers)
-
- // Verify original headers are preserved
- if headers["X-Request-Id"] != "test-request-123" {
- t.Error("Original headers should be preserved")
- }
-
- if headers["Content-Type"] != "application/json" {
- t.Error("Original headers should be preserved")
- }
-
- // Verify trace headers were added
- if _, exists := headers["Traceparent"]; !exists {
- t.Errorf("Expected 'Traceparent' header to be added. Got headers: %v", headers)
- }
-
- // Verify we have more headers than we started with
- if len(headers) <= 2 {
- t.Errorf("Expected headers to be added. Original: 2, Final: %d", len(headers))
- }
-
- t.Logf("Final headers: %+v", headers)
-}
-
-func TestInjectTraceHeadersIntoQueueWithNilPointer(t *testing.T) {
- // Test with nil pointer - should not panic
- InjectTraceHeadersIntoQueue(context.Background(), nil)
- // If we reach here, the function handled nil gracefully
-}
-
-func TestInjectTraceHeadersIntoQueueWithEmptyHeaders(t *testing.T) {
- // Setup OpenTelemetry
- tp := trace.NewTracerProvider()
- otel.SetTracerProvider(tp)
- otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
- propagation.TraceContext{},
- propagation.Baggage{},
- ))
- tracer := tp.Tracer("inject-trace-test")
-
- // Create a root span
- rootCtx, rootSpan := tracer.Start(context.Background(), "test-span")
- defer rootSpan.End()
-
- // Start with empty headers
- headers := map[string]any{}
-
- // Test injection
- InjectTraceHeadersIntoQueue(rootCtx, &headers)
-
- // Verify trace headers were added
- if len(headers) == 0 {
- t.Error("Expected trace headers to be added to empty map")
- }
-
- if _, exists := headers["Traceparent"]; !exists {
- t.Errorf("Expected 'Traceparent' header to be added. Got headers: %v", headers)
- }
-
- t.Logf("Headers added to empty map: %+v", headers)
-}
diff --git a/commons/opentelemetry/metrics/METRICS_USAGE.md b/commons/opentelemetry/metrics/METRICS_USAGE.md
index 053e5969..3647d6f8 100644
--- a/commons/opentelemetry/metrics/METRICS_USAGE.md
+++ b/commons/opentelemetry/metrics/METRICS_USAGE.md
@@ -40,7 +40,7 @@ Distribution of values with configurable buckets (e.g., response times, transact
```go
import (
"context"
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
)
func basicMetricsExample(telemetry *opentelemetry.Telemetry, ctx context.Context) {
@@ -92,7 +92,7 @@ The metrics package provides pre-configured convenience methods for common busin
```go
import (
"context"
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry/metrics"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
"go.opentelemetry.io/otel/attribute"
)
diff --git a/commons/opentelemetry/metrics/account.go b/commons/opentelemetry/metrics/account.go
index 24ecb6b4..3a2bbf3a 100644
--- a/commons/opentelemetry/metrics/account.go
+++ b/commons/opentelemetry/metrics/account.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package metrics
import (
@@ -10,9 +6,16 @@ import (
"go.opentelemetry.io/otel/attribute"
)
-func (f *MetricsFactory) RecordAccountCreated(ctx context.Context, organizationID, ledgerID string, attributes ...attribute.KeyValue) {
- f.Counter(MetricAccountsCreated).
- WithLabels(f.WithLedgerLabels(organizationID, ledgerID)).
- WithAttributes(attributes...).
- AddOne(ctx)
+// RecordAccountCreated increments the account-created counter.
+func (f *MetricsFactory) RecordAccountCreated(ctx context.Context, attributes ...attribute.KeyValue) error {
+ if f == nil {
+ return ErrNilFactory
+ }
+
+ b, err := f.Counter(MetricAccountsCreated)
+ if err != nil {
+ return err
+ }
+
+ return b.WithAttributes(attributes...).AddOne(ctx)
}
diff --git a/commons/opentelemetry/metrics/builders.go b/commons/opentelemetry/metrics/builders.go
index 06ba9e6c..e8650001 100644
--- a/commons/opentelemetry/metrics/builders.go
+++ b/commons/opentelemetry/metrics/builders.go
@@ -1,16 +1,28 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package metrics
import (
"context"
+ "errors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
+var (
+ // ErrNilCounter is returned when a counter builder has no instrument.
+ ErrNilCounter = errors.New("counter instrument is nil")
+ // ErrNilGauge is returned when a gauge builder has no instrument.
+ ErrNilGauge = errors.New("gauge instrument is nil")
+ // ErrNilHistogram is returned when a histogram builder has no instrument.
+ ErrNilHistogram = errors.New("histogram instrument is nil")
+ // ErrNilCounterBuilder is returned when a CounterBuilder method is called on a nil receiver.
+ ErrNilCounterBuilder = errors.New("counter builder is nil")
+ // ErrNilGaugeBuilder is returned when a GaugeBuilder method is called on a nil receiver.
+ ErrNilGaugeBuilder = errors.New("gauge builder is nil")
+ // ErrNilHistogramBuilder is returned when a HistogramBuilder method is called on a nil receiver.
+ ErrNilHistogramBuilder = errors.New("histogram builder is nil")
+)
+
// CounterBuilder provides a fluent API for recording counter metrics with optional labels
type CounterBuilder struct {
factory *MetricsFactory
@@ -19,8 +31,13 @@ type CounterBuilder struct {
attrs []attribute.KeyValue
}
-// WithLabels adds labels/attributes to the counter metric
+// WithLabels adds labels/attributes to the counter metric.
+// Returns a nil-safe builder if the receiver is nil.
func (c *CounterBuilder) WithLabels(labels map[string]string) *CounterBuilder {
+ if c == nil {
+ return nil
+ }
+
builder := &CounterBuilder{
factory: c.factory,
counter: c.counter,
@@ -37,8 +54,13 @@ func (c *CounterBuilder) WithLabels(labels map[string]string) *CounterBuilder {
return builder
}
-// WithAttributes adds OpenTelemetry attributes to the counter metric
+// WithAttributes adds OpenTelemetry attributes to the counter metric.
+// Returns a nil-safe builder if the receiver is nil.
func (c *CounterBuilder) WithAttributes(attrs ...attribute.KeyValue) *CounterBuilder {
+ if c == nil {
+ return nil
+ }
+
builder := &CounterBuilder{
factory: c.factory,
counter: c.counter,
@@ -53,22 +75,33 @@ func (c *CounterBuilder) WithAttributes(attrs ...attribute.KeyValue) *CounterBui
return builder
}
-// Add records a counter increment
-func (c *CounterBuilder) Add(ctx context.Context, value int64) {
+// Add records a counter increment.
+// Returns an error if the value is negative (counters are monotonically increasing).
+func (c *CounterBuilder) Add(ctx context.Context, value int64) error {
+ if c == nil {
+ return ErrNilCounterBuilder
+ }
+
if c.counter == nil {
- return
+ return ErrNilCounter
+ }
+
+ if value < 0 {
+ return ErrNegativeCounterValue
}
- // Use only the builder attributes (no trace correlation to avoid high cardinality)
c.counter.Add(ctx, value, metric.WithAttributes(c.attrs...))
+
+ return nil
}
-func (c *CounterBuilder) AddOne(ctx context.Context) {
- if c.counter == nil {
- return
+// AddOne increments the counter by one.
+func (c *CounterBuilder) AddOne(ctx context.Context) error {
+ if c == nil {
+ return ErrNilCounterBuilder
}
- c.Add(ctx, 1)
+ return c.Add(ctx, 1)
}
// GaugeBuilder provides a fluent API for recording gauge metrics with optional labels
@@ -79,8 +112,13 @@ type GaugeBuilder struct {
attrs []attribute.KeyValue
}
-// WithLabels adds labels/attributes to the gauge metric
+// WithLabels adds labels/attributes to the gauge metric.
+// Returns a nil-safe builder if the receiver is nil.
func (g *GaugeBuilder) WithLabels(labels map[string]string) *GaugeBuilder {
+ if g == nil {
+ return nil
+ }
+
builder := &GaugeBuilder{
factory: g.factory,
gauge: g.gauge,
@@ -97,8 +135,13 @@ func (g *GaugeBuilder) WithLabels(labels map[string]string) *GaugeBuilder {
return builder
}
-// WithAttributes adds OpenTelemetry attributes to the gauge metric
+// WithAttributes adds OpenTelemetry attributes to the gauge metric.
+// Returns a nil-safe builder if the receiver is nil.
func (g *GaugeBuilder) WithAttributes(attrs ...attribute.KeyValue) *GaugeBuilder {
+ if g == nil {
+ return nil
+ }
+
builder := &GaugeBuilder{
factory: g.factory,
gauge: g.gauge,
@@ -113,27 +156,23 @@ func (g *GaugeBuilder) WithAttributes(attrs ...attribute.KeyValue) *GaugeBuilder
return builder
}
-// Record sets the gauge to the provided value.
-//
-// Deprecated: use Set for application code. This method is kept for
-// parity with OpenTelemetry's instrument API (metric.Int64Gauge.Record)
-// to ease portability from raw OTEL usage. It delegates to Set.
-func (g *GaugeBuilder) Record(ctx context.Context, value int64) {
- g.Set(ctx, value)
-}
-
// Set sets the current value of a gauge (recommended for application code).
//
// This is the primary implementation for recording gauge values and is
// idiomatic for instantaneous state (e.g., queue length, in-flight operations).
// It uses only the builder attributes to avoid high-cardinality labels.
-func (g *GaugeBuilder) Set(ctx context.Context, value int64) {
+func (g *GaugeBuilder) Set(ctx context.Context, value int64) error {
+ if g == nil {
+ return ErrNilGaugeBuilder
+ }
+
if g.gauge == nil {
- return
+ return ErrNilGauge
}
- // Use only the builder attributes (no trace correlation to avoid high cardinality)
g.gauge.Record(ctx, value, metric.WithAttributes(g.attrs...))
+
+ return nil
}
// HistogramBuilder provides a fluent API for recording histogram metrics with optional labels
@@ -144,8 +183,13 @@ type HistogramBuilder struct {
attrs []attribute.KeyValue
}
-// WithLabels adds labels/attributes to the histogram metric
+// WithLabels adds labels/attributes to the histogram metric.
+// Returns a nil-safe builder if the receiver is nil.
func (h *HistogramBuilder) WithLabels(labels map[string]string) *HistogramBuilder {
+ if h == nil {
+ return nil
+ }
+
builder := &HistogramBuilder{
factory: h.factory,
histogram: h.histogram,
@@ -162,8 +206,13 @@ func (h *HistogramBuilder) WithLabels(labels map[string]string) *HistogramBuilde
return builder
}
-// WithAttributes adds OpenTelemetry attributes to the histogram metric
+// WithAttributes adds OpenTelemetry attributes to the histogram metric.
+// Returns a nil-safe builder if the receiver is nil.
func (h *HistogramBuilder) WithAttributes(attrs ...attribute.KeyValue) *HistogramBuilder {
+ if h == nil {
+ return nil
+ }
+
builder := &HistogramBuilder{
factory: h.factory,
histogram: h.histogram,
@@ -179,11 +228,16 @@ func (h *HistogramBuilder) WithAttributes(attrs ...attribute.KeyValue) *Histogra
}
// Record records a histogram value
-func (h *HistogramBuilder) Record(ctx context.Context, value int64) {
+func (h *HistogramBuilder) Record(ctx context.Context, value int64) error {
+ if h == nil {
+ return ErrNilHistogramBuilder
+ }
+
if h.histogram == nil {
- return
+ return ErrNilHistogram
}
- // Use only the builder attributes (no trace correlation to avoid high cardinality)
h.histogram.Record(ctx, value, metric.WithAttributes(h.attrs...))
+
+ return nil
}
diff --git a/commons/opentelemetry/metrics/doc.go b/commons/opentelemetry/metrics/doc.go
new file mode 100644
index 00000000..d0fb1e39
--- /dev/null
+++ b/commons/opentelemetry/metrics/doc.go
@@ -0,0 +1,8 @@
+// Package metrics provides a fluent factory for OpenTelemetry metric instruments.
+//
+// MetricsFactory caches instruments and exposes builder-style APIs for counters,
+// gauges, and histograms with low-overhead attribute composition.
+//
+// Convenience methods (for example RecordTransactionProcessed) are provided for
+// common domain metrics used across Lerian services.
+package metrics
diff --git a/commons/opentelemetry/metrics/labels.go b/commons/opentelemetry/metrics/labels.go
deleted file mode 100644
index 63f5fc01..00000000
--- a/commons/opentelemetry/metrics/labels.go
+++ /dev/null
@@ -1,20 +0,0 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
-package metrics
-
-// WithOrganizationLabels generates a map of labels with the organization ID
-func (f *MetricsFactory) WithOrganizationLabels(organizationID string) map[string]string {
- return map[string]string{
- "organization_id": organizationID,
- }
-}
-
-// WithLedgerLabels generates a map of labels with the organization ID and ledger ID
-func (f *MetricsFactory) WithLedgerLabels(organizationID, ledgerID string) map[string]string {
- labels := f.WithOrganizationLabels(organizationID)
- labels["ledger_id"] = ledgerID
-
- return labels
-}
diff --git a/commons/opentelemetry/metrics/metrics.go b/commons/opentelemetry/metrics/metrics.go
index 7664d60d..41cdbc85 100644
--- a/commons/opentelemetry/metrics/metrics.go
+++ b/commons/opentelemetry/metrics/metrics.go
@@ -1,18 +1,17 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package metrics
import (
+ "context"
+ "errors"
"fmt"
"sort"
"strconv"
"strings"
"sync"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
"go.opentelemetry.io/otel/metric"
+ "go.opentelemetry.io/otel/metric/noop"
)
// MetricsFactory provides a thread-safe factory for creating and managing OpenTelemetry metrics
@@ -25,6 +24,17 @@ type MetricsFactory struct {
logger log.Logger
}
+var (
+ // ErrNilMeter indicates that a nil OTEL meter was provided.
+ ErrNilMeter = errors.New("metric meter cannot be nil")
+ // ErrNilFactory is returned when a MetricsFactory method is called on a nil receiver.
+ ErrNilFactory = errors.New("metrics factory is nil")
+ // ErrNegativeCounterValue is returned when a negative value is passed to Counter.Add.
+ ErrNegativeCounterValue = errors.New("counter value must not be negative")
+ // ErrPercentageOutOfRange is returned when a percentage value is outside [0, 100].
+ ErrPercentageOutOfRange = errors.New("percentage value must be between 0 and 100")
+)
+
// Metric represents a metric that can be collected by the server.
type Metric struct {
Name string
@@ -78,50 +88,84 @@ var (
DefaultTransactionBuckets = []float64{1, 10, 50, 100, 500, 1000, 2500, 5000, 8000, 10000}
)
-// NewMetricsFactory creates a new MetricsFactory instance
-func NewMetricsFactory(meter metric.Meter, logger log.Logger) *MetricsFactory {
+// NewMetricsFactory creates a new MetricsFactory instance.
+func NewMetricsFactory(meter metric.Meter, logger log.Logger) (*MetricsFactory, error) {
+ if meter == nil {
+ return nil, ErrNilMeter
+ }
+
return &MetricsFactory{
meter: meter,
logger: logger,
+ }, nil
+}
+
+// NewNopFactory returns a MetricsFactory backed by OpenTelemetry's no-op meter.
+// It is safe for use as a fallback when a real meter is unavailable.
+func NewNopFactory() *MetricsFactory {
+ return &MetricsFactory{
+ meter: noop.NewMeterProvider().Meter("nop"),
+ logger: log.NewNop(),
}
}
// Counter creates or retrieves a counter metric and returns a builder for fluent API usage
-func (f *MetricsFactory) Counter(m Metric) *CounterBuilder {
- counter := f.getOrCreateCounter(m)
+func (f *MetricsFactory) Counter(m Metric) (*CounterBuilder, error) {
+ if f == nil {
+ return nil, ErrNilFactory
+ }
+
+ counter, err := f.getOrCreateCounter(m)
+ if err != nil {
+ return nil, err
+ }
return &CounterBuilder{
factory: f,
counter: counter,
name: m.Name,
- }
+ }, nil
}
// Gauge creates or retrieves a gauge metric and returns a builder for fluent API usage
-func (f *MetricsFactory) Gauge(m Metric) *GaugeBuilder {
- gauge := f.getOrCreateGauge(m)
+func (f *MetricsFactory) Gauge(m Metric) (*GaugeBuilder, error) {
+ if f == nil {
+ return nil, ErrNilFactory
+ }
+
+ gauge, err := f.getOrCreateGauge(m)
+ if err != nil {
+ return nil, err
+ }
return &GaugeBuilder{
factory: f,
gauge: gauge,
name: m.Name,
- }
+ }, nil
}
// Histogram creates or retrieves a histogram metric and returns a builder for fluent API usage
-func (f *MetricsFactory) Histogram(m Metric) *HistogramBuilder {
+func (f *MetricsFactory) Histogram(m Metric) (*HistogramBuilder, error) {
+ if f == nil {
+ return nil, ErrNilFactory
+ }
+
// Set default buckets if not provided
if m.Buckets == nil {
m.Buckets = selectDefaultBuckets(m.Name)
}
- histogram := f.getOrCreateHistogram(m)
+ histogram, err := f.getOrCreateHistogram(m)
+ if err != nil {
+ return nil, err
+ }
return &HistogramBuilder{
factory: f,
histogram: histogram,
name: m.Name,
- }
+ }, nil
}
// selectDefaultBuckets chooses default buckets based on metric name.
@@ -129,17 +173,18 @@ func (f *MetricsFactory) Histogram(m Metric) *HistogramBuilder {
func selectDefaultBuckets(name string) []float64 {
nameL := strings.ToLower(name)
- // Check substrings in deterministic priority order
- // Domain-specific patterns first, general time patterns last
+ // Check substrings in deterministic priority order.
+ // Latency/duration/time patterns first to avoid "transaction_latency"
+ // matching "transaction" instead of "latency".
patterns := []struct {
substr string
buckets []float64
}{
- {"account", DefaultAccountBuckets},
- {"transaction", DefaultTransactionBuckets},
{"latency", DefaultLatencyBuckets},
{"duration", DefaultLatencyBuckets},
{"time", DefaultLatencyBuckets},
+ {"account", DefaultAccountBuckets},
+ {"transaction", DefaultTransactionBuckets},
}
for _, p := range patterns {
@@ -152,9 +197,17 @@ func selectDefaultBuckets(name string) []float64 {
}
// getOrCreateCounter lazily creates or retrieves an existing counter
-func (f *MetricsFactory) getOrCreateCounter(m Metric) metric.Int64Counter {
+func (f *MetricsFactory) getOrCreateCounter(m Metric) (metric.Int64Counter, error) {
+ if f == nil {
+ return nil, ErrNilFactory
+ }
+
if counter, exists := f.counters.Load(m.Name); exists {
- return counter.(metric.Int64Counter)
+ if c, ok := counter.(metric.Int64Counter); ok {
+ return c, nil
+ }
+
+ return nil, fmt.Errorf("counter cache contains invalid type for %q", m.Name)
}
// Create new counter with proper options
@@ -163,25 +216,37 @@ func (f *MetricsFactory) getOrCreateCounter(m Metric) metric.Int64Counter {
counter, err := f.meter.Int64Counter(m.Name, counterOpts...)
if err != nil {
if f.logger != nil {
- f.logger.Errorf("Failed to create counter metric '%s': %v", m.Name, err)
+ f.logger.Log(context.Background(), log.LevelError, "failed to create counter metric", log.String("metric_name", m.Name), log.Err(err))
}
- // Return nil - builders will handle nil gracefully
- return nil
+
+ return nil, fmt.Errorf("create counter %q: %w", m.Name, err)
}
// Store in sync.Map for future use
if actual, loaded := f.counters.LoadOrStore(m.Name, counter); loaded {
// Another goroutine created it first, use that one
- return actual.(metric.Int64Counter)
+ if c, ok := actual.(metric.Int64Counter); ok {
+ return c, nil
+ }
+
+ return nil, fmt.Errorf("counter cache contains invalid type for %q", m.Name)
}
- return counter
+ return counter, nil
}
// getOrCreateGauge lazily creates or retrieves an existing gauge
-func (f *MetricsFactory) getOrCreateGauge(m Metric) metric.Int64Gauge {
+func (f *MetricsFactory) getOrCreateGauge(m Metric) (metric.Int64Gauge, error) {
+ if f == nil {
+ return nil, ErrNilFactory
+ }
+
if gauge, exists := f.gauges.Load(m.Name); exists {
- return gauge.(metric.Int64Gauge)
+ if g, ok := gauge.(metric.Int64Gauge); ok {
+ return g, nil
+ }
+
+ return nil, fmt.Errorf("gauge cache contains invalid type for %q", m.Name)
}
// Create new gauge with proper options
@@ -190,29 +255,50 @@ func (f *MetricsFactory) getOrCreateGauge(m Metric) metric.Int64Gauge {
gauge, err := f.meter.Int64Gauge(m.Name, gaugeOpts...)
if err != nil {
if f.logger != nil {
- f.logger.Errorf("Failed to create gauge metric '%s': %v", m.Name, err)
+ f.logger.Log(context.Background(), log.LevelError, "failed to create gauge metric", log.String("metric_name", m.Name), log.Err(err))
}
- // Return nil - builders will handle nil gracefully
- return nil
+
+ return nil, fmt.Errorf("create gauge %q: %w", m.Name, err)
}
// Store in sync.Map for future use
if actual, loaded := f.gauges.LoadOrStore(m.Name, gauge); loaded {
// Another goroutine created it first, use that one
- return actual.(metric.Int64Gauge)
+ if g, ok := actual.(metric.Int64Gauge); ok {
+ return g, nil
+ }
+
+ return nil, fmt.Errorf("gauge cache contains invalid type for %q", m.Name)
}
- return gauge
+ return gauge, nil
}
// getOrCreateHistogram lazily creates or retrieves an existing histogram.
// Uses a composite key (name + buckets hash) to ensure different bucket configs
// result in different histograms.
-func (f *MetricsFactory) getOrCreateHistogram(m Metric) metric.Int64Histogram {
+func (f *MetricsFactory) getOrCreateHistogram(m Metric) (metric.Int64Histogram, error) {
+ if f == nil {
+ return nil, ErrNilFactory
+ }
+
+ // Sort buckets before both cache key computation and instrument creation
+ // to ensure the instrument configuration matches the cache key.
+ if len(m.Buckets) > 1 {
+ sorted := make([]float64, len(m.Buckets))
+ copy(sorted, m.Buckets)
+ sort.Float64s(sorted)
+ m.Buckets = sorted
+ }
+
cacheKey := histogramCacheKey(m.Name, m.Buckets)
if histogram, exists := f.histograms.Load(cacheKey); exists {
- return histogram.(metric.Int64Histogram)
+ if h, ok := histogram.(metric.Int64Histogram); ok {
+ return h, nil
+ }
+
+ return nil, fmt.Errorf("histogram cache contains invalid type for %q", cacheKey)
}
// Create new histogram with proper options
@@ -221,19 +307,23 @@ func (f *MetricsFactory) getOrCreateHistogram(m Metric) metric.Int64Histogram {
histogram, err := f.meter.Int64Histogram(m.Name, histogramOpts...)
if err != nil {
if f.logger != nil {
- f.logger.Errorf("Failed to create histogram metric '%s': %v", m.Name, err)
+ f.logger.Log(context.Background(), log.LevelError, "failed to create histogram metric", log.String("metric_name", m.Name), log.Err(err))
}
- // Return nil - builders will handle nil gracefully
- return nil
+
+ return nil, fmt.Errorf("create histogram %q: %w", m.Name, err)
}
// Store in sync.Map for future use
if actual, loaded := f.histograms.LoadOrStore(cacheKey, histogram); loaded {
// Another goroutine created it first, use that one
- return actual.(metric.Int64Histogram)
+ if h, ok := actual.(metric.Int64Histogram); ok {
+ return h, nil
+ }
+
+ return nil, fmt.Errorf("histogram cache contains invalid type for %q", cacheKey)
}
- return histogram
+ return histogram, nil
}
// histogramCacheKey generates a unique cache key based on name and bucket configuration.
@@ -255,7 +345,7 @@ func histogramCacheKey(name string, buckets []float64) string {
}
func (f *MetricsFactory) addCounterOptions(m Metric) []metric.Int64CounterOption {
- opts := []metric.Int64CounterOption{}
+ var opts []metric.Int64CounterOption
if m.Description != "" {
opts = append(opts, metric.WithDescription(m.Description))
}
@@ -268,7 +358,7 @@ func (f *MetricsFactory) addCounterOptions(m Metric) []metric.Int64CounterOption
}
func (f *MetricsFactory) addGaugeOptions(m Metric) []metric.Int64GaugeOption {
- opts := []metric.Int64GaugeOption{}
+ var opts []metric.Int64GaugeOption
if m.Description != "" {
opts = append(opts, metric.WithDescription(m.Description))
}
@@ -281,7 +371,7 @@ func (f *MetricsFactory) addGaugeOptions(m Metric) []metric.Int64GaugeOption {
}
func (f *MetricsFactory) addHistogramOptions(m Metric) []metric.Int64HistogramOption {
- opts := []metric.Int64HistogramOption{}
+ var opts []metric.Int64HistogramOption
if m.Description != "" {
opts = append(opts, metric.WithDescription(m.Description))
}
diff --git a/commons/opentelemetry/metrics/operation_routes.go b/commons/opentelemetry/metrics/operation_routes.go
index 022dee51..e9ea4ce3 100644
--- a/commons/opentelemetry/metrics/operation_routes.go
+++ b/commons/opentelemetry/metrics/operation_routes.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package metrics
import (
@@ -10,9 +6,16 @@ import (
"go.opentelemetry.io/otel/attribute"
)
-func (f *MetricsFactory) RecordOperationRouteCreated(ctx context.Context, organizationID, ledgerID string, attributes ...attribute.KeyValue) {
- f.Counter(MetricOperationRoutesCreated).
- WithLabels(f.WithLedgerLabels(organizationID, ledgerID)).
- WithAttributes(attributes...).
- AddOne(ctx)
+// RecordOperationRouteCreated increments the operation-route-created counter.
+func (f *MetricsFactory) RecordOperationRouteCreated(ctx context.Context, attributes ...attribute.KeyValue) error {
+ if f == nil {
+ return ErrNilFactory
+ }
+
+ b, err := f.Counter(MetricOperationRoutesCreated)
+ if err != nil {
+ return err
+ }
+
+ return b.WithAttributes(attributes...).AddOne(ctx)
}
diff --git a/commons/opentelemetry/metrics/system.go b/commons/opentelemetry/metrics/system.go
new file mode 100644
index 00000000..8b554b72
--- /dev/null
+++ b/commons/opentelemetry/metrics/system.go
@@ -0,0 +1,60 @@
+package metrics
+
+import (
+ "context"
+)
+
+// Pre-configured system metrics for infrastructure monitoring.
+var (
+ // MetricSystemCPUUsage is a gauge that records the current CPU usage percentage.
+ MetricSystemCPUUsage = Metric{
+ Name: "system.cpu.usage",
+ Unit: "percentage",
+ Description: "Current CPU usage percentage of the process host.",
+ }
+
+ // MetricSystemMemUsage is a gauge that records the current memory usage percentage.
+ MetricSystemMemUsage = Metric{
+ Name: "system.mem.usage",
+ Unit: "percentage",
+ Description: "Current memory usage percentage of the process host.",
+ }
+)
+
+// RecordSystemCPUUsage records the current CPU usage percentage via the factory's gauge.
+// The percentage must be in the range [0, 100].
+func (f *MetricsFactory) RecordSystemCPUUsage(ctx context.Context, percentage int64) error {
+ if f == nil {
+ return ErrNilFactory
+ }
+
+ if percentage < 0 || percentage > 100 {
+ return ErrPercentageOutOfRange
+ }
+
+ b, err := f.Gauge(MetricSystemCPUUsage)
+ if err != nil {
+ return err
+ }
+
+ return b.Set(ctx, percentage)
+}
+
+// RecordSystemMemUsage records the current memory usage percentage via the factory's gauge.
+// The percentage must be in the range [0, 100].
+func (f *MetricsFactory) RecordSystemMemUsage(ctx context.Context, percentage int64) error {
+ if f == nil {
+ return ErrNilFactory
+ }
+
+ if percentage < 0 || percentage > 100 {
+ return ErrPercentageOutOfRange
+ }
+
+ b, err := f.Gauge(MetricSystemMemUsage)
+ if err != nil {
+ return err
+ }
+
+ return b.Set(ctx, percentage)
+}
diff --git a/commons/opentelemetry/metrics/system_test.go b/commons/opentelemetry/metrics/system_test.go
new file mode 100644
index 00000000..28bc304c
--- /dev/null
+++ b/commons/opentelemetry/metrics/system_test.go
@@ -0,0 +1,195 @@
+//go:build unit
+
+package metrics
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Test: System metric variable definitions
+// ---------------------------------------------------------------------------
+
+func TestSystemMetrics_MetricDefinitions(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ metric Metric
+ wantName string
+ wantUnit string
+ wantDescNE string // description must not be empty
+ }{
+ {
+ name: "CPU usage metric has correct name",
+ metric: MetricSystemCPUUsage,
+ wantName: "system.cpu.usage",
+ wantUnit: "percentage",
+ wantDescNE: "",
+ },
+ {
+ name: "Memory usage metric has correct name",
+ metric: MetricSystemMemUsage,
+ wantName: "system.mem.usage",
+ wantUnit: "percentage",
+ wantDescNE: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, tt.wantName, tt.metric.Name)
+ assert.Equal(t, tt.wantUnit, tt.metric.Unit)
+ assert.NotEmpty(t, tt.metric.Description)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// Test: RecordSystemCPUUsage with valid factory
+// ---------------------------------------------------------------------------
+
+func TestRecordSystemCPUUsage_ValidFactory(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestFactory(t)
+
+ err := factory.RecordSystemCPUUsage(context.Background(), 75)
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "system.cpu.usage")
+ require.NotNil(t, m, "system.cpu.usage metric must exist")
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(75), dps[0].Value)
+}
+
+// ---------------------------------------------------------------------------
+// Test: RecordSystemMemUsage with valid factory
+// ---------------------------------------------------------------------------
+
+func TestRecordSystemMemUsage_ValidFactory(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestFactory(t)
+
+ err := factory.RecordSystemMemUsage(context.Background(), 42)
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "system.mem.usage")
+ require.NotNil(t, m, "system.mem.usage metric must exist")
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(42), dps[0].Value)
+}
+
+// ---------------------------------------------------------------------------
+// Test: RecordSystemCPUUsage — zero value
+// ---------------------------------------------------------------------------
+
+func TestRecordSystemCPUUsage_ZeroValue(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestFactory(t)
+
+ err := factory.RecordSystemCPUUsage(context.Background(), 0)
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "system.cpu.usage")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(0), dps[0].Value)
+}
+
+// ---------------------------------------------------------------------------
+// Test: RecordSystemMemUsage — zero value
+// ---------------------------------------------------------------------------
+
+func TestRecordSystemMemUsage_ZeroValue(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestFactory(t)
+
+ err := factory.RecordSystemMemUsage(context.Background(), 0)
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "system.mem.usage")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(0), dps[0].Value)
+}
+
+// ---------------------------------------------------------------------------
+// Test: RecordSystemCPUUsage — boundary value 100%
+// ---------------------------------------------------------------------------
+
+func TestRecordSystemCPUUsage_MaxPercentage(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestFactory(t)
+
+ err := factory.RecordSystemCPUUsage(context.Background(), 100)
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "system.cpu.usage")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(100), dps[0].Value)
+}
+
+// ---------------------------------------------------------------------------
+// Test: RecordSystemMemUsage — overwrite (gauge last-value semantics)
+// ---------------------------------------------------------------------------
+
+func TestRecordSystemMemUsage_Overwrite(t *testing.T) {
+ t.Parallel()
+
+ factory, reader := newTestFactory(t)
+
+ require.NoError(t, factory.RecordSystemMemUsage(context.Background(), 30))
+ require.NoError(t, factory.RecordSystemMemUsage(context.Background(), 85))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "system.mem.usage")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ // Gauge keeps last value
+ assert.Equal(t, int64(85), dps[0].Value)
+}
+
+// ---------------------------------------------------------------------------
+// Test: Nop factory — system metrics don't error
+// ---------------------------------------------------------------------------
+
+func TestRecordSystemMetrics_NopFactory(t *testing.T) {
+ t.Parallel()
+
+ factory := NewNopFactory()
+
+ err := factory.RecordSystemCPUUsage(context.Background(), 50)
+ assert.NoError(t, err, "nop factory should not error for CPU usage")
+
+ err = factory.RecordSystemMemUsage(context.Background(), 60)
+ assert.NoError(t, err, "nop factory should not error for memory usage")
+}
diff --git a/commons/opentelemetry/metrics/transaction.go b/commons/opentelemetry/metrics/transaction.go
index a8799647..a2382385 100644
--- a/commons/opentelemetry/metrics/transaction.go
+++ b/commons/opentelemetry/metrics/transaction.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package metrics
import (
@@ -10,9 +6,16 @@ import (
"go.opentelemetry.io/otel/attribute"
)
-func (f *MetricsFactory) RecordTransactionProcessed(ctx context.Context, organizationID, ledgerID string, attributes ...attribute.KeyValue) {
- f.Counter(MetricTransactionsProcessed).
- WithLabels(f.WithLedgerLabels(organizationID, ledgerID)).
- WithAttributes(attributes...).
- AddOne(ctx)
+// RecordTransactionProcessed increments the transaction-processed counter.
+func (f *MetricsFactory) RecordTransactionProcessed(ctx context.Context, attributes ...attribute.KeyValue) error {
+ if f == nil {
+ return ErrNilFactory
+ }
+
+ b, err := f.Counter(MetricTransactionsProcessed)
+ if err != nil {
+ return err
+ }
+
+ return b.WithAttributes(attributes...).AddOne(ctx)
}
diff --git a/commons/opentelemetry/metrics/transaction_routes.go b/commons/opentelemetry/metrics/transaction_routes.go
index bfcaf2d2..bb2f3313 100644
--- a/commons/opentelemetry/metrics/transaction_routes.go
+++ b/commons/opentelemetry/metrics/transaction_routes.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package metrics
import (
@@ -10,9 +6,16 @@ import (
"go.opentelemetry.io/otel/attribute"
)
-func (f *MetricsFactory) RecordTransactionRouteCreated(ctx context.Context, organizationID, ledgerID string, attributes ...attribute.KeyValue) {
- f.Counter(MetricTransactionRoutesCreated).
- WithLabels(f.WithLedgerLabels(organizationID, ledgerID)).
- WithAttributes(attributes...).
- AddOne(ctx)
+// RecordTransactionRouteCreated increments the transaction-route-created counter.
+func (f *MetricsFactory) RecordTransactionRouteCreated(ctx context.Context, attributes ...attribute.KeyValue) error {
+ if f == nil {
+ return ErrNilFactory
+ }
+
+ b, err := f.Counter(MetricTransactionRoutesCreated)
+ if err != nil {
+ return err
+ }
+
+ return b.WithAttributes(attributes...).AddOne(ctx)
}
diff --git a/commons/opentelemetry/metrics/v2_test.go b/commons/opentelemetry/metrics/v2_test.go
new file mode 100644
index 00000000..47986ee7
--- /dev/null
+++ b/commons/opentelemetry/metrics/v2_test.go
@@ -0,0 +1,1798 @@
+//go:build unit
+
+package metrics
+
+import (
+ "context"
+ "sync"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/metric/noop"
+ sdkmetric "go.opentelemetry.io/otel/sdk/metric"
+ "go.opentelemetry.io/otel/sdk/metric/metricdata"
+)
+
+// ---------------------------------------------------------------------------
+// Test helpers
+// ---------------------------------------------------------------------------
+
+// newTestFactory creates a MetricsFactory backed by a real SDK meter provider
+// with a ManualReader. The ManualReader lets us export and inspect actual
+// metric data recorded by the instruments.
+func newTestFactory(t *testing.T) (*MetricsFactory, *sdkmetric.ManualReader) {
+ t.Helper()
+
+ reader := sdkmetric.NewManualReader()
+ provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader))
+ meter := provider.Meter("test")
+
+ factory, err := NewMetricsFactory(meter, &log.NopLogger{})
+ require.NoError(t, err)
+
+ return factory, reader
+}
+
+// collectMetrics is a convenience wrapper that calls reader.Collect and returns
+// the ResourceMetrics payload.
+func collectMetrics(t *testing.T, reader *sdkmetric.ManualReader) metricdata.ResourceMetrics {
+ t.Helper()
+
+ var rm metricdata.ResourceMetrics
+
+ err := reader.Collect(context.Background(), &rm)
+ require.NoError(t, err)
+
+ return rm
+}
+
+// findMetricByName walks the collected ResourceMetrics and returns the first
+// Metrics entry whose Name matches. Returns nil if not found.
+func findMetricByName(rm metricdata.ResourceMetrics, name string) *metricdata.Metrics {
+ for _, sm := range rm.ScopeMetrics {
+ for i := range sm.Metrics {
+ if sm.Metrics[i].Name == name {
+ return &sm.Metrics[i]
+ }
+ }
+ }
+
+ return nil
+}
+
+// sumDataPoints extracts data points from a Sum metric.
+func sumDataPoints(t *testing.T, m *metricdata.Metrics) []metricdata.DataPoint[int64] {
+ t.Helper()
+
+ sum, ok := m.Data.(metricdata.Sum[int64])
+ require.True(t, ok, "expected Sum[int64] data, got %T", m.Data)
+
+ return sum.DataPoints
+}
+
+// histDataPoints extracts data points from a Histogram metric.
+func histDataPoints(t *testing.T, m *metricdata.Metrics) []metricdata.HistogramDataPoint[int64] {
+ t.Helper()
+
+ hist, ok := m.Data.(metricdata.Histogram[int64])
+ require.True(t, ok, "expected Histogram[int64] data, got %T", m.Data)
+
+ return hist.DataPoints
+}
+
+// gaugeDataPoints extracts data points from a Gauge metric.
+func gaugeDataPoints(t *testing.T, m *metricdata.Metrics) []metricdata.DataPoint[int64] {
+ t.Helper()
+
+ gauge, ok := m.Data.(metricdata.Gauge[int64])
+ require.True(t, ok, "expected Gauge[int64] data, got %T", m.Data)
+
+ return gauge.DataPoints
+}
+
+// hasAttribute checks whether the attribute set contains a specific string key/value.
+func hasAttribute(attrs attribute.Set, key, value string) bool {
+ v, ok := attrs.Value(attribute.Key(key))
+ if !ok {
+ return false
+ }
+
+ return v.AsString() == value
+}
+
+// ---------------------------------------------------------------------------
+// 1. Factory creation
+// ---------------------------------------------------------------------------
+
+func TestNewMetricsFactory_NilMeter(t *testing.T) {
+ _, err := NewMetricsFactory(nil, &log.NopLogger{})
+ assert.ErrorIs(t, err, ErrNilMeter, "nil meter must be rejected")
+}
+
+func TestNewMetricsFactory_NilLogger(t *testing.T) {
+ // A nil logger is fine -- internal code guards against it.
+ meter := noop.NewMeterProvider().Meter("test")
+ factory, err := NewMetricsFactory(meter, nil)
+ require.NoError(t, err)
+ assert.NotNil(t, factory)
+}
+
+func TestNewMetricsFactory_ValidCreation(t *testing.T) {
+ factory, _ := newTestFactory(t)
+ assert.NotNil(t, factory)
+}
+
+// ---------------------------------------------------------------------------
+// 2. Counter recording and verification
+// ---------------------------------------------------------------------------
+
+func TestCounter_AddOne_RecordsValue(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{
+ Name: "requests_total",
+ Description: "Total number of requests",
+ Unit: "1",
+ })
+ require.NoError(t, err)
+
+ err = counter.AddOne(context.Background())
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "requests_total")
+ require.NotNil(t, m, "metric requests_total must exist")
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(1), dps[0].Value)
+}
+
+func TestCounter_Add_RecordsArbitraryValue(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "bytes_sent"})
+ require.NoError(t, err)
+
+ require.NoError(t, counter.Add(context.Background(), 42))
+ require.NoError(t, counter.Add(context.Background(), 8))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "bytes_sent")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(50), dps[0].Value, "counter should accumulate 42+8=50")
+}
+
+func TestCounter_AddOne_MultipleIncrements(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "events_total"})
+ require.NoError(t, err)
+
+ for i := 0; i < 5; i++ {
+ require.NoError(t, counter.AddOne(context.Background()))
+ }
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "events_total")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(5), dps[0].Value)
+}
+
+func TestCounter_ZeroValue(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "zero_counter"})
+ require.NoError(t, err)
+
+ require.NoError(t, counter.Add(context.Background(), 0))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "zero_counter")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(0), dps[0].Value)
+}
+
+func TestCounter_NilCounter_ReturnsError(t *testing.T) {
+ builder := &CounterBuilder{counter: nil}
+ err := builder.AddOne(context.Background())
+ assert.ErrorIs(t, err, ErrNilCounter)
+}
+
+// ---------------------------------------------------------------------------
+// 3. Gauge recording and verification
+// ---------------------------------------------------------------------------
+
+func TestGauge_Set_RecordsValue(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ gauge, err := factory.Gauge(Metric{
+ Name: "queue_length",
+ Description: "Current queue length",
+ Unit: "1",
+ })
+ require.NoError(t, err)
+
+ require.NoError(t, gauge.Set(context.Background(), 42))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "queue_length")
+ require.NotNil(t, m, "metric queue_length must exist")
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(42), dps[0].Value)
+}
+
+func TestGauge_SetOverwrite(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ gauge, err := factory.Gauge(Metric{Name: "connections"})
+ require.NoError(t, err)
+
+ require.NoError(t, gauge.Set(context.Background(), 10))
+ require.NoError(t, gauge.Set(context.Background(), 25))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "connections")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ // Gauge keeps last value
+ assert.Equal(t, int64(25), dps[0].Value)
+}
+
+func TestGauge_NilGauge_ReturnsError(t *testing.T) {
+ builder := &GaugeBuilder{gauge: nil}
+ err := builder.Set(context.Background(), 1)
+ assert.ErrorIs(t, err, ErrNilGauge)
+}
+
+func TestGauge_ZeroValue(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ gauge, err := factory.Gauge(Metric{Name: "zero_gauge"})
+ require.NoError(t, err)
+
+ require.NoError(t, gauge.Set(context.Background(), 0))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "zero_gauge")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(0), dps[0].Value)
+}
+
+// ---------------------------------------------------------------------------
+// 4. Histogram recording and verification
+// ---------------------------------------------------------------------------
+
+func TestHistogram_Record_RecordsValue(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{
+ Name: "request_duration",
+ Description: "Request duration in ms",
+ Unit: "ms",
+ Buckets: []float64{10, 50, 100, 250, 500, 1000},
+ })
+ require.NoError(t, err)
+
+ require.NoError(t, hist.Record(context.Background(), 75))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "request_duration")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, uint64(1), dps[0].Count)
+ assert.Equal(t, int64(75), dps[0].Sum)
+}
+
+func TestHistogram_MultipleRecords(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{
+ Name: "latency",
+ Buckets: []float64{1, 5, 10, 50, 100},
+ })
+ require.NoError(t, err)
+
+ values := []int64{3, 7, 15, 45, 90}
+ for _, v := range values {
+ require.NoError(t, hist.Record(context.Background(), v))
+ }
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "latency")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, uint64(5), dps[0].Count)
+ assert.Equal(t, int64(3+7+15+45+90), dps[0].Sum)
+}
+
+func TestHistogram_BucketBoundariesConfigured(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ customBuckets := []float64{10, 25, 50, 100}
+
+ hist, err := factory.Histogram(Metric{
+ Name: "custom_histogram",
+ Buckets: customBuckets,
+ })
+ require.NoError(t, err)
+
+ require.NoError(t, hist.Record(context.Background(), 30))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "custom_histogram")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, customBuckets, dps[0].Bounds, "bucket boundaries must match configured values")
+}
+
+func TestHistogram_NilHistogram_ReturnsError(t *testing.T) {
+ builder := &HistogramBuilder{histogram: nil}
+ err := builder.Record(context.Background(), 1)
+ assert.ErrorIs(t, err, ErrNilHistogram)
+}
+
+func TestHistogram_ZeroValue(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{Name: "zero_hist", Buckets: []float64{1, 10}})
+ require.NoError(t, err)
+
+ require.NoError(t, hist.Record(context.Background(), 0))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "zero_hist")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, uint64(1), dps[0].Count)
+ assert.Equal(t, int64(0), dps[0].Sum)
+}
+
+// ---------------------------------------------------------------------------
+// 5. Builder patterns: WithLabels, WithAttributes
+// ---------------------------------------------------------------------------
+
+func TestCounterBuilder_WithLabels(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "labeled_counter"})
+ require.NoError(t, err)
+
+ labeled := counter.WithLabels(map[string]string{
+ "env": "prod",
+ "service": "ledger",
+ })
+ require.NoError(t, labeled.AddOne(context.Background()))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "labeled_counter")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+
+ attrs := dps[0].Attributes
+ assert.True(t, hasAttribute(attrs, "env", "prod"), "must have env=prod attribute")
+ assert.True(t, hasAttribute(attrs, "service", "ledger"), "must have service=ledger attribute")
+}
+
+func TestCounterBuilder_WithAttributes(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "attr_counter"})
+ require.NoError(t, err)
+
+ withAttrs := counter.WithAttributes(
+ attribute.String("method", "POST"),
+ attribute.String("status", "200"),
+ )
+ require.NoError(t, withAttrs.AddOne(context.Background()))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "attr_counter")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.True(t, hasAttribute(dps[0].Attributes, "method", "POST"))
+ assert.True(t, hasAttribute(dps[0].Attributes, "status", "200"))
+}
+
+func TestCounterBuilder_WithLabels_EmptyMap(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "empty_labels_counter"})
+ require.NoError(t, err)
+
+ labeled := counter.WithLabels(map[string]string{})
+ require.NoError(t, labeled.AddOne(context.Background()))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "empty_labels_counter")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(1), dps[0].Value)
+}
+
+func TestCounterBuilder_ChainedLabelsAndAttributes(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "chained_counter"})
+ require.NoError(t, err)
+
+ chained := counter.
+ WithLabels(map[string]string{"region": "us-east-1"}).
+ WithAttributes(attribute.String("version", "v2"))
+ require.NoError(t, chained.AddOne(context.Background()))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "chained_counter")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.True(t, hasAttribute(dps[0].Attributes, "region", "us-east-1"))
+ assert.True(t, hasAttribute(dps[0].Attributes, "version", "v2"))
+}
+
+func TestGaugeBuilder_WithLabels(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ gauge, err := factory.Gauge(Metric{Name: "labeled_gauge"})
+ require.NoError(t, err)
+
+ labeled := gauge.WithLabels(map[string]string{"pool": "primary"})
+ require.NoError(t, labeled.Set(context.Background(), 17))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "labeled_gauge")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.True(t, hasAttribute(dps[0].Attributes, "pool", "primary"))
+ assert.Equal(t, int64(17), dps[0].Value)
+}
+
+func TestGaugeBuilder_WithAttributes(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ gauge, err := factory.Gauge(Metric{Name: "attr_gauge"})
+ require.NoError(t, err)
+
+ withAttrs := gauge.WithAttributes(attribute.String("db", "postgres"))
+ require.NoError(t, withAttrs.Set(context.Background(), 100))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "attr_gauge")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.True(t, hasAttribute(dps[0].Attributes, "db", "postgres"))
+}
+
+func TestHistogramBuilder_WithLabels(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{Name: "labeled_hist", Buckets: []float64{10, 100}})
+ require.NoError(t, err)
+
+ labeled := hist.WithLabels(map[string]string{"endpoint": "/api/v1"})
+ require.NoError(t, labeled.Record(context.Background(), 55))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "labeled_hist")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.True(t, hasAttribute(dps[0].Attributes, "endpoint", "/api/v1"))
+}
+
+func TestHistogramBuilder_WithAttributes(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{Name: "attr_hist", Buckets: []float64{5, 50}})
+ require.NoError(t, err)
+
+ withAttrs := hist.WithAttributes(attribute.String("type", "batch"))
+ require.NoError(t, withAttrs.Record(context.Background(), 20))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "attr_hist")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.True(t, hasAttribute(dps[0].Attributes, "type", "batch"))
+}
+
+// ---------------------------------------------------------------------------
+// 6. Builder immutability -- WithLabels/WithAttributes must not mutate original
+// ---------------------------------------------------------------------------
+
+func TestCounterBuilder_Immutability(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "immut_counter"})
+ require.NoError(t, err)
+
+ branch1 := counter.WithLabels(map[string]string{"branch": "1"})
+ branch2 := counter.WithLabels(map[string]string{"branch": "2"})
+
+ require.NoError(t, branch1.AddOne(context.Background()))
+ require.NoError(t, branch2.AddOne(context.Background()))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "immut_counter")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ assert.Len(t, dps, 2, "two label sets must produce two separate data points")
+
+ foundBranch1, foundBranch2 := false, false
+ for _, dp := range dps {
+ if hasAttribute(dp.Attributes, "branch", "1") {
+ foundBranch1 = true
+ }
+ if hasAttribute(dp.Attributes, "branch", "2") {
+ foundBranch2 = true
+ }
+ }
+
+ assert.True(t, foundBranch1, "must find branch=1 data point")
+ assert.True(t, foundBranch2, "must find branch=2 data point")
+}
+
+func TestGaugeBuilder_Immutability(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ gauge, err := factory.Gauge(Metric{Name: "immut_gauge"})
+ require.NoError(t, err)
+
+ branch1 := gauge.WithLabels(map[string]string{"pool": "primary"})
+ branch2 := gauge.WithLabels(map[string]string{"pool": "replica"})
+
+ require.NoError(t, branch1.Set(context.Background(), 10))
+ require.NoError(t, branch2.Set(context.Background(), 20))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "immut_gauge")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ assert.Len(t, dps, 2, "two label sets must produce two separate data points")
+}
+
+func TestHistogramBuilder_Immutability(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{Name: "immut_hist", Buckets: []float64{10, 100}})
+ require.NoError(t, err)
+
+ branch1 := hist.WithLabels(map[string]string{"route": "/a"})
+ branch2 := hist.WithLabels(map[string]string{"route": "/b"})
+
+ require.NoError(t, branch1.Record(context.Background(), 5))
+ require.NoError(t, branch2.Record(context.Background(), 50))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "immut_hist")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ assert.Len(t, dps, 2, "two label sets must produce two separate data points")
+}
+
+// ---------------------------------------------------------------------------
+// 7. Distinct attribute sets create distinct data points
+// ---------------------------------------------------------------------------
+
+func TestCounter_DifferentLabels_SeparateDataPoints(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "http_requests"})
+ require.NoError(t, err)
+
+ success := counter.WithLabels(map[string]string{"status": "200"})
+ failure := counter.WithLabels(map[string]string{"status": "500"})
+
+ require.NoError(t, success.Add(context.Background(), 100))
+ require.NoError(t, failure.Add(context.Background(), 3))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "http_requests")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 2)
+
+ for _, dp := range dps {
+ if hasAttribute(dp.Attributes, "status", "200") {
+ assert.Equal(t, int64(100), dp.Value)
+ } else if hasAttribute(dp.Attributes, "status", "500") {
+ assert.Equal(t, int64(3), dp.Value)
+ } else {
+ t.Fatal("unexpected data point without status attribute")
+ }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 8. Metric caching (getOrCreate*)
+// ---------------------------------------------------------------------------
+
+func TestCounter_CachesInstrument(t *testing.T) {
+ factory, _ := newTestFactory(t)
+
+ m := Metric{Name: "cached_counter", Description: "test"}
+
+ counter1, err := factory.Counter(m)
+ require.NoError(t, err)
+
+ counter2, err := factory.Counter(m)
+ require.NoError(t, err)
+
+ // Both builders must share the same underlying counter instrument.
+ assert.Equal(t, counter1.counter, counter2.counter, "counter must be cached")
+}
+
+func TestGauge_CachesInstrument(t *testing.T) {
+ factory, _ := newTestFactory(t)
+
+ m := Metric{Name: "cached_gauge"}
+
+ gauge1, err := factory.Gauge(m)
+ require.NoError(t, err)
+
+ gauge2, err := factory.Gauge(m)
+ require.NoError(t, err)
+
+ assert.Equal(t, gauge1.gauge, gauge2.gauge, "gauge must be cached")
+}
+
+func TestHistogram_CachesInstrument(t *testing.T) {
+ factory, _ := newTestFactory(t)
+
+ m := Metric{Name: "cached_hist", Buckets: []float64{1, 10, 100}}
+
+ hist1, err := factory.Histogram(m)
+ require.NoError(t, err)
+
+ hist2, err := factory.Histogram(m)
+ require.NoError(t, err)
+
+ assert.Equal(t, hist1.histogram, hist2.histogram, "histogram must be cached")
+}
+
+func TestDuplicateRegistrations_ShareInstrument(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ m := Metric{Name: "shared_counter"}
+
+ counter1, err := factory.Counter(m)
+ require.NoError(t, err)
+
+ counter2, err := factory.Counter(m)
+ require.NoError(t, err)
+
+ require.NoError(t, counter1.AddOne(context.Background()))
+ require.NoError(t, counter2.AddOne(context.Background()))
+
+ rm := collectMetrics(t, reader)
+ met := findMetricByName(rm, "shared_counter")
+ require.NotNil(t, met)
+
+ dps := sumDataPoints(t, met)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(2), dps[0].Value, "both builders must write to same counter")
+}
+
+// ---------------------------------------------------------------------------
+// 9. selectDefaultBuckets
+// ---------------------------------------------------------------------------
+
+func TestSelectDefaultBuckets(t *testing.T) {
+ tests := []struct {
+ name string
+ expected []float64
+ }{
+ {"account_creation_rate", DefaultAccountBuckets},
+ {"AccountTotal", DefaultAccountBuckets},
+ {"transaction_volume", DefaultTransactionBuckets},
+ {"TransactionLatency", DefaultLatencyBuckets}, // "latency" checked before "transaction"
+ {"api_latency", DefaultLatencyBuckets},
+ {"request_duration", DefaultLatencyBuckets},
+ {"processing_time", DefaultLatencyBuckets},
+ {"unknown_metric", DefaultLatencyBuckets},
+ {"", DefaultLatencyBuckets},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := selectDefaultBuckets(tt.name)
+ assert.Equal(t, tt.expected, got)
+ })
+ }
+}
+
+func TestHistogram_DefaultBucketsApplied(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ // No Buckets specified -- should use default based on name
+ hist, err := factory.Histogram(Metric{Name: "transaction_processing"})
+ require.NoError(t, err)
+
+ require.NoError(t, hist.Record(context.Background(), 500))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "transaction_processing")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, DefaultTransactionBuckets, dps[0].Bounds,
+ "transaction-related histogram must use DefaultTransactionBuckets")
+}
+
+func TestHistogram_AccountDefaultBuckets(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{Name: "account_creation"})
+ require.NoError(t, err)
+
+ require.NoError(t, hist.Record(context.Background(), 10))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "account_creation")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, DefaultAccountBuckets, dps[0].Bounds)
+}
+
+func TestHistogram_LatencyDefaultBuckets(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{Name: "api_latency"})
+ require.NoError(t, err)
+
+ require.NoError(t, hist.Record(context.Background(), 1))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "api_latency")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, DefaultLatencyBuckets, dps[0].Bounds)
+}
+
+// ---------------------------------------------------------------------------
+// 10. histogramCacheKey
+// ---------------------------------------------------------------------------
+
+func TestHistogramCacheKey(t *testing.T) {
+ tests := []struct {
+ name string
+ buckets []float64
+ expected string
+ }{
+ {"metric", nil, "metric"},
+ {"metric", []float64{}, "metric"},
+ {"metric", []float64{1, 5, 10}, "metric:1,5,10"},
+ {"metric", []float64{10, 1, 5}, "metric:1,5,10"}, // sorted
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.expected, func(t *testing.T) {
+ assert.Equal(t, tt.expected, histogramCacheKey(tt.name, tt.buckets))
+ })
+ }
+}
+
+func TestHistogram_DifferentBuckets_SeparateCacheEntries(t *testing.T) {
+ factory, _ := newTestFactory(t)
+
+ hist1, err := factory.Histogram(Metric{
+ Name: "my_hist",
+ Buckets: []float64{1, 10, 100},
+ })
+ require.NoError(t, err)
+
+ hist2, err := factory.Histogram(Metric{
+ Name: "my_hist",
+ Buckets: []float64{5, 50, 500},
+ })
+ require.NoError(t, err)
+
+ // Different buckets => different cache entries => different histogram instruments.
+ // Note: OTel SDK may or may not return different instruments for the same name,
+ // but the cache key must be different.
+ assert.NotEqual(t,
+ histogramCacheKey("my_hist", []float64{1, 10, 100}),
+ histogramCacheKey("my_hist", []float64{5, 50, 500}),
+ )
+
+ // Both histograms should work without error.
+ require.NoError(t, hist1.Record(context.Background(), 5))
+ require.NoError(t, hist2.Record(context.Background(), 25))
+}
+
+// ---------------------------------------------------------------------------
+// 11. Domain metric recording helpers
+// ---------------------------------------------------------------------------
+
+func TestRecordAccountCreated(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ err := factory.RecordAccountCreated(context.Background(),
+ attribute.String("org_id", "org-123"),
+ )
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "accounts_created")
+ require.NotNil(t, m, "accounts_created metric must exist")
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(1), dps[0].Value)
+ assert.True(t, hasAttribute(dps[0].Attributes, "org_id", "org-123"))
+}
+
+func TestRecordTransactionProcessed(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ err := factory.RecordTransactionProcessed(context.Background(),
+ attribute.String("type", "debit"),
+ )
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "transactions_processed")
+ require.NotNil(t, m, "transactions_processed metric must exist")
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(1), dps[0].Value)
+ assert.True(t, hasAttribute(dps[0].Attributes, "type", "debit"))
+}
+
+func TestRecordOperationRouteCreated(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ err := factory.RecordOperationRouteCreated(context.Background(),
+ attribute.String("operation", "transfer"),
+ )
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "operation_routes_created")
+ require.NotNil(t, m, "operation_routes_created metric must exist")
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(1), dps[0].Value)
+ assert.True(t, hasAttribute(dps[0].Attributes, "operation", "transfer"))
+}
+
+func TestRecordTransactionRouteCreated(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ err := factory.RecordTransactionRouteCreated(context.Background(),
+ attribute.String("route", "internal"),
+ )
+ require.NoError(t, err)
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "transaction_routes_created")
+ require.NotNil(t, m, "transaction_routes_created metric must exist")
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(1), dps[0].Value)
+ assert.True(t, hasAttribute(dps[0].Attributes, "route", "internal"))
+}
+
+func TestRecordHelpers_NoAttributes(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ require.NoError(t, factory.RecordAccountCreated(context.Background()))
+ require.NoError(t, factory.RecordTransactionProcessed(context.Background()))
+ require.NoError(t, factory.RecordOperationRouteCreated(context.Background()))
+ require.NoError(t, factory.RecordTransactionRouteCreated(context.Background()))
+
+ rm := collectMetrics(t, reader)
+
+ for _, name := range []string{
+ "accounts_created",
+ "transactions_processed",
+ "operation_routes_created",
+ "transaction_routes_created",
+ } {
+ m := findMetricByName(rm, name)
+ require.NotNil(t, m, "metric %q must exist", name)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(1), dps[0].Value)
+ }
+}
+
+func TestRecordHelpers_MultipleInvocations(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ for i := 0; i < 10; i++ {
+ require.NoError(t, factory.RecordAccountCreated(context.Background()))
+ }
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "accounts_created")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(10), dps[0].Value)
+}
+
+// ---------------------------------------------------------------------------
+// 12. Pre-configured Metric definitions
+// ---------------------------------------------------------------------------
+
+func TestPreConfiguredMetrics(t *testing.T) {
+ tests := []struct {
+ metric Metric
+ name string
+ }{
+ {MetricAccountsCreated, "accounts_created"},
+ {MetricTransactionsProcessed, "transactions_processed"},
+ {MetricTransactionRoutesCreated, "transaction_routes_created"},
+ {MetricOperationRoutesCreated, "operation_routes_created"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.name, tt.metric.Name)
+ assert.NotEmpty(t, tt.metric.Description)
+ assert.Equal(t, "1", tt.metric.Unit)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 13. Metric options (description, unit)
+// ---------------------------------------------------------------------------
+
+func TestCounter_DescriptionAndUnit(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{
+ Name: "desc_counter",
+ Description: "A test counter with description",
+ Unit: "requests",
+ })
+ require.NoError(t, err)
+
+ require.NoError(t, counter.AddOne(context.Background()))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "desc_counter")
+ require.NotNil(t, m)
+ assert.Equal(t, "A test counter with description", m.Description)
+ assert.Equal(t, "requests", m.Unit)
+}
+
+func TestGauge_DescriptionAndUnit(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ gauge, err := factory.Gauge(Metric{
+ Name: "desc_gauge",
+ Description: "A test gauge",
+ Unit: "connections",
+ })
+ require.NoError(t, err)
+
+ require.NoError(t, gauge.Set(context.Background(), 5))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "desc_gauge")
+ require.NotNil(t, m)
+ assert.Equal(t, "A test gauge", m.Description)
+ assert.Equal(t, "connections", m.Unit)
+}
+
+func TestHistogram_DescriptionAndUnit(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{
+ Name: "desc_hist",
+ Description: "A test histogram",
+ Unit: "ms",
+ Buckets: []float64{10, 100},
+ })
+ require.NoError(t, err)
+
+ require.NoError(t, hist.Record(context.Background(), 50))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "desc_hist")
+ require.NotNil(t, m)
+ assert.Equal(t, "A test histogram", m.Description)
+ assert.Equal(t, "ms", m.Unit)
+}
+
+func TestCounter_NoDescriptionNoUnit(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "bare_counter"})
+ require.NoError(t, err)
+
+ require.NoError(t, counter.AddOne(context.Background()))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "bare_counter")
+ require.NotNil(t, m)
+ // SDK may set empty strings; the point is no error occurred.
+}
+
+// ---------------------------------------------------------------------------
+// 14. Concurrent metric recording (goroutine safety)
+// ---------------------------------------------------------------------------
+
+func TestCounter_ConcurrentAdd(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "concurrent_counter"})
+ require.NoError(t, err)
+
+ const goroutines = 100
+
+ errs := make(chan error, goroutines)
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func() {
+ defer wg.Done()
+ if err := counter.AddOne(context.Background()); err != nil {
+ errs <- err
+ }
+ }()
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for err := range errs {
+ t.Errorf("concurrent AddOne error: %v", err)
+ }
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "concurrent_counter")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(goroutines), dps[0].Value,
+ "all concurrent increments must be accounted for")
+}
+
+func TestGauge_ConcurrentSet(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ gauge, err := factory.Gauge(Metric{Name: "concurrent_gauge"})
+ require.NoError(t, err)
+
+ const goroutines = 50
+
+ errs := make(chan error, goroutines)
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func(val int64) {
+ defer wg.Done()
+ if err := gauge.Set(context.Background(), val); err != nil {
+ errs <- err
+ }
+ }(int64(i))
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for err := range errs {
+ t.Errorf("concurrent Set error: %v", err)
+ }
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "concurrent_gauge")
+ require.NotNil(t, m)
+
+ dps := gaugeDataPoints(t, m)
+ require.NotEmpty(t, dps, "gauge must have at least one data point")
+}
+
+func TestHistogram_ConcurrentRecord(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{Name: "concurrent_hist", Buckets: []float64{10, 100, 1000}})
+ require.NoError(t, err)
+
+ const goroutines = 100
+
+ errs := make(chan error, goroutines)
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func(val int64) {
+ defer wg.Done()
+ if err := hist.Record(context.Background(), val); err != nil {
+ errs <- err
+ }
+ }(int64(i))
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for err := range errs {
+ t.Errorf("concurrent Record error: %v", err)
+ }
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "concurrent_hist")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, uint64(goroutines), dps[0].Count)
+}
+
+func TestFactory_ConcurrentCounterCreation(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ const goroutines = 50
+ m := Metric{Name: "race_counter"}
+
+ errs := make(chan error, goroutines*2)
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func() {
+ defer wg.Done()
+
+ counter, err := factory.Counter(m)
+ if err != nil {
+ errs <- err
+ return
+ }
+
+ if err := counter.AddOne(context.Background()); err != nil {
+ errs <- err
+ }
+ }()
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for err := range errs {
+ t.Errorf("concurrent counter creation error: %v", err)
+ }
+
+ rm := collectMetrics(t, reader)
+ met := findMetricByName(rm, "race_counter")
+ require.NotNil(t, met)
+
+ dps := sumDataPoints(t, met)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(goroutines), dps[0].Value,
+ "concurrent counter creation and recording must not lose data")
+}
+
+func TestFactory_ConcurrentGaugeCreation(t *testing.T) {
+ factory, _ := newTestFactory(t)
+
+ const goroutines = 50
+ m := Metric{Name: "race_gauge"}
+
+ errs := make(chan error, goroutines*2)
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func(val int64) {
+ defer wg.Done()
+
+ gauge, err := factory.Gauge(m)
+ if err != nil {
+ errs <- err
+ return
+ }
+
+ if err := gauge.Set(context.Background(), val); err != nil {
+ errs <- err
+ }
+ }(int64(i))
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for err := range errs {
+ t.Errorf("concurrent gauge creation error: %v", err)
+ }
+}
+
+func TestFactory_ConcurrentHistogramCreation(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ const goroutines = 50
+ m := Metric{Name: "race_hist", Buckets: []float64{10, 100}}
+
+ errs := make(chan error, goroutines*2)
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func(val int64) {
+ defer wg.Done()
+
+ hist, err := factory.Histogram(m)
+ if err != nil {
+ errs <- err
+ return
+ }
+
+ if err := hist.Record(context.Background(), val); err != nil {
+ errs <- err
+ }
+ }(int64(i))
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for err := range errs {
+ t.Errorf("concurrent histogram creation error: %v", err)
+ }
+
+ rm := collectMetrics(t, reader)
+ met := findMetricByName(rm, "race_hist")
+ require.NotNil(t, met)
+
+ dps := histDataPoints(t, met)
+ require.Len(t, dps, 1)
+ assert.Equal(t, uint64(goroutines), dps[0].Count)
+}
+
+func TestFactory_ConcurrentMixedMetricTypes(t *testing.T) {
+ factory, _ := newTestFactory(t)
+
+ const goroutines = 30
+
+ errs := make(chan error, goroutines*3)
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines * 3)
+
+ for i := 0; i < goroutines; i++ {
+ go func() {
+ defer wg.Done()
+ if err := factory.RecordAccountCreated(context.Background()); err != nil {
+ errs <- err
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ if err := factory.RecordTransactionProcessed(context.Background()); err != nil {
+ errs <- err
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+ if err := factory.RecordOperationRouteCreated(context.Background()); err != nil {
+ errs <- err
+ }
+ }()
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for err := range errs {
+ t.Errorf("concurrent mixed metric error: %v", err)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 15. Error sentinel values
+// ---------------------------------------------------------------------------
+
+func TestErrorSentinels(t *testing.T) {
+ assert.NotNil(t, ErrNilMeter)
+ assert.NotNil(t, ErrNilCounter)
+ assert.NotNil(t, ErrNilGauge)
+ assert.NotNil(t, ErrNilHistogram)
+
+ assert.EqualError(t, ErrNilMeter, "metric meter cannot be nil")
+ assert.EqualError(t, ErrNilCounter, "counter instrument is nil")
+ assert.EqualError(t, ErrNilGauge, "gauge instrument is nil")
+ assert.EqualError(t, ErrNilHistogram, "histogram instrument is nil")
+}
+
+// ---------------------------------------------------------------------------
+// 16. Default bucket configuration values
+// ---------------------------------------------------------------------------
+
+func TestDefaultBucketValues(t *testing.T) {
+ assert.NotEmpty(t, DefaultLatencyBuckets)
+ assert.NotEmpty(t, DefaultAccountBuckets)
+ assert.NotEmpty(t, DefaultTransactionBuckets)
+
+ // Verify they are sorted (required by OTel spec for histogram boundaries)
+ for i := 1; i < len(DefaultLatencyBuckets); i++ {
+ assert.Less(t, DefaultLatencyBuckets[i-1], DefaultLatencyBuckets[i],
+ "DefaultLatencyBuckets must be sorted")
+ }
+
+ for i := 1; i < len(DefaultAccountBuckets); i++ {
+ assert.Less(t, DefaultAccountBuckets[i-1], DefaultAccountBuckets[i],
+ "DefaultAccountBuckets must be sorted")
+ }
+
+ for i := 1; i < len(DefaultTransactionBuckets); i++ {
+ assert.Less(t, DefaultTransactionBuckets[i-1], DefaultTransactionBuckets[i],
+ "DefaultTransactionBuckets must be sorted")
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 17. addXxxOptions helpers
+// ---------------------------------------------------------------------------
+
+func TestAddCounterOptions(t *testing.T) {
+ factory, _ := newTestFactory(t)
+
+ t.Run("with description and unit", func(t *testing.T) {
+ opts := factory.addCounterOptions(Metric{
+ Name: "test",
+ Description: "desc",
+ Unit: "bytes",
+ })
+ assert.Len(t, opts, 2)
+ })
+
+ t.Run("with description only", func(t *testing.T) {
+ opts := factory.addCounterOptions(Metric{
+ Name: "test",
+ Description: "desc",
+ })
+ assert.Len(t, opts, 1)
+ })
+
+ t.Run("with unit only", func(t *testing.T) {
+ opts := factory.addCounterOptions(Metric{
+ Name: "test",
+ Unit: "ms",
+ })
+ assert.Len(t, opts, 1)
+ })
+
+ t.Run("no options", func(t *testing.T) {
+ opts := factory.addCounterOptions(Metric{Name: "test"})
+ assert.Empty(t, opts)
+ })
+}
+
+func TestAddGaugeOptions(t *testing.T) {
+ factory, _ := newTestFactory(t)
+
+ t.Run("with description and unit", func(t *testing.T) {
+ opts := factory.addGaugeOptions(Metric{
+ Name: "test",
+ Description: "desc",
+ Unit: "connections",
+ })
+ assert.Len(t, opts, 2)
+ })
+
+ t.Run("no options", func(t *testing.T) {
+ opts := factory.addGaugeOptions(Metric{Name: "test"})
+ assert.Empty(t, opts)
+ })
+}
+
+func TestAddHistogramOptions(t *testing.T) {
+ factory, _ := newTestFactory(t)
+
+ t.Run("with all options", func(t *testing.T) {
+ opts := factory.addHistogramOptions(Metric{
+ Name: "test",
+ Description: "desc",
+ Unit: "ms",
+ Buckets: []float64{1, 10, 100},
+ })
+ assert.Len(t, opts, 3) // description + unit + buckets
+ })
+
+ t.Run("with buckets only", func(t *testing.T) {
+ opts := factory.addHistogramOptions(Metric{
+ Name: "test",
+ Buckets: []float64{1, 10},
+ })
+ assert.Len(t, opts, 1)
+ })
+
+ t.Run("no options and nil buckets", func(t *testing.T) {
+ opts := factory.addHistogramOptions(Metric{Name: "test"})
+ assert.Empty(t, opts)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// 18. End-to-end: full recording pipeline
+// ---------------------------------------------------------------------------
+
+func TestEndToEnd_CounterPipeline(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ // 1. Create counter with full Metric definition
+ counter, err := factory.Counter(Metric{
+ Name: "e2e_counter",
+ Description: "End to end counter",
+ Unit: "ops",
+ })
+ require.NoError(t, err)
+
+ // 2. Record with labels
+ labeled := counter.WithLabels(map[string]string{
+ "service": "ledger",
+ "env": "staging",
+ })
+ require.NoError(t, labeled.Add(context.Background(), 10))
+
+ // 3. Record again with different labels
+ other := counter.WithLabels(map[string]string{
+ "service": "auth",
+ "env": "prod",
+ })
+ require.NoError(t, other.Add(context.Background(), 5))
+
+ // 4. Verify all data
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "e2e_counter")
+ require.NotNil(t, m)
+ assert.Equal(t, "End to end counter", m.Description)
+ assert.Equal(t, "ops", m.Unit)
+
+ dps := sumDataPoints(t, m)
+ assert.Len(t, dps, 2, "two label sets => two data points")
+
+ var totalValue int64
+ for _, dp := range dps {
+ totalValue += dp.Value
+ }
+
+ assert.Equal(t, int64(15), totalValue, "total value across all data points")
+}
+
+func TestEndToEnd_HistogramPipeline(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{
+ Name: "e2e_hist",
+ Description: "End to end histogram",
+ Unit: "ms",
+ Buckets: []float64{50, 100, 250, 500, 1000},
+ })
+ require.NoError(t, err)
+
+ // Record several values across different buckets
+ values := []int64{25, 75, 150, 300, 750, 1500}
+ for _, v := range values {
+ require.NoError(t, hist.WithLabels(map[string]string{"handler": "transfer"}).Record(context.Background(), v))
+ }
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "e2e_hist")
+ require.NotNil(t, m)
+ assert.Equal(t, "End to end histogram", m.Description)
+ assert.Equal(t, "ms", m.Unit)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, uint64(6), dps[0].Count)
+
+ var expectedSum int64
+ for _, v := range values {
+ expectedSum += v
+ }
+
+ assert.Equal(t, expectedSum, dps[0].Sum)
+ assert.Equal(t, []float64{50, 100, 250, 500, 1000}, dps[0].Bounds)
+ assert.True(t, hasAttribute(dps[0].Attributes, "handler", "transfer"))
+}
+
+func TestEndToEnd_GaugePipeline(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ gauge, err := factory.Gauge(Metric{
+ Name: "e2e_gauge",
+ Description: "End to end gauge",
+ Unit: "items",
+ })
+ require.NoError(t, err)
+
+ // Set different values for different pools
+ primary := gauge.WithLabels(map[string]string{"pool": "primary"})
+ require.NoError(t, primary.Set(context.Background(), 50))
+
+ replica := gauge.WithLabels(map[string]string{"pool": "replica"})
+ require.NoError(t, replica.Set(context.Background(), 30))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "e2e_gauge")
+ require.NotNil(t, m)
+ assert.Equal(t, "End to end gauge", m.Description)
+ assert.Equal(t, "items", m.Unit)
+
+ dps := gaugeDataPoints(t, m)
+ assert.Len(t, dps, 2, "two attribute sets must produce two data points")
+
+ for _, dp := range dps {
+ if hasAttribute(dp.Attributes, "pool", "primary") {
+ assert.Equal(t, int64(50), dp.Value)
+ } else if hasAttribute(dp.Attributes, "pool", "replica") {
+ assert.Equal(t, int64(30), dp.Value)
+ } else {
+ t.Fatal("unexpected data point without pool attribute")
+ }
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 19. Noop provider compatibility (existing tests, upgraded)
+// ---------------------------------------------------------------------------
+
+func TestNoop_FactoryCreation(t *testing.T) {
+ meter := noop.NewMeterProvider().Meter("noop-test")
+ factory, err := NewMetricsFactory(meter, &log.NopLogger{})
+ require.NoError(t, err)
+ assert.NotNil(t, factory)
+}
+
+func TestNoop_AllHelpers(t *testing.T) {
+ meter := noop.NewMeterProvider().Meter("noop-test")
+ factory, err := NewMetricsFactory(meter, &log.NopLogger{})
+ require.NoError(t, err)
+
+ require.NoError(t, factory.RecordAccountCreated(context.Background(), attribute.String("result", "ok")))
+ require.NoError(t, factory.RecordTransactionProcessed(context.Background(), attribute.String("result", "ok")))
+ require.NoError(t, factory.RecordOperationRouteCreated(context.Background(), attribute.String("result", "ok")))
+ require.NoError(t, factory.RecordTransactionRouteCreated(context.Background(), attribute.String("result", "ok")))
+}
+
+// ---------------------------------------------------------------------------
+// 20. Histogram bucket count verification
+// ---------------------------------------------------------------------------
+
+func TestHistogram_BucketCountDistribution(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{
+ Name: "bucket_test",
+ Buckets: []float64{10, 50, 100},
+ })
+ require.NoError(t, err)
+
+ // Record values that fall into specific buckets:
+ // Bucket [0, 10): values 1, 5 => count=2
+ // Bucket [10, 50): values 15, 30 => count=2
+ // Bucket [50, 100): values 60 => count=1
+ // Bucket [100, +Inf): values 200 => count=1
+ for _, v := range []int64{1, 5, 15, 30, 60, 200} {
+ require.NoError(t, hist.Record(context.Background(), v))
+ }
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "bucket_test")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, uint64(6), dps[0].Count)
+ assert.Equal(t, int64(1+5+15+30+60+200), dps[0].Sum)
+
+ // BucketCounts: [<=10, <=50, <=100, +Inf]
+ // Expected: [2, 4, 5, 6] (cumulative in OTel SDK)
+ // Note: OTel SDK uses cumulative bucket counts
+ require.Len(t, dps[0].BucketCounts, 4, "3 boundaries => 4 bucket counts")
+}
+
+// ---------------------------------------------------------------------------
+// 21. Multiple metrics on same factory
+// ---------------------------------------------------------------------------
+
+func TestFactory_MultipleMetricTypes(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ // Create one of each type
+ counter, err := factory.Counter(Metric{Name: "multi_counter"})
+ require.NoError(t, err)
+
+ gauge, err := factory.Gauge(Metric{Name: "multi_gauge"})
+ require.NoError(t, err)
+
+ hist, err := factory.Histogram(Metric{Name: "multi_hist", Buckets: []float64{10, 100}})
+ require.NoError(t, err)
+
+ // Record values
+ require.NoError(t, counter.Add(context.Background(), 7))
+ require.NoError(t, gauge.Set(context.Background(), 42))
+ require.NoError(t, hist.Record(context.Background(), 55))
+
+ // Verify all
+ rm := collectMetrics(t, reader)
+
+ ctrMet := findMetricByName(rm, "multi_counter")
+ require.NotNil(t, ctrMet)
+ ctrDps := sumDataPoints(t, ctrMet)
+ require.Len(t, ctrDps, 1)
+ assert.Equal(t, int64(7), ctrDps[0].Value)
+
+ gaugeMet := findMetricByName(rm, "multi_gauge")
+ require.NotNil(t, gaugeMet)
+ gaugeDps := gaugeDataPoints(t, gaugeMet)
+ require.Len(t, gaugeDps, 1)
+ assert.Equal(t, int64(42), gaugeDps[0].Value)
+
+ histMet := findMetricByName(rm, "multi_hist")
+ require.NotNil(t, histMet)
+ histDps := histDataPoints(t, histMet)
+ require.Len(t, histDps, 1)
+ assert.Equal(t, uint64(1), histDps[0].Count)
+ assert.Equal(t, int64(55), histDps[0].Sum)
+}
+
+// ---------------------------------------------------------------------------
+// 22. Context propagation
+// ---------------------------------------------------------------------------
+
+func TestCounter_RespectsContext(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "ctx_counter"})
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ require.NoError(t, counter.AddOne(ctx))
+ cancel() // Cancel after recording
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "ctx_counter")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(1), dps[0].Value, "value recorded before cancel must persist")
+}
+
+// ---------------------------------------------------------------------------
+// 23. Large value handling
+// ---------------------------------------------------------------------------
+
+func TestCounter_LargeValues(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "big_counter"})
+ require.NoError(t, err)
+
+ // Financial services can have very large transaction counts
+ largeVal := int64(1_000_000_000)
+ require.NoError(t, counter.Add(context.Background(), largeVal))
+ require.NoError(t, counter.Add(context.Background(), largeVal))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "big_counter")
+ require.NotNil(t, m)
+
+ dps := sumDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(2_000_000_000), dps[0].Value)
+}
+
+func TestHistogram_LargeValues(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ hist, err := factory.Histogram(Metric{
+ Name: "big_hist",
+ Buckets: []float64{1000, 10_000, 100_000, 1_000_000},
+ })
+ require.NoError(t, err)
+
+ require.NoError(t, hist.Record(context.Background(), 5_000_000))
+
+ rm := collectMetrics(t, reader)
+ m := findMetricByName(rm, "big_hist")
+ require.NotNil(t, m)
+
+ dps := histDataPoints(t, m)
+ require.Len(t, dps, 1)
+ assert.Equal(t, int64(5_000_000), dps[0].Sum)
+}
+
+// ---------------------------------------------------------------------------
+// 24. Multiple collects
+// ---------------------------------------------------------------------------
+
+func TestCounter_MultipleCollects(t *testing.T) {
+ factory, reader := newTestFactory(t)
+
+ counter, err := factory.Counter(Metric{Name: "multi_collect_counter"})
+ require.NoError(t, err)
+
+ require.NoError(t, counter.Add(context.Background(), 5))
+
+ // First collect
+ rm1 := collectMetrics(t, reader)
+ m1 := findMetricByName(rm1, "multi_collect_counter")
+ require.NotNil(t, m1)
+ dps1 := sumDataPoints(t, m1)
+ require.Len(t, dps1, 1)
+ assert.Equal(t, int64(5), dps1[0].Value)
+
+ // Record more
+ require.NoError(t, counter.Add(context.Background(), 3))
+
+ // Second collect -- cumulative counter should show total
+ rm2 := collectMetrics(t, reader)
+ m2 := findMetricByName(rm2, "multi_collect_counter")
+ require.NotNil(t, m2)
+ dps2 := sumDataPoints(t, m2)
+ require.Len(t, dps2, 1)
+ assert.Equal(t, int64(8), dps2[0].Value, "cumulative counter should show 5+3=8")
+}
diff --git a/commons/opentelemetry/obfuscation.go b/commons/opentelemetry/obfuscation.go
index 4368b26f..1f97020c 100644
--- a/commons/opentelemetry/obfuscation.go
+++ b/commons/opentelemetry/obfuscation.go
@@ -1,98 +1,221 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package opentelemetry
import (
+ "bytes"
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/hex"
"encoding/json"
- "strings"
+ "fmt"
+ "regexp"
+
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/safe"
+ "github.com/LerianStudio/lib-commons/v4/commons/security"
+)
+
+// RedactionAction defines how sensitive values are transformed.
+type RedactionAction string
- cn "github.com/LerianStudio/lib-commons/v2/commons/constants"
- "github.com/LerianStudio/lib-commons/v2/commons/security"
+const (
+ // RedactionMask replaces a sensitive value with the configured mask.
+ RedactionMask RedactionAction = "mask"
+ // RedactionHash replaces a sensitive value with an HMAC-SHA256 hash.
+ RedactionHash RedactionAction = "hash"
+ // RedactionDrop removes a sensitive field from the output.
+ RedactionDrop RedactionAction = "drop"
)
-// FieldObfuscator defines the interface for obfuscating sensitive fields in structs.
-// Implementations can provide custom logic for determining which fields to obfuscate
-// and how to obfuscate them.
-type FieldObfuscator interface {
- // ShouldObfuscate returns true if the given field name should be obfuscated
- ShouldObfuscate(fieldName string) bool
- // GetObfuscatedValue returns the value to use for obfuscated fields
- GetObfuscatedValue() string
+// RedactionRule matches fields/paths and applies a redaction action.
+type RedactionRule struct {
+ FieldPattern string
+ PathPattern string
+ Action RedactionAction
+
+ fieldRegex *regexp.Regexp
+ pathRegex *regexp.Regexp
}
-// DefaultObfuscator provides a simple implementation that obfuscates
-// common sensitive field names using the security package's word-boundary matching.
-type DefaultObfuscator struct {
- obfuscatedValue string
+// hmacKeySize is the byte length of the HMAC key generated for each Redactor.
+const hmacKeySize = 32
+
+// Redactor applies ordered redaction rules to structured payloads.
+type Redactor struct {
+ rules []RedactionRule
+ maskValue string
+ hmacKey []byte // per-instance key used by HMAC-SHA256 hashing
}
-// NewDefaultObfuscator creates a new DefaultObfuscator with common sensitive field names.
-// Uses the shared sensitive fields list from the security package to ensure consistency
-// across HTTP logging, OpenTelemetry spans, and other components.
-func NewDefaultObfuscator() *DefaultObfuscator {
- return &DefaultObfuscator{
- obfuscatedValue: cn.ObfuscatedValue,
+// NewDefaultRedactor builds a mask-based redactor from default sensitive fields.
+func NewDefaultRedactor() *Redactor {
+ fields := security.DefaultSensitiveFields()
+
+ rules := make([]RedactionRule, 0, len(fields))
+ for _, field := range fields {
+ rules = append(rules, RedactionRule{FieldPattern: `(?i)^` + regexp.QuoteMeta(field) + `$`, Action: RedactionMask})
+ }
+
+ r, err := NewRedactor(rules, cn.ObfuscatedValue)
+ if err != nil {
+ // Rule compilation failed unexpectedly. Return a conservative always-mask
+ // redactor rather than a no-rules redactor that would leak everything.
+ return NewAlwaysMaskRedactor()
}
-}
-// ShouldObfuscate returns true if the field name is in the sensitive fields list.
-// Delegates to security.IsSensitiveField for consistent word-boundary matching
-// across all components (HTTP logging, OpenTelemetry spans, URL sanitization).
-func (o *DefaultObfuscator) ShouldObfuscate(fieldName string) bool {
- return security.IsSensitiveField(fieldName)
+ return r
}
-// CustomObfuscator provides an implementation that obfuscates
-// only the specific field names provided during creation.
-type CustomObfuscator struct {
- sensitiveFields map[string]bool
- obfuscatedValue string
+// NewAlwaysMaskRedactor returns a conservative redactor that treats ALL fields as sensitive.
+// This is used as a safe fallback when normal redactor construction fails, ensuring
+// no data leaks through in fail-open scenarios.
+func NewAlwaysMaskRedactor() *Redactor {
+ return &Redactor{
+ rules: []RedactionRule{
+ {
+ // Match every field name
+ FieldPattern: ".*",
+ fieldRegex: regexp.MustCompile(".*"),
+ Action: RedactionMask,
+ },
+ },
+ maskValue: cn.ObfuscatedValue,
+ }
}
-// NewCustomObfuscator creates a new CustomObfuscator with custom sensitive field names.
-// Uses simple case-insensitive matching against the provided fields only.
-func NewCustomObfuscator(sensitiveFields []string) *CustomObfuscator {
- fieldMap := make(map[string]bool, len(sensitiveFields))
- for _, field := range sensitiveFields {
- fieldMap[strings.ToLower(field)] = true
+// NewRedactor compiles rules and returns a configured redactor.
+func NewRedactor(rules []RedactionRule, maskValue string) (*Redactor, error) {
+ if maskValue == "" {
+ maskValue = cn.ObfuscatedValue
}
- return &CustomObfuscator{
- sensitiveFields: fieldMap,
- obfuscatedValue: cn.ObfuscatedValue,
+ compiled := make([]RedactionRule, 0, len(rules))
+ for i := range rules {
+ rule := rules[i]
+ if rule.Action == "" {
+ rule.Action = RedactionMask
+ }
+
+ if rule.FieldPattern != "" {
+ re, err := safe.Compile(rule.FieldPattern)
+ if err != nil {
+ return nil, fmt.Errorf("invalid redaction field pattern at index %d: %w", i, err)
+ }
+
+ rule.fieldRegex = re
+ }
+
+ if rule.PathPattern != "" {
+ re, err := safe.Compile(rule.PathPattern)
+ if err != nil {
+ return nil, fmt.Errorf("invalid redaction path pattern at index %d: %w", i, err)
+ }
+
+ rule.pathRegex = re
+ }
+
+ compiled = append(compiled, rule)
}
+
+ key := make([]byte, hmacKeySize)
+ if _, err := rand.Read(key); err != nil {
+ return nil, fmt.Errorf("failed to generate HMAC key: %w", err)
+ }
+
+ return &Redactor{rules: compiled, maskValue: maskValue, hmacKey: key}, nil
}
-// ShouldObfuscate returns true if the field name matches one of the custom sensitive fields.
-// Uses simple case-insensitive matching (not word-boundary matching).
-func (o *CustomObfuscator) ShouldObfuscate(fieldName string) bool {
- return o.sensitiveFields[strings.ToLower(fieldName)]
+func (r *Redactor) actionFor(path, fieldName string) (RedactionAction, bool) {
+ if r == nil {
+ return "", false
+ }
+
+ for i := range r.rules {
+ rule := r.rules[i]
+ pathMatch := true
+
+ var fieldMatch bool
+ if rule.fieldRegex != nil {
+ fieldMatch = rule.fieldRegex.MatchString(fieldName)
+ } else {
+ fieldMatch = security.IsSensitiveField(fieldName)
+ }
+
+ if rule.pathRegex != nil {
+ pathMatch = rule.pathRegex.MatchString(path)
+ }
+
+ if fieldMatch && pathMatch {
+ return rule.Action, true
+ }
+ }
+
+ return "", false
}
-// GetObfuscatedValue returns the obfuscated value.
-func (o *CustomObfuscator) GetObfuscatedValue() string {
- return o.obfuscatedValue
+// redactValue applies the first matching redaction rule to a field value.
+// It returns the (possibly transformed) value, whether the field should be dropped,
+// and whether any redaction rule was applied (so the caller can skip expensive comparison).
+func (r *Redactor) redactValue(path, fieldName string, value any) (redacted any, drop bool, applied bool) {
+ action, ok := r.actionFor(path, fieldName)
+ if !ok {
+ return value, false, false
+ }
+
+ switch action {
+ case RedactionDrop:
+ return nil, true, true
+ case RedactionHash:
+ return r.hashString(fmt.Sprint(value)), false, true
+ case RedactionMask:
+ fallthrough
+ default:
+ return r.maskValue, false, true
+ }
}
-// GetObfuscatedValue returns the obfuscated value.
-func (o *DefaultObfuscator) GetObfuscatedValue() string {
- return o.obfuscatedValue
+// hashString computes an HMAC-SHA256 of v using the Redactor's per-instance key.
+// The result is a hex-encoded string prefixed with "sha256:" for identification.
+// Using HMAC prevents rainbow-table attacks against low-entropy PII.
+func (r *Redactor) hashString(v string) string {
+ if len(r.hmacKey) > 0 {
+ mac := hmac.New(sha256.New, r.hmacKey)
+ mac.Write([]byte(v))
+
+ return "sha256:" + hex.EncodeToString(mac.Sum(nil))
+ }
+
+ // Fallback for zero-key edge case (should not happen with proper construction).
+ h := sha256.Sum256([]byte(v))
+
+ return fmt.Sprintf("sha256:%x", h)
}
// obfuscateStructFields recursively obfuscates sensitive fields in a struct or map.
-func obfuscateStructFields(data any, obfuscator FieldObfuscator) any {
+func obfuscateStructFields(data any, path string, redactor *Redactor) any {
switch v := data.(type) {
case map[string]any:
result := make(map[string]any, len(v))
for key, value := range v {
- if obfuscator.ShouldObfuscate(key) {
- result[key] = obfuscator.GetObfuscatedValue()
- } else {
- result[key] = obfuscateStructFields(value, obfuscator)
+ childPath := key
+ if path != "" {
+ childPath = path + "." + key
+ }
+
+ if redactor != nil {
+ redacted, drop, applied := redactor.redactValue(childPath, key, value)
+ if drop {
+ continue
+ }
+
+ if applied {
+ result[key] = redacted
+ continue
+ }
}
+
+ result[key] = obfuscateStructFields(value, childPath, redactor)
}
return result
@@ -101,7 +224,8 @@ func obfuscateStructFields(data any, obfuscator FieldObfuscator) any {
result := make([]any, len(v))
for i, item := range v {
- result[i] = obfuscateStructFields(item, obfuscator)
+ childPath := fmt.Sprintf("%s[%d]", path, i)
+ result[i] = obfuscateStructFields(item, childPath, redactor)
}
return result
@@ -113,21 +237,29 @@ func obfuscateStructFields(data any, obfuscator FieldObfuscator) any {
// ObfuscateStruct applies obfuscation to a struct and returns the obfuscated data.
// This is a utility function that can be used independently of OpenTelemetry spans.
-func ObfuscateStruct(valueStruct any, obfuscator FieldObfuscator) (any, error) {
- if obfuscator == nil {
+func ObfuscateStruct(valueStruct any, redactor *Redactor) (any, error) {
+ if valueStruct == nil || redactor == nil {
return valueStruct, nil
}
- // Convert to JSON and back to get a map[string]any representation
+ // Convert to JSON and back to get a generic representation.
+ // Using any (not map[string]any) to handle arrays, primitives, and objects.
jsonBytes, err := json.Marshal(valueStruct)
if err != nil {
return nil, err
}
- var structData map[string]any
- if err := json.Unmarshal(jsonBytes, &structData); err != nil {
+ var data any
+
+ decoder := json.NewDecoder(bytes.NewReader(jsonBytes))
+ decoder.UseNumber()
+
+ if err := decoder.Decode(&data); err != nil {
return nil, err
}
- return obfuscateStructFields(structData, obfuscator), nil
+ // Zero the intermediate buffer to minimize sensitive data lifetime in memory
+ clear(jsonBytes)
+
+ return obfuscateStructFields(data, "", redactor), nil
}
diff --git a/commons/opentelemetry/obfuscation_example_test.go b/commons/opentelemetry/obfuscation_example_test.go
new file mode 100644
index 00000000..fac3c9db
--- /dev/null
+++ b/commons/opentelemetry/obfuscation_example_test.go
@@ -0,0 +1,49 @@
+//go:build unit
+
+package opentelemetry_test
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+)
+
+func ExampleObfuscateStruct_customRules() {
+ redactor, err := opentelemetry.NewRedactor([]opentelemetry.RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: opentelemetry.RedactionMask},
+ {FieldPattern: `(?i)^email$`, Action: opentelemetry.RedactionHash},
+ }, "***")
+ if err != nil {
+ fmt.Println("invalid rules")
+ return
+ }
+
+ masked, err := opentelemetry.ObfuscateStruct(map[string]any{
+ "name": "alice",
+ "email": "a@b.com",
+ "password": "secret",
+ }, redactor)
+ if err != nil {
+ fmt.Println("obfuscation failed")
+ return
+ }
+
+ m := masked.(map[string]any)
+
+ // password is masked, name is unchanged, email is HMAC-hashed with sha256: prefix
+ fmt.Println("name:", m["name"])
+ fmt.Println("password:", m["password"])
+ fmt.Println("email_prefix:", strings.HasPrefix(m["email"].(string), "sha256:"))
+
+ // Verify the JSON round-trips cleanly
+ b, _ := json.Marshal(masked)
+ fmt.Println("json_ok:", len(b) > 0)
+
+ // Output:
+ // name: alice
+ // password: ***
+ // email_prefix: true
+ // json_ok: true
+}
diff --git a/commons/opentelemetry/obfuscation_test.go b/commons/opentelemetry/obfuscation_test.go
index c2782d84..7fada087 100644
--- a/commons/opentelemetry/obfuscation_test.go
+++ b/commons/opentelemetry/obfuscation_test.go
@@ -1,597 +1,1303 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package opentelemetry
import (
- "context"
+ "encoding/json"
+ "fmt"
+ "strconv"
"strings"
"testing"
- cn "github.com/LerianStudio/lib-commons/v2/commons/constants"
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "go.opentelemetry.io/otel/trace/noop"
+ "go.opentelemetry.io/otel/attribute"
)
-// TestStruct represents a test struct with sensitive and non-sensitive fields
-type TestStruct struct {
- Username string `json:"username"`
- Password string `json:"password"`
- Email string `json:"email"`
- Token string `json:"token"`
- PublicData string `json:"publicData"`
- Credentials struct {
- APIKey string `json:"apikey"`
- SecretKey string `json:"secret"`
- } `json:"credentials"`
- Metadata map[string]any `json:"metadata"`
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+// mustRedactor builds a Redactor or fails the test.
+func mustRedactor(t *testing.T, rules []RedactionRule, mask string) *Redactor {
+ t.Helper()
+
+ r, err := NewRedactor(rules, mask)
+ require.NoError(t, err)
+
+ return r
}
-// NestedTestStruct represents a more complex nested structure
-type NestedTestStruct struct {
- User TestStruct `json:"user"`
- Settings struct {
- Theme string `json:"theme"`
- PrivateKey string `json:"private_key"`
- Preferences []string `json:"preferences"`
- } `json:"settings"`
- Tokens []string `json:"tokens"`
+// hashVia returns the HMAC-SHA256 hash of v using the given redactor's key.
+// This replaces the old sha256Hex helper because hashing is now keyed per-Redactor.
+func hashVia(r *Redactor, v string) string {
+ return r.hashString(v)
}
-func TestNewDefaultObfuscator(t *testing.T) {
- obfuscator := NewDefaultObfuscator()
+// ===========================================================================
+// 1. Redactor construction
+// ===========================================================================
- assert.NotNil(t, obfuscator)
- assert.Equal(t, cn.ObfuscatedValue, obfuscator.GetObfuscatedValue())
+func TestNewRedactor_EmptyRules(t *testing.T) {
+ t.Parallel()
- // Test common sensitive fields (should match security.DefaultSensitiveFields)
- sensitiveFields := []string{
- "password", "token", "secret", "key", "authorization",
- "auth", "credential", "credentials", "apikey", "api_key",
- "access_token", "refresh_token", "private_key", "privatekey",
- }
+ r, err := NewRedactor(nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, r)
+ assert.Empty(t, r.rules)
+ assert.Equal(t, cn.ObfuscatedValue, r.maskValue, "empty mask should fall back to constant")
+}
- for _, field := range sensitiveFields {
- assert.True(t, obfuscator.ShouldObfuscate(field), "Field %s should be obfuscated", field)
- assert.True(t, obfuscator.ShouldObfuscate(strings.ToUpper(field)), "Field %s (uppercase) should be obfuscated", field)
- }
+func TestNewRedactor_CustomMaskValue(t *testing.T) {
+ t.Parallel()
+
+ r, err := NewRedactor(nil, "REDACTED")
+ require.NoError(t, err)
+ assert.Equal(t, "REDACTED", r.maskValue)
+}
+
+func TestNewRedactor_DefaultActionIsMask(t *testing.T) {
+ t.Parallel()
+
+ r, err := NewRedactor([]RedactionRule{
+ {FieldPattern: `^foo$`},
+ }, "")
+ require.NoError(t, err)
+ require.Len(t, r.rules, 1)
+ assert.Equal(t, RedactionMask, r.rules[0].Action, "blank Action should default to mask")
+}
+
+func TestNewRedactor_InvalidFieldPattern(t *testing.T) {
+ t.Parallel()
+
+ _, err := NewRedactor([]RedactionRule{
+ {FieldPattern: `[invalid`},
+ }, "")
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid redaction field pattern at index 0")
+}
+
+func TestNewRedactor_InvalidPathPattern(t *testing.T) {
+ t.Parallel()
+
+ _, err := NewRedactor([]RedactionRule{
+ {PathPattern: `[broken`},
+ }, "")
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid redaction path pattern at index 0")
+}
+
+func TestNewRedactor_MultipleRulesCompileCorrectly(t *testing.T) {
+ t.Parallel()
+
+ r, err := NewRedactor([]RedactionRule{
+ {FieldPattern: `^password$`, Action: RedactionMask},
+ {FieldPattern: `^email$`, Action: RedactionHash},
+ {PathPattern: `^session\.token$`, FieldPattern: `^token$`, Action: RedactionDrop},
+ }, "***")
+ require.NoError(t, err)
+ require.Len(t, r.rules, 3)
+ assert.NotNil(t, r.rules[0].fieldRegex)
+ assert.Nil(t, r.rules[0].pathRegex)
+ assert.NotNil(t, r.rules[2].fieldRegex)
+ assert.NotNil(t, r.rules[2].pathRegex)
+}
+
+// ===========================================================================
+// 2. NewDefaultRedactor
+// ===========================================================================
+
+func TestNewDefaultRedactor_IsNotNil(t *testing.T) {
+ t.Parallel()
- // Test non-sensitive fields
- nonSensitiveFields := []string{
- "username", "email", "name", "id", "status", "created_at", "updated_at",
+ r := NewDefaultRedactor()
+ require.NotNil(t, r)
+ assert.NotEmpty(t, r.rules, "default redactor should have rules from DefaultSensitiveFields")
+ assert.Equal(t, cn.ObfuscatedValue, r.maskValue)
+}
+
+func TestNewDefaultRedactor_MatchesSensitiveFields(t *testing.T) {
+ t.Parallel()
+
+ r := NewDefaultRedactor()
+
+ // These are all in the default sensitive list
+ for _, field := range []string{"password", "token", "secret", "authorization", "apikey", "cvv", "ssn"} {
+ action, matched := r.actionFor("", field)
+ assert.True(t, matched, "field %q should match default rules", field)
+ assert.Equal(t, RedactionMask, action)
}
+}
+
+func TestNewDefaultRedactor_CaseInsensitive(t *testing.T) {
+ t.Parallel()
+
+ r := NewDefaultRedactor()
- for _, field := range nonSensitiveFields {
- assert.False(t, obfuscator.ShouldObfuscate(field), "Field %s should not be obfuscated", field)
+ for _, field := range []string{"Password", "PASSWORD", "pAsSwOrD"} {
+ _, matched := r.actionFor("", field)
+ assert.True(t, matched, "field %q should match case-insensitively", field)
}
}
-func TestNewCustomObfuscator(t *testing.T) {
- customFields := []string{"customSecret", "internalToken", "SENSITIVE_DATA"}
- obfuscator := NewCustomObfuscator(customFields)
+func TestNewDefaultRedactor_NonSensitiveFieldUnchanged(t *testing.T) {
+ t.Parallel()
+
+ r := NewDefaultRedactor()
+
+ _, matched := r.actionFor("", "username")
+ assert.False(t, matched)
+}
+
+// ===========================================================================
+// 3. actionFor (field and path matching)
+// ===========================================================================
+
+func TestActionFor_NilRedactor(t *testing.T) {
+ t.Parallel()
+
+ var r *Redactor
- assert.NotNil(t, obfuscator)
- assert.Equal(t, cn.ObfuscatedValue, obfuscator.GetObfuscatedValue())
+ action, matched := r.actionFor("any.path", "any")
+ assert.False(t, matched)
+ assert.Equal(t, RedactionAction(""), action)
+}
+
+func TestActionFor_ExactFieldMatch(t *testing.T) {
+ t.Parallel()
- // Test custom sensitive fields (case insensitive)
- assert.True(t, obfuscator.ShouldObfuscate("customSecret"))
- assert.True(t, obfuscator.ShouldObfuscate("CUSTOMSECRET"))
- assert.True(t, obfuscator.ShouldObfuscate("internalToken"))
- assert.True(t, obfuscator.ShouldObfuscate("sensitive_data"))
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `^email$`, Action: RedactionHash},
+ }, "")
- // Test that default fields are not included
- assert.False(t, obfuscator.ShouldObfuscate("password"))
- assert.False(t, obfuscator.ShouldObfuscate("token"))
+ action, ok := r.actionFor("user.email", "email")
+ assert.True(t, ok)
+ assert.Equal(t, RedactionHash, action)
}
-func TestObfuscateStructFields(t *testing.T) {
- obfuscator := NewDefaultObfuscator()
+func TestActionFor_RegexFieldPattern(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i).*password.*`, Action: RedactionMask},
+ }, "")
tests := []struct {
- name string
- input any
- expected any
+ field string
+ matched bool
}{
- {
- name: "simple map with sensitive fields",
- input: map[string]any{
- "username": "john_doe",
- "password": "secret123",
- "email": "john@example.com",
- "token": "abc123xyz",
- },
- expected: map[string]any{
- "username": "john_doe",
- "password": cn.ObfuscatedValue,
- "email": "john@example.com",
- "token": cn.ObfuscatedValue,
- },
+ {"password", true},
+ {"user_password", true},
+ {"password_hash", true},
+ {"newPassword", true},
+ {"username", false},
+ }
+
+ for _, tt := range tests {
+ action, ok := r.actionFor("", tt.field)
+ assert.Equal(t, tt.matched, ok, "field=%q", tt.field)
+
+ if tt.matched {
+ assert.Equal(t, RedactionMask, action)
+ }
+ }
+}
+
+func TestActionFor_PathPatternOnly(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {PathPattern: `^config\.db\.password$`, Action: RedactionDrop},
+ }, "")
+
+ // "password" is in the default sensitive fields, so IsSensitiveField("password") returns true.
+ // With pathPattern matching the exact path, the rule should match.
+ _, ok := r.actionFor("config.db.password", "password")
+ assert.True(t, ok, "path+field should match")
+
+ // Non-matching path but sensitive field: the pathRegex will fail the pathMatch.
+ _, ok = r.actionFor("user.password", "password")
+ assert.False(t, ok, "path should NOT match different prefix")
+}
+
+func TestActionFor_CombinedFieldAndPathPattern(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `^token$`, PathPattern: `^session\.`, Action: RedactionDrop},
+ }, "")
+
+ _, ok := r.actionFor("session.token", "token")
+ assert.True(t, ok)
+
+ // Same field, different path
+ _, ok = r.actionFor("auth.token", "token")
+ assert.False(t, ok)
+}
+
+func TestActionFor_NoMatch(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `^secret$`, Action: RedactionMask},
+ }, "")
+
+ _, ok := r.actionFor("", "name")
+ assert.False(t, ok)
+}
+
+func TestActionFor_FirstMatchWins(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionHash},
+ {FieldPattern: `(?i)^password$`, Action: RedactionDrop},
+ }, "")
+
+ action, ok := r.actionFor("", "password")
+ assert.True(t, ok)
+ assert.Equal(t, RedactionHash, action, "first matching rule should win")
+}
+
+// ===========================================================================
+// 4. redactValue (mask / hash / drop)
+// ===========================================================================
+
+func TestRedactValue_Mask(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ val, drop, applied := r.redactValue("", "password", "secret123")
+ assert.False(t, drop)
+ assert.True(t, applied)
+ assert.Equal(t, "***", val)
+}
+
+func TestRedactValue_Hash(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^email$`, Action: RedactionHash},
+ }, "")
+
+ val, drop, applied := r.redactValue("", "email", "alice@example.com")
+ assert.False(t, drop)
+ assert.True(t, applied)
+ assert.Equal(t, hashVia(r, "alice@example.com"), val)
+}
+
+func TestRedactValue_Hash_Deterministic(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^document$`, Action: RedactionHash},
+ }, "")
+
+ val1, _, _ := r.redactValue("", "document", "12345")
+ val2, _, _ := r.redactValue("", "document", "12345")
+ assert.Equal(t, val1, val2, "hashing the same value must be deterministic")
+}
+
+func TestRedactValue_Hash_DifferentInputs(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^document$`, Action: RedactionHash},
+ }, "")
+
+ val1, _, _ := r.redactValue("", "document", "abc")
+ val2, _, _ := r.redactValue("", "document", "def")
+ assert.NotEqual(t, val1, val2, "different inputs must produce different hashes")
+}
+
+func TestRedactValue_Drop(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^token$`, Action: RedactionDrop},
+ }, "")
+
+ val, drop, applied := r.redactValue("", "token", "tok_abc")
+ assert.True(t, drop)
+ assert.True(t, applied)
+ assert.Nil(t, val)
+}
+
+func TestRedactValue_NoMatch(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `^secret$`, Action: RedactionMask},
+ }, "")
+
+ val, drop, applied := r.redactValue("", "name", "Alice")
+ assert.False(t, drop)
+ assert.False(t, applied)
+ assert.Equal(t, "Alice", val)
+}
+
+func TestRedactValue_NilRedactor(t *testing.T) {
+ t.Parallel()
+
+ var r *Redactor
+
+ val, drop, applied := r.redactValue("", "password", "secret")
+ assert.False(t, drop)
+ assert.False(t, applied)
+ assert.Equal(t, "secret", val)
+}
+
+// ===========================================================================
+// 5. hashString
+// ===========================================================================
+
+func TestHashString_Deterministic(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, nil, "")
+
+ h1 := r.hashString("hello")
+ h2 := r.hashString("hello")
+ assert.Equal(t, h1, h2)
+}
+
+func TestHashString_DifferentInputs(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, nil, "")
+
+ h1 := r.hashString("foo")
+ h2 := r.hashString("bar")
+ assert.NotEqual(t, h1, h2)
+}
+
+func TestHashString_Empty(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, nil, "")
+
+ h := r.hashString("")
+ assert.NotEmpty(t, h, "hash of empty string should produce a non-empty output")
+ assert.True(t, strings.HasPrefix(h, "sha256:"), "hash should have sha256: prefix")
+}
+
+func TestHashString_DifferentRedactorsProduceDifferentHashes(t *testing.T) {
+ t.Parallel()
+
+ r1 := mustRedactor(t, nil, "")
+ r2 := mustRedactor(t, nil, "")
+
+ h1 := r1.hashString("same-input")
+ h2 := r2.hashString("same-input")
+ assert.NotEqual(t, h1, h2, "different Redactors use different HMAC keys and should produce different hashes")
+}
+
+// ===========================================================================
+// 6. obfuscateStructFields -- flat maps
+// ===========================================================================
+
+func TestObfuscateStructFields_FlatMap(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ {FieldPattern: `(?i)^email$`, Action: RedactionHash},
+ }, "***")
+
+ input := map[string]any{
+ "name": "alice",
+ "email": "alice@example.com",
+ "password": "secret",
+ }
+
+ result := obfuscateStructFields(input, "", r)
+ m, ok := result.(map[string]any)
+ require.True(t, ok)
+ assert.Equal(t, "alice", m["name"])
+ assert.Equal(t, "***", m["password"])
+ assert.Equal(t, hashVia(r, "alice@example.com"), m["email"])
+}
+
+func TestObfuscateStructFields_EmptyMap(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `^password$`, Action: RedactionMask},
+ }, "***")
+
+ result := obfuscateStructFields(map[string]any{}, "", r)
+ m, ok := result.(map[string]any)
+ require.True(t, ok)
+ assert.Empty(t, m)
+}
+
+func TestObfuscateStructFields_NilRedactor(t *testing.T) {
+ t.Parallel()
+
+ input := map[string]any{
+ "password": "secret",
+ }
+
+ result := obfuscateStructFields(input, "", nil)
+ m := result.(map[string]any)
+ assert.Equal(t, "secret", m["password"], "nil redactor should pass values through")
+}
+
+// ===========================================================================
+// 7. obfuscateStructFields -- nested maps
+// ===========================================================================
+
+func TestObfuscateStructFields_NestedTwoLevels(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ input := map[string]any{
+ "user": map[string]any{
+ "name": "bob",
+ "password": "topsecret",
},
- {
- name: "nested map with sensitive fields",
- input: map[string]any{
- "user": map[string]any{
- "name": "John",
- "password": "secret",
- },
- "config": map[string]any{
- "theme": "dark",
- "api_key": "key123",
- },
- },
- expected: map[string]any{
- "user": map[string]any{
- "name": "John",
- "password": cn.ObfuscatedValue,
- },
- "config": map[string]any{
- "theme": "dark",
- "api_key": cn.ObfuscatedValue,
- },
+ }
+
+ result := obfuscateStructFields(input, "", r).(map[string]any)
+ user := result["user"].(map[string]any)
+ assert.Equal(t, "bob", user["name"])
+ assert.Equal(t, "***", user["password"])
+}
+
+func TestObfuscateStructFields_NestedThreeLevels(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^secret$`, Action: RedactionDrop},
+ }, "")
+
+ input := map[string]any{
+ "level1": map[string]any{
+ "level2": map[string]any{
+ "secret": "deep-value",
+ "visible": "ok",
},
},
- {
- name: "array with sensitive data",
- input: []any{
- map[string]any{
- "id": 1,
- "password": "secret1",
- },
- map[string]any{
- "id": 2,
- "password": "secret2",
- },
- },
- expected: []any{
- map[string]any{
- "id": 1,
- "password": cn.ObfuscatedValue,
- },
- map[string]any{
- "id": 2,
- "password": cn.ObfuscatedValue,
- },
+ }
+
+ result := obfuscateStructFields(input, "", r).(map[string]any)
+ l2 := result["level1"].(map[string]any)["level2"].(map[string]any)
+ assert.NotContains(t, l2, "secret")
+ assert.Equal(t, "ok", l2["visible"])
+}
+
+func TestObfuscateStructFields_NestedPathPattern(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {PathPattern: `^config\.database\.password$`, FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "HIDDEN")
+
+ input := map[string]any{
+ "config": map[string]any{
+ "database": map[string]any{
+ "password": "pg_pass",
+ "host": "localhost",
},
},
- {
- name: "primitive value unchanged",
- input: "simple string",
- expected: "simple string",
- },
- {
- name: "number unchanged",
- input: 42,
- expected: 42,
- },
+ "password": "top-level-pass", // same field name, different path
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := obfuscateStructFields(tt.input, obfuscator)
- assert.Equal(t, tt.expected, result)
- })
- }
-}
-
-func TestObfuscateStruct(t *testing.T) {
- testStruct := TestStruct{
- Username: "john_doe",
- Password: "secret123",
- Email: "john@example.com",
- Token: "abc123xyz",
- PublicData: "public info",
- Credentials: struct {
- APIKey string `json:"apikey"`
- SecretKey string `json:"secret"`
- }{
- APIKey: "key123",
- SecretKey: "secret456",
- },
- Metadata: map[string]any{
- "theme": "dark",
- "private_key": "private123",
- },
+ result := obfuscateStructFields(input, "", r).(map[string]any)
+
+ dbCfg := result["config"].(map[string]any)["database"].(map[string]any)
+ assert.Equal(t, "HIDDEN", dbCfg["password"])
+ assert.Equal(t, "localhost", dbCfg["host"])
+
+ // Top-level password: no path match for the explicit path rule.
+ // However, IsSensitiveField("password") returns true, so it depends on
+ // actionFor logic. With a fieldRegex present, the match is
+ // fieldRegex.MatchString AND pathRegex.MatchString. pathRegex won't match
+ // "password" (it expects "config.database.password").
+ assert.NotEqual(t, "HIDDEN", result["password"], "top-level password should NOT match path-scoped rule")
+}
+
+// ===========================================================================
+// 8. obfuscateStructFields -- arrays
+// ===========================================================================
+
+func TestObfuscateStructFields_ArrayOfObjects(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^token$`, Action: RedactionDrop},
+ }, "")
+
+ input := []any{
+ map[string]any{"id": "1", "token": "tok_a"},
+ map[string]any{"id": "2", "token": "tok_b"},
}
- tests := []struct {
- name string
- obfuscator FieldObfuscator
- wantError bool
- }{
- {
- name: "with default obfuscator",
- obfuscator: NewDefaultObfuscator(),
- wantError: false,
- },
- {
- name: "with custom obfuscator",
- obfuscator: NewCustomObfuscator([]string{"username", "email"}),
- wantError: false,
- },
- {
- name: "without obfuscator (nil)",
- obfuscator: nil,
- wantError: false,
+ result := obfuscateStructFields(input, "", r).([]any)
+ require.Len(t, result, 2)
+
+ for i, item := range result {
+ m := item.(map[string]any)
+ assert.Equal(t, strconv.Itoa(i+1), m["id"])
+ assert.NotContains(t, m, "token")
+ }
+}
+
+func TestObfuscateStructFields_NestedArray(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ input := map[string]any{
+ "users": []any{
+ map[string]any{"name": "alice", "password": "s1"},
+ map[string]any{"name": "bob", "password": "s2"},
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result, err := ObfuscateStruct(testStruct, tt.obfuscator)
-
- if tt.wantError {
- assert.Error(t, err)
- assert.Nil(t, result)
- } else {
- assert.NoError(t, err)
- assert.NotNil(t, result)
- }
- })
+ result := obfuscateStructFields(input, "", r).(map[string]any)
+ users := result["users"].([]any)
+ require.Len(t, users, 2)
+
+ assert.Equal(t, "***", users[0].(map[string]any)["password"])
+ assert.Equal(t, "***", users[1].(map[string]any)["password"])
+ assert.Equal(t, "alice", users[0].(map[string]any)["name"])
+ assert.Equal(t, "bob", users[1].(map[string]any)["name"])
+}
+
+func TestObfuscateStructFields_EmptyArray(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, nil, "")
+
+ result := obfuscateStructFields([]any{}, "", r).([]any)
+ assert.Empty(t, result)
+}
+
+// ===========================================================================
+// 9. obfuscateStructFields -- mixed types
+// ===========================================================================
+
+func TestObfuscateStructFields_MixedTypes(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^secret$`, Action: RedactionMask},
+ }, "***")
+
+ input := map[string]any{
+ "count": float64(42),
+ "active": true,
+ "secret": "classified",
+ "nothing": nil,
+ "name": "test",
}
+
+ result := obfuscateStructFields(input, "", r).(map[string]any)
+ assert.Equal(t, float64(42), result["count"])
+ assert.Equal(t, true, result["active"])
+ assert.Equal(t, "***", result["secret"])
+ assert.Nil(t, result["nothing"])
+ assert.Equal(t, "test", result["name"])
}
-func TestObfuscateStruct_InvalidJSON(t *testing.T) {
- // Create a struct that cannot be marshaled to JSON (contains a channel)
- invalidStruct := struct {
- Name string
- Channel chan int
- }{
- Name: "test",
- Channel: make(chan int),
+func TestObfuscateStructFields_NilValue(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ input := map[string]any{
+ "password": nil,
+ }
+
+ // nil value, but the field matches -- mask replaces with mask value
+ result := obfuscateStructFields(input, "", r).(map[string]any)
+ assert.Equal(t, "***", result["password"])
+}
+
+func TestObfuscateStructFields_EmptyStringValue(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ input := map[string]any{
+ "password": "",
}
- obfuscator := NewDefaultObfuscator()
- result, err := ObfuscateStruct(invalidStruct, obfuscator)
+ result := obfuscateStructFields(input, "", r).(map[string]any)
+ assert.Equal(t, "***", result["password"])
+}
+
+func TestObfuscateStructFields_NonMapNonArray(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, nil, "")
+
+ // Primitives are returned as-is
+ assert.Equal(t, "hello", obfuscateStructFields("hello", "", r))
+ assert.Equal(t, float64(42), obfuscateStructFields(float64(42), "", r))
+ assert.Equal(t, true, obfuscateStructFields(true, "", r))
+ assert.Nil(t, obfuscateStructFields(nil, "", r))
+}
+
+// ===========================================================================
+// 10. ObfuscateStruct (public API)
+// ===========================================================================
+
+func TestObfuscateStruct_NilInput(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, nil, "")
- assert.Error(t, err)
+ result, err := ObfuscateStruct(nil, r)
+ require.NoError(t, err)
+ assert.Nil(t, result)
+}
+
+func TestObfuscateStruct_NilRedactor(t *testing.T) {
+ t.Parallel()
+
+ input := map[string]any{"password": "secret"}
+
+ result, err := ObfuscateStruct(input, nil)
+ require.NoError(t, err)
+ assert.Equal(t, input, result, "nil redactor returns input unchanged")
+}
+
+func TestObfuscateStruct_BothNil(t *testing.T) {
+ t.Parallel()
+
+ result, err := ObfuscateStruct(nil, nil)
+ require.NoError(t, err)
assert.Nil(t, result)
}
-func TestSetSpanAttributesFromStructWithObfuscation_Default(t *testing.T) {
- // Create a no-op tracer for testing
- tracer := noop.NewTracerProvider().Tracer("test")
- _, span := tracer.Start(context.TODO(), "test-span")
+func TestObfuscateStruct_FlatMap(t *testing.T) {
+ t.Parallel()
- testStruct := TestStruct{
- Username: "john_doe",
- Password: "secret123",
- Email: "john@example.com",
- Token: "abc123xyz",
- PublicData: "public info",
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ input := map[string]any{
+ "user": "alice",
+ "password": "s3cr3t",
}
- err := SetSpanAttributesFromStructWithObfuscation(&span, "test_data", testStruct)
+ result, err := ObfuscateStruct(input, r)
require.NoError(t, err)
- // The span should contain the obfuscated data (noop span doesn't store attributes)
-}
-
-func TestSetSpanAttributesFromStructWithObfuscation(t *testing.T) {
- // Create a no-op tracer for testing
- tracer := noop.NewTracerProvider().Tracer("test")
- _, span := tracer.Start(context.TODO(), "test-span")
-
- testStruct := TestStruct{
- Username: "john_doe",
- Password: "secret123",
- Email: "john@example.com",
- Token: "abc123xyz",
- PublicData: "public info",
- Credentials: struct {
- APIKey string `json:"apikey"`
- SecretKey string `json:"secret"`
- }{
- APIKey: "key123",
- SecretKey: "secret456",
- },
- Metadata: map[string]any{
- "theme": "dark",
- "private_key": "private123",
- },
+ m := result.(map[string]any)
+ assert.Equal(t, "alice", m["user"])
+ assert.Equal(t, "***", m["password"])
+}
+
+func TestObfuscateStruct_NestedStruct(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ {FieldPattern: `(?i)^email$`, Action: RedactionHash},
+ }, "***")
+
+ type Inner struct {
+ Password string `json:"password"`
+ Email string `json:"email"`
+ Name string `json:"name"`
+ }
+ type Outer struct {
+ ID string `json:"id"`
+ User Inner `json:"user"`
}
- tests := []struct {
- name string
- obfuscator FieldObfuscator
- wantError bool
- }{
- {
- name: "with default obfuscator",
- obfuscator: NewDefaultObfuscator(),
- wantError: false,
- },
- {
- name: "with custom obfuscator",
- obfuscator: NewCustomObfuscator([]string{"username", "email"}),
- wantError: false,
- },
- {
- name: "without obfuscator (nil)",
- obfuscator: nil,
- wantError: false,
+ input := Outer{
+ ID: "u1",
+ User: Inner{
+ Password: "secret",
+ Email: "alice@example.com",
+ Name: "Alice",
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- var err error
- if tt.obfuscator == nil || tt.name == "with default obfuscator" {
- err = SetSpanAttributesFromStructWithObfuscation(&span, "test_data", testStruct)
- } else {
- err = SetSpanAttributesFromStructWithCustomObfuscation(&span, "test_data", testStruct, tt.obfuscator)
- }
+ result, err := ObfuscateStruct(input, r)
+ require.NoError(t, err)
- if tt.wantError {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- }
- })
+ m := result.(map[string]any)
+ assert.Equal(t, "u1", m["id"])
+
+ user := m["user"].(map[string]any)
+ assert.Equal(t, "***", user["password"])
+ assert.Equal(t, hashVia(r, "alice@example.com"), user["email"])
+ assert.Equal(t, "Alice", user["name"])
+}
+
+func TestObfuscateStruct_ArrayOfStructs(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^token$`, Action: RedactionDrop},
+ }, "")
+
+ type Item struct {
+ ID string `json:"id"`
+ Token string `json:"token"`
+ }
+
+ input := []Item{
+ {ID: "1", Token: "tok_a"},
+ {ID: "2", Token: "tok_b"},
+ }
+
+ result, err := ObfuscateStruct(input, r)
+ require.NoError(t, err)
+
+ arr := result.([]any)
+ require.Len(t, arr, 2)
+
+ for i, item := range arr {
+ m := item.(map[string]any)
+ assert.Equal(t, strconv.Itoa(i+1), m["id"])
+ assert.NotContains(t, m, "token")
}
}
-func TestSetSpanAttributesFromStructWithObfuscation_NestedStruct(t *testing.T) {
- // Create a no-op tracer for testing
- tracer := noop.NewTracerProvider().Tracer("test")
- _, span := tracer.Start(context.TODO(), "test-span")
+func TestObfuscateStruct_UnmarshalableInput(t *testing.T) {
+ t.Parallel()
- nestedStruct := NestedTestStruct{
- User: TestStruct{
- Username: "john_doe",
- Password: "secret123",
- Token: "token456",
- },
- Settings: struct {
- Theme string `json:"theme"`
- PrivateKey string `json:"private_key"`
- Preferences []string `json:"preferences"`
- }{
- Theme: "dark",
- PrivateKey: "private789",
- Preferences: []string{"notifications", "dark_mode"},
+ r := mustRedactor(t, nil, "")
+
+ // channels cannot be marshaled to JSON
+ _, err := ObfuscateStruct(make(chan int), r)
+ require.Error(t, err)
+}
+
+// ===========================================================================
+// 11. JSON round-trip tests
+// ===========================================================================
+
+func TestJSONRoundTrip_SimplePayload(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ {FieldPattern: `(?i)^email$`, Action: RedactionHash},
+ }, "***")
+
+ jsonInput := `{"name":"alice","email":"alice@b.com","password":"pass"}`
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal([]byte(jsonInput), &parsed))
+
+ result, err := ObfuscateStruct(parsed, r)
+ require.NoError(t, err)
+
+ b, err := json.Marshal(result)
+ require.NoError(t, err)
+
+ var decoded map[string]any
+ require.NoError(t, json.Unmarshal(b, &decoded))
+
+ assert.Equal(t, "alice", decoded["name"])
+ assert.Equal(t, "***", decoded["password"])
+ assert.Contains(t, decoded["email"].(string), "sha256:")
+}
+
+func TestJSONRoundTrip_NestedPayload(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^secret$`, Action: RedactionDrop},
+ }, "")
+
+ jsonInput := `{
+ "config": {
+ "database": {
+ "host": "localhost",
+ "secret": "db_pass"
+ }
},
- Tokens: []string{"token1", "token2"},
+ "app": "myservice"
+ }`
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal([]byte(jsonInput), &parsed))
+
+ result, err := ObfuscateStruct(parsed, r)
+ require.NoError(t, err)
+
+ b, err := json.Marshal(result)
+ require.NoError(t, err)
+
+ var decoded map[string]any
+ require.NoError(t, json.Unmarshal(b, &decoded))
+
+ assert.Equal(t, "myservice", decoded["app"])
+ db := decoded["config"].(map[string]any)["database"].(map[string]any)
+ assert.Equal(t, "localhost", db["host"])
+ assert.NotContains(t, db, "secret")
+}
+
+func TestJSONRoundTrip_EmptyJSON(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, nil, "")
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal([]byte(`{}`), &parsed))
+
+ result, err := ObfuscateStruct(parsed, r)
+ require.NoError(t, err)
+
+ b, err := json.Marshal(result)
+ require.NoError(t, err)
+ assert.Equal(t, "{}", string(b))
+}
+
+func TestJSONRoundTrip_LargePayload(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ // Build a payload with many entries
+ payload := make(map[string]any, 200)
+ for i := range 200 {
+ key := fmt.Sprintf("field_%d", i)
+ payload[key] = fmt.Sprintf("value_%d", i)
}
+ payload["password"] = "should_be_masked"
- err := SetSpanAttributesFromStructWithObfuscation(&span, "nested_data", nestedStruct)
+ result, err := ObfuscateStruct(payload, r)
+ require.NoError(t, err)
- assert.NoError(t, err)
+ m := result.(map[string]any)
+ assert.Equal(t, "***", m["password"])
+ assert.Equal(t, "value_0", m["field_0"])
+ assert.Equal(t, "value_199", m["field_199"])
}
-func TestSetSpanAttributesFromStructWithObfuscation_InvalidJSON(t *testing.T) {
- // Create a no-op tracer for testing
- tracer := noop.NewTracerProvider().Tracer("test")
- _, span := tracer.Start(context.TODO(), "test-span")
+func TestJSONRoundTrip_ArrayPayload(t *testing.T) {
+ t.Parallel()
- // Create a struct that cannot be marshaled to JSON (contains a channel)
- invalidStruct := struct {
- Name string
- Channel chan int
- }{
- Name: "test",
- Channel: make(chan int),
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^ssn$`, Action: RedactionMask},
+ }, "REDACTED")
+
+ jsonInput := `[
+ {"name": "Alice", "ssn": "123-45-6789"},
+ {"name": "Bob", "ssn": "987-65-4321"}
+ ]`
+
+ var parsed []any
+ require.NoError(t, json.Unmarshal([]byte(jsonInput), &parsed))
+
+ result, err := ObfuscateStruct(parsed, r)
+ require.NoError(t, err)
+
+ arr := result.([]any)
+ require.Len(t, arr, 2)
+
+ assert.Equal(t, "REDACTED", arr[0].(map[string]any)["ssn"])
+ assert.Equal(t, "REDACTED", arr[1].(map[string]any)["ssn"])
+ assert.Equal(t, "Alice", arr[0].(map[string]any)["name"])
+ assert.Equal(t, "Bob", arr[1].(map[string]any)["name"])
+}
+
+// ===========================================================================
+// 12. All three actions end-to-end through ObfuscateStruct
+// ===========================================================================
+
+func TestObfuscateStruct_AllActionsEndToEnd(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ {FieldPattern: `(?i)^document$`, Action: RedactionHash},
+ {FieldPattern: `(?i)^token$`, PathPattern: `^session\.token$`, Action: RedactionDrop},
+ }, "***")
+
+ input := map[string]any{
+ "password": "secret",
+ "document": "123456789",
+ "session": map[string]any{"token": "tok_abc"},
+ "name": "visible",
}
- err := SetSpanAttributesFromStructWithObfuscation(&span, "invalid_data", invalidStruct)
+ result, err := ObfuscateStruct(input, r)
+ require.NoError(t, err)
+
+ m := result.(map[string]any)
+
+ // Mask
+ assert.Equal(t, "***", m["password"])
+
+ // Hash
+ hashed, ok := m["document"].(string)
+ require.True(t, ok)
+ assert.True(t, strings.HasPrefix(hashed, "sha256:"))
+ assert.NotEqual(t, "123456789", hashed)
- assert.Error(t, err)
+ // Drop
+ session, ok := m["session"].(map[string]any)
+ require.True(t, ok)
+ assert.NotContains(t, session, "token")
+
+ // Pass-through
+ assert.Equal(t, "visible", m["name"])
}
-// MockObfuscator is a custom obfuscator for testing
-type MockObfuscator struct {
- shouldObfuscateFunc func(string) bool
- obfuscatedValue string
+// ===========================================================================
+// 13. Edge cases
+// ===========================================================================
+
+func TestObfuscateStruct_FieldWithDotsInKey(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ // JSON keys with dots are just keys -- dots form paths only via nesting
+ input := map[string]any{
+ "db.password": "val", // this is a single key, NOT nested
+ }
+
+ result, err := ObfuscateStruct(input, r)
+ require.NoError(t, err)
+
+ m := result.(map[string]any)
+ // The field name is "db.password", not "password", so the direct field regex
+ // `^password$` does NOT match. When a rule has a compiled fieldRegex,
+ // IsSensitiveField is NOT used as fallback. Therefore the value passes through.
+ val, ok := m["db.password"]
+ require.True(t, ok, "key 'db.password' must exist in result map")
+ assert.Equal(t, "val", val, "dotted key should not be matched by ^password$ regex")
}
-func (m *MockObfuscator) ShouldObfuscate(fieldName string) bool {
- if m.shouldObfuscateFunc != nil {
- return m.shouldObfuscateFunc(fieldName)
+func TestObfuscateStruct_DeeplyNestedArrayOfObjects(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^secret$`, Action: RedactionMask},
+ }, "HIDDEN")
+
+ input := map[string]any{
+ "data": []any{
+ map[string]any{
+ "nested": []any{
+ map[string]any{
+ "secret": "deep_secret",
+ "visible": "ok",
+ },
+ },
+ },
+ },
}
- return false
+
+ result, err := ObfuscateStruct(input, r)
+ require.NoError(t, err)
+
+ data := result.(map[string]any)["data"].([]any)
+ nested := data[0].(map[string]any)["nested"].([]any)
+ item := nested[0].(map[string]any)
+ assert.Equal(t, "HIDDEN", item["secret"])
+ assert.Equal(t, "ok", item["visible"])
}
-func (m *MockObfuscator) GetObfuscatedValue() string {
- return m.obfuscatedValue
+func TestObfuscateStruct_NumericValues(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^pin$`, Action: RedactionMask},
+ }, "***")
+
+ // When marshaled via JSON with UseNumber(), numeric values become json.Number
+ input := map[string]any{
+ "pin": float64(1234),
+ "count": float64(10),
+ }
+
+ result, err := ObfuscateStruct(input, r)
+ require.NoError(t, err)
+
+ m := result.(map[string]any)
+ assert.Equal(t, "***", m["pin"])
+ assert.Equal(t, json.Number("10"), m["count"])
}
-func TestCustomObfuscatorInterface(t *testing.T) {
- // Create a no-op tracer for testing
- tracer := noop.NewTracerProvider().Tracer("test")
- _, span := tracer.Start(context.TODO(), "test-span")
+func TestObfuscateStruct_BooleanSensitiveField(t *testing.T) {
+ t.Parallel()
- testStruct := map[string]any{
- "public": "visible",
- "private": "hidden",
- "secret": "classified",
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^secret$`, Action: RedactionHash},
+ }, "")
+
+ input := map[string]any{
+ "secret": true,
}
- mockObfuscator := &MockObfuscator{
- shouldObfuscateFunc: func(fieldName string) bool {
- return fieldName == "private" || fieldName == "secret"
- },
- obfuscatedValue: "[REDACTED]",
+ result, err := ObfuscateStruct(input, r)
+ require.NoError(t, err)
+
+ m := result.(map[string]any)
+ hashed, ok := m["secret"].(string)
+ require.True(t, ok)
+ assert.True(t, strings.HasPrefix(hashed, "sha256:"))
+}
+
+// ===========================================================================
+// 14. Processor: redactAttributesByKey
+// ===========================================================================
+
+// These tests are in the internal package to test redactAttributesByKey directly.
+// The main processor_test.go already covers the basic flow; here we add edge cases.
+
+func TestRedactAttributesByKey_NilRedactor(t *testing.T) {
+ t.Parallel()
+
+ attrs := []attribute.KeyValue{
+ attribute.String("foo", "bar"),
}
- err := SetSpanAttributesFromStructWithCustomObfuscation(&span, "test_data", testStruct, mockObfuscator)
- assert.NoError(t, err)
+ result := redactAttributesByKey(attrs, nil)
+ assert.Equal(t, attrs, result, "nil redactor returns attributes unchanged")
}
-func TestObfuscatedValueConstant(t *testing.T) {
- assert.Equal(t, "********", cn.ObfuscatedValue)
+func TestRedactAttributesByKey_EmptyAttrs(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `^password$`, Action: RedactionMask},
+ }, "***")
+
+ result := redactAttributesByKey(nil, r)
+ assert.Empty(t, result)
}
-// TestSanitizeUTF8String tests the UTF-8 sanitization helper function
-func TestSanitizeUTF8String(t *testing.T) {
- tests := []struct {
- name string
- input string
- expected string
- }{
- {
- name: "valid UTF-8 string",
- input: "valid UTF-8 string",
- expected: "valid UTF-8 string",
- },
- {
- name: "invalid UTF-8 sequence",
- input: "invalid\x80string", // Invalid UTF-8 sequence
- expected: "invalid�string", // Replaced with Unicode replacement character
- },
- {
- name: "multiple invalid UTF-8 sequences",
- input: "test\xFFvalue\x80end", // Multiple invalid sequences
- expected: "test�value�end", // Each invalid byte replaced with Unicode replacement character
- },
- {
- name: "empty string",
- input: "",
- expected: "",
- },
- {
- name: "unicode characters (valid)",
- input: "测试字符串", // Chinese characters
- expected: "测试字符串",
- },
- {
- name: "mixed valid and invalid UTF-8",
- input: "测试\x80test字符\xFF", // Valid Chinese + invalid + valid Chinese + invalid
- expected: "测试�test字符�",
- },
- {
- name: "only invalid UTF-8",
- input: "\x80\xFF\xFE", // Consecutive invalid bytes
- expected: "�", // Consecutive invalid bytes become single replacement character
- },
- {
- name: "ASCII with invalid UTF-8",
- input: "Hello\x80World", // ASCII + invalid
- expected: "Hello�World",
- },
+func TestRedactAttributesByKey_DottedKeyExtractsFieldName(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ attrs := []attribute.KeyValue{
+ attribute.String("user.password", "secret"),
+ attribute.String("user.name", "alice"),
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := sanitizeUTF8String(tt.input)
- assert.Equal(t, tt.expected, result)
- })
+ result := redactAttributesByKey(attrs, r)
+ values := make(map[string]string, len(result))
+ for _, a := range result {
+ values[string(a.Key)] = a.Value.AsString()
}
+
+ assert.Equal(t, "***", values["user.password"])
+ assert.Equal(t, "alice", values["user.name"])
}
-// TestSetSpanAttributesWithUTF8Sanitization tests the integration of UTF-8 sanitization
-// with the span attribute setting functions
-func TestSetSpanAttributesWithUTF8Sanitization(t *testing.T) {
- // Create a no-op tracer for testing
- tracer := noop.NewTracerProvider().Tracer("test")
- _, span := tracer.Start(context.TODO(), "test-span")
+func TestRedactAttributesByKey_DropRemovesAttribute(t *testing.T) {
+ t.Parallel()
- tests := []struct {
- name string
- key string
- valueStruct any
- expectError bool
- }{
- {
- name: "struct with invalid UTF-8 in JSON output",
- key: "test\x80key", // Invalid UTF-8 in key
- valueStruct: struct {
- Name string `json:"name"`
- }{
- Name: "test\xFFvalue", // This will be in the JSON, but JSON marshaling handles UTF-8
- },
- expectError: false,
- },
- {
- name: "valid UTF-8 struct",
- key: "valid_key",
- valueStruct: TestStruct{
- Username: "测试用户", // Chinese characters
- Password: "secret123",
- Email: "test@example.com",
- },
- expectError: false,
- },
- {
- name: "struct that cannot be marshaled",
- key: "invalid_struct",
- valueStruct: struct {
- Channel chan int
- }{
- Channel: make(chan int),
- },
- expectError: true,
- },
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^token$`, Action: RedactionDrop},
+ }, "")
+
+ attrs := []attribute.KeyValue{
+ attribute.String("auth.token", "tok_123"),
+ attribute.String("auth.type", "bearer"),
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := SetSpanAttributesFromStructWithObfuscation(&span, tt.key, tt.valueStruct)
+ result := redactAttributesByKey(attrs, r)
+ require.Len(t, result, 1)
+ assert.Equal(t, "auth.type", string(result[0].Key))
+}
- if tt.expectError {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- }
- })
+func TestRedactAttributesByKey_HashProducesConsistentOutput(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^document$`, Action: RedactionHash},
+ }, "")
+
+ attrs := []attribute.KeyValue{
+ attribute.String("customer.document", "123456789"),
}
+
+ result1 := redactAttributesByKey(attrs, r)
+ result2 := redactAttributesByKey(attrs, r)
+
+ require.Len(t, result1, 1)
+ require.Len(t, result2, 1)
+ assert.Equal(t, result1[0].Value.AsString(), result2[0].Value.AsString())
+ assert.True(t, strings.HasPrefix(result1[0].Value.AsString(), "sha256:"))
}
-// TestUTF8SanitizationWithCustomObfuscator tests UTF-8 sanitization with custom obfuscator
-func TestUTF8SanitizationWithCustomObfuscator(t *testing.T) {
- // Create a no-op tracer for testing
- tracer := noop.NewTracerProvider().Tracer("test")
- _, span := tracer.Start(context.TODO(), "test-span")
+func TestRedactAttributesByKey_MultipleAttributes(t *testing.T) {
+ t.Parallel()
- // Create a struct with UTF-8 content
- testStruct := struct {
- Name string `json:"name"`
- Password string `json:"password"`
- City string `json:"city"`
- }{
- Name: "测试用户", // Chinese characters
- Password: "秘密123", // Chinese + ASCII
- City: "北京", // Chinese characters
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ {FieldPattern: `(?i)^token$`, Action: RedactionDrop},
+ {FieldPattern: `(?i)^document$`, Action: RedactionHash},
+ }, "***")
+
+ attrs := []attribute.KeyValue{
+ attribute.String("user.id", "u1"),
+ attribute.String("user.password", "secret"),
+ attribute.String("auth.token", "tok_123"),
+ attribute.String("customer.document", "123456789"),
+ attribute.Int64("request.count", 5),
}
- // Test with custom obfuscator
- customObfuscator := NewCustomObfuscator([]string{"password"})
- err := SetSpanAttributesFromStructWithCustomObfuscation(&span, "user\x80data", testStruct, customObfuscator)
+ result := redactAttributesByKey(attrs, r)
- // Should not error even with invalid UTF-8 in key
- assert.NoError(t, err)
+ values := make(map[string]string, len(result))
+ for _, a := range result {
+ values[string(a.Key)] = a.Value.Emit()
+ }
+
+ assert.Equal(t, "u1", values["user.id"])
+ assert.Equal(t, "***", values["user.password"])
+ assert.NotContains(t, values, "auth.token")
+ assert.True(t, strings.HasPrefix(values["customer.document"], "sha256:"))
+ assert.Equal(t, "5", values["request.count"])
}
-// BenchmarkSanitizeUTF8String benchmarks the UTF-8 sanitization function
-func BenchmarkSanitizeUTF8String(b *testing.B) {
- tests := []struct {
- name string
- input string
- }{
- {
- name: "valid UTF-8",
- input: "valid UTF-8 string with unicode: 测试",
- },
- {
- name: "invalid UTF-8",
- input: "invalid\x80string\xFFwith\xFEmultiple",
- },
- {
- name: "short valid string",
- input: "test",
- },
- {
- name: "short invalid string",
- input: "\x80",
- },
+func TestRedactAttributesByKey_KeyWithoutDot(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ attrs := []attribute.KeyValue{
+ attribute.String("password", "secret"),
}
- for _, tt := range tests {
- b.Run(tt.name, func(b *testing.B) {
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- _ = sanitizeUTF8String(tt.input)
+ result := redactAttributesByKey(attrs, r)
+ require.Len(t, result, 1)
+ assert.Equal(t, "***", result[0].Value.AsString())
+}
+
+// ===========================================================================
+// 15. Processor interface compliance
+// ===========================================================================
+
+func TestAttrBagSpanProcessor_NoOpMethods(t *testing.T) {
+ t.Parallel()
+
+ p := AttrBagSpanProcessor{}
+ assert.NoError(t, p.Shutdown(nil))
+ assert.NoError(t, p.ForceFlush(nil))
+}
+
+func TestRedactingAttrBagSpanProcessor_NoOpMethods(t *testing.T) {
+ t.Parallel()
+
+ p := RedactingAttrBagSpanProcessor{}
+ assert.NoError(t, p.Shutdown(nil))
+ assert.NoError(t, p.ForceFlush(nil))
+}
+
+// ===========================================================================
+// 16. RedactionAction constants
+// ===========================================================================
+
+func TestRedactionActionConstants(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, RedactionAction("mask"), RedactionMask)
+ assert.Equal(t, RedactionAction("hash"), RedactionHash)
+ assert.Equal(t, RedactionAction("drop"), RedactionDrop)
+}
+
+// ===========================================================================
+// 17. Concurrency safety
+// ===========================================================================
+
+func TestObfuscateStruct_ConcurrentSafety(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ {FieldPattern: `(?i)^email$`, Action: RedactionHash},
+ {FieldPattern: `(?i)^token$`, Action: RedactionDrop},
+ }, "***")
+
+ // We rely on -race flag to detect data races. Here we just exercise
+ // concurrent calls to ensure no panics.
+ done := make(chan struct{}, 50)
+ for i := range 50 {
+ go func(idx int) {
+ defer func() { done <- struct{}{} }()
+
+ payload := map[string]any{
+ "id": fmt.Sprintf("user_%d", idx),
+ "password": "secret",
+ "email": fmt.Sprintf("user%d@test.com", idx),
+ "token": "tok_concurrent",
+ "data": map[string]any{
+ "password": "nested_secret",
+ },
+ }
+
+ result, err := ObfuscateStruct(payload, r)
+ if err != nil {
+ t.Errorf("concurrent ObfuscateStruct failed: %v", err)
+ return
+ }
+
+ m := result.(map[string]any)
+ if m["password"] != "***" {
+ t.Errorf("expected masked password, got %v", m["password"])
}
- })
+ }(i)
+ }
+
+ for range 50 {
+ <-done
+ }
+}
+
+func TestRedactAttributesByKey_ConcurrentSafety(t *testing.T) {
+ t.Parallel()
+
+ r := mustRedactor(t, []RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ }, "***")
+
+ attrs := []attribute.KeyValue{
+ attribute.String("user.password", "secret"),
+ attribute.String("user.name", "alice"),
+ }
+
+ done := make(chan struct{}, 50)
+ for range 50 {
+ go func() {
+ defer func() { done <- struct{}{} }()
+ result := redactAttributesByKey(attrs, r)
+ if len(result) != 2 {
+ t.Errorf("expected 2 attributes, got %d", len(result))
+ }
+ }()
+ }
+
+ for range 50 {
+ <-done
}
}
diff --git a/commons/opentelemetry/otel.go b/commons/opentelemetry/otel.go
index 1734e3a3..fcb7980f 100644
--- a/commons/opentelemetry/otel.go
+++ b/commons/opentelemetry/otel.go
@@ -1,24 +1,25 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package opentelemetry
import (
+ "bytes"
"context"
"encoding/json"
"errors"
"fmt"
- stdlog "log"
"maps"
"net/http"
+ "os"
+ "reflect"
+ "strconv"
"strings"
"unicode/utf8"
- "github.com/LerianStudio/lib-commons/v2/commons"
- constant "github.com/LerianStudio/lib-commons/v2/commons/constants"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry/metrics"
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
+ "github.com/LerianStudio/lib-commons/v4/commons/security"
"github.com/gofiber/fiber/v2"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
@@ -28,6 +29,7 @@ import (
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/log/global"
+ "go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/propagation"
sdklog "go.opentelemetry.io/otel/sdk/log"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
@@ -38,13 +40,27 @@ import (
"google.golang.org/grpc/metadata"
)
+const (
+ maxSpanAttributeStringLength = 4096
+ maxAttributeDepth = 32
+ maxAttributeCount = 128
+ defaultAttrPrefix = "value"
+)
+
var (
- // ErrNilTelemetryConfig indicates that nil config was provided to InitializeTelemetryWithError
- ErrNilTelemetryConfig = errors.New("telemetry config cannot be nil")
- // ErrNilTelemetryLogger indicates that config.Logger is nil
+ // ErrNilTelemetryLogger is returned when telemetry config has no logger.
ErrNilTelemetryLogger = errors.New("telemetry config logger cannot be nil")
+ // ErrEmptyEndpoint is returned when telemetry is enabled without exporter endpoint.
+ ErrEmptyEndpoint = errors.New("collector exporter endpoint cannot be empty when telemetry is enabled")
+ // ErrNilTelemetry is returned when a telemetry method receives a nil receiver.
+ ErrNilTelemetry = errors.New("telemetry instance is nil")
+ // ErrNilShutdown is returned when telemetry shutdown handlers are unavailable.
+ ErrNilShutdown = errors.New("telemetry shutdown function is nil")
+ // ErrNilProvider is returned when ApplyGlobals is called with nil providers.
+ ErrNilProvider = errors.New("telemetry providers must not be nil when applying globals")
)
+// TelemetryConfig configures tracing, metrics, logging, and propagation behavior.
type TelemetryConfig struct {
LibraryName string
ServiceName string
@@ -52,22 +68,308 @@ type TelemetryConfig struct {
DeploymentEnv string
CollectorExporterEndpoint string
EnableTelemetry bool
+ InsecureExporter bool
Logger log.Logger
+ Propagator propagation.TextMapPropagator
+ Redactor *Redactor
}
+// Telemetry holds configured OpenTelemetry providers and lifecycle handlers.
type Telemetry struct {
TelemetryConfig
TracerProvider *sdktrace.TracerProvider
- MetricProvider *sdkmetric.MeterProvider
+ MeterProvider *sdkmetric.MeterProvider
LoggerProvider *sdklog.LoggerProvider
MetricsFactory *metrics.MetricsFactory
shutdown func()
+ shutdownCtx func(context.Context) error
+}
+
+// NewTelemetry builds telemetry providers and exporters from configuration.
+func NewTelemetry(cfg TelemetryConfig) (*Telemetry, error) {
+ if cfg.Logger == nil {
+ return nil, ErrNilTelemetryLogger
+ }
+
+ if cfg.Propagator == nil {
+ cfg.Propagator = propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})
+ }
+
+ if cfg.Redactor == nil {
+ cfg.Redactor = NewDefaultRedactor()
+ }
+
+ normalizeEndpoint(&cfg)
+ normalizeEndpointEnvVars()
+
+ if cfg.EnableTelemetry && strings.TrimSpace(cfg.CollectorExporterEndpoint) == "" {
+ return handleEmptyEndpoint(cfg)
+ }
+
+ ctx := context.Background()
+
+ if !cfg.EnableTelemetry {
+ cfg.Logger.Log(ctx, log.LevelWarn, "Telemetry disabled")
+
+ return newNoopTelemetry(cfg)
+ }
+
+ if cfg.InsecureExporter && cfg.DeploymentEnv != "" &&
+ cfg.DeploymentEnv != "development" && cfg.DeploymentEnv != "local" {
+ cfg.Logger.Log(ctx, log.LevelWarn,
+ "InsecureExporter is enabled in non-development environment",
+ log.String("environment", cfg.DeploymentEnv))
+ }
+
+ return initExporters(ctx, cfg)
+}
+
+// normalizeEndpoint strips URL scheme from the collector endpoint and infers security mode.
+// gRPC WithEndpoint() expects host:port, not a full URL.
+// Consumers commonly pass OTEL_EXPORTER_OTLP_ENDPOINT as "http://host:4317".
+func normalizeEndpoint(cfg *TelemetryConfig) {
+ ep := strings.TrimSpace(cfg.CollectorExporterEndpoint)
+ if ep == "" {
+ return
+ }
+
+ switch {
+ case strings.HasPrefix(ep, "http://"):
+ cfg.CollectorExporterEndpoint = strings.TrimPrefix(ep, "http://")
+ cfg.InsecureExporter = true
+ case strings.HasPrefix(ep, "https://"):
+ cfg.CollectorExporterEndpoint = strings.TrimPrefix(ep, "https://")
+ default:
+ // No scheme — assume insecure (common in k8s internal comms).
+ cfg.InsecureExporter = true
+ }
+}
+
+// normalizeEndpointEnvVars ensures OTEL exporter endpoint environment variables
+// contain a URL scheme. The OTEL SDK's envconfig reads these via url.Parse(),
+// which fails on bare "host:port" values. Adding "http://" prevents noisy
+// "parse url" errors from the SDK's internal logger.
+func normalizeEndpointEnvVars() {
+ for _, key := range []string{
+ "OTEL_EXPORTER_OTLP_ENDPOINT",
+ "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT",
+ "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT",
+ } {
+ v := strings.TrimSpace(os.Getenv(key))
+ if v == "" || strings.HasPrefix(v, "http://") || strings.HasPrefix(v, "https://") {
+ continue
+ }
+
+ _ = os.Setenv(key, "http://"+v)
+ }
+}
+
+// handleEmptyEndpoint handles the case where telemetry is enabled but the collector
+// endpoint is empty, returning noop providers installed as globals.
+func handleEmptyEndpoint(cfg TelemetryConfig) (*Telemetry, error) {
+ cfg.Logger.Log(context.Background(), log.LevelWarn,
+ "Telemetry enabled but collector endpoint is empty; falling back to noop providers")
+
+ tl, noopErr := newNoopTelemetry(cfg)
+ if noopErr != nil {
+ return nil, noopErr
+ }
+
+ // Set noop providers as globals so downstream libraries (e.g. otelfiber)
+ // do not create real gRPC exporters that leak background goroutines.
+ _ = tl.ApplyGlobals()
+
+ return tl, ErrEmptyEndpoint
+}
+
+// initExporters creates OTLP exporters, providers, and a metrics factory,
+// rolling back partial allocations on failure.
+func initExporters(ctx context.Context, cfg TelemetryConfig) (*Telemetry, error) {
+ r := cfg.newResource()
+
+ // Track all allocated resources for rollback if a later step fails.
+ var cleanups []shutdownable
+
+ tExp, err := cfg.newTracerExporter(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("can't initialize tracer exporter: %w", err)
+ }
+
+ cleanups = append(cleanups, tExp)
+
+ mExp, err := cfg.newMetricExporter(ctx)
+ if err != nil {
+ shutdownAll(ctx, cleanups)
+
+ return nil, fmt.Errorf("can't initialize metric exporter: %w", err)
+ }
+
+ cleanups = append(cleanups, mExp)
+
+ lExp, err := cfg.newLoggerExporter(ctx)
+ if err != nil {
+ shutdownAll(ctx, cleanups)
+
+ return nil, fmt.Errorf("can't initialize logger exporter: %w", err)
+ }
+
+ cleanups = append(cleanups, lExp)
+
+ mp := cfg.newMeterProvider(r, mExp)
+ cleanups = append(cleanups, mp)
+
+ tp := cfg.newTracerProvider(r, tExp)
+ cleanups = append(cleanups, tp)
+
+ lp := cfg.newLoggerProvider(r, lExp)
+ cleanups = append(cleanups, lp)
+
+ metricsFactory, err := metrics.NewMetricsFactory(mp.Meter(cfg.LibraryName), cfg.Logger)
+ if err != nil {
+ shutdownAll(ctx, cleanups)
+
+ return nil, err
+ }
+
+ shutdown, shutdownCtx := buildShutdownHandlers(cfg.Logger, mp, tp, lp, tExp, mExp, lExp)
+
+ return &Telemetry{
+ TelemetryConfig: cfg,
+ TracerProvider: tp,
+ MeterProvider: mp,
+ LoggerProvider: lp,
+ MetricsFactory: metricsFactory,
+ shutdown: shutdown,
+ shutdownCtx: shutdownCtx,
+ }, nil
+}
+
+// newNoopTelemetry creates a Telemetry instance with no-op providers (no exporters).
+// This is used when telemetry is disabled or when the collector endpoint is empty,
+// ensuring global OTEL providers are safe no-ops that do not leak goroutines.
+func newNoopTelemetry(cfg TelemetryConfig) (*Telemetry, error) {
+ mp := sdkmetric.NewMeterProvider()
+ tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(RedactingAttrBagSpanProcessor{Redactor: cfg.Redactor}))
+ lp := sdklog.NewLoggerProvider()
+
+ metricsFactory, err := metrics.NewMetricsFactory(mp.Meter(cfg.LibraryName), cfg.Logger)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Telemetry{
+ TelemetryConfig: cfg,
+ TracerProvider: tp,
+ MeterProvider: mp,
+ LoggerProvider: lp,
+ MetricsFactory: metricsFactory,
+ shutdown: func() {},
+ shutdownCtx: func(context.Context) error { return nil },
+ }, nil
+}
+
+// shutdownAll performs best-effort shutdown of all allocated components.
+// Used during NewTelemetry to roll back partial allocations on failure.
+func shutdownAll(ctx context.Context, components []shutdownable) {
+ for _, c := range components {
+ if isNilShutdownable(c) {
+ continue
+ }
+
+ _ = c.Shutdown(ctx)
+ }
+}
+
+// ApplyGlobals sets this instance as the process-global OTEL providers/propagator.
+// Returns an error if any required provider is nil.
+func (tl *Telemetry) ApplyGlobals() error {
+ if tl == nil {
+ return ErrNilTelemetry
+ }
+
+ if tl.TracerProvider == nil || tl.MeterProvider == nil || tl.Propagator == nil {
+ return ErrNilProvider
+ }
+
+ otel.SetTracerProvider(tl.TracerProvider)
+ otel.SetMeterProvider(tl.MeterProvider)
+
+ if tl.LoggerProvider != nil {
+ global.SetLoggerProvider(tl.LoggerProvider)
+ }
+
+ otel.SetTextMapPropagator(tl.Propagator)
+
+ return nil
+}
+
+// Tracer returns a tracer from this telemetry instance.
+func (tl *Telemetry) Tracer(name string) (trace.Tracer, error) {
+ if tl == nil || tl.TracerProvider == nil {
+ // Logger is intentionally nil: nil/incomplete Telemetry means no reliable logger available.
+ asserter := assert.New(context.Background(), nil, "opentelemetry", "Tracer")
+ _ = asserter.NoError(context.Background(), ErrNilTelemetry, "telemetry tracer provider is nil")
+
+ return nil, ErrNilTelemetry
+ }
+
+ return tl.TracerProvider.Tracer(name), nil
+}
+
+// Meter returns a meter from this telemetry instance.
+func (tl *Telemetry) Meter(name string) (metric.Meter, error) {
+ if tl == nil || tl.MeterProvider == nil {
+ // Logger is intentionally nil: nil/incomplete Telemetry means no reliable logger available.
+ asserter := assert.New(context.Background(), nil, "opentelemetry", "Meter")
+ _ = asserter.NoError(context.Background(), ErrNilTelemetry, "telemetry meter provider is nil")
+
+ return nil, ErrNilTelemetry
+ }
+
+ return tl.MeterProvider.Meter(name), nil
+}
+
+// ShutdownTelemetry shuts down telemetry components using background context.
+func (tl *Telemetry) ShutdownTelemetry() {
+ if tl == nil {
+ return
+ }
+
+ if err := tl.ShutdownTelemetryWithContext(context.Background()); err != nil {
+ asserter := assert.New(context.Background(), tl.Logger, "opentelemetry", "ShutdownTelemetry")
+ _ = asserter.NoError(context.Background(), err, "telemetry shutdown failed")
+
+ return
+ }
+}
+
+// ShutdownTelemetryWithContext shuts down telemetry components with caller context.
+func (tl *Telemetry) ShutdownTelemetryWithContext(ctx context.Context) error {
+ if tl == nil {
+ // Logger is intentionally nil: nil receiver means no Telemetry instance to extract logger from.
+ asserter := assert.New(context.Background(), nil, "opentelemetry", "ShutdownTelemetryWithContext")
+ _ = asserter.NoError(context.Background(), ErrNilTelemetry, "cannot shutdown nil telemetry")
+
+ return ErrNilTelemetry
+ }
+
+ if tl.shutdownCtx != nil {
+ return tl.shutdownCtx(ctx)
+ }
+
+ if tl.shutdown != nil {
+ tl.shutdown()
+ return nil
+ }
+
+ asserter := assert.New(context.Background(), tl.Logger, "opentelemetry", "ShutdownTelemetryWithContext")
+ _ = asserter.NoError(context.Background(), ErrNilShutdown, "cannot shutdown telemetry without configured shutdown function")
+
+ return ErrNilShutdown
}
-// NewResource creates a new resource with custom attributes.
func (tl *TelemetryConfig) newResource() *sdkresource.Resource {
- // Create a resource with only our custom attributes to avoid schema URL conflicts
- r := sdkresource.NewWithAttributes(
+ return sdkresource.NewWithAttributes(
semconv.SchemaURL,
semconv.ServiceName(tl.ServiceName),
semconv.ServiceVersion(tl.ServiceVersion),
@@ -75,379 +377,456 @@ func (tl *TelemetryConfig) newResource() *sdkresource.Resource {
semconv.TelemetrySDKName(constant.TelemetrySDKName),
semconv.TelemetrySDKLanguageGo,
)
-
- return r
}
-// NewLoggerExporter creates a new logger exporter that writes to stdout.
func (tl *TelemetryConfig) newLoggerExporter(ctx context.Context) (*otlploggrpc.Exporter, error) {
- exporter, err := otlploggrpc.New(ctx, otlploggrpc.WithEndpoint(tl.CollectorExporterEndpoint), otlploggrpc.WithInsecure())
- if err != nil {
- return nil, err
+ opts := []otlploggrpc.Option{otlploggrpc.WithEndpoint(tl.CollectorExporterEndpoint)}
+ if tl.InsecureExporter {
+ opts = append(opts, otlploggrpc.WithInsecure())
}
- return exporter, nil
+ return otlploggrpc.New(ctx, opts...)
}
-// newMetricExporter creates a new metric exporter that writes to stdout.
func (tl *TelemetryConfig) newMetricExporter(ctx context.Context) (*otlpmetricgrpc.Exporter, error) {
- exp, err := otlpmetricgrpc.New(ctx, otlpmetricgrpc.WithEndpoint(tl.CollectorExporterEndpoint), otlpmetricgrpc.WithInsecure())
- if err != nil {
- return nil, err
+ opts := []otlpmetricgrpc.Option{otlpmetricgrpc.WithEndpoint(tl.CollectorExporterEndpoint)}
+ if tl.InsecureExporter {
+ opts = append(opts, otlpmetricgrpc.WithInsecure())
}
- return exp, nil
+ return otlpmetricgrpc.New(ctx, opts...)
}
-// newTracerExporter creates a new tracer exporter that writes to stdout.
func (tl *TelemetryConfig) newTracerExporter(ctx context.Context) (*otlptrace.Exporter, error) {
- exporter, err := otlptracegrpc.New(ctx, otlptracegrpc.WithEndpoint(tl.CollectorExporterEndpoint), otlptracegrpc.WithInsecure())
- if err != nil {
- return nil, err
+ opts := []otlptracegrpc.Option{otlptracegrpc.WithEndpoint(tl.CollectorExporterEndpoint)}
+ if tl.InsecureExporter {
+ opts = append(opts, otlptracegrpc.WithInsecure())
}
- return exporter, nil
+ return otlptracegrpc.New(ctx, opts...)
}
-// newLoggerProvider creates a new logger provider with stdout exporter and default resource.
func (tl *TelemetryConfig) newLoggerProvider(rsc *sdkresource.Resource, exp *otlploggrpc.Exporter) *sdklog.LoggerProvider {
bp := sdklog.NewBatchProcessor(exp)
- lp := sdklog.NewLoggerProvider(sdklog.WithResource(rsc), sdklog.WithProcessor(bp))
-
- return lp
+ return sdklog.NewLoggerProvider(sdklog.WithResource(rsc), sdklog.WithProcessor(bp))
}
-// newMeterProvider creates a new meter provider with stdout exporter and default resource.
func (tl *TelemetryConfig) newMeterProvider(res *sdkresource.Resource, exp *otlpmetricgrpc.Exporter) *sdkmetric.MeterProvider {
- mp := sdkmetric.NewMeterProvider(
+ return sdkmetric.NewMeterProvider(
sdkmetric.WithResource(res),
sdkmetric.WithReader(sdkmetric.NewPeriodicReader(exp)),
)
-
- return mp
}
-// newTracerProvider creates a new tracer provider with stdout exporter and default resource.
func (tl *TelemetryConfig) newTracerProvider(rsc *sdkresource.Resource, exp *otlptrace.Exporter) *sdktrace.TracerProvider {
- tp := sdktrace.NewTracerProvider(
- sdktrace.WithBatcher(exp),
+ return sdktrace.NewTracerProvider(
sdktrace.WithResource(rsc),
- sdktrace.WithSpanProcessor(AttrBagSpanProcessor{}),
+ sdktrace.WithSpanProcessor(RedactingAttrBagSpanProcessor{Redactor: tl.Redactor}),
+ sdktrace.WithBatcher(exp),
)
+}
- return tp
+type shutdownable interface {
+ Shutdown(ctx context.Context) error
}
-// ShutdownTelemetry shuts down the telemetry providers and exporters.
-func (tl *Telemetry) ShutdownTelemetry() {
- tl.shutdown()
+// isNilShutdownable checks for both untyped nil and interface-wrapped typed nil
+// (e.g., a concrete pointer that is nil but stored in a shutdownable interface).
+func isNilShutdownable(s shutdownable) bool {
+ if s == nil {
+ return true
+ }
+
+ v := reflect.ValueOf(s)
+
+ return v.Kind() == reflect.Ptr && v.IsNil()
}
-// InitializeTelemetryWithError initializes the telemetry providers and sets them globally.
-// Returns an error instead of calling Fatalf on failure.
-func InitializeTelemetryWithError(cfg *TelemetryConfig) (*Telemetry, error) {
- if cfg == nil {
- return nil, ErrNilTelemetryConfig
+func buildShutdownHandlers(l log.Logger, components ...shutdownable) (func(), func(context.Context) error) {
+ shutdown := func() {
+ ctx := context.Background()
+
+ for _, c := range components {
+ if isNilShutdownable(c) {
+ continue
+ }
+
+ if err := c.Shutdown(ctx); err != nil {
+ l.Log(ctx, log.LevelError, "telemetry shutdown error", log.Err(err))
+ }
+ }
}
- if cfg.Logger == nil {
- return nil, ErrNilTelemetryLogger
+ shutdownCtx := func(ctx context.Context) error {
+ var errs []error
+
+ for _, c := range components {
+ if isNilShutdownable(c) {
+ continue
+ }
+
+ if err := c.Shutdown(ctx); err != nil {
+ errs = append(errs, err)
+ }
+ }
+
+ return errors.Join(errs...)
}
- ctx := context.Background()
- l := cfg.Logger
+ return shutdown, shutdownCtx
+}
- if !cfg.EnableTelemetry {
- l.Warn("Telemetry turned off ⚠️ ")
+// isNilSpan checks for both untyped nil and interface-wrapped typed nil values.
+// trace.Span is an interface, so a concrete pointer that is nil but stored in
+// a trace.Span variable would pass a simple `span == nil` check.
+func isNilSpan(span trace.Span) bool {
+ if span == nil {
+ return true
+ }
+
+ v := reflect.ValueOf(span)
+
+ return v.Kind() == reflect.Ptr && v.IsNil()
+}
+
+// maxSpanErrorLength is the maximum length for error messages written to span status/events.
+const maxSpanErrorLength = 1024
+
+// sanitizeSpanMessage sanitizes an error message for span output:
+// - Truncates to a safe maximum length
+// - Strips common sensitive-looking patterns (bearer tokens, passwords in URLs)
+func sanitizeSpanMessage(msg string) string {
+ // Strip common sensitive patterns
+ for _, pattern := range []struct{ prefix, replacement string }{
+ {"Bearer ", "Bearer [REDACTED]"},
+ {"Basic ", "Basic [REDACTED]"},
+ } {
+ if idx := strings.Index(msg, pattern.prefix); idx >= 0 {
+ end := idx + len(pattern.prefix)
+ // Find the end of the token (next space or end of string)
+ tokenEnd := strings.IndexByte(msg[end:], ' ')
+ if tokenEnd < 0 {
+ msg = msg[:idx] + pattern.replacement
+ } else {
+ msg = msg[:idx] + pattern.replacement + msg[end+tokenEnd:]
+ }
+ }
+ }
- mp := sdkmetric.NewMeterProvider()
- tp := sdktrace.NewTracerProvider()
- lp := sdklog.NewLoggerProvider()
+ if len(msg) > maxSpanErrorLength {
+ msg = msg[:maxSpanErrorLength]
+ // Ensure valid UTF-8 after truncation
+ if !utf8.ValidString(msg) {
+ msg = strings.ToValidUTF8(msg, "")
+ }
+ }
- metricsFactory := metrics.NewMetricsFactory(mp.Meter(cfg.LibraryName), l)
+ return msg
+}
- return &Telemetry{
- TelemetryConfig: *cfg,
- TracerProvider: tp,
- MetricProvider: mp,
- LoggerProvider: lp,
- MetricsFactory: metricsFactory,
- shutdown: func() {},
- }, nil
+// HandleSpanBusinessErrorEvent records a business-error event on a span.
+func HandleSpanBusinessErrorEvent(span trace.Span, eventName string, err error) {
+ if isNilSpan(span) || err == nil {
+ return
}
- l.Infof("Initializing telemetry...")
+ span.AddEvent(eventName, trace.WithAttributes(attribute.String("error", sanitizeSpanMessage(err.Error()))))
+}
- r := cfg.newResource()
+// HandleSpanEvent records a generic event with optional attributes on a span.
+func HandleSpanEvent(span trace.Span, eventName string, attributes ...attribute.KeyValue) {
+ if isNilSpan(span) {
+ return
+ }
- tExp, err := cfg.newTracerExporter(ctx)
- if err != nil {
- return nil, fmt.Errorf("can't initialize tracer exporter: %w", err)
+ span.AddEvent(eventName, trace.WithAttributes(attributes...))
+}
+
+// HandleSpanError marks a span as failed and records the error.
+func HandleSpanError(span trace.Span, message string, err error) {
+ if isNilSpan(span) || err == nil {
+ return
}
- mExp, err := cfg.newMetricExporter(ctx)
- if err != nil {
- return nil, fmt.Errorf("can't initialize metric exporter: %w", err)
+ // Build status message: avoid malformed ": " when message is empty
+ statusMsg := sanitizeSpanMessage(err.Error())
+ if message != "" {
+ statusMsg = message + ": " + statusMsg
}
- lExp, err := cfg.newLoggerExporter(ctx)
+ span.SetStatus(codes.Error, statusMsg)
+ span.RecordError(err)
+}
+
+// SetSpanAttributesFromValue flattens a value and sets resulting attributes on a span.
+func SetSpanAttributesFromValue(span trace.Span, prefix string, value any, redactor *Redactor) error {
+ if isNilSpan(span) {
+ return nil
+ }
+
+ attrs, err := BuildAttributesFromValue(prefix, value, redactor)
if err != nil {
- return nil, fmt.Errorf("can't initialize logger exporter: %w", err)
+ return err
}
- mp := cfg.newMeterProvider(r, mExp)
- otel.SetMeterProvider(mp)
+ if len(attrs) > 0 {
+ span.SetAttributes(attrs...)
+ }
- meter := mp.Meter(cfg.LibraryName)
- metricsFactory := metrics.NewMetricsFactory(meter, l)
+ return nil
+}
- tp := cfg.newTracerProvider(r, tExp)
- otel.SetTracerProvider(tp)
+// BuildAttributesFromValue flattens a value into OTEL attributes with optional redaction.
+func BuildAttributesFromValue(prefix string, value any, redactor *Redactor) ([]attribute.KeyValue, error) {
+ if value == nil {
+ return nil, nil
+ }
- lp := cfg.newLoggerProvider(r, lExp)
- global.SetLoggerProvider(lp)
+ processed := value
- shutdownHandler := func() {
- err := mp.Shutdown(ctx)
- if err != nil {
- l.Errorf("can't shutdown metric provider: %v", err)
- }
+ if redactor != nil {
+ var err error
- err = tp.Shutdown(ctx)
+ processed, err = ObfuscateStruct(value, redactor)
if err != nil {
- l.Errorf("can't shutdown tracer provider: %v", err)
+ return nil, err
}
+ }
- err = lp.Shutdown(ctx)
- if err != nil {
- l.Errorf("can't shutdown logger provider: %v", err)
- }
+ b, err := json.Marshal(processed)
+ if err != nil {
+ return nil, err
+ }
- err = tExp.Shutdown(ctx)
- if err != nil {
- l.Errorf("can't shutdown tracer exporter: %v", err)
- }
+ // Use json.NewDecoder with UseNumber() to preserve numeric precision.
+ // This avoids float64 rounding for large integers (e.g., financial amounts).
+ var decoded any
- err = mExp.Shutdown(ctx)
- if err != nil {
- l.Errorf("can't shutdown metric exporter: %v", err)
- }
+ dec := json.NewDecoder(bytes.NewReader(b))
+ dec.UseNumber()
- err = lExp.Shutdown(ctx)
- if err != nil {
- l.Errorf("can't shutdown logger exporter: %v", err)
- }
+ if err := dec.Decode(&decoded); err != nil {
+ return nil, err
}
- otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
+ // Use fallback prefix for top-level scalars/slices to avoid empty keys.
+ effectivePrefix := sanitizeUTF8String(prefix)
+ if effectivePrefix == "" {
+ switch decoded.(type) {
+ case map[string]any:
+ // Maps expand their own keys; empty prefix is fine.
+ case []any:
+ effectivePrefix = "item"
+ default:
+ effectivePrefix = defaultAttrPrefix
+ }
+ }
- l.Infof("Telemetry initialized ✅ ")
+ attrs := make([]attribute.KeyValue, 0, 16)
+ flattenAttributes(&attrs, effectivePrefix, decoded, 0)
- return &Telemetry{
- TelemetryConfig: TelemetryConfig{
- LibraryName: cfg.LibraryName,
- ServiceName: cfg.ServiceName,
- ServiceVersion: cfg.ServiceVersion,
- DeploymentEnv: cfg.DeploymentEnv,
- CollectorExporterEndpoint: cfg.CollectorExporterEndpoint,
- EnableTelemetry: cfg.EnableTelemetry,
- Logger: l,
- },
- TracerProvider: tp,
- MetricProvider: mp,
- LoggerProvider: lp,
- MetricsFactory: metricsFactory,
- shutdown: shutdownHandler,
- }, nil
+ return attrs, nil
}
-// Deprecated: Use InitializeTelemetryWithError for proper error handling.
-// InitializeTelemetry initializes the telemetry providers and sets them globally.
-func InitializeTelemetry(cfg *TelemetryConfig) *Telemetry {
- telemetry, err := InitializeTelemetryWithError(cfg)
- if err != nil {
- if cfg == nil || cfg.Logger == nil || errors.Is(err, ErrNilTelemetryConfig) || errors.Is(err, ErrNilTelemetryLogger) {
- stdlog.Fatalf("%v", err)
- }
+func flattenAttributes(attrs *[]attribute.KeyValue, prefix string, value any, depth int) {
+ if depth >= maxAttributeDepth {
+ return
+ }
- cfg.Logger.Fatalf("%v", err)
+ if len(*attrs) >= maxAttributeCount {
+ return
}
- return telemetry
+ switch v := value.(type) {
+ case map[string]any:
+ flattenMap(attrs, prefix, v, depth)
+ case []any:
+ flattenSlice(attrs, prefix, v, depth)
+ case string:
+ s := truncateUTF8(sanitizeUTF8String(v), maxSpanAttributeStringLength)
+ *attrs = append(*attrs, attribute.String(resolveKey(prefix, defaultAttrPrefix), s))
+ case float64:
+ *attrs = append(*attrs, attribute.Float64(resolveKey(prefix, defaultAttrPrefix), v))
+ case bool:
+ *attrs = append(*attrs, attribute.Bool(resolveKey(prefix, defaultAttrPrefix), v))
+ case json.Number:
+ flattenJSONNumber(attrs, prefix, v)
+ case nil:
+ return
+ default:
+ *attrs = append(*attrs, attribute.String(resolveKey(prefix, defaultAttrPrefix), sanitizeUTF8String(fmt.Sprint(v))))
+ }
}
-// SetSpanAttributesFromStruct converts a struct to a JSON string and sets it as an attribute on the span.
-func SetSpanAttributesFromStruct(span *trace.Span, key string, valueStruct any) error {
- jsonByte, err := json.Marshal(valueStruct)
- if err != nil {
- return err
+// resolveKey returns prefix if non-empty, otherwise falls back to fallback.
+func resolveKey(prefix, fallback string) string {
+ if prefix == "" {
+ return fallback
}
- vStr := string(jsonByte)
+ return prefix
+}
- (*span).SetAttributes(attribute.KeyValue{
- Key: attribute.Key(key),
- Value: attribute.StringValue(vStr),
- })
+func flattenMap(attrs *[]attribute.KeyValue, prefix string, m map[string]any, depth int) {
+ for key, child := range m {
+ next := sanitizeUTF8String(key)
+ if prefix != "" {
+ next = prefix + "." + next
+ }
- return nil
+ flattenAttributes(attrs, next, child, depth+1)
+ }
}
-// Deprecated: Use SetSpanAttributesFromStruct instead.
-//
-// SetSpanAttributesFromStructWithObfuscation converts a struct to a JSON string,
-// obfuscates sensitive fields using the default obfuscator, and sets it as an attribute on the span.
-func SetSpanAttributesFromStructWithObfuscation(span *trace.Span, key string, valueStruct any) error {
- return SetSpanAttributesFromStructWithCustomObfuscation(span, key, valueStruct, NewDefaultObfuscator())
+func flattenSlice(attrs *[]attribute.KeyValue, prefix string, s []any, depth int) {
+ idxKey := resolveKey(prefix, "item")
+ for i, child := range s {
+ next := idxKey + "." + strconv.Itoa(i)
+ flattenAttributes(attrs, next, child, depth+1)
+ }
}
-// Deprecated: Use SetSpanAttributesFromStruct instead.
-//
-// SetSpanAttributesFromStructWithCustomObfuscation converts a struct to a JSON string,
-// obfuscates sensitive fields using the custom obfuscator provided, and sets it as an attribute on the span.
-func SetSpanAttributesFromStructWithCustomObfuscation(span *trace.Span, key string, valueStruct any, obfuscator FieldObfuscator) error {
- processedStruct, err := ObfuscateStruct(valueStruct, obfuscator)
- if err != nil {
- return err
+func flattenJSONNumber(attrs *[]attribute.KeyValue, prefix string, v json.Number) {
+ key := resolveKey(prefix, defaultAttrPrefix)
+
+ // Try Int64 first for precision, fall back to Float64
+ if i, err := v.Int64(); err == nil {
+ *attrs = append(*attrs, attribute.Int64(key, i))
+ } else if f, err := v.Float64(); err == nil {
+ *attrs = append(*attrs, attribute.Float64(key, f))
+ } else {
+ *attrs = append(*attrs, attribute.String(key, string(v)))
}
+}
- jsonByte, err := json.Marshal(processedStruct)
- if err != nil {
- return err
+// truncateUTF8 truncates a string to at most maxBytes, ensuring the result is valid UTF-8.
+// If the byte-slice cut lands in the middle of a multi-byte rune, incomplete trailing bytes
+// are trimmed so the result is always valid.
+func truncateUTF8(s string, maxBytes int) string {
+ if len(s) <= maxBytes {
+ return s
}
- (*span).SetAttributes(attribute.KeyValue{
- Key: attribute.Key(sanitizeUTF8String(key)),
- Value: attribute.StringValue(sanitizeUTF8String(string(jsonByte))),
- })
+ s = s[:maxBytes]
- return nil
+ // If the truncation produced invalid UTF-8, trim the trailing incomplete rune
+ for len(s) > 0 && !utf8.ValidString(s) {
+ s = s[:len(s)-1]
+ }
+
+ return s
}
-// SetSpanAttributeForParam sets a span attribute for a Fiber request parameter with consistent naming
-// entityName is a snake_case string used to identify id name, for example the "organization" entity name will result in "app.request.organization_id"
-// otherwise the path parameter "id" in a Fiber request for example "/v1/organizations/:id" will be parsed as "app.request.id"
+// SetSpanAttributeForParam adds a request parameter attribute to the current context bag.
+// Sensitive parameter names (as determined by security.IsSensitiveField) are masked.
func SetSpanAttributeForParam(c *fiber.Ctx, param, value, entityName string) {
- spanAttrKey := "app.request." + param
+ if c == nil {
+ return
+ }
+ spanAttrKey := "app.request." + param
if entityName != "" && param == "id" {
spanAttrKey = "app.request." + entityName + "_id"
}
- c.SetUserContext(commons.ContextWithSpanAttributes(c.UserContext(), attribute.String(spanAttrKey, value)))
+ // Mask value if the parameter name is considered sensitive
+ attrValue := value
+ if security.IsSensitiveField(param) {
+ attrValue = "[REDACTED]"
+ }
+
+ c.SetUserContext(commons.ContextWithSpanAttributes(c.UserContext(), attribute.String(spanAttrKey, attrValue)))
}
-// HandleSpanBusinessErrorEvent adds a business error event to the span.
-func HandleSpanBusinessErrorEvent(span *trace.Span, eventName string, err error) {
- if span != nil && err != nil {
- (*span).AddEvent(eventName, trace.WithAttributes(attribute.String("error", err.Error())))
+// InjectTraceContext injects trace context into a generic text map carrier.
+func InjectTraceContext(ctx context.Context, carrier propagation.TextMapCarrier) {
+ if carrier == nil {
+ return
}
+
+ otel.GetTextMapPropagator().Inject(ctx, carrier)
}
-// HandleSpanEvent adds an event to the span.
-func HandleSpanEvent(span *trace.Span, eventName string, attributes ...attribute.KeyValue) {
- if span != nil {
- (*span).AddEvent(eventName, trace.WithAttributes(attributes...))
+// ExtractTraceContext extracts trace context from a generic text map carrier.
+func ExtractTraceContext(ctx context.Context, carrier propagation.TextMapCarrier) context.Context {
+ if carrier == nil {
+ return ctx
}
+
+ return otel.GetTextMapPropagator().Extract(ctx, carrier)
}
-// HandleSpanError sets the status of the span to error and records the error.
-func HandleSpanError(span *trace.Span, message string, err error) {
- if span != nil && err != nil {
- (*span).SetStatus(codes.Error, message+": "+err.Error())
- (*span).RecordError(err)
+// InjectHTTPContext injects trace headers into HTTP headers.
+func InjectHTTPContext(ctx context.Context, headers http.Header) {
+ if headers == nil {
+ return
}
-}
-// InjectHTTPContext modifies HTTP headers for trace propagation in outgoing client requests
-func InjectHTTPContext(headers *http.Header, ctx context.Context) {
- carrier := propagation.HeaderCarrier{}
- otel.GetTextMapPropagator().Inject(ctx, carrier)
+ InjectTraceContext(ctx, propagation.HeaderCarrier(headers))
+}
- for k, v := range carrier {
- if len(v) > 0 {
- headers.Set(k, v[0])
- }
+// ExtractHTTPContext extracts trace headers from a Fiber request.
+func ExtractHTTPContext(ctx context.Context, c *fiber.Ctx) context.Context {
+ if c == nil {
+ return ctx
}
-}
-// ExtractHTTPContext extracts OpenTelemetry trace context from incoming HTTP headers
-// and injects it into the context. It works with Fiber's HTTP context.
-func ExtractHTTPContext(c *fiber.Ctx) context.Context {
- // Create a carrier from the HTTP headers
carrier := propagation.HeaderCarrier{}
-
- // Extract headers that might contain trace information
for key, value := range c.Request().Header.All() {
carrier.Set(string(key), string(value))
}
- // Extract the trace context
- return otel.GetTextMapPropagator().Extract(c.UserContext(), carrier)
+ return ExtractTraceContext(ctx, carrier)
}
-// InjectGRPCContext injects OpenTelemetry trace context into outgoing gRPC metadata.
-// It normalizes W3C trace headers to lowercase for gRPC compatibility.
-func InjectGRPCContext(ctx context.Context) context.Context {
- md, _ := metadata.FromOutgoingContext(ctx)
+// InjectGRPCContext injects trace context into gRPC metadata.
+func InjectGRPCContext(ctx context.Context, md metadata.MD) metadata.MD {
if md == nil {
md = metadata.New(nil)
}
- // Returns the canonical format of the MIME header key s.
- // The canonicalization converts the first letter and any letter
- // following a hyphen to upper case; the rest are converted to lowercase.
- // For example, the canonical key for "accept-encoding" is "Accept-Encoding".
- // MIME header keys are assumed to be ASCII only.
- // If s contains a space or invalid header field bytes, it is
- // returned without modifications.
- otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(md))
+ InjectTraceContext(ctx, propagation.HeaderCarrier(md))
- if traceparentValues, exists := md["Traceparent"]; exists && len(traceparentValues) > 0 {
+ if traceparentValues, exists := md[constant.HeaderTraceparentPascal]; exists && len(traceparentValues) > 0 {
md[constant.MetadataTraceparent] = traceparentValues
- delete(md, "Traceparent")
+ delete(md, constant.HeaderTraceparentPascal)
}
- if tracestateValues, exists := md["Tracestate"]; exists && len(tracestateValues) > 0 {
+ if tracestateValues, exists := md[constant.HeaderTracestatePascal]; exists && len(tracestateValues) > 0 {
md[constant.MetadataTracestate] = tracestateValues
- delete(md, "Tracestate")
+ delete(md, constant.HeaderTracestatePascal)
}
- return metadata.NewOutgoingContext(ctx, md)
+ return md
}
-// ExtractGRPCContext extracts OpenTelemetry trace context from incoming gRPC metadata
-// and injects it into the context. It handles case normalization for W3C trace headers.
-func ExtractGRPCContext(ctx context.Context) context.Context {
- md, ok := metadata.FromIncomingContext(ctx)
- if !ok || md == nil {
+// ExtractGRPCContext extracts trace context from gRPC metadata.
+func ExtractGRPCContext(ctx context.Context, md metadata.MD) context.Context {
+ if md == nil {
return ctx
}
mdCopy := md.Copy()
if traceparentValues, exists := mdCopy[constant.MetadataTraceparent]; exists && len(traceparentValues) > 0 {
- mdCopy["Traceparent"] = traceparentValues
+ mdCopy[constant.HeaderTraceparentPascal] = traceparentValues
delete(mdCopy, constant.MetadataTraceparent)
}
if tracestateValues, exists := mdCopy[constant.MetadataTracestate]; exists && len(tracestateValues) > 0 {
- mdCopy["Tracestate"] = tracestateValues
+ mdCopy[constant.HeaderTracestatePascal] = tracestateValues
delete(mdCopy, constant.MetadataTracestate)
}
- return otel.GetTextMapPropagator().Extract(ctx, propagation.HeaderCarrier(mdCopy))
+ return ExtractTraceContext(ctx, propagation.HeaderCarrier(mdCopy))
}
-// InjectQueueTraceContext injects OpenTelemetry trace context into RabbitMQ headers
-// for distributed tracing across queue messages. Returns a map of headers to be
-// added to the RabbitMQ message headers.
+// InjectQueueTraceContext serializes trace context to string headers for queues.
func InjectQueueTraceContext(ctx context.Context) map[string]string {
carrier := propagation.HeaderCarrier{}
- otel.GetTextMapPropagator().Inject(ctx, carrier)
-
- headers := make(map[string]string)
+ InjectTraceContext(ctx, carrier)
+ headers := make(map[string]string, len(carrier))
for k, v := range carrier {
if len(v) > 0 {
headers[k] = v[0]
@@ -457,9 +836,7 @@ func InjectQueueTraceContext(ctx context.Context) map[string]string {
return headers
}
-// ExtractQueueTraceContext extracts OpenTelemetry trace context from RabbitMQ headers
-// and returns a new context with the extracted trace information. This enables
-// distributed tracing continuity across queue message boundaries.
+// ExtractQueueTraceContext extracts trace context from queue string headers.
func ExtractQueueTraceContext(ctx context.Context, headers map[string]string) context.Context {
if headers == nil {
return ctx
@@ -470,52 +847,14 @@ func ExtractQueueTraceContext(ctx context.Context, headers map[string]string) co
carrier.Set(k, v)
}
- return otel.GetTextMapPropagator().Extract(ctx, carrier)
-}
-
-// GetTraceIDFromContext extracts the trace ID from the current span context
-// Returns empty string if no active span or trace ID is found
-func GetTraceIDFromContext(ctx context.Context) string {
- span := trace.SpanFromContext(ctx)
- if span == nil {
- return ""
- }
-
- spanContext := span.SpanContext()
-
- if !spanContext.IsValid() {
- return ""
- }
-
- return spanContext.TraceID().String()
-}
-
-// GetTraceStateFromContext extracts the trace state from the current span context
-// Returns empty string if no active span or trace state is found
-func GetTraceStateFromContext(ctx context.Context) string {
- span := trace.SpanFromContext(ctx)
- if span == nil {
- return ""
- }
-
- spanContext := span.SpanContext()
-
- if !spanContext.IsValid() {
- return ""
- }
-
- return spanContext.TraceState().String()
+ return ExtractTraceContext(ctx, carrier)
}
-// PrepareQueueHeaders prepares RabbitMQ headers with trace context injection
-// following W3C trace context standards. Returns a map suitable for amqp.Table.
+// PrepareQueueHeaders merges base headers with propagated trace headers.
func PrepareQueueHeaders(ctx context.Context, baseHeaders map[string]any) map[string]any {
headers := make(map[string]any)
-
- // Copy base headers first
maps.Copy(headers, baseHeaders)
- // Inject trace context using W3C standards
traceHeaders := InjectQueueTraceContext(ctx)
for k, v := range traceHeaders {
headers[k] = v
@@ -524,28 +863,28 @@ func PrepareQueueHeaders(ctx context.Context, baseHeaders map[string]any) map[st
return headers
}
-// InjectTraceHeadersIntoQueue adds OpenTelemetry trace headers to existing RabbitMQ headers
-// following W3C trace context standards. Modifies the headers map in place.
+// InjectTraceHeadersIntoQueue injects propagated trace headers into a mutable map.
func InjectTraceHeadersIntoQueue(ctx context.Context, headers *map[string]any) {
if headers == nil {
return
}
- // Inject trace context using W3C standards
+ if *headers == nil {
+ *headers = make(map[string]any)
+ }
+
traceHeaders := InjectQueueTraceContext(ctx)
for k, v := range traceHeaders {
(*headers)[k] = v
}
}
-// ExtractTraceContextFromQueueHeaders extracts OpenTelemetry trace context from RabbitMQ amqp.Table headers
-// and returns a new context with the extracted trace information. Handles type conversion automatically.
+// ExtractTraceContextFromQueueHeaders extracts trace context from AMQP-style headers.
func ExtractTraceContextFromQueueHeaders(baseCtx context.Context, amqpHeaders map[string]any) context.Context {
if len(amqpHeaders) == 0 {
return baseCtx
}
- // Convert amqp.Table headers to map[string]string for trace extraction
traceHeaders := make(map[string]string)
for k, v := range amqpHeaders {
@@ -558,16 +897,33 @@ func ExtractTraceContextFromQueueHeaders(baseCtx context.Context, amqpHeaders ma
return baseCtx
}
- // Extract trace context using existing function
return ExtractQueueTraceContext(baseCtx, traceHeaders)
}
-func (tl *Telemetry) EndTracingSpans(ctx context.Context) {
- trace.SpanFromContext(ctx).End()
+// GetTraceIDFromContext returns the current span trace ID, or empty if unavailable.
+func GetTraceIDFromContext(ctx context.Context) string {
+ span := trace.SpanFromContext(ctx)
+
+ sc := span.SpanContext()
+ if !sc.IsValid() {
+ return ""
+ }
+
+ return sc.TraceID().String()
+}
+
+// GetTraceStateFromContext returns the current span tracestate, or empty if unavailable.
+func GetTraceStateFromContext(ctx context.Context) string {
+ span := trace.SpanFromContext(ctx)
+
+ sc := span.SpanContext()
+ if !sc.IsValid() {
+ return ""
+ }
+
+ return sc.TraceState().String()
}
-// sanitizeUTF8String validates and sanitizes UTF-8 string.
-// If the string contains invalid UTF-8 characters, they are replaced with the Unicode replacement character (�).
func sanitizeUTF8String(s string) string {
if !utf8.ValidString(s) {
return strings.ToValidUTF8(s, "�")
diff --git a/commons/opentelemetry/otel_example_test.go b/commons/opentelemetry/otel_example_test.go
new file mode 100644
index 00000000..f3fcc6bc
--- /dev/null
+++ b/commons/opentelemetry/otel_example_test.go
@@ -0,0 +1,36 @@
+//go:build unit
+
+package opentelemetry_test
+
+import (
+ "fmt"
+ "sort"
+ "strings"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+)
+
+func ExampleBuildAttributesFromValue() {
+ type payload struct {
+ ID string `json:"id"`
+ RiskScore int `json:"risk_score"`
+ }
+
+ attrs, err := opentelemetry.BuildAttributesFromValue("customer", payload{
+ ID: "cst_123",
+ RiskScore: 8,
+ }, nil)
+
+ keys := make([]string, 0, len(attrs))
+ for _, kv := range attrs {
+ keys = append(keys, string(kv.Key))
+ }
+ sort.Strings(keys)
+
+ fmt.Println(err == nil)
+ fmt.Println(strings.Join(keys, ","))
+
+ // Output:
+ // true
+ // customer.id,customer.risk_score
+}
diff --git a/commons/opentelemetry/otel_test.go b/commons/opentelemetry/otel_test.go
index f0027740..ae970a28 100644
--- a/commons/opentelemetry/otel_test.go
+++ b/commons/opentelemetry/otel_test.go
@@ -1,107 +1,1278 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package opentelemetry
import (
- "errors"
+ "context"
+ "os"
+ "strings"
"testing"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/codes"
+ "go.opentelemetry.io/otel/propagation"
+ sdktrace "go.opentelemetry.io/otel/sdk/trace"
+ "go.opentelemetry.io/otel/sdk/trace/tracetest"
+ "go.opentelemetry.io/otel/trace"
+ "google.golang.org/grpc/metadata"
)
-func TestInitializeTelemetryWithError_TelemetryDisabled(t *testing.T) {
- cfg := &TelemetryConfig{
- LibraryName: "test-lib",
- ServiceName: "test-service",
- ServiceVersion: "1.0.0",
- DeploymentEnv: "test",
+// ===========================================================================
+// 1. NewTelemetry validation
+// ===========================================================================
+
+func TestNewTelemetry_NilLogger(t *testing.T) {
+ t.Parallel()
+
+ tl, err := NewTelemetry(TelemetryConfig{
EnableTelemetry: false,
- Logger: &log.NoneLogger{},
- }
+ })
+ require.ErrorIs(t, err, ErrNilTelemetryLogger)
+ assert.Nil(t, tl)
+}
- telemetry, err := InitializeTelemetryWithError(cfg)
+func TestNewTelemetry_EnabledEmptyEndpoint(t *testing.T) {
+ t.Parallel()
- assert.NoError(t, err)
- assert.NotNil(t, telemetry)
- assert.NotNil(t, telemetry.TracerProvider)
- assert.NotNil(t, telemetry.MetricProvider)
- assert.NotNil(t, telemetry.LoggerProvider)
- assert.NotNil(t, telemetry.MetricsFactory)
+ tl, err := NewTelemetry(TelemetryConfig{
+ EnableTelemetry: true,
+ LibraryName: "test-lib",
+ Logger: log.NewNop(),
+ })
+ require.ErrorIs(t, err, ErrEmptyEndpoint)
+ require.NotNil(t, tl, "must return noop Telemetry to prevent goroutine leaks")
+ assert.NotNil(t, tl.TracerProvider)
+ assert.NotNil(t, tl.MeterProvider)
+ assert.NotNil(t, tl.LoggerProvider)
+ assert.NotNil(t, tl.MetricsFactory)
}
-func TestInitializeTelemetry_TelemetryDisabled(t *testing.T) {
- cfg := &TelemetryConfig{
+func TestNewTelemetry_EnabledWhitespaceEndpoint(t *testing.T) {
+ t.Parallel()
+
+ tl, err := NewTelemetry(TelemetryConfig{
+ EnableTelemetry: true,
+ CollectorExporterEndpoint: " ",
+ LibraryName: "test-lib",
+ Logger: log.NewNop(),
+ })
+ require.ErrorIs(t, err, ErrEmptyEndpoint)
+ require.NotNil(t, tl, "must return noop Telemetry to prevent goroutine leaks")
+ assert.NotNil(t, tl.TracerProvider)
+ assert.NotNil(t, tl.MeterProvider)
+ assert.NotNil(t, tl.LoggerProvider)
+ assert.NotNil(t, tl.MetricsFactory)
+}
+
+func TestNewTelemetry_EnabledEmptyEndpoint_SetsGlobalNoopProviders(t *testing.T) {
+ // Not parallel: mutates global OTEL providers.
+ prevTP := otel.GetTracerProvider()
+ prevMP := otel.GetMeterProvider()
+ t.Cleanup(func() {
+ otel.SetTracerProvider(prevTP)
+ otel.SetMeterProvider(prevMP)
+ })
+
+ tl, err := NewTelemetry(TelemetryConfig{
+ EnableTelemetry: true,
LibraryName: "test-lib",
- ServiceName: "test-service",
- ServiceVersion: "1.0.0",
+ Logger: log.NewNop(),
+ })
+ require.ErrorIs(t, err, ErrEmptyEndpoint)
+ require.NotNil(t, tl)
+
+ // Verify that noop providers were installed as globals, preventing
+ // downstream libraries from spawning real gRPC exporters.
+ assert.Same(t, tl.TracerProvider, otel.GetTracerProvider(),
+ "global tracer provider must be the noop instance")
+ assert.Same(t, tl.MeterProvider, otel.GetMeterProvider(),
+ "global meter provider must be the noop instance")
+}
+
+func TestNewTelemetry_DisabledReturnsNoopProviders(t *testing.T) {
+ t.Parallel()
+
+ tl, err := NewTelemetry(TelemetryConfig{
+ LibraryName: "test-lib",
+ ServiceName: "test-svc",
+ ServiceVersion: "0.1.0",
DeploymentEnv: "test",
EnableTelemetry: false,
- Logger: &log.NoneLogger{},
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err)
+ require.NotNil(t, tl)
+ assert.NotNil(t, tl.TracerProvider)
+ assert.NotNil(t, tl.MeterProvider)
+ assert.NotNil(t, tl.LoggerProvider)
+ assert.NotNil(t, tl.MetricsFactory)
+ assert.NotNil(t, tl.Redactor)
+ assert.NotNil(t, tl.Propagator)
+}
+
+func TestNewTelemetry_DefaultPropagatorAndRedactor(t *testing.T) {
+ t.Parallel()
+
+ tl, err := NewTelemetry(TelemetryConfig{
+ LibraryName: "test-lib",
+ EnableTelemetry: false,
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err)
+ assert.NotNil(t, tl.Propagator, "default propagator should be set")
+ assert.NotNil(t, tl.Redactor, "default redactor should be set")
+}
+
+// ===========================================================================
+// 1b. Endpoint normalization
+// ===========================================================================
+
+func TestNewTelemetry_EndpointNormalization(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ endpoint string
+ wantEndpoint string
+ wantInsecure bool
+ insecureOverride bool // initial InsecureExporter value
+ }{
+ {
+ name: "http scheme stripped and insecure inferred",
+ endpoint: "http://otel-collector:4317",
+ wantEndpoint: "otel-collector:4317",
+ wantInsecure: true,
+ },
+ {
+ name: "https scheme stripped and insecure stays false",
+ endpoint: "https://otel-collector:4317",
+ wantEndpoint: "otel-collector:4317",
+ wantInsecure: false,
+ },
+ {
+ name: "no scheme defaults to insecure",
+ endpoint: "otel-collector:4317",
+ wantEndpoint: "otel-collector:4317",
+ wantInsecure: true,
+ },
+ {
+ name: "https with explicit insecure override preserved",
+ endpoint: "https://otel-collector:4317",
+ insecureOverride: true,
+ wantEndpoint: "otel-collector:4317",
+ wantInsecure: true,
+ },
+ {
+ name: "http with trailing slash",
+ endpoint: "http://otel-collector:4317/",
+ wantEndpoint: "otel-collector:4317/",
+ wantInsecure: true,
+ },
}
- telemetry := InitializeTelemetry(cfg)
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
- assert.NotNil(t, telemetry)
- assert.NotNil(t, telemetry.TracerProvider)
- assert.NotNil(t, telemetry.MetricProvider)
- assert.NotNil(t, telemetry.LoggerProvider)
+ // Use telemetry disabled so we don't need a real collector.
+ tl, err := NewTelemetry(TelemetryConfig{
+ LibraryName: "test-lib",
+ EnableTelemetry: false,
+ CollectorExporterEndpoint: tt.endpoint,
+ InsecureExporter: tt.insecureOverride,
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err)
+ require.NotNil(t, tl)
+ assert.Equal(t, tt.wantEndpoint, tl.CollectorExporterEndpoint,
+ "endpoint should be normalized")
+ assert.Equal(t, tt.wantInsecure, tl.InsecureExporter,
+ "InsecureExporter should be inferred from scheme")
+ })
+ }
}
-func TestInitializeTelemetryWithError_NilConfig(t *testing.T) {
- telemetry, err := InitializeTelemetryWithError(nil)
+// ===========================================================================
+// 1c. Endpoint environment variable normalization
+// ===========================================================================
- assert.Nil(t, telemetry)
- assert.Error(t, err)
- assert.True(t, errors.Is(err, ErrNilTelemetryConfig))
+func TestNormalizeEndpointEnvVars(t *testing.T) {
+ envKeys := []string{
+ "OTEL_EXPORTER_OTLP_ENDPOINT",
+ "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT",
+ "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT",
+ }
+
+ tests := []struct {
+ name string
+ value string
+ set bool
+ expected string
+ }{
+ {
+ name: "bare host:port gets http scheme",
+ value: "10.10.0.202:4317",
+ set: true,
+ expected: "http://10.10.0.202:4317",
+ },
+ {
+ name: "hostname:port gets http scheme",
+ value: "otel-collector:4317",
+ set: true,
+ expected: "http://otel-collector:4317",
+ },
+ {
+ name: "http scheme preserved",
+ value: "http://otel-collector:4317",
+ set: true,
+ expected: "http://otel-collector:4317",
+ },
+ {
+ name: "https scheme preserved",
+ value: "https://otel-collector:4317",
+ set: true,
+ expected: "https://otel-collector:4317",
+ },
+ {
+ name: "whitespace trimmed before adding scheme",
+ value: " 10.10.0.202:4317 ",
+ set: true,
+ expected: "http://10.10.0.202:4317",
+ },
+ {
+ name: "empty value skipped",
+ value: "",
+ set: true,
+ expected: "",
+ },
+ {
+ name: "whitespace-only value skipped",
+ value: " ",
+ set: true,
+ expected: " ",
+ },
+ {
+ name: "unset env var skipped",
+ value: "",
+ set: false,
+ expected: "",
+ },
+ }
+
+ for _, tt := range tests {
+ for _, key := range envKeys {
+ t.Run(tt.name+"/"+key, func(t *testing.T) {
+ if tt.set {
+ t.Setenv(key, tt.value)
+ }
+
+ normalizeEndpointEnvVars()
+
+ got := os.Getenv(key)
+ assert.Equal(t, tt.expected, got)
+ })
+ }
+ }
}
-func TestInitializeTelemetryWithError_NilLogger(t *testing.T) {
- cfg := &TelemetryConfig{
+// ===========================================================================
+// 2. Telemetry methods on nil receiver
+// ===========================================================================
+
+func TestTelemetry_ApplyGlobals_NilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var tl *Telemetry
+ err := tl.ApplyGlobals()
+ require.ErrorIs(t, err, ErrNilTelemetry)
+}
+
+func TestTelemetry_Tracer_NilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var tl *Telemetry
+ tr, err := tl.Tracer("test")
+ require.ErrorIs(t, err, ErrNilTelemetry)
+ assert.Nil(t, tr)
+}
+
+func TestTelemetry_Meter_NilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var tl *Telemetry
+ m, err := tl.Meter("test")
+ require.ErrorIs(t, err, ErrNilTelemetry)
+ assert.Nil(t, m)
+}
+
+func TestTelemetry_ShutdownTelemetry_NilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var tl *Telemetry
+ assert.NotPanics(t, func() { tl.ShutdownTelemetry() })
+}
+
+func TestTelemetry_ShutdownTelemetryWithContext_NilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var tl *Telemetry
+ err := tl.ShutdownTelemetryWithContext(context.Background())
+ require.ErrorIs(t, err, ErrNilTelemetry)
+}
+
+// ===========================================================================
+// 3. Telemetry with disabled telemetry — provider access
+// ===========================================================================
+
+func newDisabledTelemetry(t *testing.T) *Telemetry {
+ t.Helper()
+
+ tl, err := NewTelemetry(TelemetryConfig{
LibraryName: "test-lib",
- ServiceName: "test-service",
- ServiceVersion: "1.0.0",
- DeploymentEnv: "test",
+ ServiceName: "test-svc",
+ ServiceVersion: "0.1.0",
EnableTelemetry: false,
- Logger: nil,
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err)
+
+ return tl
+}
+
+func TestTelemetry_Disabled_Tracer(t *testing.T) {
+ t.Parallel()
+
+ tl := newDisabledTelemetry(t)
+ tr, err := tl.Tracer("test-tracer")
+ require.NoError(t, err)
+ assert.NotNil(t, tr)
+}
+
+func TestTelemetry_Disabled_Meter(t *testing.T) {
+ t.Parallel()
+
+ tl := newDisabledTelemetry(t)
+ m, err := tl.Meter("test-meter")
+ require.NoError(t, err)
+ assert.NotNil(t, m)
+}
+
+func TestTelemetry_Disabled_ShutdownWithContext(t *testing.T) {
+ t.Parallel()
+
+ tl := newDisabledTelemetry(t)
+ err := tl.ShutdownTelemetryWithContext(context.Background())
+ require.NoError(t, err)
+}
+
+func TestTelemetry_Disabled_ShutdownTelemetry(t *testing.T) {
+ t.Parallel()
+
+ tl := newDisabledTelemetry(t)
+ assert.NotPanics(t, func() { tl.ShutdownTelemetry() })
+}
+
+func TestTelemetry_Disabled_ApplyGlobals(t *testing.T) {
+ prevTP := otel.GetTracerProvider()
+ prevMP := otel.GetMeterProvider()
+ t.Cleanup(func() {
+ otel.SetTracerProvider(prevTP)
+ otel.SetMeterProvider(prevMP)
+ })
+
+ tl := newDisabledTelemetry(t)
+ require.NoError(t, tl.ApplyGlobals())
+ assert.Same(t, tl.TracerProvider, otel.GetTracerProvider())
+ assert.Same(t, tl.MeterProvider, otel.GetMeterProvider())
+}
+
+// ===========================================================================
+// 4. ShutdownTelemetryWithContext — nil shutdown functions
+// ===========================================================================
+
+func TestTelemetry_ShutdownWithContext_NilShutdownFuncs(t *testing.T) {
+ t.Parallel()
+
+ tl := &Telemetry{
+ TelemetryConfig: TelemetryConfig{Logger: log.NewNop()},
+ shutdown: nil,
+ shutdownCtx: nil,
+ }
+
+ err := tl.ShutdownTelemetryWithContext(context.Background())
+ require.ErrorIs(t, err, ErrNilShutdown)
+}
+
+func TestTelemetry_ShutdownWithContext_FallbackToShutdown(t *testing.T) {
+ t.Parallel()
+
+ called := false
+ tl := &Telemetry{
+ TelemetryConfig: TelemetryConfig{Logger: log.NewNop()},
+ shutdown: func() { called = true },
+ shutdownCtx: nil,
}
- telemetry, err := InitializeTelemetryWithError(cfg)
+ err := tl.ShutdownTelemetryWithContext(context.Background())
+ require.NoError(t, err)
+ assert.True(t, called, "fallback shutdown should have been invoked")
+}
- assert.Nil(t, telemetry)
- assert.Error(t, err)
- assert.True(t, errors.Is(err, ErrNilTelemetryLogger))
+// ===========================================================================
+// 5. Context propagation helpers — nil/empty edge cases
+// ===========================================================================
+
+func TestInjectTraceContext_NilCarrier(t *testing.T) {
+ t.Parallel()
+ assert.NotPanics(t, func() { InjectTraceContext(context.Background(), nil) })
}
-func TestInitializeTelemetryWithError_EnabledWithLazyConnection(t *testing.T) {
- // Note: gRPC uses lazy connection, so the exporter creation succeeds initially.
- // The actual connection error would happen when trying to export data.
- // This test verifies that InitializeTelemetryWithError handles valid configuration
- // without panicking and returns a functional Telemetry instance.
- cfg := &TelemetryConfig{
- LibraryName: "test-lib",
- ServiceName: "test-service",
- ServiceVersion: "1.0.0",
- DeploymentEnv: "test",
- CollectorExporterEndpoint: "localhost:4317",
- EnableTelemetry: true,
- Logger: &log.NoneLogger{},
+func TestExtractTraceContext_NilCarrier(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ result := ExtractTraceContext(ctx, nil)
+ assert.Equal(t, ctx, result)
+}
+
+func TestInjectHTTPContext_NilHeaders(t *testing.T) {
+ t.Parallel()
+ assert.NotPanics(t, func() { InjectHTTPContext(context.Background(), nil) })
+}
+
+func TestInjectGRPCContext_NilMD(t *testing.T) {
+ t.Parallel()
+
+ md := InjectGRPCContext(context.Background(), nil)
+ require.NotNil(t, md, "nil md should produce a new metadata.MD")
+}
+
+func TestExtractGRPCContext_NilMD(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ result := ExtractGRPCContext(ctx, nil)
+ assert.Equal(t, ctx, result)
+}
+
+func TestExtractGRPCContext_WithTraceparentKey(t *testing.T) {
+ t.Parallel()
+
+ md := metadata.MD{
+ "traceparent": {"00-00112233445566778899aabbccddeeff-0123456789abcdef-01"},
}
+ ctx := ExtractGRPCContext(context.Background(), md)
+ assert.NotNil(t, ctx)
+
+ span := trace.SpanFromContext(ctx)
+ assert.Equal(t, "00112233445566778899aabbccddeeff", span.SpanContext().TraceID().String())
+}
+
+func TestInjectQueueTraceContext_ReturnsMap(t *testing.T) {
+ t.Parallel()
+
+ headers := InjectQueueTraceContext(context.Background())
+ require.NotNil(t, headers)
+}
+
+func TestExtractQueueTraceContext_NilHeaders(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ result := ExtractQueueTraceContext(ctx, nil)
+ assert.Equal(t, ctx, result)
+}
+
+func TestPrepareQueueHeaders_MergesHeaders(t *testing.T) {
+ t.Parallel()
+
+ base := map[string]any{"routing_key": "my.queue"}
+ result := PrepareQueueHeaders(context.Background(), base)
+ require.NotNil(t, result)
+ assert.Equal(t, "my.queue", result["routing_key"])
+}
+
+func TestPrepareQueueHeaders_DoesNotMutateBase(t *testing.T) {
+ t.Parallel()
+
+ base := map[string]any{"key": "val"}
+ result := PrepareQueueHeaders(context.Background(), base)
+ assert.Len(t, base, 1)
+ assert.NotSame(t, &base, &result)
+}
+
+func TestInjectTraceHeadersIntoQueue_NilPointer(t *testing.T) {
+ t.Parallel()
+ assert.NotPanics(t, func() { InjectTraceHeadersIntoQueue(context.Background(), nil) })
+}
+
+func TestInjectTraceHeadersIntoQueue_NilMap(t *testing.T) {
+ t.Parallel()
+
+ var headers map[string]any
+ InjectTraceHeadersIntoQueue(context.Background(), &headers)
+ require.NotNil(t, headers, "nil *map should be initialized")
+}
+
+func TestInjectTraceHeadersIntoQueue_ValidMap(t *testing.T) {
+ t.Parallel()
+
+ headers := map[string]any{"existing": "value"}
+ InjectTraceHeadersIntoQueue(context.Background(), &headers)
+ assert.Equal(t, "value", headers["existing"])
+}
+
+func TestExtractTraceContextFromQueueHeaders_EmptyHeaders(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ result := ExtractTraceContextFromQueueHeaders(ctx, nil)
+ assert.Equal(t, ctx, result)
+
+ result = ExtractTraceContextFromQueueHeaders(ctx, map[string]any{})
+ assert.Equal(t, ctx, result)
+}
- telemetry, err := InitializeTelemetryWithError(cfg)
+func TestExtractTraceContextFromQueueHeaders_NonStringValues(t *testing.T) {
+ t.Parallel()
- // With gRPC lazy connection, this should succeed
+ ctx := context.Background()
+ headers := map[string]any{
+ "traceparent": 12345,
+ "other": true,
+ }
+ result := ExtractTraceContextFromQueueHeaders(ctx, headers)
+ assert.Equal(t, ctx, result, "non-string values should be skipped, returning original ctx")
+}
+
+func TestExtractTraceContextFromQueueHeaders_ValidHeaders(t *testing.T) {
+ prev := otel.GetTextMapPropagator()
+ t.Cleanup(func() { otel.SetTextMapPropagator(prev) })
+ otel.SetTextMapPropagator(propagation.TraceContext{})
+
+ headers := map[string]any{
+ "traceparent": "00-00112233445566778899aabbccddeeff-0123456789abcdef-01",
+ }
+ ctx := ExtractTraceContextFromQueueHeaders(context.Background(), headers)
+ span := trace.SpanFromContext(ctx)
+ assert.Equal(t, "00112233445566778899aabbccddeeff", span.SpanContext().TraceID().String())
+}
+
+// ===========================================================================
+// 6. GetTraceIDFromContext / GetTraceStateFromContext
+// ===========================================================================
+
+func TestGetTraceIDFromContext_NoActiveSpan(t *testing.T) {
+ t.Parallel()
+ assert.Empty(t, GetTraceIDFromContext(context.Background()))
+}
+
+func TestGetTraceStateFromContext_NoActiveSpan(t *testing.T) {
+ t.Parallel()
+ assert.Empty(t, GetTraceStateFromContext(context.Background()))
+}
+
+func TestGetTraceIDFromContext_WithSpan(t *testing.T) {
+ t.Parallel()
+
+ tp := sdktrace.NewTracerProvider()
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+
+ ctx, span := tp.Tracer("test").Start(context.Background(), "op")
+ defer span.End()
+
+ traceID := GetTraceIDFromContext(ctx)
+ assert.NotEmpty(t, traceID)
+ assert.Len(t, traceID, 32) // hex-encoded 16-byte trace ID
+}
+
+func TestGetTraceStateFromContext_WithSpan(t *testing.T) {
+ t.Parallel()
+
+ tp := sdktrace.NewTracerProvider()
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+
+ ctx, span := tp.Tracer("test").Start(context.Background(), "op")
+ defer span.End()
+
+ // SDK-created spans have empty tracestate by default, which is valid.
+ state := GetTraceStateFromContext(ctx)
+ assert.NotNil(t, state) // zero-value string is fine
+}
+
+// ===========================================================================
+// 7. flattenAttributes via BodyToSpanAttributes / BuildAttributesFromValue
+// ===========================================================================
+
+func TestFlattenAttributes_NestedMap(t *testing.T) {
+ t.Parallel()
+
+ attrs, err := BuildAttributesFromValue("root", map[string]any{
+ "user": map[string]any{
+ "name": "alice",
+ "age": float64(30),
+ },
+ "active": true,
+ }, nil)
require.NoError(t, err)
- require.NotNil(t, telemetry)
- assert.NotNil(t, telemetry.TracerProvider)
- assert.NotNil(t, telemetry.MetricProvider)
- assert.NotNil(t, telemetry.LoggerProvider)
- // Clean up
+ m := attrsToMap(attrs)
+ assert.Equal(t, "alice", m["root.user.name"])
+ assert.Contains(t, m, "root.user.age")
+ assert.Contains(t, m, "root.active")
+}
+
+func TestFlattenAttributes_Array(t *testing.T) {
+ t.Parallel()
+
+ attrs, err := BuildAttributesFromValue("items", map[string]any{
+ "list": []any{"a", "b"},
+ }, nil)
+ require.NoError(t, err)
+
+ m := attrsToMap(attrs)
+ assert.Equal(t, "a", m["items.list.0"])
+ assert.Equal(t, "b", m["items.list.1"])
+}
+
+func TestFlattenAttributes_NilValue(t *testing.T) {
+ t.Parallel()
+
+ attrs, err := BuildAttributesFromValue("prefix", nil, nil)
+ require.NoError(t, err)
+ assert.Nil(t, attrs)
+}
+
+func TestFlattenAttributes_StringTruncation(t *testing.T) {
+ t.Parallel()
+
+ longStr := strings.Repeat("x", maxSpanAttributeStringLength+500)
+ attrs, err := BuildAttributesFromValue("k", map[string]any{"v": longStr}, nil)
+ require.NoError(t, err)
+ require.Len(t, attrs, 1)
+ assert.Len(t, attrs[0].Value.AsString(), maxSpanAttributeStringLength)
+}
+
+func TestFlattenAttributes_DepthLimit(t *testing.T) {
+ t.Parallel()
+
+ // Build a deeply nested map exceeding maxAttributeDepth
+ nested := map[string]any{"leaf": "value"}
+ for i := 0; i < maxAttributeDepth+5; i++ {
+ nested = map[string]any{"level": nested}
+ }
+
+ var attrs []attribute.KeyValue
+ flattenAttributes(&attrs, "root", nested, 0)
+
+ // The leaf should never appear because depth is exceeded
+ for _, a := range attrs {
+ assert.NotContains(t, string(a.Key), "leaf")
+ }
+}
+
+func TestFlattenAttributes_CountLimit(t *testing.T) {
+ t.Parallel()
+
+ // Build a flat map with more than maxAttributeCount entries
+ wide := make(map[string]any, maxAttributeCount+50)
+ for i := 0; i < maxAttributeCount+50; i++ {
+ wide[strings.Repeat("k", 3)+strings.Repeat("0", 4)+string(rune('a'+i%26))+strings.Repeat("0", 3)] = "v"
+ }
+
+ var attrs []attribute.KeyValue
+ flattenAttributes(&attrs, "root", wide, 0)
+
+ assert.LessOrEqual(t, len(attrs), maxAttributeCount)
+}
+
+func TestFlattenAttributes_JsonNumber(t *testing.T) {
+ t.Parallel()
+
+ // json.Number is produced when using a Decoder with UseNumber()
+ attrs, err := BuildAttributesFromValue("n", map[string]any{
+ "count": float64(42),
+ }, nil)
+ require.NoError(t, err)
+
+ m := attrsToMap(attrs)
+ assert.Contains(t, m, "n.count")
+}
+
+func TestFlattenAttributes_BoolValues(t *testing.T) {
+ t.Parallel()
+
+ attrs, err := BuildAttributesFromValue("cfg", map[string]any{
+ "enabled": true,
+ "debug": false,
+ }, nil)
+ require.NoError(t, err)
+ assert.Len(t, attrs, 2)
+}
+
+// ===========================================================================
+// 8. sanitizeUTF8String
+// ===========================================================================
+
+func TestSanitizeUTF8String_ValidString(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, "hello world", sanitizeUTF8String("hello world"))
+}
+
+func TestSanitizeUTF8String_InvalidUTF8(t *testing.T) {
+ t.Parallel()
+
+ invalid := "hello\x80world"
+ result := sanitizeUTF8String(invalid)
+ assert.NotContains(t, result, "\x80")
+ assert.Contains(t, result, "hello")
+ assert.Contains(t, result, "world")
+}
+
+func TestSanitizeUTF8String_EmptyString(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, "", sanitizeUTF8String(""))
+}
+
+func TestSanitizeUTF8String_Unicode(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, "日本語テスト", sanitizeUTF8String("日本語テスト"))
+}
+
+// ===========================================================================
+// 9. HandleSpan helpers
+// ===========================================================================
+
+func TestHandleSpanBusinessErrorEvent_NilSpan(t *testing.T) {
+ t.Parallel()
+ assert.NotPanics(t, func() { HandleSpanBusinessErrorEvent(nil, "evt", assert.AnError) })
+}
+
+func TestHandleSpanBusinessErrorEvent_NilError(t *testing.T) {
+ t.Parallel()
+
+ tp := sdktrace.NewTracerProvider()
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+ _, span := tp.Tracer("test").Start(context.Background(), "op")
+ defer span.End()
+
+ assert.NotPanics(t, func() { HandleSpanBusinessErrorEvent(span, "evt", nil) })
+}
+
+func TestHandleSpanEvent_NilSpan(t *testing.T) {
+ t.Parallel()
+ assert.NotPanics(t, func() { HandleSpanEvent(nil, "evt") })
+}
+
+func TestHandleSpanError_NilSpan(t *testing.T) {
+ t.Parallel()
+ assert.NotPanics(t, func() { HandleSpanError(nil, "msg", assert.AnError) })
+}
+
+func TestHandleSpanError_NilError(t *testing.T) {
+ t.Parallel()
+
+ tp := sdktrace.NewTracerProvider()
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+ _, span := tp.Tracer("test").Start(context.Background(), "op")
+ defer span.End()
+
+ assert.NotPanics(t, func() { HandleSpanError(span, "msg", nil) })
+}
+
+// ===========================================================================
+// 10. SetSpanAttributesFromValue
+// ===========================================================================
+
+func TestSetSpanAttributesFromValue_NilSpan(t *testing.T) {
+ t.Parallel()
+ err := SetSpanAttributesFromValue(nil, "prefix", map[string]any{"k": "v"}, nil)
+ assert.NoError(t, err)
+}
+
+func TestSetSpanAttributesFromValue_NilValue(t *testing.T) {
+ t.Parallel()
+
+ tp := sdktrace.NewTracerProvider()
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+ _, span := tp.Tracer("test").Start(context.Background(), "op")
+ defer span.End()
+
+ err := SetSpanAttributesFromValue(span, "prefix", nil, nil)
+ assert.NoError(t, err)
+}
+
+// ===========================================================================
+// 11. BuildAttributesFromValue with redactor
+// ===========================================================================
+
+func TestBuildAttributesFromValue_WithRedactor(t *testing.T) {
+ t.Parallel()
+
+ r := NewDefaultRedactor()
+ attrs, err := BuildAttributesFromValue("req", map[string]any{
+ "username": "alice",
+ "password": "secret123",
+ }, r)
+ require.NoError(t, err)
+
+ m := attrsToMap(attrs)
+ assert.Equal(t, "alice", m["req.username"])
+ assert.NotEqual(t, "secret123", m["req.password"], "password should be redacted")
+}
+
+func TestBuildAttributesFromValue_StructInput(t *testing.T) {
+ t.Parallel()
+
+ type payload struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ }
+
+ attrs, err := BuildAttributesFromValue("obj", payload{ID: "123", Name: "test"}, nil)
+ require.NoError(t, err)
+
+ m := attrsToMap(attrs)
+ assert.Equal(t, "123", m["obj.id"])
+ assert.Equal(t, "test", m["obj.name"])
+}
+
+// ===========================================================================
+// 12. isNilShutdownable
+// ===========================================================================
+
+func TestIsNilShutdownable_UntypedNil(t *testing.T) {
+ t.Parallel()
+ assert.True(t, isNilShutdownable(nil))
+}
+
+func TestIsNilShutdownable_TypedNil(t *testing.T) {
+ t.Parallel()
+
+ var tp *sdktrace.TracerProvider
+ assert.True(t, isNilShutdownable(tp))
+}
+
+func TestIsNilShutdownable_ValidValue(t *testing.T) {
+ t.Parallel()
+
+ tp := sdktrace.NewTracerProvider()
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+ assert.False(t, isNilShutdownable(tp))
+}
+
+// ===========================================================================
+// 13. InjectGRPCContext key normalization
+// ===========================================================================
+
+func TestInjectGRPCContext_TraceparentKeyNormalization(t *testing.T) {
+ prev := otel.GetTextMapPropagator()
+ t.Cleanup(func() { otel.SetTextMapPropagator(prev) })
+ otel.SetTextMapPropagator(propagation.TraceContext{})
+
+ tp := sdktrace.NewTracerProvider()
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+
+ ctx, span := tp.Tracer("test").Start(context.Background(), "op")
+ defer span.End()
+
+ md := InjectGRPCContext(ctx, nil)
+ // The function should normalize "Traceparent" -> "traceparent"
+ assert.NotEmpty(t, md.Get("traceparent"), "traceparent key should be lowercase")
+}
+
+// ===========================================================================
+// 14. Propagation round-trip
+// ===========================================================================
+
+func TestQueuePropagation_RoundTrip(t *testing.T) {
+ prev := otel.GetTextMapPropagator()
+ prevTP := otel.GetTracerProvider()
+ t.Cleanup(func() {
+ otel.SetTextMapPropagator(prev)
+ otel.SetTracerProvider(prevTP)
+ })
+
+ otel.SetTextMapPropagator(propagation.TraceContext{})
+ tp := sdktrace.NewTracerProvider()
+ otel.SetTracerProvider(tp)
+
+ ctx, span := tp.Tracer("test").Start(context.Background(), "producer")
+ defer span.End()
+
+ originalTraceID := span.SpanContext().TraceID().String()
+
+ // Inject into queue headers
+ queueHeaders := InjectQueueTraceContext(ctx)
+ assert.NotEmpty(t, queueHeaders)
+
+ // Extract on consumer side
+ consumerCtx := ExtractQueueTraceContext(context.Background(), queueHeaders)
+ extractedTraceID := GetTraceIDFromContext(consumerCtx)
+ assert.Equal(t, originalTraceID, extractedTraceID)
+
+ _ = tp.Shutdown(context.Background())
+}
+
+func TestHTTPPropagation_InjectAndVerify(t *testing.T) {
+ prev := otel.GetTextMapPropagator()
+ prevTP := otel.GetTracerProvider()
t.Cleanup(func() {
- telemetry.ShutdownTelemetry()
+ otel.SetTextMapPropagator(prev)
+ otel.SetTracerProvider(prevTP)
})
+
+ otel.SetTextMapPropagator(propagation.TraceContext{})
+ tp := sdktrace.NewTracerProvider()
+ otel.SetTracerProvider(tp)
+
+ ctx, span := tp.Tracer("test").Start(context.Background(), "http-req")
+ defer span.End()
+
+ headers := make(map[string][]string)
+ InjectHTTPContext(ctx, headers)
+ assert.NotEmpty(t, headers["Traceparent"])
+
+ _ = tp.Shutdown(context.Background())
+}
+
+// ===========================================================================
+// 15. buildShutdownHandlers
+// ===========================================================================
+
+func TestBuildShutdownHandlers_NoComponents(t *testing.T) {
+ t.Parallel()
+
+ shutdown, shutdownCtx := buildShutdownHandlers(log.NewNop())
+ assert.NotPanics(t, func() { shutdown() })
+
+ err := shutdownCtx(context.Background())
+ assert.NoError(t, err)
+}
+
+func TestBuildShutdownHandlers_WithProviders(t *testing.T) {
+ t.Parallel()
+
+ tp := sdktrace.NewTracerProvider()
+ shutdown, shutdownCtx := buildShutdownHandlers(log.NewNop(), tp)
+
+ err := shutdownCtx(context.Background())
+ assert.NoError(t, err)
+
+ // Second shutdown may error (already shut down), but should not panic
+ assert.NotPanics(t, func() { shutdown() })
+}
+
+func TestBuildShutdownHandlers_NilComponents(t *testing.T) {
+ t.Parallel()
+
+ shutdown, shutdownCtx := buildShutdownHandlers(log.NewNop(), nil)
+ assert.NotPanics(t, func() { shutdown() })
+
+ err := shutdownCtx(context.Background())
+ assert.NoError(t, err)
+}
+
+func TestBuildShutdownHandlers_TypedNilProvider(t *testing.T) {
+ t.Parallel()
+
+ var tp *sdktrace.TracerProvider
+ shutdown, shutdownCtx := buildShutdownHandlers(log.NewNop(), tp)
+ assert.NotPanics(t, func() { shutdown() })
+
+ err := shutdownCtx(context.Background())
+ assert.NoError(t, err)
+}
+
+// ===========================================================================
+// 16. HandleSpan helpers with real spans
+// ===========================================================================
+
+func TestHandleSpanBusinessErrorEvent_WithSpan(t *testing.T) {
+ t.Parallel()
+
+ exporter := tracetest.NewInMemoryExporter()
+ tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter))
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+ _, span := tp.Tracer("test").Start(context.Background(), "op")
+
+ HandleSpanBusinessErrorEvent(span, "business_error", assert.AnError)
+ span.End()
+
+ spans := exporter.GetSpans()
+ require.Len(t, spans, 1)
+ require.NotEmpty(t, spans[0].Events, "business error event must be recorded")
+ assert.Equal(t, "business_error", spans[0].Events[0].Name)
+ // Status should remain OK (business errors don't set ERROR status)
+ assert.Equal(t, codes.Unset, spans[0].Status.Code, "business error must not set ERROR status")
+}
+
+func TestHandleSpanEvent_WithSpan(t *testing.T) {
+ t.Parallel()
+
+ exporter := tracetest.NewInMemoryExporter()
+ tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter))
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+ _, span := tp.Tracer("test").Start(context.Background(), "op")
+
+ HandleSpanEvent(span, "my_event", attribute.String("key", "value"))
+ span.End()
+
+ spans := exporter.GetSpans()
+ require.Len(t, spans, 1)
+ require.NotEmpty(t, spans[0].Events, "event must be recorded on span")
+ assert.Equal(t, "my_event", spans[0].Events[0].Name)
+}
+
+func TestHandleSpanError_WithSpan(t *testing.T) {
+ t.Parallel()
+
+ exporter := tracetest.NewInMemoryExporter()
+ tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter))
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+ _, span := tp.Tracer("test").Start(context.Background(), "op")
+
+ HandleSpanError(span, "something failed", assert.AnError)
+ span.End()
+
+ spans := exporter.GetSpans()
+ require.Len(t, spans, 1)
+ assert.Equal(t, codes.Error, spans[0].Status.Code, "HandleSpanError must set ERROR status")
+ assert.Contains(t, spans[0].Status.Description, "something failed")
+}
+
+func TestHandleSpanError_EmptyMessage(t *testing.T) {
+ t.Parallel()
+
+ exporter := tracetest.NewInMemoryExporter()
+ tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter))
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+ _, span := tp.Tracer("test").Start(context.Background(), "op")
+
+ HandleSpanError(span, "", assert.AnError)
+ span.End()
+
+ spans := exporter.GetSpans()
+ require.Len(t, spans, 1)
+ assert.Equal(t, codes.Error, spans[0].Status.Code)
+ // With empty message, status should be just the error text (no leading ": ")
+ assert.False(t, strings.HasPrefix(spans[0].Status.Description, ": "),
+ "empty message must not produce leading ': ' in status description")
+}
+
+// ===========================================================================
+// 17. ShutdownTelemetry (non-nil) exercises error branch
+// ===========================================================================
+
+func TestTelemetry_ShutdownTelemetry_NonNil(t *testing.T) {
+ t.Parallel()
+
+ tl := newDisabledTelemetry(t)
+ assert.NotPanics(t, func() { tl.ShutdownTelemetry() })
+}
+
+// ===========================================================================
+// 18. InjectGRPCContext / ExtractGRPCContext tracestate normalization
+// ===========================================================================
+
+func TestInjectGRPCContext_TracestateNormalization(t *testing.T) {
+ prev := otel.GetTextMapPropagator()
+ t.Cleanup(func() { otel.SetTextMapPropagator(prev) })
+ otel.SetTextMapPropagator(propagation.TraceContext{})
+
+ traceID, _ := trace.TraceIDFromHex("00112233445566778899aabbccddeeff")
+ spanID, _ := trace.SpanIDFromHex("0123456789abcdef")
+ ts := trace.TraceState{}
+ ts, _ = ts.Insert("vendor", "val")
+
+ sc := trace.NewSpanContext(trace.SpanContextConfig{
+ TraceID: traceID,
+ SpanID: spanID,
+ TraceFlags: trace.FlagsSampled,
+ TraceState: ts,
+ Remote: true,
+ })
+ ctx := trace.ContextWithSpanContext(context.Background(), sc)
+
+ md := InjectGRPCContext(ctx, nil)
+ assert.NotEmpty(t, md.Get("traceparent"))
+ assert.NotEmpty(t, md.Get("tracestate"))
+ // Verify PascalCase keys are removed
+ _, hasPascal := md["Traceparent"]
+ assert.False(t, hasPascal)
+}
+
+func TestExtractGRPCContext_TracestateNormalization(t *testing.T) {
+ prev := otel.GetTextMapPropagator()
+ t.Cleanup(func() { otel.SetTextMapPropagator(prev) })
+ otel.SetTextMapPropagator(propagation.TraceContext{})
+
+ md := metadata.MD{
+ "traceparent": {"00-00112233445566778899aabbccddeeff-0123456789abcdef-01"},
+ "tracestate": {"vendor=val"},
+ }
+ ctx := ExtractGRPCContext(context.Background(), md)
+ span := trace.SpanFromContext(ctx)
+ assert.Equal(t, "00112233445566778899aabbccddeeff", span.SpanContext().TraceID().String())
+}
+
+// ===========================================================================
+// 19. Processor OnStart/OnEnd via tracer pipeline
+// ===========================================================================
+
+func TestAttrBagSpanProcessor_OnStartOnEnd_WithTracer(t *testing.T) {
+ t.Parallel()
+
+ tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(AttrBagSpanProcessor{}))
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+
+ ctx, span := tp.Tracer("test").Start(context.Background(), "op")
+ defer span.End()
+ assert.NotNil(t, ctx)
+}
+
+func TestRedactingAttrBagSpanProcessor_OnStartOnEnd_WithTracer(t *testing.T) {
+ t.Parallel()
+
+ p := RedactingAttrBagSpanProcessor{Redactor: NewDefaultRedactor()}
+ tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(p))
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+
+ ctx, span := tp.Tracer("test").Start(context.Background(), "op")
+ defer span.End()
+ assert.NotNil(t, ctx)
+}
+
+func TestAttrBagSpanProcessor_OnStart_WithContextAttributes(t *testing.T) {
+ t.Parallel()
+
+ exporter := tracetest.NewInMemoryExporter()
+ tp := sdktrace.NewTracerProvider(
+ sdktrace.WithSpanProcessor(AttrBagSpanProcessor{}),
+ sdktrace.WithSyncer(exporter),
+ )
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+
+ ctx := commons.ContextWithSpanAttributes(context.Background(), attribute.String("app.request.id", "r1"))
+ _, span := tp.Tracer("test").Start(ctx, "op")
+ span.End()
+
+ spans := exporter.GetSpans()
+ require.Len(t, spans, 1)
+ // Verify the context attribute was applied to the span
+ found := false
+ for _, a := range spans[0].Attributes {
+ if a.Key == "app.request.id" && a.Value.AsString() == "r1" {
+ found = true
+ }
+ }
+ assert.True(t, found, "span must contain app.request.id=r1 from context bag")
+}
+
+func TestRedactingAttrBagSpanProcessor_OnStart_WithContextAttributes(t *testing.T) {
+ t.Parallel()
+
+ p := RedactingAttrBagSpanProcessor{Redactor: NewDefaultRedactor()}
+ exporter := tracetest.NewInMemoryExporter()
+ tp := sdktrace.NewTracerProvider(
+ sdktrace.WithSpanProcessor(p),
+ sdktrace.WithSyncer(exporter),
+ )
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+
+ ctx := commons.ContextWithSpanAttributes(context.Background(),
+ attribute.String("app.request.id", "r1"),
+ attribute.String("user.password", "secret"),
+ )
+ _, span := tp.Tracer("test").Start(ctx, "op")
+ span.End()
+
+ spans := exporter.GetSpans()
+ require.Len(t, spans, 1)
+ // Verify the request ID is present and password is redacted
+ for _, a := range spans[0].Attributes {
+ if a.Key == "app.request.id" {
+ assert.Equal(t, "r1", a.Value.AsString(), "non-sensitive field should pass through")
+ }
+ if a.Key == "user.password" {
+ assert.NotEqual(t, "secret", a.Value.AsString(), "sensitive field should be redacted")
+ }
+ }
+}
+
+func TestRedactingAttrBagSpanProcessor_OnStart_NilRedactor(t *testing.T) {
+ t.Parallel()
+
+ p := RedactingAttrBagSpanProcessor{Redactor: nil}
+ tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(p))
+ t.Cleanup(func() { _ = tp.Shutdown(context.Background()) })
+
+ ctx, span := tp.Tracer("test").Start(context.Background(), "op")
+ defer span.End()
+ assert.NotNil(t, ctx)
+}
+
+// ===========================================================================
+// 20. flattenAttributes edge case: default branch (non-primitive type)
+// ===========================================================================
+
+func TestFlattenAttributes_DefaultBranch(t *testing.T) {
+ t.Parallel()
+
+ // After JSON round-trip, custom types become primitives. Test directly
+ // with a type that isn't map/slice/string/float64/bool/json.Number/nil.
+ type custom struct{ X int }
+ var attrs []attribute.KeyValue
+ flattenAttributes(&attrs, "key", custom{X: 42}, 0)
+ require.Len(t, attrs, 1)
+ assert.Equal(t, "key", string(attrs[0].Key))
+ assert.Contains(t, attrs[0].Value.AsString(), "42")
+}
+
+// ===========================================================================
+// 21. newResource coverage
+// ===========================================================================
+
+func TestNewResource(t *testing.T) {
+ t.Parallel()
+
+ cfg := &TelemetryConfig{
+ ServiceName: "svc",
+ ServiceVersion: "1.0",
+ DeploymentEnv: "test",
+ }
+ r := cfg.newResource()
+ assert.NotNil(t, r)
+}
+
+// ===========================================================================
+// 22. BuildAttributesFromValue error path
+// ===========================================================================
+
+func TestBuildAttributesFromValue_UnmarshalableValue(t *testing.T) {
+ t.Parallel()
+
+ // A channel cannot be JSON-marshaled
+ ch := make(chan int)
+ attrs, err := BuildAttributesFromValue("prefix", ch, nil)
+ assert.Error(t, err)
+ assert.Nil(t, attrs)
+}
+
+// ===========================================================================
+// helpers
+// ===========================================================================
+
+func attrsToMap(attrs []attribute.KeyValue) map[string]string {
+ m := make(map[string]string, len(attrs))
+ for _, a := range attrs {
+ m[string(a.Key)] = a.Value.Emit()
+ }
+
+ return m
}
diff --git a/commons/opentelemetry/processor.go b/commons/opentelemetry/processor.go
index 6cb1ca8b..5e7edbc1 100644
--- a/commons/opentelemetry/processor.go
+++ b/commons/opentelemetry/processor.go
@@ -1,13 +1,11 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package opentelemetry
import (
"context"
+ "strings"
- "github.com/LerianStudio/lib-commons/v2/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons"
+ "go.opentelemetry.io/otel/attribute"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
)
@@ -16,14 +14,79 @@ import (
// AttrBagSpanProcessor copies request-scoped attributes from context into every span at start.
type AttrBagSpanProcessor struct{}
+// RedactingAttrBagSpanProcessor copies request attributes and applies redaction rules by key.
+type RedactingAttrBagSpanProcessor struct {
+ Redactor *Redactor
+}
+
+// OnStart applies request-scoped context attributes to newly started spans.
func (AttrBagSpanProcessor) OnStart(ctx context.Context, s sdktrace.ReadWriteSpan) {
if kv := commons.AttributesFromContext(ctx); len(kv) > 0 {
s.SetAttributes(kv...)
}
}
-func (AttrBagSpanProcessor) OnEnd(s sdktrace.ReadOnlySpan) {}
+// OnStart applies request-scoped attributes and redacts sensitive values before writing to span.
+func (p RedactingAttrBagSpanProcessor) OnStart(ctx context.Context, s sdktrace.ReadWriteSpan) {
+ kv := commons.AttributesFromContext(ctx)
+ if len(kv) == 0 {
+ return
+ }
+
+ if p.Redactor != nil {
+ kv = redactAttributesByKey(kv, p.Redactor)
+ }
+
+ s.SetAttributes(kv...)
+}
+
+// OnEnd is a no-op for this processor.
+func (AttrBagSpanProcessor) OnEnd(sdktrace.ReadOnlySpan) {}
+
+// OnEnd is a no-op for this processor.
+func (RedactingAttrBagSpanProcessor) OnEnd(sdktrace.ReadOnlySpan) {}
+
+// Shutdown is a no-op and always returns nil.
+func (AttrBagSpanProcessor) Shutdown(context.Context) error { return nil }
-func (AttrBagSpanProcessor) Shutdown(ctx context.Context) error { return nil }
+// Shutdown is a no-op and always returns nil.
+func (RedactingAttrBagSpanProcessor) Shutdown(context.Context) error { return nil }
-func (AttrBagSpanProcessor) ForceFlush(ctx context.Context) error { return nil }
+// ForceFlush is a no-op and always returns nil.
+func (AttrBagSpanProcessor) ForceFlush(context.Context) error { return nil }
+
+// ForceFlush is a no-op and always returns nil.
+func (RedactingAttrBagSpanProcessor) ForceFlush(context.Context) error { return nil }
+
+func redactAttributesByKey(attrs []attribute.KeyValue, redactor *Redactor) []attribute.KeyValue {
+ if redactor == nil {
+ return attrs
+ }
+
+ redacted := make([]attribute.KeyValue, 0, len(attrs))
+ for _, attr := range attrs {
+ key := string(attr.Key)
+
+ fieldName := key
+ if idx := strings.LastIndex(key, "."); idx >= 0 && idx+1 < len(key) {
+ fieldName = key[idx+1:]
+ }
+
+ action, ok := redactor.actionFor(key, fieldName)
+ if !ok {
+ redacted = append(redacted, attr)
+ continue
+ }
+
+ switch action {
+ case RedactionDrop:
+ continue
+ case RedactionHash:
+ redacted = append(redacted, attribute.String(string(attr.Key), redactor.hashString(attr.Value.Emit())))
+ default:
+ redacted = append(redacted, attribute.String(string(attr.Key), redactor.maskValue))
+ }
+ }
+
+ return redacted
+}
diff --git a/commons/opentelemetry/processor_test.go b/commons/opentelemetry/processor_test.go
new file mode 100644
index 00000000..17f051c0
--- /dev/null
+++ b/commons/opentelemetry/processor_test.go
@@ -0,0 +1,42 @@
+//go:build unit
+
+package opentelemetry
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/attribute"
+)
+
+func TestRedactAttributesByKey(t *testing.T) {
+ t.Parallel()
+
+ redactor, err := NewRedactor([]RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ {FieldPattern: `(?i)^token$`, Action: RedactionDrop},
+ {FieldPattern: `(?i)^document$`, Action: RedactionHash},
+ }, "***")
+ require.NoError(t, err)
+
+ attrs := []attribute.KeyValue{
+ attribute.String("user.id", "u1"),
+ attribute.String("user.password", "secret"),
+ attribute.String("auth.token", "tok_123"),
+ attribute.String("customer.document", "123456789"),
+ }
+
+ redacted := redactAttributesByKey(attrs, redactor)
+
+ values := make(map[string]string, len(redacted))
+ for _, attr := range redacted {
+ values[string(attr.Key)] = attr.Value.AsString()
+ }
+
+ assert.Equal(t, "u1", values["user.id"])
+ assert.Equal(t, "***", values["user.password"])
+ assert.NotContains(t, values, "auth.token")
+ assert.Contains(t, values["customer.document"], "sha256:")
+ assert.NotEqual(t, "123456789", values["customer.document"])
+}
diff --git a/commons/opentelemetry/queue_trace_example_test.go b/commons/opentelemetry/queue_trace_example_test.go
new file mode 100644
index 00000000..52c7da12
--- /dev/null
+++ b/commons/opentelemetry/queue_trace_example_test.go
@@ -0,0 +1,46 @@
+//go:build unit
+
+package opentelemetry_test
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/propagation"
+ "go.opentelemetry.io/otel/trace"
+)
+
+func ExamplePrepareQueueHeaders() {
+ prev := otel.GetTextMapPropagator()
+ otel.SetTextMapPropagator(propagation.TraceContext{})
+ defer otel.SetTextMapPropagator(prev)
+
+ traceID, _ := trace.TraceIDFromHex("00112233445566778899aabbccddeeff")
+ spanID, _ := trace.SpanIDFromHex("0123456789abcdef")
+
+ ctx := trace.ContextWithSpanContext(context.Background(), trace.NewSpanContext(trace.SpanContextConfig{
+ TraceID: traceID,
+ SpanID: spanID,
+ TraceFlags: trace.FlagsSampled,
+ Remote: true,
+ }))
+
+ headers := opentelemetry.PrepareQueueHeaders(ctx, map[string]any{"message_type": "transaction.created"})
+ traceParent, ok := headers["traceparent"]
+ if !ok {
+ traceParent = headers["Traceparent"]
+ }
+
+ extracted := opentelemetry.ExtractTraceContextFromQueueHeaders(context.Background(), headers)
+
+ fmt.Println(headers["message_type"])
+ fmt.Println(traceParent)
+ fmt.Println(opentelemetry.GetTraceIDFromContext(extracted))
+
+ // Output:
+ // transaction.created
+ // 00-00112233445566778899aabbccddeeff-0123456789abcdef-01
+ // 00112233445566778899aabbccddeeff
+}
diff --git a/commons/opentelemetry/queue_trace_test.go b/commons/opentelemetry/queue_trace_test.go
deleted file mode 100644
index 6ec237ba..00000000
--- a/commons/opentelemetry/queue_trace_test.go
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
-package opentelemetry
-
-import (
- "context"
- "testing"
-
- "go.opentelemetry.io/otel"
- "go.opentelemetry.io/otel/propagation"
- "go.opentelemetry.io/otel/sdk/trace"
- "go.opentelemetry.io/otel/trace/noop"
-)
-
-func TestQueueTraceContextPropagation(t *testing.T) {
- // Setup OpenTelemetry with proper propagator and real tracer
- tp := trace.NewTracerProvider()
- otel.SetTracerProvider(tp)
- otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
- propagation.TraceContext{},
- propagation.Baggage{},
- ))
- tracer := tp.Tracer("queue-trace-test")
-
- // Create a root span to simulate an HTTP request
- rootCtx, rootSpan := tracer.Start(context.Background(), "http-request")
- defer rootSpan.End()
-
- // Test injection
- headers := InjectQueueTraceContext(rootCtx)
- t.Logf("Injected headers: %+v", headers)
- if len(headers) == 0 {
- t.Error("Expected trace headers to be injected, got empty map")
- return
- }
-
- // Verify Traceparent header exists (OpenTelemetry uses canonical case)
- if _, exists := headers["Traceparent"]; !exists {
- t.Errorf("Expected 'Traceparent' header to be present in injected headers. Available headers: %v", headers)
- return
- }
-
- // Test extraction
- extractedCtx := ExtractQueueTraceContext(context.Background(), headers)
- if extractedCtx == nil {
- t.Error("Expected extracted context to be non-nil")
- }
-
- // Verify trace ID propagation
- originalTraceID := GetTraceIDFromContext(rootCtx)
- extractedTraceID := GetTraceIDFromContext(extractedCtx)
-
- if originalTraceID == "" {
- t.Error("Expected original trace ID to be non-empty")
- }
-
- if extractedTraceID == "" {
- t.Error("Expected extracted trace ID to be non-empty")
- }
-
- if originalTraceID != extractedTraceID {
- t.Errorf("Expected trace IDs to match: original=%s, extracted=%s", originalTraceID, extractedTraceID)
- }
-}
-
-func TestQueueTraceContextWithNilHeaders(t *testing.T) {
- ctx := context.Background()
-
- // Test extraction with nil headers
- extractedCtx := ExtractQueueTraceContext(ctx, nil)
- if extractedCtx != ctx {
- t.Error("Expected extracted context to be the same as input when headers are nil")
- }
-
- // Test extraction with empty headers
- extractedCtx = ExtractQueueTraceContext(ctx, map[string]string{})
- if extractedCtx == nil {
- t.Error("Expected extracted context to be non-nil even with empty headers")
- }
-}
-
-func TestGetTraceIDAndStateFromContext(t *testing.T) {
- // Test with empty context
- emptyCtx := context.Background()
- traceID := GetTraceIDFromContext(emptyCtx)
- traceState := GetTraceStateFromContext(emptyCtx)
-
- if traceID != "" {
- t.Errorf("Expected empty trace ID for empty context, got: %s", traceID)
- }
-
- if traceState != "" {
- t.Errorf("Expected empty trace state for empty context, got: %s", traceState)
- }
-
- // Test with span context
- tp := noop.NewTracerProvider()
- tracer := tp.Tracer("test")
- ctx, span := tracer.Start(context.Background(), "test-span")
- defer span.End()
-
- traceID = GetTraceIDFromContext(ctx)
- traceState = GetTraceStateFromContext(ctx)
-
- // Note: With noop tracer, these might still be empty, but functions should not panic
- if traceID == "" {
- t.Log("Trace ID is empty with noop tracer (expected)")
- }
-
- if traceState == "" {
- t.Log("Trace state is empty with noop tracer (expected)")
- }
-}
diff --git a/commons/opentelemetry/v2_test.go b/commons/opentelemetry/v2_test.go
new file mode 100644
index 00000000..856f5252
--- /dev/null
+++ b/commons/opentelemetry/v2_test.go
@@ -0,0 +1,176 @@
+//go:build unit
+
+package opentelemetry
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/log/global"
+ "go.opentelemetry.io/otel/propagation"
+ "go.opentelemetry.io/otel/sdk/trace"
+ "go.opentelemetry.io/otel/sdk/trace/tracetest"
+ oteltrace "go.opentelemetry.io/otel/trace"
+ "google.golang.org/grpc/metadata"
+)
+
+func TestNewTelemetry_Disabled(t *testing.T) {
+ tl, err := NewTelemetry(TelemetryConfig{
+ LibraryName: "test-lib",
+ ServiceName: "test-service",
+ ServiceVersion: "1.0.0",
+ DeploymentEnv: "test",
+ EnableTelemetry: false,
+ Logger: &log.NopLogger{},
+ })
+ require.NoError(t, err)
+ require.NotNil(t, tl)
+ assert.NotNil(t, tl.TracerProvider)
+ assert.NotNil(t, tl.MeterProvider)
+ assert.NotNil(t, tl.MetricsFactory)
+}
+
+func TestSetSpanAttributesFromValue_FlattensAndRedacts(t *testing.T) {
+ exporter := tracetest.NewInMemoryExporter()
+ tp := trace.NewTracerProvider(trace.WithSyncer(exporter))
+ tracer := tp.Tracer("test")
+
+ _, span := tracer.Start(context.Background(), "test-span")
+ err := SetSpanAttributesFromValue(span, "request", map[string]any{
+ "user": map[string]any{
+ "id": "u1",
+ "password": "top-secret",
+ },
+ "amount": 12.3,
+ }, NewDefaultRedactor())
+ require.NoError(t, err)
+ span.End()
+
+ spans := exporter.GetSpans()
+ require.Len(t, spans, 1)
+
+ attrs := spans[0].Attributes
+ find := func(key string) string {
+ for _, a := range attrs {
+ if string(a.Key) == key {
+ return a.Value.AsString()
+ }
+ }
+ return ""
+ }
+
+ assert.Equal(t, "u1", find("request.user.id"))
+ assert.NotEmpty(t, find("request.user.password"))
+ assert.NotEqual(t, "top-secret", find("request.user.password"))
+
+ if err := tp.Shutdown(context.Background()); err != nil {
+ t.Errorf("tp.Shutdown failed: %v", err)
+ }
+}
+
+func TestPropagation_HTTP_GRPC_Queue(t *testing.T) {
+ prevPropagator := otel.GetTextMapPropagator()
+ prevTracerProvider := otel.GetTracerProvider()
+ t.Cleanup(func() {
+ otel.SetTextMapPropagator(prevPropagator)
+ otel.SetTracerProvider(prevTracerProvider)
+ })
+
+ otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
+
+ tp := trace.NewTracerProvider()
+ otel.SetTracerProvider(tp)
+ tracer := tp.Tracer("test")
+
+ ctx, span := tracer.Start(context.Background(), "root")
+ defer span.End()
+
+ headers := map[string][]string{}
+ InjectHTTPContext(ctx, headers)
+ assert.NotEmpty(t, headers["Traceparent"])
+
+ md := InjectGRPCContext(ctx, nil)
+ assert.NotEmpty(t, md.Get("traceparent"))
+
+ queueHeaders := InjectQueueTraceContext(ctx)
+ extracted := ExtractQueueTraceContext(context.Background(), queueHeaders)
+ assert.Equal(t, span.SpanContext().TraceID().String(), oteltrace.SpanFromContext(extracted).SpanContext().TraceID().String())
+
+ if err := tp.Shutdown(context.Background()); err != nil {
+ t.Errorf("tp.Shutdown failed: %v", err)
+ }
+}
+
+func TestApplyGlobalsRestoresProviders(t *testing.T) {
+ prevPropagator := otel.GetTextMapPropagator()
+ prevTracerProvider := otel.GetTracerProvider()
+ prevMeterProvider := otel.GetMeterProvider()
+ prevLoggerProvider := global.GetLoggerProvider()
+ t.Cleanup(func() {
+ otel.SetTextMapPropagator(prevPropagator)
+ otel.SetTracerProvider(prevTracerProvider)
+ otel.SetMeterProvider(prevMeterProvider)
+ global.SetLoggerProvider(prevLoggerProvider)
+ })
+
+ tl, err := NewTelemetry(TelemetryConfig{
+ LibraryName: "test-lib",
+ ServiceName: "test-service",
+ ServiceVersion: "1.0.0",
+ DeploymentEnv: "test",
+ EnableTelemetry: false,
+ Logger: &log.NopLogger{},
+ })
+ require.NoError(t, err)
+
+ require.NoError(t, tl.ApplyGlobals())
+
+ assert.Same(t, tl.TracerProvider, otel.GetTracerProvider())
+ assert.Same(t, tl.MeterProvider, otel.GetMeterProvider())
+ assert.Same(t, tl.LoggerProvider, global.GetLoggerProvider())
+}
+
+func TestObfuscateStruct_Actions(t *testing.T) {
+ redactor, err := NewRedactor([]RedactionRule{
+ {FieldPattern: `(?i)^password$`, Action: RedactionMask},
+ {FieldPattern: `(?i)^document$`, Action: RedactionHash},
+ {PathPattern: `(?i)^session\.token$`, FieldPattern: `(?i)^token$`, Action: RedactionDrop},
+ }, "***")
+ require.NoError(t, err)
+
+ payload := map[string]any{
+ "password": "secret",
+ "document": "123456789",
+ "session": map[string]any{"token": "tok_abc"},
+ }
+
+ obfuscated, err := ObfuscateStruct(payload, redactor)
+ require.NoError(t, err)
+
+ b, err := json.Marshal(obfuscated)
+ require.NoError(t, err)
+
+ var decoded map[string]any
+ require.NoError(t, json.Unmarshal(b, &decoded))
+ assert.Equal(t, "***", decoded["password"])
+ assert.Contains(t, decoded["document"], "sha256:")
+ assert.NotContains(t, decoded["session"], "token")
+}
+
+func TestHandleSpanHelpers_NoPanicsOnNil(t *testing.T) {
+ var span oteltrace.Span
+ assert.NotPanics(t, func() {
+ HandleSpanEvent(span, "event", attribute.String("k", "v"))
+ HandleSpanBusinessErrorEvent(span, "event", assert.AnError)
+ HandleSpanError(span, "msg", assert.AnError)
+ })
+ assert.NotPanics(t, func() {
+ _ = ExtractGRPCContext(context.Background(), metadata.MD{})
+ })
+}
diff --git a/commons/os.go b/commons/os.go
index dff699e1..31ab52bb 100644
--- a/commons/os.go
+++ b/commons/os.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package commons
import (
@@ -29,30 +25,42 @@ func GetenvOrDefault(key string, defaultValue string) string {
return str
}
-// GetenvBoolOrDefault returns the value of os.Getenv(key string) value as bool or defaultValue if error
-// Is the environment variable (key) is not defined, it returns the given defaultValue
-// If the environment variable (key) is not a valid bool format, it returns the given defaultValue
+// GetenvBoolOrDefault returns the value of os.Getenv(key string) value as bool or defaultValue if error.
+// If the environment variable (key) is not defined, it returns the given defaultValue.
+// If the environment variable (key) is not a valid bool format, it returns the given defaultValue.
// If any error occurring during bool parse, it returns the given defaultValue.
+// A warning is printed to stderr when a non-empty value fails to parse, providing
+// visibility into misconfigured environment variables.
func GetenvBoolOrDefault(key string, defaultValue bool) bool {
str := os.Getenv(key)
val, err := strconv.ParseBool(str)
if err != nil {
+ if str != "" {
+ fmt.Fprintf(os.Stderr, "WARN: env var %s=%q is not a valid bool, using default %v\n", key, str, defaultValue)
+ }
+
return defaultValue
}
return val
}
-// GetenvIntOrDefault returns the value of os.Getenv(key string) value as int or defaultValue if error
-// If the environment variable (key) is not defined, it returns the given defaultValue
-// If the environment variable (key) is not a valid int format, it returns the given defaultValue
+// GetenvIntOrDefault returns the value of os.Getenv(key string) value as int or defaultValue if error.
+// If the environment variable (key) is not defined, it returns the given defaultValue.
+// If the environment variable (key) is not a valid int format, it returns the given defaultValue.
// If any error occurring during int parse, it returns the given defaultValue.
+// A warning is printed to stderr when a non-empty value fails to parse, providing
+// visibility into misconfigured environment variables.
func GetenvIntOrDefault(key string, defaultValue int64) int64 {
str := os.Getenv(key)
val, err := strconv.ParseInt(str, 10, 64)
if err != nil {
+ if str != "" {
+ fmt.Fprintf(os.Stderr, "WARN: env var %s=%q is not a valid int, using default %v\n", key, str, defaultValue)
+ }
+
return defaultValue
}
@@ -70,19 +78,20 @@ var (
localEnvConfigOnce sync.Once
)
-// InitLocalEnvConfig load a .env file to set up local environment vars
+// InitLocalEnvConfig load a .env file to set up local environment vars.
// It's called once per application process.
+// Version and environment are always logged in a plain startup banner format.
func InitLocalEnvConfig() *LocalEnvConfig {
version := GetenvOrDefault("VERSION", "NO-VERSION")
envName := GetenvOrDefault("ENV_NAME", "local")
- fmt.Printf("VERSION: \u001B[31m%s\u001B[0m\n", version)
- fmt.Printf("ENVIRONMENT NAME: \u001B[31m(%s)\u001B[0m\n", envName)
+ fmt.Printf("VERSION: %s\n\n", version)
+ fmt.Printf("ENVIRONMENT NAME: %s\n\n", envName)
if envName == "local" {
localEnvConfigOnce.Do(func() {
if err := godotenv.Load(); err != nil {
- fmt.Println("Skipping \u001B[31m.env\u001B[0m file, using env", envName)
+ fmt.Printf("Skipping .env file; using environment: %s\n", envName)
localEnvConfig = &LocalEnvConfig{
Initialized: false,
@@ -97,13 +106,29 @@ func InitLocalEnvConfig() *LocalEnvConfig {
})
}
+ // Always return a non-nil config with safe defaults so callers never
+ // need to nil-check. Non-local environments get Initialized=false.
+ if localEnvConfig == nil {
+ return &LocalEnvConfig{Initialized: false}
+ }
+
return localEnvConfig
}
-// SetConfigFromEnvVars builds a struct by setting it fields values using the "var" tag
-// Constraints: s any - must be an initialized pointer
+// ErrNilConfig indicates that a nil configuration value was passed to SetConfigFromEnvVars.
+var ErrNilConfig = errors.New("config must not be nil")
+
+// ErrNotStruct indicates that the pointer target is not a struct.
+var ErrNotStruct = errors.New("pointer must reference a struct")
+
+// SetConfigFromEnvVars builds a struct by setting its field values using the "env" tag.
+// Constraints: s must be a non-nil pointer to an initialized struct.
// Supported types: String, Boolean, Int, Int8, Int16, Int32 and Int64.
func SetConfigFromEnvVars(s any) error {
+ if s == nil {
+ return ErrNilConfig
+ }
+
v := reflect.ValueOf(s)
t := v.Type()
@@ -111,8 +136,18 @@ func SetConfigFromEnvVars(s any) error {
return ErrNotPointer
}
+ // Guard against typed-nil pointers (e.g. (*MyStruct)(nil)).
+ if v.IsNil() {
+ return ErrNilConfig
+ }
+
+ // The pointer must reference a struct.
+ if t.Elem().Kind() != reflect.Struct {
+ return ErrNotStruct
+ }
+
e := t.Elem()
- for i := 0; i < e.NumField(); i++ {
+ for i := range e.NumField() {
f := e.Field(i)
if tag, ok := f.Tag.Lookup("env"); ok {
values := strings.Split(tag, ",")
@@ -134,13 +169,3 @@ func SetConfigFromEnvVars(s any) error {
return nil
}
-
-// Deprecated: Use SetConfigFromEnvVars instead for proper error handling.
-// EnsureConfigFromEnvVars panics on error. Prefer SetConfigFromEnvVars for graceful error handling.
-func EnsureConfigFromEnvVars(s any) any {
- if err := SetConfigFromEnvVars(s); err != nil {
- panic(err)
- }
-
- return s
-}
diff --git a/commons/os_test.go b/commons/os_test.go
index c3a2e992..780ca28f 100644
--- a/commons/os_test.go
+++ b/commons/os_test.go
@@ -1,14 +1,17 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package commons
import (
+ "bytes"
+ "io"
"os"
+ "strings"
+ "sync"
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestGetenvOrDefault_WithValue(t *testing.T) {
@@ -186,28 +189,91 @@ func TestSetConfigFromEnvVars_MissingEnvVars(t *testing.T) {
assert.Empty(t, config.Field, "missing env var should result in zero value")
}
-func TestEnsureConfigFromEnvVars_Success(t *testing.T) {
+func TestSetConfigFromEnvVars_NilInterface(t *testing.T) {
+ err := SetConfigFromEnvVars(nil)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilConfig)
+}
+
+func TestSetConfigFromEnvVars_TypedNilPointer(t *testing.T) {
type Config struct {
- Field string `env:"TEST_ENSURE_FIELD"`
+ Field string `env:"TEST_FIELD"`
}
- t.Setenv("TEST_ENSURE_FIELD", "value")
+ var config *Config // typed nil
- config := &Config{}
- result := EnsureConfigFromEnvVars(config)
+ err := SetConfigFromEnvVars(config)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilConfig)
+}
+
+func TestSetConfigFromEnvVars_PointerToNonStruct(t *testing.T) {
+ s := "not a struct"
+
+ err := SetConfigFromEnvVars(&s)
- assert.NotNil(t, result)
- assert.Equal(t, "value", config.Field)
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrNotStruct)
}
-func TestEnsureConfigFromEnvVars_PanicOnNonPointer(t *testing.T) {
- type Config struct {
- Field string `env:"TEST_FIELD"`
+func TestInitLocalEnvConfig_NonLocalReturnsNonNil(t *testing.T) {
+ t.Setenv("VERSION", "1.0.0")
+ t.Setenv("ENV_NAME", "production")
+
+ // Reset the once guard so we can test fresh.
+ localEnvConfig = nil
+ localEnvConfigOnce = sync.Once{}
+
+ result := InitLocalEnvConfig()
+
+ require.NotNil(t, result, "InitLocalEnvConfig must return non-nil even for non-local env")
+ assert.False(t, result.Initialized)
+}
+
+func TestInitLocalEnvConfigPrintsVersionAndEnvironment(t *testing.T) {
+ t.Setenv("VERSION", "NO-VERSION")
+ t.Setenv("ENV_NAME", "development")
+
+ localEnvConfig = nil
+ localEnvConfigOnce = sync.Once{}
+
+ stdout := os.Stdout
+ reader, writer, err := os.Pipe()
+ if err != nil {
+ t.Fatalf("create pipe: %v", err)
}
- config := Config{}
+ os.Stdout = writer
+
+ var output bytes.Buffer
+ copyDone := make(chan struct{})
+ copyErrCh := make(chan error, 1)
+ go func() {
+ _, copyErr := io.Copy(&output, reader)
+ copyErrCh <- copyErr
+ close(copyDone)
+ }()
- assert.Panics(t, func() {
- EnsureConfigFromEnvVars(config)
- }, "EnsureConfigFromEnvVars should panic on non-pointer")
+ defer func() {
+ require.NoError(t, reader.Close())
+ os.Stdout = stdout
+ }()
+
+ InitLocalEnvConfig()
+
+ if err := writer.Close(); err != nil {
+ t.Fatalf("close pipe writer: %v", err)
+ }
+
+ <-copyDone
+ require.NoError(t, <-copyErrCh)
+
+ result := output.String()
+
+ want := "VERSION: NO-VERSION\n\nENVIRONMENT NAME: development\n\n"
+ if !strings.Contains(result, want) {
+ t.Fatalf("unexpected output. got: %q", result)
+ }
}
diff --git a/commons/outbox/classifier.go b/commons/outbox/classifier.go
new file mode 100644
index 00000000..f29e33f1
--- /dev/null
+++ b/commons/outbox/classifier.go
@@ -0,0 +1,16 @@
+package outbox
+
+// RetryClassifier determines whether an error should not be retried.
+type RetryClassifier interface {
+ IsNonRetryable(err error) bool
+}
+
+type RetryClassifierFunc func(err error) bool
+
+func (fn RetryClassifierFunc) IsNonRetryable(err error) bool {
+ if fn == nil {
+ return false
+ }
+
+ return fn(err)
+}
diff --git a/commons/outbox/config.go b/commons/outbox/config.go
new file mode 100644
index 00000000..9b47fab5
--- /dev/null
+++ b/commons/outbox/config.go
@@ -0,0 +1,300 @@
+package outbox
+
+import (
+ "strings"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ "go.opentelemetry.io/otel/metric"
+)
+
+const (
+ defaultDispatchInterval = 2 * time.Second
+ defaultBatchSize = 50
+ defaultPublishMaxAttempts = 3
+ defaultPublishBackoff = 200 * time.Millisecond
+ defaultListPendingFailureThreshold = 3
+ defaultRetryWindow = 5 * time.Minute
+ defaultMaxDispatchAttempts = 10
+ defaultProcessingTimeout = 10 * time.Minute
+ defaultPriorityBudget = 10
+ defaultMaxFailedPerBatch = 25
+ defaultMaxTenantMetricDimensions = 1000
+ defaultMaxTrackedFailureTenants = 4096
+ defaultTenantFailureCounterFallback = "_default"
+)
+
+// DispatcherConfig controls dispatcher polling, retry, and metric behavior.
+type DispatcherConfig struct {
+ // DispatchInterval is the periodic interval between dispatch cycles.
+ DispatchInterval time.Duration
+ // BatchSize is the max number of events processed per cycle.
+ BatchSize int
+ // PublishMaxAttempts is the max publish attempts for one event.
+ PublishMaxAttempts int
+ // PublishBackoff is the base backoff between publish retries.
+ PublishBackoff time.Duration
+ // ListPendingFailureThreshold emits an error log once repeated list failures reach this count.
+ ListPendingFailureThreshold int
+ // RetryWindow is the minimum age for failed events to become retry-eligible.
+ RetryWindow time.Duration
+ // MaxDispatchAttempts is the max total dispatch attempts before invalidation.
+ MaxDispatchAttempts int
+ // ProcessingTimeout is the age threshold for reclaiming stuck processing events.
+ ProcessingTimeout time.Duration
+ // PriorityBudget limits how many events can be selected via priority lists per cycle.
+ PriorityBudget int
+ // MaxFailedPerBatch limits how many failed events are reclaimed in one cycle.
+ MaxFailedPerBatch int
+ // PriorityEventTypes defines ordered event types to pull first each cycle.
+ PriorityEventTypes []string
+ // IncludeTenantMetrics enables tenant metric attributes and can increase cardinality.
+ IncludeTenantMetrics bool
+ // MaxTenantMetricDimensions caps unique tenant labels before falling back to an overflow label.
+ MaxTenantMetricDimensions int
+ // MaxTrackedListPendingFailureTenants caps in-memory tenant counters for ListPending failures.
+ MaxTrackedListPendingFailureTenants int
+ // MeterProvider overrides the default global meter provider when set.
+ MeterProvider metric.MeterProvider
+}
+
+// DefaultDispatcherConfig returns the baseline dispatcher configuration.
+func DefaultDispatcherConfig() DispatcherConfig {
+ return DispatcherConfig{
+ DispatchInterval: defaultDispatchInterval,
+ BatchSize: defaultBatchSize,
+ PublishMaxAttempts: defaultPublishMaxAttempts,
+ PublishBackoff: defaultPublishBackoff,
+ ListPendingFailureThreshold: defaultListPendingFailureThreshold,
+ RetryWindow: defaultRetryWindow,
+ MaxDispatchAttempts: defaultMaxDispatchAttempts,
+ ProcessingTimeout: defaultProcessingTimeout,
+ PriorityBudget: defaultPriorityBudget,
+ MaxFailedPerBatch: defaultMaxFailedPerBatch,
+ PriorityEventTypes: nil,
+ IncludeTenantMetrics: false,
+ MaxTenantMetricDimensions: defaultMaxTenantMetricDimensions,
+ MaxTrackedListPendingFailureTenants: defaultMaxTrackedFailureTenants,
+ MeterProvider: nil,
+ }
+}
+
+func (cfg *DispatcherConfig) normalize() {
+ defaults := DefaultDispatcherConfig()
+
+ if cfg.DispatchInterval <= 0 {
+ cfg.DispatchInterval = defaults.DispatchInterval
+ }
+
+ if cfg.BatchSize <= 0 {
+ cfg.BatchSize = defaults.BatchSize
+ }
+
+ if cfg.PublishMaxAttempts <= 0 {
+ cfg.PublishMaxAttempts = defaults.PublishMaxAttempts
+ }
+
+ if cfg.PublishBackoff <= 0 {
+ cfg.PublishBackoff = defaults.PublishBackoff
+ }
+
+ if cfg.ListPendingFailureThreshold <= 0 {
+ cfg.ListPendingFailureThreshold = defaults.ListPendingFailureThreshold
+ }
+
+ if cfg.RetryWindow <= 0 {
+ cfg.RetryWindow = defaults.RetryWindow
+ }
+
+ if cfg.MaxDispatchAttempts <= 0 {
+ cfg.MaxDispatchAttempts = defaults.MaxDispatchAttempts
+ }
+
+ if cfg.ProcessingTimeout <= 0 {
+ cfg.ProcessingTimeout = defaults.ProcessingTimeout
+ }
+
+ if cfg.PriorityBudget <= 0 {
+ cfg.PriorityBudget = defaults.PriorityBudget
+ }
+
+ if cfg.MaxFailedPerBatch <= 0 {
+ cfg.MaxFailedPerBatch = defaults.MaxFailedPerBatch
+ }
+
+ if cfg.MaxTenantMetricDimensions <= 0 {
+ cfg.MaxTenantMetricDimensions = defaults.MaxTenantMetricDimensions
+ }
+
+ if cfg.MaxTrackedListPendingFailureTenants <= 0 {
+ cfg.MaxTrackedListPendingFailureTenants = defaults.MaxTrackedListPendingFailureTenants
+ }
+}
+
+// DispatcherOption mutates dispatcher configuration at construction.
+type DispatcherOption func(*Dispatcher)
+
+// WithBatchSize sets the maximum events processed in one dispatch cycle.
+func WithBatchSize(size int) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if size > 0 {
+ dispatcher.cfg.BatchSize = size
+ }
+ }
+}
+
+// WithDispatchInterval sets the dispatch polling interval.
+func WithDispatchInterval(interval time.Duration) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if interval > 0 {
+ dispatcher.cfg.DispatchInterval = interval
+ }
+ }
+}
+
+// WithPublishMaxAttempts sets max publish attempts per event.
+func WithPublishMaxAttempts(maxAttempts int) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if maxAttempts > 0 {
+ dispatcher.cfg.PublishMaxAttempts = maxAttempts
+ }
+ }
+}
+
+// WithPublishBackoff sets base backoff for publish retry attempts.
+func WithPublishBackoff(backoff time.Duration) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if backoff > 0 {
+ dispatcher.cfg.PublishBackoff = backoff
+ }
+ }
+}
+
+// WithRetryWindow sets failed-event cooldown before retry reclamation.
+func WithRetryWindow(retryWindow time.Duration) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if retryWindow > 0 {
+ dispatcher.cfg.RetryWindow = retryWindow
+ }
+ }
+}
+
+// WithMaxDispatchAttempts sets max dispatch attempts before invalidation.
+func WithMaxDispatchAttempts(attempts int) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if attempts > 0 {
+ dispatcher.cfg.MaxDispatchAttempts = attempts
+ }
+ }
+}
+
+// WithProcessingTimeout sets the timeout used to reclaim stuck processing events.
+func WithProcessingTimeout(timeout time.Duration) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if timeout > 0 {
+ dispatcher.cfg.ProcessingTimeout = timeout
+ }
+ }
+}
+
+// WithListPendingFailureThreshold sets the log threshold for repeated list failures.
+func WithListPendingFailureThreshold(threshold int) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if threshold > 0 {
+ dispatcher.cfg.ListPendingFailureThreshold = threshold
+ }
+ }
+}
+
+// WithPriorityBudget sets the per-cycle priority selection budget.
+func WithPriorityBudget(budget int) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if budget > 0 {
+ dispatcher.cfg.PriorityBudget = budget
+ }
+ }
+}
+
+// WithMaxFailedPerBatch sets max failed events reclaimed each cycle.
+func WithMaxFailedPerBatch(maxFailed int) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if maxFailed > 0 {
+ dispatcher.cfg.MaxFailedPerBatch = maxFailed
+ }
+ }
+}
+
+// WithPriorityEventTypes sets the ordered event types selected before generic pending events.
+func WithPriorityEventTypes(eventTypes ...string) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ types := make([]string, 0, len(eventTypes))
+ for _, eventType := range eventTypes {
+ normalized := strings.TrimSpace(eventType)
+ if normalized == "" {
+ continue
+ }
+
+ types = append(types, normalized)
+ }
+
+ if len(types) == 0 {
+ dispatcher.cfg.PriorityEventTypes = nil
+
+ return
+ }
+
+ dispatcher.cfg.PriorityEventTypes = types
+ }
+}
+
+// WithRetryClassifier sets the non-retryable error classifier.
+func WithRetryClassifier(classifier RetryClassifier) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if nilcheck.Interface(classifier) {
+ dispatcher.retryClassifier = nil
+
+ return
+ }
+
+ dispatcher.retryClassifier = classifier
+ }
+}
+
+// WithTenantMetricAttributes toggles tenant attributes for dispatcher metrics.
+func WithTenantMetricAttributes(enabled bool) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ dispatcher.cfg.IncludeTenantMetrics = enabled
+ }
+}
+
+// WithMaxTenantMetricDimensions sets the maximum unique tenant labels used in metrics.
+func WithMaxTenantMetricDimensions(maxDimensions int) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if maxDimensions > 0 {
+ dispatcher.cfg.MaxTenantMetricDimensions = maxDimensions
+ }
+ }
+}
+
+// WithMaxTrackedListPendingFailureTenants sets the in-memory cap for tenant-specific ListPending failure counters.
+func WithMaxTrackedListPendingFailureTenants(maxTenants int) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if maxTenants > 0 {
+ dispatcher.cfg.MaxTrackedListPendingFailureTenants = maxTenants
+ }
+ }
+}
+
+// WithMeterProvider injects a custom meter provider for dispatcher metrics.
+// Passing nil keeps the default global OpenTelemetry meter provider.
+func WithMeterProvider(provider metric.MeterProvider) DispatcherOption {
+ return func(dispatcher *Dispatcher) {
+ if nilcheck.Interface(provider) {
+ dispatcher.cfg.MeterProvider = nil
+
+ return
+ }
+
+ dispatcher.cfg.MeterProvider = provider
+ }
+}
diff --git a/commons/outbox/config_test.go b/commons/outbox/config_test.go
new file mode 100644
index 00000000..1189e8fb
--- /dev/null
+++ b/commons/outbox/config_test.go
@@ -0,0 +1,138 @@
+//go:build unit
+
+package outbox
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type pointerRetryClassifier struct{}
+
+func (*pointerRetryClassifier) IsNonRetryable(error) bool { return true }
+
+func TestDispatcherConfigNormalize_AppliesDefaults(t *testing.T) {
+ t.Parallel()
+
+ cfg := DispatcherConfig{
+ DispatchInterval: -1,
+ BatchSize: 0,
+ PublishMaxAttempts: -2,
+ PublishBackoff: 0,
+ ListPendingFailureThreshold: -1,
+ RetryWindow: 0,
+ MaxDispatchAttempts: 0,
+ ProcessingTimeout: -5,
+ PriorityBudget: 0,
+ MaxFailedPerBatch: -1,
+ }
+
+ cfg.normalize()
+
+ defaults := DefaultDispatcherConfig()
+ require.Equal(t, defaults.DispatchInterval, cfg.DispatchInterval)
+ require.Equal(t, defaults.BatchSize, cfg.BatchSize)
+ require.Equal(t, defaults.PublishMaxAttempts, cfg.PublishMaxAttempts)
+ require.Equal(t, defaults.PublishBackoff, cfg.PublishBackoff)
+ require.Equal(t, defaults.ListPendingFailureThreshold, cfg.ListPendingFailureThreshold)
+ require.Equal(t, defaults.RetryWindow, cfg.RetryWindow)
+ require.Equal(t, defaults.MaxDispatchAttempts, cfg.MaxDispatchAttempts)
+ require.Equal(t, defaults.ProcessingTimeout, cfg.ProcessingTimeout)
+ require.Equal(t, defaults.PriorityBudget, cfg.PriorityBudget)
+ require.Equal(t, defaults.MaxFailedPerBatch, cfg.MaxFailedPerBatch)
+ require.Equal(t, defaults.MaxTenantMetricDimensions, cfg.MaxTenantMetricDimensions)
+ require.Equal(t, defaults.MaxTrackedListPendingFailureTenants, cfg.MaxTrackedListPendingFailureTenants)
+ require.False(t, cfg.IncludeTenantMetrics)
+ require.Nil(t, cfg.MeterProvider)
+}
+
+func TestDispatcherConfigNormalize_PreservesValidValues(t *testing.T) {
+ t.Parallel()
+
+ cfg := DispatcherConfig{
+ DispatchInterval: 3 * time.Second,
+ BatchSize: 25,
+ PublishMaxAttempts: 7,
+ PublishBackoff: 120 * time.Millisecond,
+ ListPendingFailureThreshold: 8,
+ RetryWindow: 2 * time.Minute,
+ MaxDispatchAttempts: 9,
+ ProcessingTimeout: 4 * time.Minute,
+ PriorityBudget: 11,
+ MaxFailedPerBatch: 13,
+ IncludeTenantMetrics: true,
+ MaxTenantMetricDimensions: 55,
+ MaxTrackedListPendingFailureTenants: 99,
+ }
+
+ cfg.normalize()
+
+ require.Equal(t, 3*time.Second, cfg.DispatchInterval)
+ require.Equal(t, 25, cfg.BatchSize)
+ require.Equal(t, 7, cfg.PublishMaxAttempts)
+ require.Equal(t, 120*time.Millisecond, cfg.PublishBackoff)
+ require.Equal(t, 8, cfg.ListPendingFailureThreshold)
+ require.Equal(t, 2*time.Minute, cfg.RetryWindow)
+ require.Equal(t, 9, cfg.MaxDispatchAttempts)
+ require.Equal(t, 4*time.Minute, cfg.ProcessingTimeout)
+ require.Equal(t, 11, cfg.PriorityBudget)
+ require.Equal(t, 13, cfg.MaxFailedPerBatch)
+ require.True(t, cfg.IncludeTenantMetrics)
+ require.Equal(t, 55, cfg.MaxTenantMetricDimensions)
+ require.Equal(t, 99, cfg.MaxTrackedListPendingFailureTenants)
+}
+
+func TestWithRetryClassifier_IgnoresTypedNil(t *testing.T) {
+ t.Parallel()
+
+ dispatcher := &Dispatcher{}
+ var classifier *pointerRetryClassifier
+
+ WithRetryClassifier(classifier)(dispatcher)
+
+ require.Nil(t, dispatcher.retryClassifier)
+}
+
+func TestWithMaxTenantMetricDimensions(t *testing.T) {
+ t.Parallel()
+
+ dispatcher := &Dispatcher{cfg: DefaultDispatcherConfig()}
+
+ WithMaxTenantMetricDimensions(42)(dispatcher)
+ require.Equal(t, 42, dispatcher.cfg.MaxTenantMetricDimensions)
+
+ WithMaxTenantMetricDimensions(0)(dispatcher)
+ require.Equal(t, 42, dispatcher.cfg.MaxTenantMetricDimensions)
+}
+
+func TestWithPriorityEventTypes_EmptyInputKeepsNil(t *testing.T) {
+ t.Parallel()
+
+ dispatcher := &Dispatcher{cfg: DefaultDispatcherConfig()}
+
+ WithPriorityEventTypes("")(dispatcher)
+ require.Nil(t, dispatcher.cfg.PriorityEventTypes)
+}
+
+func TestWithPriorityEventTypes_TrimsWhitespaceAndDropsEmpty(t *testing.T) {
+ t.Parallel()
+
+ dispatcher := &Dispatcher{cfg: DefaultDispatcherConfig()}
+
+ WithPriorityEventTypes(" payments.created ", "\t", "payments.failed", " ")(dispatcher)
+ require.Equal(t, []string{"payments.created", "payments.failed"}, dispatcher.cfg.PriorityEventTypes)
+}
+
+func TestWithMaxTrackedListPendingFailureTenants(t *testing.T) {
+ t.Parallel()
+
+ dispatcher := &Dispatcher{cfg: DefaultDispatcherConfig()}
+
+ WithMaxTrackedListPendingFailureTenants(12)(dispatcher)
+ require.Equal(t, 12, dispatcher.cfg.MaxTrackedListPendingFailureTenants)
+
+ WithMaxTrackedListPendingFailureTenants(0)(dispatcher)
+ require.Equal(t, 12, dispatcher.cfg.MaxTrackedListPendingFailureTenants)
+}
diff --git a/commons/outbox/dispatcher.go b/commons/outbox/dispatcher.go
new file mode 100644
index 00000000..11aa3bef
--- /dev/null
+++ b/commons/outbox/dispatcher.go
@@ -0,0 +1,913 @@
+package outbox
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ "github.com/google/uuid"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/metric"
+ "go.opentelemetry.io/otel/trace"
+ "go.opentelemetry.io/otel/trace/noop"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/backoff"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
+)
+
+const overflowTenantMetricLabel = "_other"
+
+type tenantRequirementReporter interface {
+ RequiresTenant() bool
+}
+
+// Dispatcher handles publishing outbox events through registered handlers.
+type Dispatcher struct {
+ repo OutboxRepository
+ handlers *HandlerRegistry
+ retryClassifier RetryClassifier
+ logger libLog.Logger
+ tracer trace.Tracer
+ cfg DispatcherConfig
+
+ listPendingFailureCounts map[string]int
+ failureCountsMu sync.Mutex
+ tenantMetricKeys map[string]struct{}
+ tenantMetricMu sync.Mutex
+
+ stop chan struct{}
+ stopOnce sync.Once
+ runStateMu sync.Mutex
+ running bool
+ cancelFunc context.CancelFunc
+ dispatchWg sync.WaitGroup
+ tenantTurn int
+
+ metrics dispatcherMetrics
+}
+
+var _ libCommons.App = (*Dispatcher)(nil)
+
+// DispatchResult captures one dispatch cycle outcome.
+type DispatchResult struct {
+ Processed int
+ Published int
+ Failed int
+ StateUpdateFailed int
+}
+
+// NewDispatcher creates a generic outbox dispatcher.
+func NewDispatcher(
+ repo OutboxRepository,
+ handlers *HandlerRegistry,
+ logger libLog.Logger,
+ tracer trace.Tracer,
+ opts ...DispatcherOption,
+) (*Dispatcher, error) {
+ if nilcheck.Interface(repo) {
+ return nil, ErrOutboxRepositoryRequired
+ }
+
+ if handlers == nil {
+ return nil, ErrHandlerRegistryRequired
+ }
+
+ if nilcheck.Interface(tracer) {
+ tracer = noop.NewTracerProvider().Tracer("commons.noop")
+ }
+
+ if nilcheck.Interface(logger) {
+ logger = libLog.NewNop()
+ }
+
+ dispatcher := &Dispatcher{
+ repo: repo,
+ handlers: handlers,
+ logger: logger,
+ tracer: tracer,
+ cfg: DefaultDispatcherConfig(),
+ listPendingFailureCounts: make(map[string]int),
+ tenantMetricKeys: make(map[string]struct{}),
+ stop: make(chan struct{}),
+ }
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(dispatcher)
+ }
+ }
+
+ dispatcher.cfg.normalize()
+ dispatcher.ensureFailureCounterFallback()
+
+ if dispatcher.cfg.IncludeTenantMetrics {
+ dispatcher.logger.Log(
+ context.Background(),
+ libLog.LevelWarn,
+ fmt.Sprintf(
+ "outbox tenant metric attributes enabled; cardinality capped at %d with overflow label %q",
+ dispatcher.cfg.MaxTenantMetricDimensions,
+ overflowTenantMetricLabel,
+ ),
+ )
+ }
+
+ metrics, err := newDispatcherMetrics(dispatcher.cfg.MeterProvider)
+ if err != nil {
+ return nil, fmt.Errorf("init outbox metrics: %w", err)
+ }
+
+ dispatcher.metrics = metrics
+
+ return dispatcher, nil
+}
+
+// Run starts the dispatcher loop until Stop is called.
+func (dispatcher *Dispatcher) Run(launcher *libCommons.Launcher) error {
+ return dispatcher.RunContext(context.Background(), launcher)
+}
+
+// RunContext starts the dispatcher loop until Stop is called or ctx is cancelled.
+func (dispatcher *Dispatcher) RunContext(parentCtx context.Context, launcher *libCommons.Launcher) error {
+ if dispatcher == nil {
+ return ErrOutboxDispatcherRequired
+ }
+
+ if dispatcher.repo == nil || dispatcher.handlers == nil {
+ return ErrOutboxDispatcherRequired
+ }
+
+ if parentCtx == nil {
+ parentCtx = context.Background()
+ }
+
+ ctx, cancel := context.WithCancel(parentCtx)
+ if !dispatcher.registerRun(cancel) {
+ cancel()
+
+ return ErrOutboxDispatcherRunning
+ }
+
+ defer dispatcher.clearRun()
+
+ if launcher != nil && launcher.Logger != nil {
+ launcher.Logger.Log(context.Background(), libLog.LevelInfo, "outbox dispatcher started")
+ defer launcher.Logger.Log(context.Background(), libLog.LevelInfo, "outbox dispatcher stopped")
+ }
+
+ defer runtime.RecoverAndLogWithContext(
+ ctx,
+ dispatcher.logger,
+ "outbox",
+ "dispatcher_run",
+ )
+
+ ticker := time.NewTicker(dispatcher.cfg.DispatchInterval)
+ defer ticker.Stop()
+
+ func() {
+ dispatcher.dispatchWg.Add(1)
+ defer dispatcher.dispatchWg.Done()
+
+ initCtx, span := dispatcher.tracer.Start(ctx, "outbox.dispatcher.initial_dispatch")
+ defer span.End()
+ defer runtime.RecoverAndLogWithContext(initCtx, dispatcher.logger, "outbox", "dispatcher_initial")
+
+ dispatcher.dispatchAcrossTenants(initCtx)
+ }()
+
+ for {
+ select {
+ case <-dispatcher.stop:
+ return nil
+ case <-ctx.Done():
+ return nil
+ case <-ticker.C:
+ select {
+ case <-dispatcher.stop:
+ return nil
+ case <-ctx.Done():
+ return nil
+ default:
+ }
+
+ func() {
+ dispatcher.dispatchWg.Add(1)
+ defer dispatcher.dispatchWg.Done()
+
+ tickCtx, span := dispatcher.tracer.Start(ctx, "outbox.dispatcher.dispatch_once")
+ defer span.End()
+ defer runtime.RecoverAndLogWithContext(tickCtx, dispatcher.logger, "outbox", "dispatcher_tick")
+
+ dispatcher.dispatchAcrossTenants(tickCtx)
+ }()
+ }
+ }
+}
+
+// Stop signals the dispatcher loop to stop.
+func (dispatcher *Dispatcher) Stop() {
+ if dispatcher == nil {
+ return
+ }
+
+ dispatcher.stopOnce.Do(func() {
+ dispatcher.runStateMu.Lock()
+ cancel := dispatcher.cancelFunc
+
+ stop := dispatcher.stop
+ if stop == nil {
+ stop = make(chan struct{})
+ dispatcher.stop = stop
+ }
+ dispatcher.runStateMu.Unlock()
+
+ if cancel != nil {
+ cancel()
+ }
+
+ close(stop)
+ })
+}
+
+// Shutdown waits for in-flight dispatch cycle completion.
+func (dispatcher *Dispatcher) Shutdown(ctx context.Context) error {
+ if dispatcher == nil {
+ return nil
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ dispatcher.Stop()
+
+ done := make(chan struct{})
+
+ runtime.SafeGo(dispatcher.logger, "outbox.dispatcher_shutdown_wait", runtime.KeepRunning, func() {
+ dispatcher.dispatchWg.Wait()
+ close(done)
+ })
+
+ select {
+ case <-done:
+ return nil
+ case <-ctx.Done():
+ return fmt.Errorf("dispatcher shutdown: %w", ctx.Err())
+ }
+}
+
+// DispatchOnce processes one tenant-scoped dispatch cycle.
+func (dispatcher *Dispatcher) DispatchOnce(ctx context.Context) int {
+ return dispatcher.DispatchOnceResult(ctx).Processed
+}
+
+// DispatchOnceResult processes one tenant-scoped dispatch cycle and returns counters.
+func (dispatcher *Dispatcher) DispatchOnceResult(ctx context.Context) DispatchResult {
+ if dispatcher == nil {
+ return DispatchResult{}
+ }
+
+ if dispatcher.repo == nil || dispatcher.handlers == nil {
+ return DispatchResult{}
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ logger := dispatcher.logger
+ if nilcheck.Interface(logger) {
+ logger = libLog.NewNop()
+ }
+
+ tracer := dispatcher.tracer
+ if nilcheck.Interface(tracer) {
+ tracer = noop.NewTracerProvider().Tracer("commons.noop")
+ }
+
+ start := time.Now().UTC()
+
+ ctx, span := tracer.Start(ctx, "outbox.dispatch")
+ defer span.End()
+
+ events := dispatcher.collectEvents(ctx, span)
+ processed := 0
+ published := 0
+ failed := 0
+ stateUpdateFailed := 0
+
+ tenantKey := tenantKeyFromContext(ctx)
+ dispatcher.recordQueueDepth(ctx, tenantKey, int64(len(events)))
+
+ // Delivery semantics are at-least-once: publish happens before MarkPublished.
+ // If state persistence fails after publish, consumers must remain idempotent.
+ for _, event := range events {
+ if ctx.Err() != nil {
+ break
+ }
+
+ if event == nil {
+ continue
+ }
+
+ processed++
+
+ if err := dispatcher.publishEventWithRetry(ctx, event); err != nil {
+ dispatcher.handlePublishError(ctx, logger, event, err)
+
+ failed++
+
+ continue
+ }
+
+ published++
+
+ if err := dispatcher.repo.MarkPublished(ctx, event.ID, time.Now().UTC()); err != nil {
+ logger.Log(
+ ctx,
+ libLog.LevelError,
+ "outbox event published to broker but failed to persist PUBLISHED state; event may be retried",
+ libLog.String("event_id", event.ID.String()),
+ libLog.String("error", sanitizeErrorForStorage(err)),
+ )
+ dispatcher.addStateUpdateFailure(ctx, tenantKey, 1)
+
+ stateUpdateFailed++
+
+ continue
+ }
+ }
+
+ dispatcher.addDispatchedEvents(ctx, tenantKey, int64(published))
+ dispatcher.addFailedEvents(ctx, tenantKey, int64(failed))
+ dispatcher.recordDispatchLatency(ctx, tenantKey, time.Since(start).Seconds())
+
+ return DispatchResult{
+ Processed: processed,
+ Published: published,
+ Failed: failed,
+ StateUpdateFailed: stateUpdateFailed,
+ }
+}
+
+func (dispatcher *Dispatcher) tenantMetricAttribute(tenantKey string) (attribute.KeyValue, bool) {
+ if !dispatcher.cfg.IncludeTenantMetrics {
+ return attribute.KeyValue{}, false
+ }
+
+ boundedTenant := dispatcher.boundedTenantMetricKey(tenantKey)
+
+ return attribute.String("tenant", boundedTenant), true
+}
+
+func (dispatcher *Dispatcher) boundedTenantMetricKey(tenantKey string) string {
+ if tenantKey == "" {
+ tenantKey = defaultTenantFailureCounterFallback
+ }
+
+ dispatcher.tenantMetricMu.Lock()
+ defer dispatcher.tenantMetricMu.Unlock()
+
+ if dispatcher.tenantMetricKeys == nil {
+ dispatcher.tenantMetricKeys = make(map[string]struct{})
+ }
+
+ if _, exists := dispatcher.tenantMetricKeys[tenantKey]; exists {
+ return tenantKey
+ }
+
+ if len(dispatcher.tenantMetricKeys) < dispatcher.cfg.MaxTenantMetricDimensions {
+ dispatcher.tenantMetricKeys[tenantKey] = struct{}{}
+
+ return tenantKey
+ }
+
+ return overflowTenantMetricLabel
+}
+
+func (dispatcher *Dispatcher) recordQueueDepth(ctx context.Context, tenantKey string, depth int64) {
+ if dispatcher.metrics.queueDepth == nil {
+ return
+ }
+
+ dispatcher.metrics.queueDepth.Record(ctx, depth, dispatcher.tenantRecordOptions(tenantKey)...)
+}
+
+func (dispatcher *Dispatcher) addDispatchedEvents(ctx context.Context, tenantKey string, count int64) {
+ if dispatcher.metrics.eventsDispatched == nil || count <= 0 {
+ return
+ }
+
+ dispatcher.metrics.eventsDispatched.Add(ctx, count, dispatcher.tenantAddOptions(tenantKey)...)
+}
+
+func (dispatcher *Dispatcher) addFailedEvents(ctx context.Context, tenantKey string, count int64) {
+ if dispatcher.metrics.eventsFailed == nil || count <= 0 {
+ return
+ }
+
+ dispatcher.metrics.eventsFailed.Add(ctx, count, dispatcher.tenantAddOptions(tenantKey)...)
+}
+
+func (dispatcher *Dispatcher) addStateUpdateFailure(ctx context.Context, tenantKey string, count int64) {
+ if dispatcher.metrics.eventsStateFailed == nil || count <= 0 {
+ return
+ }
+
+ dispatcher.metrics.eventsStateFailed.Add(ctx, count, dispatcher.tenantAddOptions(tenantKey)...)
+}
+
+func (dispatcher *Dispatcher) recordDispatchLatency(ctx context.Context, tenantKey string, latencySeconds float64) {
+ if dispatcher.metrics.dispatchLatency == nil {
+ return
+ }
+
+ dispatcher.metrics.dispatchLatency.Record(ctx, latencySeconds, dispatcher.tenantRecordOptions(tenantKey)...)
+}
+
+// dispatchAcrossTenants intentionally keeps tenant dispatch sequential for per-cycle
+// predictability, but rotates the starting tenant between cycles to reduce unfairness
+// when a single tenant is consistently slow.
+func (dispatcher *Dispatcher) dispatchAcrossTenants(ctx context.Context) {
+ if ctx.Err() != nil {
+ return
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ if nilcheck.Interface(logger) {
+ logger = dispatcher.logger
+ }
+
+ if nilcheck.Interface(tracer) {
+ tracer = dispatcher.tracer
+ }
+
+ if nilcheck.Interface(tracer) {
+ tracer = noop.NewTracerProvider().Tracer("commons.noop")
+ }
+
+ ctx, span := tracer.Start(ctx, "outbox.dispatcher.tenants")
+ defer span.End()
+
+ tenants, err := dispatcher.repo.ListTenants(ctx)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to list tenants", err)
+ libLog.SafeError(logger, ctx, "failed to list tenants", err, false)
+
+ return
+ }
+
+ orderedTenants := dispatcher.tenantDispatchOrder(nonEmptyTenants(tenants))
+ if len(orderedTenants) == 0 {
+ dispatcher.dispatchWithoutDiscoveredTenant(ctx, tracer)
+
+ return
+ }
+
+ for _, tenantID := range orderedTenants {
+ if ctx.Err() != nil {
+ break
+ }
+
+ tenantCtx := ContextWithTenantID(ctx, tenantID)
+ tenantCtx, tenantSpan := tracer.Start(tenantCtx, "outbox.dispatcher.tenant")
+ result := dispatcher.DispatchOnceResult(tenantCtx)
+ // Keep tenant trace correlation without exposing raw tenant identifiers.
+ tenantSpan.SetAttributes(
+ attribute.String("tenant.id_hash", hashTenantID(tenantID)),
+ attribute.Int("outbox.dispatch.processed", result.Processed),
+ attribute.Int("outbox.dispatch.published", result.Published),
+ attribute.Int("outbox.dispatch.failed", result.Failed),
+ attribute.Int("outbox.dispatch.state_update_failed", result.StateUpdateFailed),
+ )
+
+ tenantSpan.End()
+ }
+}
+
+func (dispatcher *Dispatcher) dispatchWithoutDiscoveredTenant(ctx context.Context, tracer trace.Tracer) {
+ tenantID, ok := TenantIDFromContext(ctx)
+ if ok && tenantID != "" {
+ dispatcher.DispatchOnceResult(ctx)
+
+ return
+ }
+
+ requiresTenant := true
+ if reporter, ok := dispatcher.repo.(tenantRequirementReporter); ok {
+ requiresTenant = reporter.RequiresTenant()
+ }
+
+ if requiresTenant {
+ dispatcher.logger.Log(
+ ctx,
+ libLog.LevelWarn,
+ "outbox tenant discovery returned no tenants; skipping dispatch because repository requires tenant context",
+ )
+
+ return
+ }
+
+ fallbackCtx, fallbackSpan := tracer.Start(ctx, "outbox.dispatcher.default_scope")
+ result := dispatcher.DispatchOnceResult(fallbackCtx)
+ fallbackSpan.SetAttributes(
+ attribute.Int("outbox.dispatch.processed", result.Processed),
+ attribute.Int("outbox.dispatch.published", result.Published),
+ attribute.Int("outbox.dispatch.failed", result.Failed),
+ attribute.Int("outbox.dispatch.state_update_failed", result.StateUpdateFailed),
+ )
+ fallbackSpan.End()
+}
+
+func nonEmptyTenants(tenants []string) []string {
+ if len(tenants) == 0 {
+ return nil
+ }
+
+ result := make([]string, 0, len(tenants))
+ for _, tenantID := range tenants {
+ tenantID = strings.TrimSpace(tenantID)
+
+ if tenantID == "" {
+ continue
+ }
+
+ result = append(result, tenantID)
+ }
+
+ return result
+}
+
+func (dispatcher *Dispatcher) tenantAddOptions(tenantKey string) []metric.AddOption {
+ if attr, ok := dispatcher.tenantMetricAttribute(tenantKey); ok {
+ return []metric.AddOption{metric.WithAttributes(attr)}
+ }
+
+ return nil
+}
+
+func (dispatcher *Dispatcher) tenantRecordOptions(tenantKey string) []metric.RecordOption {
+ if attr, ok := dispatcher.tenantMetricAttribute(tenantKey); ok {
+ return []metric.RecordOption{metric.WithAttributes(attr)}
+ }
+
+ return nil
+}
+
+func (dispatcher *Dispatcher) registerRun(cancel context.CancelFunc) bool {
+ dispatcher.runStateMu.Lock()
+ defer dispatcher.runStateMu.Unlock()
+
+ if dispatcher.running {
+ return false
+ }
+
+ if dispatcher.stop == nil || isClosedSignal(dispatcher.stop) {
+ dispatcher.stop = make(chan struct{})
+ dispatcher.stopOnce = sync.Once{}
+ }
+
+ dispatcher.running = true
+ dispatcher.cancelFunc = cancel
+
+ return true
+}
+
+func (dispatcher *Dispatcher) clearRun() {
+ dispatcher.runStateMu.Lock()
+ defer dispatcher.runStateMu.Unlock()
+
+ dispatcher.running = false
+ dispatcher.cancelFunc = nil
+}
+
+func (dispatcher *Dispatcher) tenantDispatchOrder(tenants []string) []string {
+ if len(tenants) <= 1 {
+ return append([]string(nil), tenants...)
+ }
+
+ dispatcher.runStateMu.Lock()
+ start := dispatcher.tenantTurn % len(tenants)
+ dispatcher.tenantTurn = (dispatcher.tenantTurn + 1) % len(tenants)
+ dispatcher.runStateMu.Unlock()
+
+ ordered := make([]string, 0, len(tenants))
+ ordered = append(ordered, tenants[start:]...)
+ ordered = append(ordered, tenants[:start]...)
+
+ return ordered
+}
+
+// collectEvents gathers events for a single dispatch cycle using a priority-layered
+// strategy. Events are collected in this order:
+//
+// 1. Priority events: pending events matching PriorityEventTypes (up to PriorityBudget)
+// 2. Stuck events: PROCESSING events older than ProcessingTimeout (reclaimed for retry)
+// 3. Failed events: FAILED events older than RetryWindow with remaining attempts
+// 4. Pending events: remaining PENDING events ordered by created_at ASC
+//
+// Within each layer, ordering follows the respective SQL query (typically ASC by
+// created_at or updated_at). The total batch is bounded by BatchSize. Duplicate
+// events (e.g., a priority event also in the pending set) are removed.
+func (dispatcher *Dispatcher) collectEvents(ctx context.Context, span trace.Span) []*OutboxEvent {
+ logger := dispatcher.logger
+ failedBefore := time.Now().UTC().Add(-dispatcher.cfg.RetryWindow)
+ processingBefore := time.Now().UTC().Add(-dispatcher.cfg.ProcessingTimeout)
+
+ priorityBudget := min(dispatcher.cfg.PriorityBudget, dispatcher.cfg.BatchSize)
+ priorityEvents := dispatcher.collectPriorityEvents(ctx, span, priorityBudget)
+ collected := len(priorityEvents)
+
+ stuckLimit := dispatcher.cfg.BatchSize - collected
+ if stuckLimit <= 0 {
+ return deduplicateEvents(priorityEvents)
+ }
+
+ stuckEvents, err := dispatcher.repo.ResetStuckProcessing(
+ ctx,
+ stuckLimit,
+ processingBefore,
+ dispatcher.cfg.MaxDispatchAttempts,
+ )
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to reset stuck events", err)
+ libLog.SafeError(logger, ctx, "failed to reset stuck events", err, false)
+ }
+
+ collected += len(stuckEvents)
+
+ failedLimit := min(dispatcher.cfg.BatchSize-collected, dispatcher.cfg.MaxFailedPerBatch)
+ if failedLimit <= 0 {
+ return deduplicateEvents(append(priorityEvents, stuckEvents...))
+ }
+
+ failedEvents, err := dispatcher.repo.ResetForRetry(
+ ctx,
+ failedLimit,
+ failedBefore,
+ dispatcher.cfg.MaxDispatchAttempts,
+ )
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to reset failed events for retry", err)
+ libLog.SafeError(logger, ctx, "failed to reset failed events for retry", err, false)
+ }
+
+ collected += len(failedEvents)
+
+ remaining := dispatcher.cfg.BatchSize - collected
+ if remaining <= 0 {
+ return deduplicateEvents(append(append(priorityEvents, stuckEvents...), failedEvents...))
+ }
+
+ pendingEvents, err := dispatcher.repo.ListPending(ctx, remaining)
+ if err != nil {
+ tenantKey := tenantKeyFromContext(ctx)
+ dispatcher.handleListPendingError(ctx, span, tenantKey, err)
+
+ return deduplicateEvents(append(append(priorityEvents, stuckEvents...), failedEvents...))
+ }
+
+ tenantKey := tenantKeyFromContext(ctx)
+ dispatcher.clearListPendingFailureCount(tenantKey)
+
+ all := make([]*OutboxEvent, 0, collected+len(pendingEvents))
+ all = append(all, priorityEvents...)
+ all = append(all, stuckEvents...)
+ all = append(all, failedEvents...)
+ all = append(all, pendingEvents...)
+
+ return deduplicateEvents(all)
+}
+
+func deduplicateEvents(events []*OutboxEvent) []*OutboxEvent {
+ if len(events) == 0 {
+ return events
+ }
+
+ seen := make(map[uuid.UUID]bool, len(events))
+ result := make([]*OutboxEvent, 0, len(events))
+
+ for _, event := range events {
+ if event == nil {
+ continue
+ }
+
+ if seen[event.ID] {
+ continue
+ }
+
+ seen[event.ID] = true
+ result = append(result, event)
+ }
+
+ return result
+}
+
+func (dispatcher *Dispatcher) collectPriorityEvents(
+ ctx context.Context,
+ span trace.Span,
+ budget int,
+) []*OutboxEvent {
+ if budget <= 0 || len(dispatcher.cfg.PriorityEventTypes) == 0 {
+ return nil
+ }
+
+ var result []*OutboxEvent
+
+ for _, eventType := range dispatcher.cfg.PriorityEventTypes {
+ remaining := budget - len(result)
+ if remaining <= 0 {
+ break
+ }
+
+ events, err := dispatcher.repo.ListPendingByType(ctx, eventType, remaining)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to list priority events", err)
+ libLog.SafeError(dispatcher.logger, ctx, "failed to list priority events", err, false)
+
+ continue
+ }
+
+ result = append(result, events...)
+ }
+
+ return result
+}
+
+func tenantKeyFromContext(ctx context.Context) string {
+ tenantID, ok := TenantIDFromContext(ctx)
+ if ok && tenantID != "" {
+ return tenantID
+ }
+
+ return defaultTenantFailureCounterFallback
+}
+
+func hashTenantID(tenantID string) string {
+ if tenantID == "" {
+ return ""
+ }
+
+ sum := sha256.Sum256([]byte(tenantID))
+
+ return hex.EncodeToString(sum[:8])
+}
+
+func isClosedSignal(signal <-chan struct{}) bool {
+ if signal == nil {
+ return false
+ }
+
+ select {
+ case <-signal:
+ return true
+ default:
+ return false
+ }
+}
+
+func (dispatcher *Dispatcher) ensureFailureCounterFallback() {
+ dispatcher.failureCountsMu.Lock()
+ defer dispatcher.failureCountsMu.Unlock()
+
+ if dispatcher.listPendingFailureCounts == nil {
+ dispatcher.listPendingFailureCounts = make(map[string]int)
+ }
+
+ dispatcher.listPendingFailureCounts[defaultTenantFailureCounterFallback] = 0
+}
+
+func (dispatcher *Dispatcher) handleListPendingError(ctx context.Context, span trace.Span, tenantKey string, err error) {
+ logger := dispatcher.logger
+
+ libOpentelemetry.HandleSpanError(span, "failed to list outbox events", err)
+ libLog.SafeError(logger, ctx, "failed to list outbox events", err, false)
+
+ counterTenantKey := tenantKey
+
+ dispatcher.failureCountsMu.Lock()
+
+ maxTracked := dispatcher.cfg.MaxTrackedListPendingFailureTenants
+ if maxTracked <= 1 {
+ counterTenantKey = defaultTenantFailureCounterFallback
+ } else if _, exists := dispatcher.listPendingFailureCounts[counterTenantKey]; !exists &&
+ len(dispatcher.listPendingFailureCounts) >= maxTracked {
+ counterTenantKey = defaultTenantFailureCounterFallback
+ }
+
+ dispatcher.listPendingFailureCounts[counterTenantKey]++
+ count := dispatcher.listPendingFailureCounts[counterTenantKey]
+ dispatcher.failureCountsMu.Unlock()
+
+ if count >= dispatcher.cfg.ListPendingFailureThreshold {
+ fields := []libLog.Field{libLog.Int("count", count)}
+ if counterTenantKey == "" || counterTenantKey == defaultTenantFailureCounterFallback {
+ fields = append(fields, libLog.String("tenant_bucket", defaultTenantFailureCounterFallback))
+ } else {
+ fields = append(fields, libLog.String("tenant_hash", hashTenantID(counterTenantKey)))
+ }
+
+ logger.Log(ctx, libLog.LevelError, "outbox list pending failures exceeded threshold", fields...)
+ }
+}
+
+func (dispatcher *Dispatcher) clearListPendingFailureCount(tenantKey string) {
+ dispatcher.failureCountsMu.Lock()
+ defer dispatcher.failureCountsMu.Unlock()
+
+ if tenantKey == "" || tenantKey == defaultTenantFailureCounterFallback {
+ dispatcher.listPendingFailureCounts[defaultTenantFailureCounterFallback] = 0
+ return
+ }
+
+ if _, exists := dispatcher.listPendingFailureCounts[tenantKey]; !exists {
+ // Untracked tenants are folded into fallback when cap is reached. Any
+ // successful list for such tenants should also clear fallback failures.
+ dispatcher.listPendingFailureCounts[defaultTenantFailureCounterFallback] = 0
+ return
+ }
+
+ delete(dispatcher.listPendingFailureCounts, tenantKey)
+}
+
+func (dispatcher *Dispatcher) publishEventWithRetry(ctx context.Context, event *OutboxEvent) error {
+ maxAttempts := dispatcher.cfg.PublishMaxAttempts
+ if maxAttempts <= 0 {
+ maxAttempts = defaultPublishMaxAttempts
+ }
+
+ publishBackoff := dispatcher.cfg.PublishBackoff
+ if publishBackoff <= 0 {
+ publishBackoff = defaultPublishBackoff
+ }
+
+ var lastErr error
+
+ for attempt := range maxAttempts {
+ err := dispatcher.publishEvent(ctx, event)
+ if err == nil {
+ return nil
+ }
+
+ lastErr = fmt.Errorf("publish attempt %d/%d failed: %w", attempt+1, maxAttempts, err)
+ if dispatcher.isNonRetryableError(err) || attempt == maxAttempts-1 {
+ break
+ }
+
+ delay := backoff.ExponentialWithJitter(publishBackoff, attempt)
+ if waitErr := backoff.WaitContext(ctx, delay); waitErr != nil {
+ lastErr = fmt.Errorf("publish retry wait interrupted: %w", waitErr)
+ break
+ }
+ }
+
+ return lastErr
+}
+
+func (dispatcher *Dispatcher) publishEvent(ctx context.Context, event *OutboxEvent) error {
+ if event == nil {
+ return ErrOutboxEventRequired
+ }
+
+ if len(event.Payload) == 0 {
+ return ErrOutboxEventPayloadRequired
+ }
+
+ return dispatcher.handlers.Handle(ctx, event)
+}
+
+func (dispatcher *Dispatcher) handlePublishError(
+ ctx context.Context,
+ logger libLog.Logger,
+ event *OutboxEvent,
+ err error,
+) {
+ if dispatcher.isNonRetryableError(err) {
+ if markErr := dispatcher.repo.MarkInvalid(ctx, event.ID, sanitizeErrorForStorage(err)); markErr != nil {
+ logger.Log(ctx, libLog.LevelError, "failed to mark outbox invalid", libLog.String("error", sanitizeErrorForStorage(markErr)))
+ }
+
+ return
+ }
+
+ if markErr := dispatcher.repo.MarkFailed(ctx, event.ID, sanitizeErrorForStorage(err), dispatcher.cfg.MaxDispatchAttempts); markErr != nil {
+ logger.Log(ctx, libLog.LevelError, "failed to mark outbox failed", libLog.String("error", sanitizeErrorForStorage(markErr)))
+ }
+}
+
+func (dispatcher *Dispatcher) isNonRetryableError(err error) bool {
+ if err == nil || nilcheck.Interface(dispatcher.retryClassifier) {
+ return false
+ }
+
+ return dispatcher.retryClassifier.IsNonRetryable(err)
+}
diff --git a/commons/outbox/dispatcher_test.go b/commons/outbox/dispatcher_test.go
new file mode 100644
index 00000000..20532bfa
--- /dev/null
+++ b/commons/outbox/dispatcher_test.go
@@ -0,0 +1,1156 @@
+//go:build unit
+
+package outbox
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/trace/noop"
+)
+
+type fakeRepo struct {
+ mu sync.Mutex
+ pending []*OutboxEvent
+ pendingByTenant map[string][]*OutboxEvent
+ pendingByType map[string][]*OutboxEvent
+ stuck []*OutboxEvent
+ failedForRetry []*OutboxEvent
+ markedPub []uuid.UUID
+ markPublishedCalls []uuid.UUID
+ markedFail []uuid.UUID
+ markedInv []uuid.UUID
+ tenants []string
+ tenantsErr error
+ listPendingErr error
+ listPendingTypeErr error
+ resetStuckErr error
+ resetForRetryErr error
+ markPublishedErr error
+ markFailedErr error
+ markInvalidErr error
+ listPendingBlocked <-chan struct{}
+ blockIgnoresCtx bool
+ listPendingCalls int32
+ listPendingTenants []string
+}
+
+type tenantAwareFakeRepo struct {
+ *fakeRepo
+ requiresTenant bool
+}
+
+func (repo *tenantAwareFakeRepo) RequiresTenant() bool {
+ if repo == nil {
+ return true
+ }
+
+ return repo.requiresTenant
+}
+
+func (repo *fakeRepo) Create(context.Context, *OutboxEvent) (*OutboxEvent, error) {
+ return nil, nil
+}
+
+func (repo *fakeRepo) CreateWithTx(context.Context, Tx, *OutboxEvent) (*OutboxEvent, error) {
+ return nil, nil
+}
+
+func (repo *fakeRepo) ListPending(ctx context.Context, _ int) ([]*OutboxEvent, error) {
+ atomic.AddInt32(&repo.listPendingCalls, 1)
+
+ if repo.listPendingBlocked != nil {
+ if repo.blockIgnoresCtx {
+ <-repo.listPendingBlocked
+ } else {
+ select {
+ case <-repo.listPendingBlocked:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ }
+ }
+
+ if repo.listPendingErr != nil {
+ return nil, repo.listPendingErr
+ }
+
+ if repo.pendingByTenant != nil {
+ tenantID, ok := TenantIDFromContext(ctx)
+ if ok {
+ repo.mu.Lock()
+ repo.listPendingTenants = append(repo.listPendingTenants, tenantID)
+ repo.mu.Unlock()
+
+ if tenantPending, exists := repo.pendingByTenant[tenantID]; exists {
+ return tenantPending, nil
+ }
+ }
+ }
+
+ return repo.pending, nil
+}
+
+func (repo *fakeRepo) ListPendingByType(_ context.Context, eventType string, _ int) ([]*OutboxEvent, error) {
+ if repo.listPendingTypeErr != nil {
+ return nil, repo.listPendingTypeErr
+ }
+
+ if repo.pendingByType != nil {
+ if events, exists := repo.pendingByType[eventType]; exists {
+ return events, nil
+ }
+
+ return nil, nil
+ }
+
+ result := make([]*OutboxEvent, 0)
+ for _, event := range repo.pending {
+ if event != nil && event.EventType == eventType {
+ result = append(result, event)
+ }
+ }
+
+ return result, nil
+}
+
+func (repo *fakeRepo) ListTenants(context.Context) ([]string, error) {
+ if repo.tenantsErr != nil {
+ return nil, repo.tenantsErr
+ }
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+
+ return append([]string(nil), repo.tenants...), nil
+}
+
+func (repo *fakeRepo) listPendingCallCount() int {
+ return int(atomic.LoadInt32(&repo.listPendingCalls))
+}
+
+func (repo *fakeRepo) GetByID(context.Context, uuid.UUID) (*OutboxEvent, error) { return nil, nil }
+
+func (repo *fakeRepo) MarkPublished(_ context.Context, id uuid.UUID, _ time.Time) error {
+ repo.mu.Lock()
+ repo.markPublishedCalls = append(repo.markPublishedCalls, id)
+ repo.mu.Unlock()
+
+ if repo.markPublishedErr != nil {
+ return repo.markPublishedErr
+ }
+
+ repo.mu.Lock()
+ repo.markedPub = append(repo.markedPub, id)
+ repo.mu.Unlock()
+
+ return nil
+}
+
+func (repo *fakeRepo) MarkFailed(_ context.Context, id uuid.UUID, _ string, _ int) error {
+ if repo.markFailedErr != nil {
+ return repo.markFailedErr
+ }
+
+ repo.mu.Lock()
+ repo.markedFail = append(repo.markedFail, id)
+ repo.mu.Unlock()
+
+ return nil
+}
+
+func (repo *fakeRepo) ListFailedForRetry(context.Context, int, time.Time, int) ([]*OutboxEvent, error) {
+ return nil, nil
+}
+
+func (repo *fakeRepo) ResetForRetry(context.Context, int, time.Time, int) ([]*OutboxEvent, error) {
+ if repo.resetForRetryErr != nil {
+ return nil, repo.resetForRetryErr
+ }
+
+ return repo.failedForRetry, nil
+}
+
+func (repo *fakeRepo) ResetStuckProcessing(context.Context, int, time.Time, int) ([]*OutboxEvent, error) {
+ if repo.resetStuckErr != nil {
+ return nil, repo.resetStuckErr
+ }
+
+ return repo.stuck, nil
+}
+
+func (repo *fakeRepo) MarkInvalid(_ context.Context, id uuid.UUID, _ string) error {
+ if repo.markInvalidErr != nil {
+ return repo.markInvalidErr
+ }
+
+ repo.mu.Lock()
+ repo.markedInv = append(repo.markedInv, id)
+ repo.mu.Unlock()
+
+ return nil
+}
+
+func (repo *fakeRepo) listPendingTenantOrder() []string {
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+
+ return append([]string(nil), repo.listPendingTenants...)
+}
+
+func TestDispatcher_DispatchOncePublishes(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ handlers := NewHandlerRegistry()
+
+ eventID := uuid.New()
+ repo.pending = []*OutboxEvent{{ID: eventID, EventType: "payment.created", Payload: []byte("ok")}}
+
+ handled := false
+ require.NoError(t, handlers.Register("payment.created", func(_ context.Context, event *OutboxEvent) error {
+ handled = true
+ require.Equal(t, eventID, event.ID)
+
+ return nil
+ }))
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithPublishMaxAttempts(1),
+ )
+ require.NoError(t, err)
+
+ processed := dispatcher.DispatchOnce(context.Background())
+ require.Equal(t, 1, processed)
+ require.True(t, handled)
+ require.Len(t, repo.markedPub, 1)
+ require.Equal(t, eventID, repo.markedPub[0])
+}
+
+func TestDispatcher_DispatchOnceMarksInvalidOnNonRetryable(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ handlers := NewHandlerRegistry()
+
+ eventID := uuid.New()
+ repo.pending = []*OutboxEvent{{ID: eventID, EventType: "payment.created", Payload: []byte("ok")}}
+
+ nonRetryable := errors.New("non-retryable")
+ require.NoError(t, handlers.Register("payment.created", func(context.Context, *OutboxEvent) error {
+ return nonRetryable
+ }))
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithPublishMaxAttempts(1),
+ WithRetryClassifier(RetryClassifierFunc(func(err error) bool {
+ return errors.Is(err, nonRetryable)
+ })),
+ )
+ require.NoError(t, err)
+
+ _ = dispatcher.DispatchOnce(context.Background())
+ require.Len(t, repo.markedInv, 1)
+ require.Equal(t, eventID, repo.markedInv[0])
+ require.Empty(t, repo.markedFail)
+}
+
+func TestDeduplicateEvents_FiltersNilAndDuplicates(t *testing.T) {
+ t.Parallel()
+
+ idA := uuid.New()
+ idB := uuid.New()
+
+ events := []*OutboxEvent{
+ nil,
+ {ID: idA},
+ {ID: idA},
+ nil,
+ {ID: idB},
+ }
+
+ result := deduplicateEvents(events)
+ require.Len(t, result, 2)
+ require.Equal(t, idA, result[0].ID)
+ require.Equal(t, idB, result[1].ID)
+}
+
+func TestDispatcher_DispatchOnceStopsOnContextCancellation(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ firstID := uuid.New()
+ secondID := uuid.New()
+ repo.pending = []*OutboxEvent{
+ {ID: firstID, EventType: "payment.created", Payload: []byte("1")},
+ {ID: secondID, EventType: "payment.created", Payload: []byte("2")},
+ }
+
+ handlers := NewHandlerRegistry()
+ ctx, cancel := context.WithCancel(context.Background())
+ handled := make([]uuid.UUID, 0, 2)
+
+ require.NoError(t, handlers.Register("payment.created", func(_ context.Context, event *OutboxEvent) error {
+ handled = append(handled, event.ID)
+ if event.ID == firstID {
+ cancel()
+ }
+
+ return nil
+ }))
+
+ dispatcher, err := NewDispatcher(repo, handlers, nil, noop.NewTracerProvider().Tracer("test"), WithPublishMaxAttempts(1))
+ require.NoError(t, err)
+
+ processed := dispatcher.DispatchOnce(ctx)
+ require.Equal(t, 1, processed)
+ require.Equal(t, []uuid.UUID{firstID}, handled)
+ require.Equal(t, []uuid.UUID{firstID}, repo.markedPub)
+}
+
+func TestDispatcher_DispatchOnceMarkPublishedErrorDoesNotMarkFailedOrInvalid(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{markPublishedErr: errors.New("db write failed")}
+ eventID := uuid.New()
+ repo.pending = []*OutboxEvent{{ID: eventID, EventType: "payment.created", Payload: []byte("ok")}}
+
+ handlers := NewHandlerRegistry()
+ require.NoError(t, handlers.Register("payment.created", func(context.Context, *OutboxEvent) error {
+ return nil
+ }))
+
+ dispatcher, err := NewDispatcher(repo, handlers, nil, noop.NewTracerProvider().Tracer("test"), WithPublishMaxAttempts(1))
+ require.NoError(t, err)
+
+ result := dispatcher.DispatchOnceResult(context.Background())
+ require.Equal(t, 1, result.Processed)
+ require.Equal(t, 1, result.Published)
+ require.Equal(t, 1, result.StateUpdateFailed)
+ require.Equal(t, 0, result.Failed)
+ require.Equal(t, []uuid.UUID{eventID}, repo.markPublishedCalls)
+ require.Empty(t, repo.markedPub)
+ require.Empty(t, repo.markedFail)
+ require.Empty(t, repo.markedInv)
+}
+
+func TestDispatcher_DispatchOnceRetryableErrorMarksFailed(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ eventID := uuid.New()
+ repo.pending = []*OutboxEvent{{ID: eventID, EventType: "payment.created", Payload: []byte("ok")}}
+
+ handlers := NewHandlerRegistry()
+ retryErr := errors.New("temporary broker outage")
+ require.NoError(t, handlers.Register("payment.created", func(context.Context, *OutboxEvent) error {
+ return retryErr
+ }))
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithPublishMaxAttempts(1),
+ )
+ require.NoError(t, err)
+
+ processed := dispatcher.DispatchOnce(context.Background())
+ require.Equal(t, 1, processed)
+ require.Equal(t, []uuid.UUID{eventID}, repo.markedFail)
+ require.Empty(t, repo.markedInv)
+ require.Empty(t, repo.markedPub)
+}
+
+func TestDispatcher_PublishEventWithRetry_SucceedsAfterTransientError(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ handlers := NewHandlerRegistry()
+ event := &OutboxEvent{ID: uuid.New(), EventType: "payment.created", Payload: []byte("ok")}
+
+ attempts := 0
+ require.NoError(t, handlers.Register("payment.created", func(context.Context, *OutboxEvent) error {
+ attempts++
+ if attempts == 1 {
+ return errors.New("temporary failure")
+ }
+
+ return nil
+ }))
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithPublishMaxAttempts(3),
+ WithPublishBackoff(time.Millisecond),
+ )
+ require.NoError(t, err)
+
+ err = dispatcher.publishEventWithRetry(context.Background(), event)
+ require.NoError(t, err)
+ require.Equal(t, 2, attempts)
+}
+
+func TestDispatcher_PublishEventWithRetry_StopsOnNonRetryableError(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ handlers := NewHandlerRegistry()
+ event := &OutboxEvent{ID: uuid.New(), EventType: "payment.created", Payload: []byte("ok")}
+
+ nonRetryable := errors.New("validation failed")
+ attempts := 0
+ require.NoError(t, handlers.Register("payment.created", func(context.Context, *OutboxEvent) error {
+ attempts++
+
+ return nonRetryable
+ }))
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithPublishMaxAttempts(5),
+ WithPublishBackoff(time.Millisecond),
+ WithRetryClassifier(RetryClassifierFunc(func(err error) bool {
+ return errors.Is(err, nonRetryable)
+ })),
+ )
+ require.NoError(t, err)
+
+ err = dispatcher.publishEventWithRetry(context.Background(), event)
+ require.Error(t, err)
+ require.Equal(t, 1, attempts)
+}
+
+func TestDispatcher_PublishEventWithRetry_StopsWhenContextCancelled(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ handlers := NewHandlerRegistry()
+ event := &OutboxEvent{ID: uuid.New(), EventType: "payment.created", Payload: []byte("ok")}
+
+ require.NoError(t, handlers.Register("payment.created", func(context.Context, *OutboxEvent) error {
+ return errors.New("temporary failure")
+ }))
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithPublishMaxAttempts(5),
+ WithPublishBackoff(50*time.Millisecond),
+ )
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
+ defer cancel()
+
+ err = dispatcher.publishEventWithRetry(ctx, event)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "publish retry wait interrupted")
+}
+
+func TestNewDispatcher_ValidationErrors(t *testing.T) {
+ t.Parallel()
+
+ handlers := NewHandlerRegistry()
+
+ dispatcher, err := NewDispatcher(nil, handlers, nil, noop.NewTracerProvider().Tracer("test"))
+ require.Nil(t, dispatcher)
+ require.ErrorIs(t, err, ErrOutboxRepositoryRequired)
+
+ repo := &fakeRepo{}
+ dispatcher, err = NewDispatcher(repo, nil, nil, noop.NewTracerProvider().Tracer("test"))
+ require.Nil(t, dispatcher)
+ require.ErrorIs(t, err, ErrHandlerRegistryRequired)
+}
+
+func TestDeduplicateEvents_EmptyInput(t *testing.T) {
+ t.Parallel()
+
+ result := deduplicateEvents(nil)
+ require.Nil(t, result)
+}
+
+func TestDispatcher_DispatchOnceNilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var dispatcher *Dispatcher
+
+ require.Equal(t, 0, dispatcher.DispatchOnce(context.Background()))
+}
+
+func TestDispatcher_DispatchOnceResultNilContext(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ handlers := NewHandlerRegistry()
+
+ dispatcher, err := NewDispatcher(repo, handlers, nil, noop.NewTracerProvider().Tracer("test"))
+ require.NoError(t, err)
+
+ result := dispatcher.DispatchOnceResult(nil)
+ require.Equal(t, 0, result.Processed)
+ require.Equal(t, 0, result.Published)
+ require.Equal(t, 0, result.Failed)
+ require.Equal(t, 0, result.StateUpdateFailed)
+}
+
+func TestDispatcher_DispatchOnceResult_ZeroValueIsSafe(t *testing.T) {
+ t.Parallel()
+
+ dispatcher := &Dispatcher{}
+
+ require.NotPanics(t, func() {
+ result := dispatcher.DispatchOnceResult(context.Background())
+ require.Equal(t, DispatchResult{}, result)
+ })
+}
+
+func TestDispatcher_RunStopShutdownLifecycle(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{tenants: []string{"tenant-1"}}
+ handlers := NewHandlerRegistry()
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithDispatchInterval(5*time.Millisecond),
+ )
+ require.NoError(t, err)
+
+ runDone := make(chan error, 1)
+ go func() {
+ runDone <- dispatcher.Run(nil)
+ }()
+
+ require.Eventually(t, func() bool {
+ return repo.listPendingCallCount() > 0
+ }, time.Second, time.Millisecond)
+
+ require.NoError(t, dispatcher.Shutdown(context.Background()))
+
+ select {
+ case err := <-runDone:
+ require.NoError(t, err)
+ case <-time.After(time.Second):
+ t.Fatal("dispatcher run did not stop")
+ }
+}
+
+func TestDispatcher_RunContext_CanRestartAfterShutdown(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{tenants: []string{"tenant-1"}}
+ handlers := NewHandlerRegistry()
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithDispatchInterval(5*time.Millisecond),
+ )
+ require.NoError(t, err)
+
+ runOnce := func() {
+ initialCalls := repo.listPendingCallCount()
+
+ runDone := make(chan error, 1)
+ go func() {
+ runDone <- dispatcher.Run(nil)
+ }()
+
+ require.Eventually(t, func() bool {
+ return repo.listPendingCallCount() > initialCalls
+ }, time.Second, time.Millisecond)
+
+ require.NoError(t, dispatcher.Shutdown(context.Background()))
+ require.NoError(t, <-runDone)
+ }
+
+ runOnce()
+ runOnce()
+}
+
+func TestDispatcher_RunContextStopsWhenParentCancelled(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{tenants: []string{"tenant-1"}}
+ handlers := NewHandlerRegistry()
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithDispatchInterval(5*time.Millisecond),
+ )
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ runDone := make(chan error, 1)
+ go func() {
+ runDone <- dispatcher.RunContext(ctx, nil)
+ }()
+
+ require.Eventually(t, func() bool {
+ return repo.listPendingCallCount() > 0
+ }, time.Second, time.Millisecond)
+
+ cancel()
+
+ select {
+ case err := <-runDone:
+ require.NoError(t, err)
+ case <-time.After(time.Second):
+ t.Fatal("dispatcher run did not stop after parent context cancellation")
+ }
+}
+
+func TestDispatcher_RunContextRejectsConcurrentRun(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{tenants: []string{"tenant-1"}}
+ handlers := NewHandlerRegistry()
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithDispatchInterval(5*time.Millisecond),
+ )
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ runDone := make(chan error, 1)
+ go func() {
+ runDone <- dispatcher.RunContext(ctx, nil)
+ }()
+
+ require.Eventually(t, func() bool {
+ return repo.listPendingCallCount() > 0
+ }, time.Second, time.Millisecond)
+
+ err = dispatcher.RunContext(context.Background(), nil)
+ require.ErrorIs(t, err, ErrOutboxDispatcherRunning)
+
+ cancel()
+ require.NoError(t, <-runDone)
+}
+
+func TestDispatcher_ShutdownTimeoutWhenDispatchBlocked(t *testing.T) {
+ t.Parallel()
+
+ block := make(chan struct{})
+ repo := &fakeRepo{
+ tenants: []string{"tenant-1"},
+ listPendingBlocked: block,
+ blockIgnoresCtx: true,
+ }
+
+ handlers := NewHandlerRegistry()
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithDispatchInterval(5*time.Millisecond),
+ )
+ require.NoError(t, err)
+
+ runDone := make(chan error, 1)
+ go func() {
+ runDone <- dispatcher.Run(nil)
+ }()
+
+ require.Eventually(t, func() bool {
+ return repo.listPendingCallCount() > 0
+ }, time.Second, time.Millisecond)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
+ defer cancel()
+
+ err = dispatcher.Shutdown(ctx)
+ require.ErrorIs(t, err, context.DeadlineExceeded)
+ require.ErrorContains(t, err, "dispatcher shutdown")
+
+ close(block)
+
+ select {
+ case runErr := <-runDone:
+ require.NoError(t, runErr)
+ case <-time.After(time.Second):
+ t.Fatal("dispatcher run did not exit after unblock")
+ }
+}
+
+func TestDispatcher_CollectEventsPipelinePrioritizesAndDeduplicates(t *testing.T) {
+ t.Parallel()
+
+ priorityID := uuid.New()
+ stuckID := uuid.New()
+ failedID := uuid.New()
+
+ repo := &fakeRepo{
+ pendingByType: map[string][]*OutboxEvent{
+ "priority.payment": {{ID: priorityID, EventType: "priority.payment", Payload: []byte("p")}},
+ },
+ stuck: []*OutboxEvent{
+ {ID: priorityID, EventType: "priority.payment", Payload: []byte("dup")},
+ {ID: stuckID, EventType: "stuck.payment", Payload: []byte("s")},
+ },
+ failedForRetry: []*OutboxEvent{{ID: failedID, EventType: "failed.payment", Payload: []byte("f")}},
+ pending: []*OutboxEvent{{ID: uuid.New(), EventType: "pending.payment", Payload: []byte("x")}},
+ }
+
+ handlers := NewHandlerRegistry()
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithBatchSize(4),
+ WithPriorityBudget(2),
+ WithMaxFailedPerBatch(2),
+ WithPriorityEventTypes("priority.payment"),
+ )
+ require.NoError(t, err)
+
+ ctx, span := dispatcher.tracer.Start(context.Background(), "test.collect_events")
+ defer span.End()
+
+ collected := dispatcher.collectEvents(ctx, span)
+ require.Len(t, collected, 3)
+ require.Equal(t, priorityID, collected[0].ID)
+ require.Equal(t, stuckID, collected[1].ID)
+ require.Equal(t, failedID, collected[2].ID)
+}
+
+func TestDispatcher_CollectEvents_ContinuesWhenResetStuckProcessingFails(t *testing.T) {
+ t.Parallel()
+
+ failedID := uuid.New()
+ pendingID := uuid.New()
+
+ repo := &fakeRepo{
+ resetStuckErr: errors.New("reset stuck failed"),
+ failedForRetry: []*OutboxEvent{{ID: failedID, EventType: "failed.payment", Payload: []byte("f")}},
+ pending: []*OutboxEvent{{ID: pendingID, EventType: "pending.payment", Payload: []byte("p")}},
+ }
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ NewHandlerRegistry(),
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithBatchSize(4),
+ WithMaxFailedPerBatch(2),
+ )
+ require.NoError(t, err)
+
+ ctx, span := dispatcher.tracer.Start(context.Background(), "test.collect_events_reset_stuck_error")
+ defer span.End()
+
+ collected := dispatcher.collectEvents(ctx, span)
+ require.Len(t, collected, 2)
+ require.Equal(t, failedID, collected[0].ID)
+ require.Equal(t, pendingID, collected[1].ID)
+}
+
+func TestDispatcher_CollectEvents_ContinuesWhenResetForRetryFails(t *testing.T) {
+ t.Parallel()
+
+ stuckID := uuid.New()
+ pendingID := uuid.New()
+
+ repo := &fakeRepo{
+ stuck: []*OutboxEvent{{ID: stuckID, EventType: "stuck.payment", Payload: []byte("s")}},
+ resetForRetryErr: errors.New("reset retry failed"),
+ pending: []*OutboxEvent{{ID: pendingID, EventType: "pending.payment", Payload: []byte("p")}},
+ }
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ NewHandlerRegistry(),
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithBatchSize(4),
+ WithMaxFailedPerBatch(2),
+ )
+ require.NoError(t, err)
+
+ ctx, span := dispatcher.tracer.Start(context.Background(), "test.collect_events_reset_retry_error")
+ defer span.End()
+
+ collected := dispatcher.collectEvents(ctx, span)
+ require.Len(t, collected, 2)
+ require.Equal(t, stuckID, collected[0].ID)
+ require.Equal(t, pendingID, collected[1].ID)
+}
+
+func TestDispatcher_CollectEvents_ContinuesWhenListPendingByTypeFails(t *testing.T) {
+ t.Parallel()
+
+ stuckID := uuid.New()
+ failedID := uuid.New()
+ pendingID := uuid.New()
+
+ repo := &fakeRepo{
+ listPendingTypeErr: errors.New("list pending by type failed"),
+ stuck: []*OutboxEvent{{ID: stuckID, EventType: "stuck.payment", Payload: []byte("s")}},
+ failedForRetry: []*OutboxEvent{{ID: failedID, EventType: "failed.payment", Payload: []byte("f")}},
+ pending: []*OutboxEvent{{ID: pendingID, EventType: "pending.payment", Payload: []byte("p")}},
+ }
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ NewHandlerRegistry(),
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithBatchSize(4),
+ WithPriorityBudget(2),
+ WithMaxFailedPerBatch(2),
+ WithPriorityEventTypes("priority.payment"),
+ )
+ require.NoError(t, err)
+
+ ctx, span := dispatcher.tracer.Start(context.Background(), "test.collect_events_priority_error")
+ defer span.End()
+
+ collected := dispatcher.collectEvents(ctx, span)
+ require.Len(t, collected, 3)
+ require.Equal(t, stuckID, collected[0].ID)
+ require.Equal(t, failedID, collected[1].ID)
+ require.Equal(t, pendingID, collected[2].ID)
+}
+
+func TestDispatcher_DispatchAcrossTenantsProcessesEachTenant(t *testing.T) {
+ t.Parallel()
+
+ tenantA := "tenant-a"
+ tenantB := "tenant-b"
+ eventA := uuid.New()
+ eventB := uuid.New()
+
+ repo := &fakeRepo{
+ tenants: []string{tenantA, tenantB},
+ pendingByTenant: map[string][]*OutboxEvent{
+ tenantA: {{ID: eventA, EventType: "payment.created", Payload: []byte("a")}},
+ tenantB: {{ID: eventB, EventType: "payment.created", Payload: []byte("b")}},
+ },
+ }
+
+ handlers := NewHandlerRegistry()
+ handledTenants := make(map[string]bool)
+ require.NoError(t, handlers.Register("payment.created", func(ctx context.Context, _ *OutboxEvent) error {
+ tenantID, ok := TenantIDFromContext(ctx)
+ require.True(t, ok)
+ handledTenants[tenantID] = true
+
+ return nil
+ }))
+
+ dispatcher, err := NewDispatcher(repo, handlers, nil, noop.NewTracerProvider().Tracer("test"), WithPublishMaxAttempts(1))
+ require.NoError(t, err)
+
+ dispatcher.dispatchAcrossTenants(context.Background())
+
+ require.True(t, handledTenants[tenantA])
+ require.True(t, handledTenants[tenantB])
+ require.ElementsMatch(t, []uuid.UUID{eventA, eventB}, repo.markedPub)
+}
+
+func TestDispatcher_DispatchAcrossTenantsRoundRobinStartingTenant(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{
+ tenants: []string{"tenant-a", "tenant-b", "tenant-c"},
+ pendingByTenant: map[string][]*OutboxEvent{
+ "tenant-a": {},
+ "tenant-b": {},
+ "tenant-c": {},
+ },
+ }
+
+ dispatcher, err := NewDispatcher(repo, NewHandlerRegistry(), nil, noop.NewTracerProvider().Tracer("test"))
+ require.NoError(t, err)
+
+ dispatcher.dispatchAcrossTenants(context.Background())
+ dispatcher.dispatchAcrossTenants(context.Background())
+
+ order := repo.listPendingTenantOrder()
+ require.Len(t, order, 6)
+ require.Equal(t, "tenant-a", order[0])
+ require.Equal(t, "tenant-b", order[3])
+}
+
+func TestDispatcher_DispatchAcrossTenants_StopsAfterContextCancelBetweenTenants(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ t.Cleanup(cancel)
+
+ repo := &fakeRepo{
+ tenants: []string{"tenant-a", "tenant-b"},
+ pendingByTenant: map[string][]*OutboxEvent{
+ "tenant-a": {{ID: uuid.New(), EventType: "payment.created", Payload: []byte("a")}},
+ "tenant-b": {{ID: uuid.New(), EventType: "payment.created", Payload: []byte("b")}},
+ },
+ }
+
+ handlers := NewHandlerRegistry()
+ handledTenants := make(map[string]bool)
+ require.NoError(t, handlers.Register("payment.created", func(handlerCtx context.Context, _ *OutboxEvent) error {
+ tenantID, ok := TenantIDFromContext(handlerCtx)
+ require.True(t, ok)
+ handledTenants[tenantID] = true
+ cancel()
+
+ return nil
+ }))
+
+ dispatcher, err := NewDispatcher(repo, handlers, nil, noop.NewTracerProvider().Tracer("test"), WithPublishMaxAttempts(1))
+ require.NoError(t, err)
+
+ dispatcher.dispatchAcrossTenants(ctx)
+
+ require.True(t, handledTenants["tenant-a"])
+ require.False(t, handledTenants["tenant-b"])
+}
+
+func TestDispatcher_DispatchAcrossTenantsEmptyList(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{tenants: []string{}}
+ dispatcher, err := NewDispatcher(repo, NewHandlerRegistry(), nil, noop.NewTracerProvider().Tracer("test"))
+ require.NoError(t, err)
+
+ dispatcher.dispatchAcrossTenants(context.Background())
+
+ require.Equal(t, 0, repo.listPendingCallCount())
+}
+
+func TestDispatcher_DispatchAcrossTenantsEmptyListFallsBackWhenTenantNotRequired(t *testing.T) {
+ t.Parallel()
+
+ baseRepo := &fakeRepo{
+ tenants: []string{},
+ pending: []*OutboxEvent{{ID: uuid.New(), EventType: "payment.created", Payload: []byte("ok")}},
+ }
+ repo := &tenantAwareFakeRepo{fakeRepo: baseRepo, requiresTenant: false}
+
+ handlers := NewHandlerRegistry()
+ require.NoError(t, handlers.Register("payment.created", func(context.Context, *OutboxEvent) error {
+ return nil
+ }))
+
+ dispatcher, err := NewDispatcher(repo, handlers, nil, noop.NewTracerProvider().Tracer("test"), WithPublishMaxAttempts(1))
+ require.NoError(t, err)
+
+ dispatcher.dispatchAcrossTenants(context.Background())
+
+ require.Equal(t, 1, baseRepo.listPendingCallCount())
+ require.Len(t, baseRepo.markedPub, 1)
+}
+
+func TestDispatcher_DispatchAcrossTenantsEmptyListSkipsWhenTenantRequired(t *testing.T) {
+ t.Parallel()
+
+ baseRepo := &fakeRepo{
+ tenants: []string{},
+ pending: []*OutboxEvent{{ID: uuid.New(), EventType: "payment.created", Payload: []byte("ok")}},
+ }
+ repo := &tenantAwareFakeRepo{fakeRepo: baseRepo, requiresTenant: true}
+
+ dispatcher, err := NewDispatcher(repo, NewHandlerRegistry(), nil, noop.NewTracerProvider().Tracer("test"))
+ require.NoError(t, err)
+
+ dispatcher.dispatchAcrossTenants(context.Background())
+
+ require.Equal(t, 0, baseRepo.listPendingCallCount())
+ require.Empty(t, baseRepo.markedPub)
+}
+
+func TestDispatcher_HandleListPendingErrorCapsTrackedTenants(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ handlers := NewHandlerRegistry()
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithListPendingFailureThreshold(100),
+ WithMaxTrackedListPendingFailureTenants(2),
+ )
+ require.NoError(t, err)
+
+ ctx := context.Background()
+ _, span := dispatcher.tracer.Start(ctx, "test.list_pending_error")
+
+ errFailure := errors.New("list pending failed")
+ dispatcher.handleListPendingError(ctx, span, "tenant-1", errFailure)
+ dispatcher.handleListPendingError(ctx, span, "tenant-2", errFailure)
+ dispatcher.handleListPendingError(ctx, span, "tenant-3", errFailure)
+
+ span.End()
+
+ require.Len(t, dispatcher.listPendingFailureCounts, 2)
+ require.Equal(t, 2, dispatcher.listPendingFailureCounts[defaultTenantFailureCounterFallback])
+}
+
+func TestDispatcher_BoundedTenantMetricKeyUsesOverflowLabel(t *testing.T) {
+ t.Parallel()
+
+ dispatcher := &Dispatcher{
+ cfg: DispatcherConfig{
+ IncludeTenantMetrics: true,
+ MaxTenantMetricDimensions: 2,
+ },
+ tenantMetricKeys: make(map[string]struct{}),
+ }
+
+ require.Equal(t, "tenant-1", dispatcher.boundedTenantMetricKey("tenant-1"))
+ require.Equal(t, "tenant-2", dispatcher.boundedTenantMetricKey("tenant-2"))
+ require.Equal(t, overflowTenantMetricLabel, dispatcher.boundedTenantMetricKey("tenant-3"))
+ require.Equal(t, 2, len(dispatcher.tenantMetricKeys))
+}
+
+func TestDispatcher_HandlePublishError_LogsMarkInvalidFailure(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{markInvalidErr: errors.New("mark invalid failed")}
+ handlers := NewHandlerRegistry()
+ nonRetryable := errors.New("non-retryable")
+
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithRetryClassifier(RetryClassifierFunc(func(err error) bool {
+ return errors.Is(err, nonRetryable)
+ })),
+ )
+ require.NoError(t, err)
+
+ dispatcher.handlePublishError(
+ context.Background(),
+ dispatcher.logger,
+ &OutboxEvent{ID: uuid.New()},
+ nonRetryable,
+ )
+
+ require.Empty(t, repo.markedInv)
+}
+
+func TestDispatcher_HandlePublishError_LogsMarkFailedFailure(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{markFailedErr: errors.New("mark failed failed")}
+ handlers := NewHandlerRegistry()
+
+ dispatcher, err := NewDispatcher(repo, handlers, nil, noop.NewTracerProvider().Tracer("test"))
+ require.NoError(t, err)
+
+ dispatcher.handlePublishError(
+ context.Background(),
+ dispatcher.logger,
+ &OutboxEvent{ID: uuid.New()},
+ errors.New("retryable"),
+ )
+
+ require.Empty(t, repo.markedFail)
+}
+
+func TestDispatcher_DispatchOnce_EmptyPayloadMarksFailed(t *testing.T) {
+ t.Parallel()
+
+ eventID := uuid.New()
+ repo := &fakeRepo{pending: []*OutboxEvent{{ID: eventID, EventType: "payment.created", Payload: nil}}}
+ handlers := NewHandlerRegistry()
+
+ dispatcher, err := NewDispatcher(repo, handlers, nil, noop.NewTracerProvider().Tracer("test"), WithPublishMaxAttempts(1))
+ require.NoError(t, err)
+
+ result := dispatcher.DispatchOnceResult(context.Background())
+
+ require.Equal(t, 1, result.Processed)
+ require.Equal(t, 1, result.Failed)
+ require.Equal(t, []uuid.UUID{eventID}, repo.markedFail)
+ require.Empty(t, repo.markedPub)
+}
+
+func TestDispatcher_DispatchAcrossTenants_ListTenantsErrorDoesNotDispatch(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{tenantsErr: errors.New("list tenants failed")}
+ handlers := NewHandlerRegistry()
+ dispatcher, err := NewDispatcher(repo, handlers, nil, noop.NewTracerProvider().Tracer("test"))
+ require.NoError(t, err)
+
+ dispatcher.dispatchAcrossTenants(context.Background())
+
+ require.Equal(t, 0, repo.listPendingCallCount())
+ require.Empty(t, repo.markedPub)
+}
+
+func TestNonEmptyTenants_TrimWhitespaceEntries(t *testing.T) {
+ t.Parallel()
+
+ tenants := nonEmptyTenants([]string{"tenant-a", " ", "\ttenant-b\n", "", "tenant-c"})
+ require.Equal(t, []string{"tenant-a", "tenant-b", "tenant-c"}, tenants)
+}
+
+func TestDispatcher_ClearListPendingFailureCount_ResetsFallbackForOverflowTenant(t *testing.T) {
+ t.Parallel()
+
+ repo := &fakeRepo{}
+ handlers := NewHandlerRegistry()
+ dispatcher, err := NewDispatcher(
+ repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ WithMaxTrackedListPendingFailureTenants(2),
+ WithListPendingFailureThreshold(100),
+ )
+ require.NoError(t, err)
+
+ ctx := context.Background()
+ _, span := dispatcher.tracer.Start(ctx, "test.overflow_reset")
+ errList := errors.New("list pending failed")
+
+ dispatcher.handleListPendingError(ctx, span, "tenant-1", errList)
+ dispatcher.handleListPendingError(ctx, span, "tenant-2", errList)
+ dispatcher.handleListPendingError(ctx, span, "tenant-3", errList)
+ require.Equal(t, 2, dispatcher.listPendingFailureCounts[defaultTenantFailureCounterFallback])
+
+ dispatcher.clearListPendingFailureCount("tenant-3")
+ span.End()
+
+ require.Equal(t, 0, dispatcher.listPendingFailureCounts[defaultTenantFailureCounterFallback])
+}
diff --git a/commons/outbox/doc.go b/commons/outbox/doc.go
new file mode 100644
index 00000000..7642f181
--- /dev/null
+++ b/commons/outbox/doc.go
@@ -0,0 +1,5 @@
+// Package outbox provides transactional outbox primitives.
+//
+// It includes an event model, repository contracts, a generic dispatcher with
+// retry controls, and PostgreSQL adapters under the postgres subpackage.
+package outbox
diff --git a/commons/outbox/errors.go b/commons/outbox/errors.go
new file mode 100644
index 00000000..05e095e7
--- /dev/null
+++ b/commons/outbox/errors.go
@@ -0,0 +1,21 @@
+package outbox
+
+import "errors"
+
+var (
+ ErrOutboxEventRequired = errors.New("outbox event is required")
+ ErrOutboxRepositoryRequired = errors.New("outbox repository is required")
+ ErrOutboxDispatcherRequired = errors.New("outbox dispatcher is required")
+ ErrOutboxDispatcherRunning = errors.New("outbox dispatcher is already running")
+ ErrOutboxEventPayloadRequired = errors.New("outbox event payload is required")
+ ErrOutboxEventPayloadTooLarge = errors.New("outbox event payload exceeds maximum allowed size")
+ ErrOutboxEventPayloadNotJSON = errors.New("outbox event payload must be valid JSON (stored as JSONB)")
+ ErrHandlerRegistryRequired = errors.New("handler registry is required")
+ ErrEventTypeRequired = errors.New("event type is required")
+ ErrEventHandlerRequired = errors.New("event handler is required")
+ ErrHandlerAlreadyRegistered = errors.New("event handler already registered")
+ ErrHandlerNotRegistered = errors.New("event handler is not registered")
+ ErrTenantIDRequired = errors.New("tenant id is required")
+ ErrOutboxStatusInvalid = errors.New("invalid outbox status")
+ ErrOutboxTransitionInvalid = errors.New("invalid outbox status transition")
+)
diff --git a/commons/outbox/event.go b/commons/outbox/event.go
new file mode 100644
index 00000000..b8354412
--- /dev/null
+++ b/commons/outbox/event.go
@@ -0,0 +1,95 @@
+package outbox
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ "github.com/google/uuid"
+)
+
+const (
+ OutboxStatusPending = "PENDING"
+ OutboxStatusProcessing = "PROCESSING"
+ OutboxStatusPublished = "PUBLISHED"
+ OutboxStatusFailed = "FAILED"
+ OutboxStatusInvalid = "INVALID"
+ DefaultMaxPayloadBytes = 1 << 20
+)
+
+// OutboxEvent is an event stored in the outbox for reliable delivery.
+type OutboxEvent struct {
+ ID uuid.UUID
+ EventType string
+ AggregateID uuid.UUID
+ Payload []byte
+ Status string
+ Attempts int
+ PublishedAt *time.Time
+ LastError string
+ CreatedAt time.Time
+ UpdatedAt time.Time
+}
+
+// NewOutboxEvent creates a valid outbox event initialized as pending.
+func NewOutboxEvent(
+ ctx context.Context,
+ eventType string,
+ aggregateID uuid.UUID,
+ payload []byte,
+) (*OutboxEvent, error) {
+ return NewOutboxEventWithID(ctx, uuid.New(), eventType, aggregateID, payload)
+}
+
+// NewOutboxEventWithID creates a valid outbox event initialized as pending using a caller-provided ID.
+func NewOutboxEventWithID(
+ ctx context.Context,
+ eventID uuid.UUID,
+ eventType string,
+ aggregateID uuid.UUID,
+ payload []byte,
+) (*OutboxEvent, error) {
+ asserter := assert.New(ctx, nil, "outbox", "outbox.new_event")
+
+ if err := asserter.That(ctx, eventID != uuid.Nil, "event id is required"); err != nil {
+ return nil, fmt.Errorf("outbox event id: %w", err)
+ }
+
+ eventType = strings.TrimSpace(eventType)
+
+ if err := asserter.NotEmpty(ctx, eventType, "event type is required"); err != nil {
+ return nil, fmt.Errorf("outbox event type: %w", err)
+ }
+
+ if err := asserter.That(ctx, aggregateID != uuid.Nil, "aggregate id is required"); err != nil {
+ return nil, fmt.Errorf("outbox event aggregate id: %w", err)
+ }
+
+ if err := asserter.That(ctx, len(payload) > 0, "payload is required"); err != nil {
+ return nil, fmt.Errorf("outbox event payload: %w", err)
+ }
+
+ if err := asserter.That(ctx, len(payload) <= DefaultMaxPayloadBytes, "payload exceeds max size"); err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrOutboxEventPayloadTooLarge, err)
+ }
+
+ if err := asserter.That(ctx, json.Valid(payload), "payload must be valid JSON"); err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrOutboxEventPayloadNotJSON, err)
+ }
+
+ now := time.Now().UTC()
+
+ return &OutboxEvent{
+ ID: eventID,
+ EventType: eventType,
+ AggregateID: aggregateID,
+ Payload: payload,
+ Status: OutboxStatusPending,
+ Attempts: 0,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }, nil
+}
diff --git a/commons/outbox/event_test.go b/commons/outbox/event_test.go
new file mode 100644
index 00000000..8c0323b8
--- /dev/null
+++ b/commons/outbox/event_test.go
@@ -0,0 +1,88 @@
+//go:build unit
+
+package outbox
+
+import (
+ "context"
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewOutboxEvent(t *testing.T) {
+ t.Parallel()
+
+ aggregateID := uuid.New()
+ payload := []byte(`{"key":"value"}`)
+
+ event, err := NewOutboxEvent(context.Background(), "event.type", aggregateID, payload)
+ require.NoError(t, err)
+ require.NotNil(t, event)
+ require.Equal(t, "event.type", event.EventType)
+ require.Equal(t, aggregateID, event.AggregateID)
+ require.Equal(t, payload, event.Payload)
+ require.Equal(t, OutboxStatusPending, event.Status)
+ require.Equal(t, 0, event.Attempts)
+ require.NotEqual(t, uuid.Nil, event.ID)
+ require.False(t, event.CreatedAt.IsZero())
+ require.False(t, event.UpdatedAt.IsZero())
+ require.Equal(t, event.CreatedAt, event.UpdatedAt)
+}
+
+func TestNewOutboxEventValidation(t *testing.T) {
+ t.Parallel()
+
+ event, err := NewOutboxEvent(context.Background(), "", uuid.New(), []byte(`{"k":"v"}`))
+ require.Error(t, err)
+ require.Nil(t, event)
+ require.Contains(t, err.Error(), "event type")
+
+ event, err = NewOutboxEvent(context.Background(), "type", uuid.Nil, []byte(`{"k":"v"}`))
+ require.Error(t, err)
+ require.Nil(t, event)
+ require.Contains(t, err.Error(), "aggregate id")
+
+ event, err = NewOutboxEvent(context.Background(), "type", uuid.New(), nil)
+ require.Error(t, err)
+ require.Nil(t, event)
+ require.Contains(t, err.Error(), "payload")
+
+ oversizedPayload := make([]byte, DefaultMaxPayloadBytes+1)
+ event, err = NewOutboxEvent(context.Background(), "type", uuid.New(), oversizedPayload)
+ require.Error(t, err)
+ require.Nil(t, event)
+ require.ErrorIs(t, err, ErrOutboxEventPayloadTooLarge)
+
+ event, err = NewOutboxEvent(context.Background(), "type", uuid.New(), []byte("not-json"))
+ require.Error(t, err)
+ require.Nil(t, event)
+ require.ErrorIs(t, err, ErrOutboxEventPayloadNotJSON)
+
+ event, err = NewOutboxEvent(context.Background(), " ", uuid.New(), []byte(`{"k":"v"}`))
+ require.Error(t, err)
+ require.Nil(t, event)
+ require.Contains(t, err.Error(), "event type")
+}
+
+func TestNewOutboxEventWithID(t *testing.T) {
+ t.Parallel()
+
+ eventID := uuid.New()
+ aggregateID := uuid.New()
+
+ event, err := NewOutboxEventWithID(context.Background(), eventID, "event.type", aggregateID, []byte(`{"key":"value"}`))
+ require.NoError(t, err)
+ require.NotNil(t, event)
+ require.Equal(t, eventID, event.ID)
+ require.Equal(t, OutboxStatusPending, event.Status)
+}
+
+func TestNewOutboxEventWithIDValidation(t *testing.T) {
+ t.Parallel()
+
+ event, err := NewOutboxEventWithID(context.Background(), uuid.Nil, "event.type", uuid.New(), []byte(`{"key":"value"}`))
+ require.Error(t, err)
+ require.Nil(t, event)
+ require.Contains(t, err.Error(), "event id")
+}
diff --git a/commons/outbox/handler.go b/commons/outbox/handler.go
new file mode 100644
index 00000000..547c8c7c
--- /dev/null
+++ b/commons/outbox/handler.go
@@ -0,0 +1,76 @@
+package outbox
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+)
+
+// EventHandler handles one outbox event.
+type EventHandler func(ctx context.Context, event *OutboxEvent) error
+
+// HandlerRegistry stores event handlers by event type.
+type HandlerRegistry struct {
+ mu sync.RWMutex
+ handlers map[string]EventHandler
+}
+
+func NewHandlerRegistry() *HandlerRegistry {
+ return &HandlerRegistry{handlers: map[string]EventHandler{}}
+}
+
+func (registry *HandlerRegistry) Register(eventType string, handler EventHandler) error {
+ if registry == nil {
+ return ErrHandlerRegistryRequired
+ }
+
+ normalizedType := strings.TrimSpace(eventType)
+ if normalizedType == "" {
+ return ErrEventTypeRequired
+ }
+
+ if handler == nil {
+ return ErrEventHandlerRequired
+ }
+
+ registry.mu.Lock()
+ defer registry.mu.Unlock()
+
+ if registry.handlers == nil {
+ registry.handlers = make(map[string]EventHandler)
+ }
+
+ if _, exists := registry.handlers[normalizedType]; exists {
+ return fmt.Errorf("%w: %s", ErrHandlerAlreadyRegistered, normalizedType)
+ }
+
+ registry.handlers[normalizedType] = handler
+
+ return nil
+}
+
+func (registry *HandlerRegistry) Handle(ctx context.Context, event *OutboxEvent) error {
+ if registry == nil {
+ return ErrHandlerRegistryRequired
+ }
+
+ if event == nil {
+ return ErrOutboxEventRequired
+ }
+
+ eventType := strings.TrimSpace(event.EventType)
+ if eventType == "" {
+ return ErrEventTypeRequired
+ }
+
+ registry.mu.RLock()
+ handler, ok := registry.handlers[eventType]
+ registry.mu.RUnlock()
+
+ if !ok {
+ return fmt.Errorf("%w: %s", ErrHandlerNotRegistered, eventType)
+ }
+
+ return handler(ctx, event)
+}
diff --git a/commons/outbox/handler_test.go b/commons/outbox/handler_test.go
new file mode 100644
index 00000000..6797c865
--- /dev/null
+++ b/commons/outbox/handler_test.go
@@ -0,0 +1,124 @@
+//go:build unit
+
+package outbox
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+)
+
+func TestHandlerRegistry_RegisterAndHandle(t *testing.T) {
+ t.Parallel()
+
+ registry := NewHandlerRegistry()
+ handled := false
+
+ err := registry.Register("payment.created", func(_ context.Context, event *OutboxEvent) error {
+ handled = true
+ require.Equal(t, "payment.created", event.EventType)
+ return nil
+ })
+ require.NoError(t, err)
+
+ event := &OutboxEvent{ID: uuid.New(), EventType: "payment.created", Payload: []byte(`{"ok":true}`)}
+ err = registry.Handle(context.Background(), event)
+ require.NoError(t, err)
+ require.True(t, handled)
+}
+
+func TestHandlerRegistry_RegisterDuplicate(t *testing.T) {
+ t.Parallel()
+
+ registry := NewHandlerRegistry()
+ require.NoError(t, registry.Register("same", func(_ context.Context, _ *OutboxEvent) error { return nil }))
+
+ err := registry.Register("same", func(_ context.Context, _ *OutboxEvent) error { return nil })
+ require.ErrorIs(t, err, ErrHandlerAlreadyRegistered)
+}
+
+func TestHandlerRegistry_RegisterNormalizesEventType(t *testing.T) {
+ t.Parallel()
+
+ registry := NewHandlerRegistry()
+ require.NoError(t, registry.Register(" payment.created ", func(_ context.Context, _ *OutboxEvent) error { return nil }))
+
+ err := registry.Handle(context.Background(), &OutboxEvent{ID: uuid.New(), EventType: "payment.created", Payload: []byte(`{"x":1}`)})
+ require.NoError(t, err)
+}
+
+func TestHandlerRegistry_HandleMissing(t *testing.T) {
+ t.Parallel()
+
+ registry := NewHandlerRegistry()
+ err := registry.Handle(context.Background(), &OutboxEvent{ID: uuid.New(), EventType: "missing", Payload: []byte(`{"x":1}`)})
+ require.ErrorIs(t, err, ErrHandlerNotRegistered)
+}
+
+func TestHandlerRegistry_HandlePropagatesHandlerError(t *testing.T) {
+ t.Parallel()
+
+ registry := NewHandlerRegistry()
+ handlerErr := errors.New("publish to broker failed")
+ require.NoError(t, registry.Register("payment.created", func(_ context.Context, _ *OutboxEvent) error {
+ return handlerErr
+ }))
+
+ event := &OutboxEvent{ID: uuid.New(), EventType: "payment.created", Payload: []byte(`{"ok":true}`)}
+ err := registry.Handle(context.Background(), event)
+ require.ErrorIs(t, err, handlerErr)
+}
+
+func TestRetryClassifierFunc_IsNonRetryable(t *testing.T) {
+ t.Parallel()
+
+ classifier := RetryClassifierFunc(func(err error) bool {
+ return errors.Is(err, ErrHandlerNotRegistered)
+ })
+
+ require.True(t, classifier.IsNonRetryable(ErrHandlerNotRegistered))
+ require.False(t, classifier.IsNonRetryable(errors.New("other")))
+}
+
+func TestHandlerRegistry_NilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var registry *HandlerRegistry
+
+ err := registry.Register("event", func(context.Context, *OutboxEvent) error { return nil })
+ require.ErrorIs(t, err, ErrHandlerRegistryRequired)
+
+ err = registry.Handle(context.Background(), &OutboxEvent{ID: uuid.New(), EventType: "event", Payload: []byte(`{"ok":true}`)})
+ require.ErrorIs(t, err, ErrHandlerRegistryRequired)
+}
+
+func TestHandlerRegistry_RegisterValidation(t *testing.T) {
+ t.Parallel()
+
+ registry := NewHandlerRegistry()
+
+ err := registry.Register("", func(context.Context, *OutboxEvent) error { return nil })
+ require.ErrorIs(t, err, ErrEventTypeRequired)
+
+ err = registry.Register("payment.created", nil)
+ require.ErrorIs(t, err, ErrEventHandlerRequired)
+}
+
+func TestHandlerRegistry_HandleNilEvent(t *testing.T) {
+ t.Parallel()
+
+ registry := NewHandlerRegistry()
+ err := registry.Handle(context.Background(), nil)
+ require.ErrorIs(t, err, ErrOutboxEventRequired)
+}
+
+func TestHandlerRegistry_HandleRejectsBlankEventType(t *testing.T) {
+ t.Parallel()
+
+ registry := NewHandlerRegistry()
+ err := registry.Handle(context.Background(), &OutboxEvent{ID: uuid.New(), EventType: " ", Payload: []byte(`{"ok":true}`)})
+ require.ErrorIs(t, err, ErrEventTypeRequired)
+}
diff --git a/commons/outbox/metrics.go b/commons/outbox/metrics.go
new file mode 100644
index 00000000..775420ab
--- /dev/null
+++ b/commons/outbox/metrics.go
@@ -0,0 +1,76 @@
+package outbox
+
+import (
+ "fmt"
+
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/metric"
+)
+
+type dispatcherMetrics struct {
+ eventsDispatched metric.Int64Counter
+ eventsFailed metric.Int64Counter
+ eventsStateFailed metric.Int64Counter
+ dispatchLatency metric.Float64Histogram
+ queueDepth metric.Int64Gauge
+}
+
+func newDispatcherMetrics(provider metric.MeterProvider) (dispatcherMetrics, error) {
+ if provider == nil {
+ provider = otel.GetMeterProvider()
+ }
+
+ meter := provider.Meter("commons.outbox.dispatcher")
+
+ var (
+ metrics dispatcherMetrics
+ err error
+ )
+
+ metrics.eventsDispatched, err = meter.Int64Counter(
+ "outbox.events.dispatched",
+ metric.WithDescription("Number of outbox events successfully published"),
+ metric.WithUnit("{event}"),
+ )
+ if err != nil {
+ return dispatcherMetrics{}, fmt.Errorf("create outbox.events.dispatched counter: %w", err)
+ }
+
+ metrics.eventsFailed, err = meter.Int64Counter(
+ "outbox.events.failed",
+ metric.WithDescription("Number of outbox events that failed to publish"),
+ metric.WithUnit("{event}"),
+ )
+ if err != nil {
+ return dispatcherMetrics{}, fmt.Errorf("create outbox.events.failed counter: %w", err)
+ }
+
+ metrics.eventsStateFailed, err = meter.Int64Counter(
+ "outbox.events.state_update_failed",
+ metric.WithDescription("Number of outbox events published but not persisted as published"),
+ metric.WithUnit("{event}"),
+ )
+ if err != nil {
+ return dispatcherMetrics{}, fmt.Errorf("create outbox.events.state_update_failed counter: %w", err)
+ }
+
+ metrics.dispatchLatency, err = meter.Float64Histogram(
+ "outbox.dispatch.latency",
+ metric.WithDescription("Time taken per dispatch cycle"),
+ metric.WithUnit("s"),
+ )
+ if err != nil {
+ return dispatcherMetrics{}, fmt.Errorf("create outbox.dispatch.latency histogram: %w", err)
+ }
+
+ metrics.queueDepth, err = meter.Int64Gauge(
+ "outbox.queue.depth",
+ metric.WithDescription("Number of outbox events selected in a dispatch cycle (pending and reclaimed)"),
+ metric.WithUnit("{event}"),
+ )
+ if err != nil {
+ return dispatcherMetrics{}, fmt.Errorf("create outbox.queue.depth gauge: %w", err)
+ }
+
+ return metrics, nil
+}
diff --git a/commons/outbox/metrics_test.go b/commons/outbox/metrics_test.go
new file mode 100644
index 00000000..d0c39edc
--- /dev/null
+++ b/commons/outbox/metrics_test.go
@@ -0,0 +1,98 @@
+//go:build unit
+
+package outbox
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/metric"
+ "go.opentelemetry.io/otel/metric/noop"
+)
+
+type testMeterProvider struct {
+ metric.MeterProvider
+ meter metric.Meter
+}
+
+func (provider testMeterProvider) Meter(_ string, _ ...metric.MeterOption) metric.Meter {
+ return provider.meter
+}
+
+type failingMeter struct {
+ metric.Meter
+ failOnName string
+ failErr error
+}
+
+func (meter failingMeter) Int64Counter(name string, options ...metric.Int64CounterOption) (metric.Int64Counter, error) {
+ if name == meter.failOnName {
+ return nil, meter.failErr
+ }
+
+ return meter.Meter.Int64Counter(name, options...)
+}
+
+func (meter failingMeter) Float64Histogram(name string, options ...metric.Float64HistogramOption) (metric.Float64Histogram, error) {
+ if name == meter.failOnName {
+ return nil, meter.failErr
+ }
+
+ return meter.Meter.Float64Histogram(name, options...)
+}
+
+func (meter failingMeter) Int64Gauge(name string, options ...metric.Int64GaugeOption) (metric.Int64Gauge, error) {
+ if name == meter.failOnName {
+ return nil, meter.failErr
+ }
+
+ return meter.Meter.Int64Gauge(name, options...)
+}
+
+func TestNewDispatcherMetrics_DefaultProvider(t *testing.T) {
+ t.Parallel()
+
+ metrics, err := newDispatcherMetrics(nil)
+ require.NoError(t, err)
+ require.NotNil(t, metrics.eventsDispatched)
+ require.NotNil(t, metrics.eventsFailed)
+ require.NotNil(t, metrics.eventsStateFailed)
+ require.NotNil(t, metrics.dispatchLatency)
+ require.NotNil(t, metrics.queueDepth)
+}
+
+func TestNewDispatcherMetrics_ErrorPaths(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ instrument string
+ errText string
+ }{
+ {name: "eventsDispatched counter", instrument: "outbox.events.dispatched", errText: "create outbox.events.dispatched counter"},
+ {name: "eventsFailed counter", instrument: "outbox.events.failed", errText: "create outbox.events.failed counter"},
+ {name: "eventsStateFailed counter", instrument: "outbox.events.state_update_failed", errText: "create outbox.events.state_update_failed counter"},
+ {name: "dispatchLatency histogram", instrument: "outbox.dispatch.latency", errText: "create outbox.dispatch.latency histogram"},
+ {name: "queueDepth gauge", instrument: "outbox.queue.depth", errText: "create outbox.queue.depth gauge"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ provider := testMeterProvider{
+ MeterProvider: noop.NewMeterProvider(),
+ meter: failingMeter{
+ Meter: noop.NewMeterProvider().Meter("test"),
+ failOnName: tt.instrument,
+ failErr: errors.New("instrument creation failed"),
+ },
+ }
+
+ _, err := newDispatcherMetrics(provider)
+ require.Error(t, err)
+ require.ErrorContains(t, err, tt.errText)
+ })
+ }
+}
diff --git a/commons/outbox/postgres/column_resolver.go b/commons/outbox/postgres/column_resolver.go
new file mode 100644
index 00000000..7e9eec3b
--- /dev/null
+++ b/commons/outbox/postgres/column_resolver.go
@@ -0,0 +1,240 @@
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/outbox"
+ libPostgres "github.com/LerianStudio/lib-commons/v4/commons/postgres"
+ "golang.org/x/sync/singleflight"
+)
+
+// ColumnResolver supports column-per-tenant strategy.
+//
+// ApplyTenant is a no-op because tenant scoping is handled by SQL WHERE clauses
+// in Repository when tenantColumn is configured.
+type ColumnResolver struct {
+ client *libPostgres.Client
+ tableName string
+ tenantColumn string
+ tenantTTL time.Duration
+ cacheMu sync.RWMutex
+ cache []string
+ cacheSet bool
+ cacheUntil time.Time
+ sfGroup singleflight.Group
+}
+
+const defaultTenantDiscoveryTTL = 10 * time.Second
+
+// defaultTenantDiscoveryTimeout caps how long a singleflight tenant-discovery
+// query may run. Because context.WithoutCancel strips any parent deadline, an
+// explicit timeout prevents unbounded queries from blocking all coalesced callers.
+const defaultTenantDiscoveryTimeout = 5 * time.Second
+
+type ColumnResolverOption func(*ColumnResolver)
+
+func WithColumnResolverTableName(tableName string) ColumnResolverOption {
+ return func(resolver *ColumnResolver) {
+ resolver.tableName = tableName
+ }
+}
+
+func WithColumnResolverTenantColumn(tenantColumn string) ColumnResolverOption {
+ return func(resolver *ColumnResolver) {
+ resolver.tenantColumn = tenantColumn
+ }
+}
+
+func WithColumnResolverTenantDiscoveryTTL(ttl time.Duration) ColumnResolverOption {
+ return func(resolver *ColumnResolver) {
+ if ttl > 0 {
+ resolver.tenantTTL = ttl
+ }
+ }
+}
+
+func NewColumnResolver(client *libPostgres.Client, opts ...ColumnResolverOption) (*ColumnResolver, error) {
+ if client == nil {
+ return nil, ErrConnectionRequired
+ }
+
+ resolver := &ColumnResolver{
+ client: client,
+ tableName: "outbox_events",
+ tenantColumn: "tenant_id",
+ tenantTTL: defaultTenantDiscoveryTTL,
+ }
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(resolver)
+ }
+ }
+
+ resolver.tableName = strings.TrimSpace(resolver.tableName)
+ resolver.tenantColumn = strings.TrimSpace(resolver.tenantColumn)
+
+ if resolver.tableName == "" {
+ resolver.tableName = "outbox_events"
+ }
+
+ if resolver.tenantColumn == "" {
+ resolver.tenantColumn = "tenant_id"
+ }
+
+ if err := validateIdentifierPath(resolver.tableName); err != nil {
+ return nil, fmt.Errorf("table name: %w", err)
+ }
+
+ if err := validateIdentifier(resolver.tenantColumn); err != nil {
+ return nil, fmt.Errorf("tenant column: %w", err)
+ }
+
+ return resolver, nil
+}
+
+func (resolver *ColumnResolver) ApplyTenant(_ context.Context, _ *sql.Tx, _ string) error {
+ return nil
+}
+
+func (resolver *ColumnResolver) DiscoverTenants(ctx context.Context) ([]string, error) {
+ if resolver == nil || resolver.client == nil {
+ return nil, ErrConnectionRequired
+ }
+
+ if cached, ok := resolver.cachedTenants(time.Now().UTC()); ok {
+ return cached, nil
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ // Coalesce concurrent cache-miss queries via singleflight to prevent
+ // thundering herd on TTL expiry when multiple dispatchers poll tenants.
+ result, err, _ := resolver.sfGroup.Do("discover", func() (any, error) {
+ // Double-check cache inside singleflight — another caller may have
+ // already refreshed it while we were waiting for the flight leader.
+ if cached, ok := resolver.cachedTenants(time.Now().UTC()); ok {
+ return cached, nil
+ }
+
+ // Use a context that inherits values but not cancellation,
+ // so first caller's timeout doesn't cascade to coalesced callers.
+ // Apply an explicit timeout to prevent unbounded queries when the
+ // parent context's deadline was stripped by WithoutCancel.
+ sfCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), defaultTenantDiscoveryTimeout)
+ defer cancel()
+
+ return resolver.queryTenants(sfCtx)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ tenants, ok := result.([]string)
+ if !ok {
+ return nil, fmt.Errorf("unexpected type from singleflight: got %T, expected []string", result)
+ }
+
+ return tenants, nil
+}
+
+func (resolver *ColumnResolver) queryTenants(ctx context.Context) ([]string, error) {
+ db, err := resolver.primaryDB(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ table := quoteIdentifierPath(resolver.tableName)
+ column := quoteIdentifier(resolver.tenantColumn)
+
+ query := "SELECT DISTINCT " + column + " FROM " + table + // #nosec G202 -- table/column names validated at construction via validateIdentifier/validateIdentifierPath; quote functions escape identifiers
+ " WHERE status IN ($1, $2, $3) AND " + column + " IS NOT NULL ORDER BY " + column
+
+ rows, err := db.QueryContext(
+ ctx,
+ query,
+ outbox.OutboxStatusPending,
+ outbox.OutboxStatusFailed,
+ outbox.OutboxStatusProcessing,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("querying distinct tenant ids: %w", err)
+ }
+ defer rows.Close()
+
+ tenants := make([]string, 0)
+
+ for rows.Next() {
+ var tenant string
+ if scanErr := rows.Scan(&tenant); scanErr != nil {
+ return nil, fmt.Errorf("scanning tenant id: %w", scanErr)
+ }
+
+ tenant = strings.TrimSpace(tenant)
+
+ if tenant != "" {
+ tenants = append(tenants, tenant)
+ }
+ }
+
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterating tenant ids: %w", err)
+ }
+
+ resolver.storeCachedTenants(tenants, time.Now().UTC())
+
+ return tenants, nil
+}
+
+// RequiresTenant returns true because column-per-tenant strategy always requires
+// a tenant ID to scope queries via WHERE clauses.
+func (resolver *ColumnResolver) RequiresTenant() bool {
+ return true
+}
+
+func (resolver *ColumnResolver) TenantColumn() string {
+ if resolver == nil {
+ return ""
+ }
+
+ return resolver.tenantColumn
+}
+
+func (resolver *ColumnResolver) primaryDB(ctx context.Context) (*sql.DB, error) {
+ return resolvePrimaryDB(ctx, resolver.client)
+}
+
+func (resolver *ColumnResolver) cachedTenants(now time.Time) ([]string, bool) {
+ if resolver.tenantTTL <= 0 {
+ return nil, false
+ }
+
+ resolver.cacheMu.RLock()
+ defer resolver.cacheMu.RUnlock()
+
+ if !resolver.cacheSet || !now.Before(resolver.cacheUntil) {
+ return nil, false
+ }
+
+ return append([]string(nil), resolver.cache...), true
+}
+
+func (resolver *ColumnResolver) storeCachedTenants(tenants []string, now time.Time) {
+ if resolver.tenantTTL <= 0 {
+ return
+ }
+
+ resolver.cacheMu.Lock()
+ defer resolver.cacheMu.Unlock()
+
+ resolver.cache = append([]string(nil), tenants...)
+ resolver.cacheSet = true
+ resolver.cacheUntil = now.Add(resolver.tenantTTL)
+}
diff --git a/commons/outbox/postgres/column_resolver_test.go b/commons/outbox/postgres/column_resolver_test.go
new file mode 100644
index 00000000..422168ff
--- /dev/null
+++ b/commons/outbox/postgres/column_resolver_test.go
@@ -0,0 +1,91 @@
+//go:build unit
+
+package postgres
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ libPostgres "github.com/LerianStudio/lib-commons/v4/commons/postgres"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewColumnResolver_NilClient(t *testing.T) {
+ t.Parallel()
+
+ resolver, err := NewColumnResolver(nil)
+ require.Nil(t, resolver)
+ require.ErrorIs(t, err, ErrConnectionRequired)
+}
+
+func TestNewColumnResolver_ValidatesIdentifiers(t *testing.T) {
+ t.Parallel()
+
+ client := &libPostgres.Client{}
+
+ _, err := NewColumnResolver(client, WithColumnResolverTableName(`public.outbox";drop`))
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrInvalidIdentifier)
+
+ _, err = NewColumnResolver(client, WithColumnResolverTenantColumn(`tenant-id`))
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrInvalidIdentifier)
+}
+
+func TestColumnResolver_DiscoverTenantsNilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var resolver *ColumnResolver
+
+ tenants, err := resolver.DiscoverTenants(context.Background())
+ require.Nil(t, tenants)
+ require.ErrorIs(t, err, ErrConnectionRequired)
+}
+
+func TestNewColumnResolver_AppliesTenantDiscoveryTTLOption(t *testing.T) {
+ t.Parallel()
+
+ client := &libPostgres.Client{}
+
+ resolver, err := NewColumnResolver(client, WithColumnResolverTenantDiscoveryTTL(2*time.Minute))
+ require.NoError(t, err)
+ require.Equal(t, 2*time.Minute, resolver.tenantTTL)
+}
+
+func TestColumnResolver_DiscoverTenantsReturnsCachedSnapshot(t *testing.T) {
+ t.Parallel()
+
+ resolver := &ColumnResolver{
+ client: &libPostgres.Client{},
+ tenantTTL: time.Minute,
+ cache: []string{"tenant-a", "tenant-b"},
+ cacheSet: true,
+ cacheUntil: time.Now().UTC().Add(time.Minute),
+ }
+
+ tenants, err := resolver.DiscoverTenants(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, []string{"tenant-a", "tenant-b"}, tenants)
+
+ tenants[0] = "mutated"
+ again, err := resolver.DiscoverTenants(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, []string{"tenant-a", "tenant-b"}, again)
+}
+
+func TestColumnResolver_DiscoverTenantsReturnsCachedEmptySnapshot(t *testing.T) {
+ t.Parallel()
+
+ resolver := &ColumnResolver{
+ client: &libPostgres.Client{},
+ tenantTTL: time.Minute,
+ cache: []string{},
+ cacheSet: true,
+ cacheUntil: time.Now().UTC().Add(time.Minute),
+ }
+
+ tenants, err := resolver.DiscoverTenants(context.Background())
+ require.NoError(t, err)
+ require.Empty(t, tenants)
+}
diff --git a/commons/outbox/postgres/db.go b/commons/outbox/postgres/db.go
new file mode 100644
index 00000000..6d52f54a
--- /dev/null
+++ b/commons/outbox/postgres/db.go
@@ -0,0 +1,45 @@
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+
+ "github.com/bxcodec/dbresolver/v2"
+)
+
+type resolverProvider interface {
+ Resolver(ctx context.Context) (dbresolver.DB, error)
+}
+
+func resolvePrimaryDB(ctx context.Context, client resolverProvider) (*sql.DB, error) {
+ if client == nil {
+ return nil, ErrConnectionRequired
+ }
+
+ value := reflect.ValueOf(client)
+ if value.Kind() == reflect.Pointer && value.IsNil() {
+ return nil, ErrConnectionRequired
+ }
+
+ resolved, err := client.Resolver(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get database connection: %w", err)
+ }
+
+ if resolved == nil {
+ return nil, ErrNoPrimaryDB
+ }
+
+ primaryDBs := resolved.PrimaryDBs()
+ if len(primaryDBs) == 0 {
+ return nil, ErrNoPrimaryDB
+ }
+
+ if primaryDBs[0] == nil {
+ return nil, ErrNoPrimaryDB
+ }
+
+ return primaryDBs[0], nil
+}
diff --git a/commons/outbox/postgres/db_test.go b/commons/outbox/postgres/db_test.go
new file mode 100644
index 00000000..9cf3acab
--- /dev/null
+++ b/commons/outbox/postgres/db_test.go
@@ -0,0 +1,153 @@
+//go:build unit
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "testing"
+ "time"
+
+ libPostgres "github.com/LerianStudio/lib-commons/v4/commons/postgres"
+ "github.com/bxcodec/dbresolver/v2"
+ "github.com/stretchr/testify/require"
+)
+
+type resolverProviderFunc func(context.Context) (dbresolver.DB, error)
+
+func (fn resolverProviderFunc) Resolver(ctx context.Context) (dbresolver.DB, error) {
+ return fn(ctx)
+}
+
+type fakeDBResolver struct {
+ primary []*sql.DB
+}
+
+func (resolver fakeDBResolver) Begin() (dbresolver.Tx, error) { return nil, nil }
+
+func (resolver fakeDBResolver) BeginTx(context.Context, *sql.TxOptions) (dbresolver.Tx, error) {
+ return nil, nil
+}
+
+func (resolver fakeDBResolver) Close() error { return nil }
+
+func (resolver fakeDBResolver) Conn(context.Context) (dbresolver.Conn, error) { return nil, nil }
+
+func (resolver fakeDBResolver) Driver() driver.Driver { return nil }
+
+func (resolver fakeDBResolver) Exec(string, ...interface{}) (sql.Result, error) { return nil, nil }
+
+func (resolver fakeDBResolver) ExecContext(context.Context, string, ...interface{}) (sql.Result, error) {
+ return nil, nil
+}
+
+func (resolver fakeDBResolver) Ping() error { return nil }
+
+func (resolver fakeDBResolver) PingContext(context.Context) error { return nil }
+
+func (resolver fakeDBResolver) Prepare(string) (dbresolver.Stmt, error) { return nil, nil }
+
+func (resolver fakeDBResolver) PrepareContext(context.Context, string) (dbresolver.Stmt, error) {
+ return nil, nil
+}
+
+func (resolver fakeDBResolver) Query(string, ...interface{}) (*sql.Rows, error) { return nil, nil }
+
+func (resolver fakeDBResolver) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) {
+ return nil, nil
+}
+
+func (resolver fakeDBResolver) QueryRow(string, ...interface{}) *sql.Row { return nil }
+
+func (resolver fakeDBResolver) QueryRowContext(context.Context, string, ...interface{}) *sql.Row {
+ return nil
+}
+
+func (resolver fakeDBResolver) SetConnMaxIdleTime(time.Duration) {}
+
+func (resolver fakeDBResolver) SetConnMaxLifetime(time.Duration) {}
+
+func (resolver fakeDBResolver) SetMaxIdleConns(int) {}
+
+func (resolver fakeDBResolver) SetMaxOpenConns(int) {}
+
+func (resolver fakeDBResolver) PrimaryDBs() []*sql.DB { return resolver.primary }
+
+func (resolver fakeDBResolver) ReplicaDBs() []*sql.DB { return nil }
+
+func (resolver fakeDBResolver) Stats() sql.DBStats { return sql.DBStats{} }
+
+func TestResolvePrimaryDB_NilClient(t *testing.T) {
+ t.Parallel()
+
+ db, err := resolvePrimaryDB(context.Background(), nil)
+ require.Nil(t, db)
+ require.ErrorIs(t, err, ErrConnectionRequired)
+}
+
+func TestResolvePrimaryDB_NilContext(t *testing.T) {
+ t.Parallel()
+
+ client, err := libPostgres.New(libPostgres.Config{
+ PrimaryDSN: "postgres://localhost:5432/postgres",
+ ReplicaDSN: "postgres://localhost:5432/postgres",
+ })
+ require.NoError(t, err)
+
+ db, err := resolvePrimaryDB(nil, client)
+ require.Nil(t, db)
+ require.Error(t, err)
+ require.ErrorContains(t, err, "failed to get database connection")
+ require.True(t, errors.Is(err, libPostgres.ErrNilContext))
+}
+
+func TestResolvePrimaryDB_ResolverFailure(t *testing.T) {
+ t.Parallel()
+
+ client, err := libPostgres.New(libPostgres.Config{
+ PrimaryDSN: "postgres://invalid:invalid@127.0.0.1:1/postgres",
+ ReplicaDSN: "postgres://invalid:invalid@127.0.0.1:1/postgres",
+ })
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+ defer cancel()
+
+ db, err := resolvePrimaryDB(ctx, client)
+ require.Nil(t, db)
+ require.ErrorContains(t, err, "failed to get database connection")
+ require.NotErrorIs(t, err, ErrNoPrimaryDB)
+ require.NotErrorIs(t, err, ErrConnectionRequired)
+}
+
+func TestResolvePrimaryDB_NilResolvedDB(t *testing.T) {
+ t.Parallel()
+
+ db, err := resolvePrimaryDB(context.Background(), resolverProviderFunc(func(context.Context) (dbresolver.DB, error) {
+ return nil, nil
+ }))
+ require.Nil(t, db)
+ require.ErrorIs(t, err, ErrNoPrimaryDB)
+}
+
+func TestResolvePrimaryDB_EmptyPrimaryDBs(t *testing.T) {
+ t.Parallel()
+
+ db, err := resolvePrimaryDB(context.Background(), resolverProviderFunc(func(context.Context) (dbresolver.DB, error) {
+ return fakeDBResolver{primary: []*sql.DB{}}, nil
+ }))
+ require.Nil(t, db)
+ require.ErrorIs(t, err, ErrNoPrimaryDB)
+}
+
+func TestResolvePrimaryDB_NilPrimaryDBEntry(t *testing.T) {
+ t.Parallel()
+
+ db, err := resolvePrimaryDB(context.Background(), resolverProviderFunc(func(context.Context) (dbresolver.DB, error) {
+ return fakeDBResolver{primary: []*sql.DB{nil}}, nil
+ }))
+ require.Nil(t, db)
+ require.ErrorIs(t, err, ErrNoPrimaryDB)
+}
diff --git a/commons/outbox/postgres/doc.go b/commons/outbox/postgres/doc.go
new file mode 100644
index 00000000..382049d8
--- /dev/null
+++ b/commons/outbox/postgres/doc.go
@@ -0,0 +1,10 @@
+// Package postgres provides PostgreSQL adapters for outbox repository contracts.
+//
+// Migration files under migrations/ include two mutually exclusive tracks:
+// - schema-per-tenant in migrations/
+// - column-per-tenant in migrations/column/
+// Choose one strategy per deployment.
+//
+// SchemaResolver enforces non-empty tenant context by default. Use
+// WithAllowEmptyTenant only for explicit single-tenant/public-schema flows.
+package postgres
diff --git a/commons/outbox/postgres/migrations/000001_outbox_events_schema.down.sql b/commons/outbox/postgres/migrations/000001_outbox_events_schema.down.sql
new file mode 100644
index 00000000..a94db5b9
--- /dev/null
+++ b/commons/outbox/postgres/migrations/000001_outbox_events_schema.down.sql
@@ -0,0 +1,2 @@
+DROP TABLE IF EXISTS outbox_events;
+DROP TYPE IF EXISTS outbox_event_status;
diff --git a/commons/outbox/postgres/migrations/000001_outbox_events_schema.up.sql b/commons/outbox/postgres/migrations/000001_outbox_events_schema.up.sql
new file mode 100644
index 00000000..419b325b
--- /dev/null
+++ b/commons/outbox/postgres/migrations/000001_outbox_events_schema.up.sql
@@ -0,0 +1,37 @@
+-- Schema-per-tenant outbox_events table template.
+-- Apply this migration inside each tenant schema.
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_type t
+ INNER JOIN pg_namespace n ON n.oid = t.typnamespace
+ WHERE t.typname = 'outbox_event_status'
+ AND n.nspname = current_schema()
+ ) THEN
+ CREATE TYPE outbox_event_status AS ENUM ('PENDING', 'PROCESSING', 'PUBLISHED', 'FAILED', 'INVALID');
+ END IF;
+END $$;
+
+CREATE TABLE IF NOT EXISTS outbox_events (
+ id UUID PRIMARY KEY,
+ event_type VARCHAR(255) NOT NULL,
+ aggregate_id UUID NOT NULL,
+ payload JSONB NOT NULL,
+ status outbox_event_status NOT NULL DEFAULT 'PENDING',
+ attempts INT NOT NULL DEFAULT 0,
+ published_at TIMESTAMPTZ NULL,
+ last_error VARCHAR(512),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_outbox_events_status_created_at
+ ON outbox_events (status, created_at ASC);
+
+CREATE INDEX IF NOT EXISTS idx_outbox_events_status_updated_at
+ ON outbox_events (status, updated_at ASC);
+
+CREATE INDEX IF NOT EXISTS idx_outbox_events_event_type_status_created_at
+ ON outbox_events (event_type, status, created_at ASC);
diff --git a/commons/outbox/postgres/migrations/README.md b/commons/outbox/postgres/migrations/README.md
new file mode 100644
index 00000000..df2d753e
--- /dev/null
+++ b/commons/outbox/postgres/migrations/README.md
@@ -0,0 +1,13 @@
+# Outbox migrations
+
+This directory contains two alternative migration tracks for `outbox_events`:
+
+- `000001_outbox_events_schema.*.sql`: schema-per-tenant strategy (default track in this directory)
+- `column/000001_outbox_events_column.*.sql`: column-per-tenant strategy (`tenant_id`)
+
+Use exactly one track for a given deployment topology.
+
+- For schema-per-tenant deployments, point migrations to this directory.
+- For column-per-tenant deployments, point migrations to `migrations/column`.
+
+Column track note: primary key is `(tenant_id, id)` to avoid cross-tenant key coupling.
diff --git a/commons/outbox/postgres/migrations/column/000001_outbox_events_column.down.sql b/commons/outbox/postgres/migrations/column/000001_outbox_events_column.down.sql
new file mode 100644
index 00000000..a94db5b9
--- /dev/null
+++ b/commons/outbox/postgres/migrations/column/000001_outbox_events_column.down.sql
@@ -0,0 +1,2 @@
+DROP TABLE IF EXISTS outbox_events;
+DROP TYPE IF EXISTS outbox_event_status;
diff --git a/commons/outbox/postgres/migrations/column/000001_outbox_events_column.up.sql b/commons/outbox/postgres/migrations/column/000001_outbox_events_column.up.sql
new file mode 100644
index 00000000..5b84a437
--- /dev/null
+++ b/commons/outbox/postgres/migrations/column/000001_outbox_events_column.up.sql
@@ -0,0 +1,32 @@
+-- Column-per-tenant outbox_events table template.
+-- Apply this migration once in a shared schema.
+
+DO $$ BEGIN
+ CREATE TYPE outbox_event_status AS ENUM ('PENDING', 'PROCESSING', 'PUBLISHED', 'FAILED', 'INVALID');
+EXCEPTION
+ WHEN duplicate_object THEN null;
+END $$;
+
+CREATE TABLE IF NOT EXISTS outbox_events (
+ id UUID NOT NULL,
+ tenant_id TEXT NOT NULL,
+ event_type VARCHAR(255) NOT NULL,
+ aggregate_id UUID NOT NULL,
+ payload JSONB NOT NULL,
+ status outbox_event_status NOT NULL DEFAULT 'PENDING',
+ attempts INT NOT NULL DEFAULT 0,
+ published_at TIMESTAMPTZ NULL,
+ last_error VARCHAR(512),
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ PRIMARY KEY (tenant_id, id)
+);
+
+CREATE INDEX IF NOT EXISTS idx_outbox_events_tenant_status_created_at
+ ON outbox_events (tenant_id, status, created_at ASC);
+
+CREATE INDEX IF NOT EXISTS idx_outbox_events_tenant_status_updated_at
+ ON outbox_events (tenant_id, status, updated_at ASC);
+
+CREATE INDEX IF NOT EXISTS idx_outbox_events_tenant_event_type_status_created_at
+ ON outbox_events (tenant_id, event_type, status, created_at ASC);
diff --git a/commons/outbox/postgres/repository.go b/commons/outbox/postgres/repository.go
new file mode 100644
index 00000000..78ac87b8
--- /dev/null
+++ b/commons/outbox/postgres/repository.go
@@ -0,0 +1,1539 @@
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "regexp"
+ "strings"
+ "time"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/outbox"
+ libPostgres "github.com/LerianStudio/lib-commons/v4/commons/postgres"
+ "github.com/google/uuid"
+)
+
+const maxSQLIdentifierLength = 63
+
+var (
+ ErrConnectionRequired = errors.New("postgres connection is required")
+ ErrTransactionRequired = errors.New("postgres transaction is required")
+ ErrStateTransitionConflict = errors.New("outbox event state transition conflict")
+ ErrRepositoryNotInitialized = errors.New("outbox repository not initialized")
+ ErrLimitMustBePositive = errors.New("limit must be greater than zero")
+ ErrIDRequired = errors.New("id is required")
+ ErrAggregateIDRequired = errors.New("aggregate id is required")
+ ErrMaxAttemptsMustBePositive = errors.New("maxAttempts must be greater than zero")
+ ErrEventTypeRequired = errors.New("event type is required")
+ ErrTenantResolverRequired = errors.New("tenant resolver is required")
+ ErrTenantDiscovererRequired = errors.New("tenant discoverer is required")
+ ErrNoPrimaryDB = errors.New("no primary database configured for tenant transaction")
+ ErrInvalidIdentifier = errors.New("invalid sql identifier")
+ identifierPattern = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
+ defaultTransactionTimeout = 30 * time.Second
+ outboxColumns = "id, event_type, aggregate_id, payload, status, attempts, published_at, last_error, created_at, updated_at"
+)
+
+type tenantColumnProvider interface {
+ TenantColumn() string
+}
+
+type tenantRequirementProvider interface {
+ RequiresTenant() bool
+}
+
+type Option func(*Repository)
+
+func WithLogger(logger libLog.Logger) Option {
+ return func(repo *Repository) {
+ if nilcheck.Interface(logger) {
+ return
+ }
+
+ repo.logger = logger
+ }
+}
+
+func WithTableName(tableName string) Option {
+ return func(repo *Repository) {
+ repo.tableName = tableName
+ }
+}
+
+func WithTenantColumn(tenantColumn string) Option {
+ return func(repo *Repository) {
+ repo.tenantColumn = tenantColumn
+ }
+}
+
+func WithTransactionTimeout(timeout time.Duration) Option {
+ return func(repo *Repository) {
+ if timeout > 0 {
+ repo.transactionTimeout = timeout
+ }
+ }
+}
+
+// Repository persists outbox events in PostgreSQL.
+type Repository struct {
+ client *libPostgres.Client
+ tenantResolver outbox.TenantResolver
+ tenantDiscoverer outbox.TenantDiscoverer
+ primaryDBLookup func(context.Context) (*sql.DB, error)
+ requireTenant bool
+ logger libLog.Logger
+ tableName string
+ tenantColumn string
+ transactionTimeout time.Duration
+}
+
+// NewRepository creates a PostgreSQL outbox repository.
+func NewRepository(
+ client *libPostgres.Client,
+ tenantResolver outbox.TenantResolver,
+ tenantDiscoverer outbox.TenantDiscoverer,
+ opts ...Option,
+) (*Repository, error) {
+ if client == nil {
+ return nil, ErrConnectionRequired
+ }
+
+ if nilcheck.Interface(tenantResolver) {
+ return nil, ErrTenantResolverRequired
+ }
+
+ if nilcheck.Interface(tenantDiscoverer) {
+ return nil, ErrTenantDiscovererRequired
+ }
+
+ repo := &Repository{
+ client: client,
+ tenantResolver: tenantResolver,
+ tenantDiscoverer: tenantDiscoverer,
+ logger: libLog.NewNop(),
+ tableName: "outbox_events",
+ transactionTimeout: defaultTransactionTimeout,
+ }
+
+ if provider, ok := tenantResolver.(tenantColumnProvider); ok {
+ repo.tenantColumn = provider.TenantColumn()
+ }
+
+ if provider, ok := tenantResolver.(tenantRequirementProvider); ok {
+ repo.requireTenant = provider.RequiresTenant()
+ }
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(repo)
+ }
+ }
+
+ if nilcheck.Interface(repo.logger) {
+ repo.logger = libLog.NewNop()
+ }
+
+ repo.tableName = strings.TrimSpace(repo.tableName)
+ if repo.tableName == "" {
+ repo.tableName = "outbox_events"
+ }
+
+ repo.tenantColumn = strings.TrimSpace(repo.tenantColumn)
+
+ if err := validateIdentifierPath(repo.tableName); err != nil {
+ return nil, fmt.Errorf("table name: %w", err)
+ }
+
+ if repo.tenantColumn != "" {
+ if err := validateIdentifier(repo.tenantColumn); err != nil {
+ return nil, fmt.Errorf("tenant column: %w", err)
+ }
+ }
+
+ return repo, nil
+}
+
+// GetByID retrieves an outbox event by id.
+func (repo *Repository) GetByID(ctx context.Context, id uuid.UUID) (*outbox.OutboxEvent, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return nil, ErrRepositoryNotInitialized
+ }
+
+ if id == uuid.Nil {
+ return nil, ErrIDRequired
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.get_outbox_by_id")
+ defer span.End()
+
+ result, err := withTenantTxOrExisting(repo, ctx, nil, func(tx *sql.Tx) (*outbox.OutboxEvent, error) {
+ table := quoteIdentifierPath(repo.tableName)
+ query := "SELECT " + outboxColumns + " FROM " + table + " WHERE id = $1" // #nosec G202 -- table name validated at construction via validateIdentifierPath; quoteIdentifierPath escapes identifiers
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(2, tenantID)
+ if filterErr != nil {
+ return nil, filterErr
+ }
+
+ args := make([]any, 0, 1+len(filterArgs))
+ args = append(args, id)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+
+ row := tx.QueryRowContext(ctx, query, args...)
+
+ return scanOutboxEvent(row)
+ })
+ if err != nil {
+ if !errors.Is(err, sql.ErrNoRows) {
+ libOpentelemetry.HandleSpanError(span, "failed to get outbox event", err)
+ logSanitizedError(logger, ctx, "failed to get outbox event", err)
+ }
+
+ return nil, fmt.Errorf("getting outbox event: %w", err)
+ }
+
+ return result, nil
+}
+
+// Create stores a new outbox event using a new transaction.
+func (repo *Repository) Create(ctx context.Context, event *outbox.OutboxEvent) (*outbox.OutboxEvent, error) {
+ return repo.create(ctx, nil, event)
+}
+
+// CreateWithTx stores a new outbox event using an existing transaction.
+func (repo *Repository) CreateWithTx(
+ ctx context.Context,
+ tx outbox.Tx,
+ event *outbox.OutboxEvent,
+) (*outbox.OutboxEvent, error) {
+ return repo.create(ctx, tx, event)
+}
+
+func (repo *Repository) create(
+ ctx context.Context,
+ tx *sql.Tx,
+ event *outbox.OutboxEvent,
+) (*outbox.OutboxEvent, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return nil, ErrRepositoryNotInitialized
+ }
+
+ if err := validateCreateEvent(event); err != nil {
+ return nil, err
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.create_outbox_event")
+ defer span.End()
+
+ result, err := withTenantTxOrExisting(repo, ctx, tx, func(execTx *sql.Tx) (*outbox.OutboxEvent, error) {
+ createValues := normalizedCreateValues(event, time.Now().UTC())
+ table := quoteIdentifierPath(repo.tableName)
+ query := "INSERT INTO " + table + // #nosec G202 -- table name validated at construction; quoteIdentifierPath escapes identifiers
+ " (id, event_type, aggregate_id, payload, status, attempts, published_at, last_error, created_at, updated_at"
+
+ args := []any{
+ createValues.id,
+ createValues.eventType,
+ createValues.aggregateID,
+ createValues.payload,
+ createValues.status,
+ createValues.attempts,
+ createValues.publishedAt,
+ createValues.lastError,
+ createValues.createdAt,
+ createValues.updatedAt,
+ }
+
+ if repo.tenantColumn != "" {
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ query += ", " + quoteIdentifier(repo.tenantColumn)
+
+ args = append(args, tenantID)
+ }
+
+ var placeholders strings.Builder
+
+ for i := range args {
+ if i > 0 {
+ placeholders.WriteString(", ")
+ }
+
+ fmt.Fprintf(&placeholders, "$%d", i+1)
+ }
+
+ query += ") VALUES (" + placeholders.String() + ") RETURNING " + outboxColumns
+
+ row := execTx.QueryRowContext(ctx, query, args...)
+
+ return scanOutboxEvent(row)
+ })
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to create outbox event", err)
+ logSanitizedError(logger, ctx, "failed to create outbox event", err)
+
+ return nil, fmt.Errorf("creating outbox event: %w", err)
+ }
+
+ return result, nil
+}
+
+// ListPending retrieves pending outbox events up to the given limit.
+func (repo *Repository) ListPending(ctx context.Context, limit int) ([]*outbox.OutboxEvent, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return nil, ErrRepositoryNotInitialized
+ }
+
+ if limit <= 0 {
+ return nil, ErrLimitMustBePositive
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.list_outbox_pending")
+ defer span.End()
+
+ result, err := withTenantTxOrExisting(repo, ctx, nil, func(tx *sql.Tx) ([]*outbox.OutboxEvent, error) {
+ events, err := repo.listPendingRows(ctx, tx, limit)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(events) == 0 {
+ return events, nil
+ }
+
+ ids := collectEventIDs(events)
+ if len(ids) == 0 {
+ return events, nil
+ }
+
+ now := time.Now().UTC()
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ if err := repo.markEventsProcessing(ctx, tx, now, ids, tenantID, outbox.OutboxStatusPending); err != nil {
+ return nil, err
+ }
+
+ applyProcessingState(events, now)
+
+ return events, nil
+ })
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to list outbox events", err)
+ logSanitizedError(logger, ctx, "failed to list outbox events", err)
+
+ return nil, fmt.Errorf("listing pending events: %w", err)
+ }
+
+ return result, nil
+}
+
+// ListPendingByType retrieves pending outbox events filtered by event type.
+func (repo *Repository) ListPendingByType(
+ ctx context.Context,
+ eventType string,
+ limit int,
+) ([]*outbox.OutboxEvent, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return nil, ErrRepositoryNotInitialized
+ }
+
+ if limit <= 0 {
+ return nil, ErrLimitMustBePositive
+ }
+
+ eventType = strings.TrimSpace(eventType)
+
+ if eventType == "" {
+ return nil, ErrEventTypeRequired
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.list_outbox_pending_by_type")
+ defer span.End()
+
+ result, err := withTenantTxOrExisting(repo, ctx, nil, func(tx *sql.Tx) ([]*outbox.OutboxEvent, error) {
+ events, err := repo.listPendingByTypeRows(ctx, tx, eventType, limit)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(events) == 0 {
+ return events, nil
+ }
+
+ ids := collectEventIDs(events)
+ if len(ids) == 0 {
+ return events, nil
+ }
+
+ now := time.Now().UTC()
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ if err := repo.markEventsProcessing(ctx, tx, now, ids, tenantID, outbox.OutboxStatusPending); err != nil {
+ return nil, err
+ }
+
+ applyProcessingState(events, now)
+
+ return events, nil
+ })
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to list outbox events by type", err)
+ logSanitizedError(logger, ctx, "failed to list outbox events by type", err)
+
+ return nil, fmt.Errorf("listing pending events by type: %w", err)
+ }
+
+ return result, nil
+}
+
+// ListTenants returns tenant IDs discovered by the configured discoverer.
+func (repo *Repository) ListTenants(ctx context.Context) ([]string, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return nil, ErrRepositoryNotInitialized
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.list_outbox_tenants")
+ defer span.End()
+
+ tenants, err := repo.tenantDiscoverer.DiscoverTenants(ctx)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to list tenant schemas", err)
+ logSanitizedError(logger, ctx, "failed to list tenant schemas", err)
+
+ return nil, fmt.Errorf("list tenant schemas: %w", err)
+ }
+
+ return tenants, nil
+}
+
+// MarkPublished marks an outbox event as published.
+func (repo *Repository) MarkPublished(ctx context.Context, id uuid.UUID, publishedAt time.Time) error {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return ErrRepositoryNotInitialized
+ }
+
+ if err := outbox.ValidateOutboxTransition(outbox.OutboxStatusProcessing, outbox.OutboxStatusPublished); err != nil {
+ return fmt.Errorf("mark published transition: %w", err)
+ }
+
+ if id == uuid.Nil {
+ return ErrIDRequired
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.mark_outbox_published")
+ defer span.End()
+
+ _, err := withTenantTxOrExisting(repo, ctx, nil, func(tx *sql.Tx) (struct{}, error) {
+ table := quoteIdentifierPath(repo.tableName)
+ query := "UPDATE " + table + " SET status = $1::outbox_event_status, published_at = $2, updated_at = $3 " + // #nosec G202 -- table name validated at construction; quoteIdentifierPath escapes identifiers
+ "WHERE id = $4 AND status = $5::outbox_event_status"
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return struct{}{}, tenantErr
+ }
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(6, tenantID)
+ if filterErr != nil {
+ return struct{}{}, filterErr
+ }
+
+ args := make([]any, 0, 5+len(filterArgs))
+ args = append(args, outbox.OutboxStatusPublished, publishedAt, time.Now().UTC(), id, outbox.OutboxStatusProcessing)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+
+ result, execErr := tx.ExecContext(ctx, query, args...)
+ if execErr != nil {
+ return struct{}{}, fmt.Errorf("executing update: %w", execErr)
+ }
+
+ if err := ensureRowsAffected(result); err != nil {
+ return struct{}{}, err
+ }
+
+ return struct{}{}, nil
+ })
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to mark outbox published", err)
+ logSanitizedError(logger, ctx, "failed to mark outbox published", err)
+
+ return fmt.Errorf("marking published: %w", err)
+ }
+
+ return nil
+}
+
+// MarkFailed marks an outbox event as failed and may transition to invalid.
+func (repo *Repository) MarkFailed(ctx context.Context, id uuid.UUID, errMsg string, maxAttempts int) error {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return ErrRepositoryNotInitialized
+ }
+
+ if err := outbox.ValidateOutboxTransition(outbox.OutboxStatusProcessing, outbox.OutboxStatusFailed); err != nil {
+ return fmt.Errorf("mark failed transition: %w", err)
+ }
+
+ if err := outbox.ValidateOutboxTransition(outbox.OutboxStatusProcessing, outbox.OutboxStatusInvalid); err != nil {
+ return fmt.Errorf("mark failed->invalid transition: %w", err)
+ }
+
+ if id == uuid.Nil {
+ return ErrIDRequired
+ }
+
+ if maxAttempts <= 0 {
+ return ErrMaxAttemptsMustBePositive
+ }
+
+ errMsg = outbox.SanitizeErrorMessageForStorage(errMsg)
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.mark_outbox_failed")
+ defer span.End()
+
+ _, err := withTenantTxOrExisting(repo, ctx, nil, func(tx *sql.Tx) (struct{}, error) {
+ table := quoteIdentifierPath(repo.tableName)
+ query := "UPDATE " + table + " SET " + // #nosec G202 -- table name validated at construction; quoteIdentifierPath escapes identifiers
+ "status = CASE WHEN attempts + 1 >= $1 THEN $2 ELSE $3 END::outbox_event_status, " +
+ "attempts = attempts + 1, " +
+ "last_error = CASE WHEN attempts + 1 >= $1 THEN $4 ELSE $5 END, " +
+ "updated_at = $6 WHERE id = $7 AND status = $8::outbox_event_status"
+
+ args := []any{
+ maxAttempts,
+ outbox.OutboxStatusInvalid,
+ outbox.OutboxStatusFailed,
+ "max dispatch attempts exceeded",
+ errMsg,
+ time.Now().UTC(),
+ id,
+ outbox.OutboxStatusProcessing,
+ }
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return struct{}{}, tenantErr
+ }
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(9, tenantID)
+ if filterErr != nil {
+ return struct{}{}, filterErr
+ }
+
+ query += filter
+
+ args = append(args, filterArgs...)
+
+ result, execErr := tx.ExecContext(ctx, query, args...)
+ if execErr != nil {
+ return struct{}{}, fmt.Errorf("executing update: %w", execErr)
+ }
+
+ if err := ensureRowsAffected(result); err != nil {
+ return struct{}{}, err
+ }
+
+ return struct{}{}, nil
+ })
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to mark outbox failed", err)
+ logSanitizedError(logger, ctx, "failed to mark outbox failed", err)
+
+ return fmt.Errorf("marking failed: %w", err)
+ }
+
+ return nil
+}
+
+// ListFailedForRetry lists failed events eligible for retry.
+func (repo *Repository) ListFailedForRetry(
+ ctx context.Context,
+ limit int,
+ failedBefore time.Time,
+ maxAttempts int,
+) ([]*outbox.OutboxEvent, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return nil, ErrRepositoryNotInitialized
+ }
+
+ if limit <= 0 {
+ return nil, ErrLimitMustBePositive
+ }
+
+ if maxAttempts <= 0 {
+ return nil, ErrMaxAttemptsMustBePositive
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.list_failed_for_retry")
+ defer span.End()
+
+ result, err := withTenantTxOrExisting(repo, ctx, nil, func(tx *sql.Tx) ([]*outbox.OutboxEvent, error) {
+ return repo.listFailedForRetryRows(ctx, tx, limit, failedBefore, maxAttempts, false)
+ })
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to list failed events for retry", err)
+ logSanitizedError(logger, ctx, "failed to list failed events for retry", err)
+
+ return nil, fmt.Errorf("listing failed events for retry: %w", err)
+ }
+
+ return result, nil
+}
+
+// ResetForRetry atomically selects and resets failed events to processing.
+func (repo *Repository) ResetForRetry(
+ ctx context.Context,
+ limit int,
+ failedBefore time.Time,
+ maxAttempts int,
+) ([]*outbox.OutboxEvent, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return nil, ErrRepositoryNotInitialized
+ }
+
+ if limit <= 0 {
+ return nil, ErrLimitMustBePositive
+ }
+
+ if maxAttempts <= 0 {
+ return nil, ErrMaxAttemptsMustBePositive
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.reset_for_retry")
+ defer span.End()
+
+ result, err := withTenantTxOrExisting(repo, ctx, nil, func(tx *sql.Tx) ([]*outbox.OutboxEvent, error) {
+ events, err := repo.listFailedForRetryRows(ctx, tx, limit, failedBefore, maxAttempts, true)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(events) == 0 {
+ return events, nil
+ }
+
+ ids := collectEventIDs(events)
+ if len(ids) == 0 {
+ return events, nil
+ }
+
+ now := time.Now().UTC()
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ if err := repo.markEventsProcessing(ctx, tx, now, ids, tenantID, outbox.OutboxStatusFailed); err != nil {
+ return nil, err
+ }
+
+ applyProcessingState(events, now)
+
+ return events, nil
+ })
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to reset events for retry", err)
+ logSanitizedError(logger, ctx, "failed to reset events for retry", err)
+
+ return nil, fmt.Errorf("resetting events for retry: %w", err)
+ }
+
+ return result, nil
+}
+
+// ResetStuckProcessing reclaims long-running processing events.
+func (repo *Repository) ResetStuckProcessing(
+ ctx context.Context,
+ limit int,
+ processingBefore time.Time,
+ maxAttempts int,
+) ([]*outbox.OutboxEvent, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return nil, ErrRepositoryNotInitialized
+ }
+
+ if limit <= 0 {
+ return nil, ErrLimitMustBePositive
+ }
+
+ if maxAttempts <= 0 {
+ return nil, ErrMaxAttemptsMustBePositive
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.reset_outbox_processing")
+ defer span.End()
+
+ result, err := withTenantTxOrExisting(repo, ctx, nil, func(tx *sql.Tx) ([]*outbox.OutboxEvent, error) {
+ events, err := repo.listStuckProcessingRows(ctx, tx, limit, processingBefore)
+ if err != nil {
+ return nil, err
+ }
+
+ if len(events) == 0 {
+ return events, nil
+ }
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ retryEvents, exhaustedIDs := splitStuckEvents(events, maxAttempts)
+ now := time.Now().UTC()
+
+ retryIDs := collectEventIDs(retryEvents)
+ if len(retryIDs) > 0 {
+ if err := repo.markStuckEventsReprocessing(ctx, tx, now, retryIDs, tenantID); err != nil {
+ return nil, err
+ }
+
+ applyStuckReprocessingState(retryEvents, now)
+ }
+
+ if len(exhaustedIDs) > 0 {
+ if err := repo.markStuckEventsInvalid(ctx, tx, now, exhaustedIDs, tenantID); err != nil {
+ return nil, err
+ }
+ }
+
+ return retryEvents, nil
+ })
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to reset stuck events", err)
+ logSanitizedError(logger, ctx, "failed to reset stuck events", err)
+
+ return nil, fmt.Errorf("reset stuck events: %w", err)
+ }
+
+ return result, nil
+}
+
+// MarkInvalid marks an outbox event as invalid.
+func (repo *Repository) MarkInvalid(ctx context.Context, id uuid.UUID, errMsg string) error {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if !repo.initialized() {
+ return ErrRepositoryNotInitialized
+ }
+
+ if err := outbox.ValidateOutboxTransition(outbox.OutboxStatusProcessing, outbox.OutboxStatusInvalid); err != nil {
+ return fmt.Errorf("mark invalid transition: %w", err)
+ }
+
+ if id == uuid.Nil {
+ return ErrIDRequired
+ }
+
+ errMsg = outbox.SanitizeErrorMessageForStorage(errMsg)
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "postgres.mark_outbox_invalid")
+ defer span.End()
+
+ _, err := withTenantTxOrExisting(repo, ctx, nil, func(tx *sql.Tx) (struct{}, error) {
+ table := quoteIdentifierPath(repo.tableName)
+ query := "UPDATE " + table + " SET status = $1::outbox_event_status, last_error = $2, updated_at = $3 " + // #nosec G202 -- table name validated at construction; quoteIdentifierPath escapes identifiers
+ "WHERE id = $4 AND status = $5::outbox_event_status"
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return struct{}{}, tenantErr
+ }
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(6, tenantID)
+ if filterErr != nil {
+ return struct{}{}, filterErr
+ }
+
+ args := make([]any, 0, 5+len(filterArgs))
+ args = append(args, outbox.OutboxStatusInvalid, errMsg, time.Now().UTC(), id, outbox.OutboxStatusProcessing)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+
+ result, execErr := tx.ExecContext(ctx, query, args...)
+ if execErr != nil {
+ return struct{}{}, fmt.Errorf("executing update: %w", execErr)
+ }
+
+ if err := ensureRowsAffected(result); err != nil {
+ return struct{}{}, err
+ }
+
+ return struct{}{}, nil
+ })
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "failed to mark outbox invalid", err)
+ logSanitizedError(logger, ctx, "failed to mark outbox invalid", err)
+
+ return fmt.Errorf("marking invalid: %w", err)
+ }
+
+ return nil
+}
+
+func (repo *Repository) listPendingRows(ctx context.Context, tx *sql.Tx, limit int) ([]*outbox.OutboxEvent, error) {
+ table := quoteIdentifierPath(repo.tableName)
+ query := "SELECT " + outboxColumns + " FROM " + table + " WHERE status = $1"
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(2, tenantID)
+ if filterErr != nil {
+ return nil, filterErr
+ }
+
+ args := make([]any, 0, 1+len(filterArgs)+1)
+ args = append(args, outbox.OutboxStatusPending)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+ query += fmt.Sprintf(" ORDER BY created_at ASC LIMIT $%d FOR UPDATE SKIP LOCKED", len(args)+1)
+ args = append(args, limit)
+
+ return queryOutboxEvents(ctx, tx, query, args, limit, "querying pending events")
+}
+
+func (repo *Repository) listPendingByTypeRows(
+ ctx context.Context,
+ tx *sql.Tx,
+ eventType string,
+ limit int,
+) ([]*outbox.OutboxEvent, error) {
+ table := quoteIdentifierPath(repo.tableName)
+ query := "SELECT " + outboxColumns + " FROM " + table + " WHERE status = $1 AND event_type = $2"
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(3, tenantID)
+ if filterErr != nil {
+ return nil, filterErr
+ }
+
+ args := make([]any, 0, 2+len(filterArgs)+1)
+ args = append(args, outbox.OutboxStatusPending, eventType)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+
+ query += fmt.Sprintf(" ORDER BY created_at ASC LIMIT $%d FOR UPDATE SKIP LOCKED", len(args)+1)
+ args = append(args, limit)
+
+ return queryOutboxEvents(ctx, tx, query, args, limit, "querying pending events by type")
+}
+
+func (repo *Repository) listFailedForRetryRows(
+ ctx context.Context,
+ tx *sql.Tx,
+ limit int,
+ failedBefore time.Time,
+ maxAttempts int,
+ forUpdate bool,
+) ([]*outbox.OutboxEvent, error) {
+ table := quoteIdentifierPath(repo.tableName)
+ query := "SELECT " + outboxColumns + " FROM " + table +
+ " WHERE status = $1 AND attempts < $2 AND updated_at <= $3"
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(4, tenantID)
+ if filterErr != nil {
+ return nil, filterErr
+ }
+
+ args := make([]any, 0, 3+len(filterArgs)+1)
+ args = append(args, outbox.OutboxStatusFailed, maxAttempts, failedBefore)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+ query += fmt.Sprintf(" ORDER BY updated_at ASC LIMIT $%d", len(args)+1)
+ args = append(args, limit)
+
+ if forUpdate {
+ query += " FOR UPDATE SKIP LOCKED"
+ }
+
+ return queryOutboxEvents(ctx, tx, query, args, limit, "querying failed events for retry")
+}
+
+func (repo *Repository) listStuckProcessingRows(
+ ctx context.Context,
+ tx *sql.Tx,
+ limit int,
+ processingBefore time.Time,
+) ([]*outbox.OutboxEvent, error) {
+ table := quoteIdentifierPath(repo.tableName)
+ query := "SELECT " + outboxColumns + " FROM " + table +
+ " WHERE status = $1 AND updated_at <= $2"
+
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return nil, tenantErr
+ }
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(3, tenantID)
+ if filterErr != nil {
+ return nil, filterErr
+ }
+
+ args := make([]any, 0, 2+len(filterArgs)+1)
+ args = append(args, outbox.OutboxStatusProcessing, processingBefore)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+ query += fmt.Sprintf(" ORDER BY updated_at ASC LIMIT $%d FOR UPDATE SKIP LOCKED", len(args)+1)
+ args = append(args, limit)
+
+ return queryOutboxEvents(ctx, tx, query, args, limit, "querying stuck events")
+}
+
+func (repo *Repository) markEventsProcessing(
+ ctx context.Context,
+ tx *sql.Tx,
+ now time.Time,
+ ids []uuid.UUID,
+ tenantID string,
+ fromStatus string,
+) error {
+ return repo.markEventsWithStatus(
+ ctx,
+ tx,
+ now,
+ outbox.OutboxStatusProcessing,
+ ids,
+ tenantID,
+ fromStatus,
+ )
+}
+
+func (repo *Repository) markEventsWithStatus(
+ ctx context.Context,
+ tx *sql.Tx,
+ now time.Time,
+ status string,
+ ids []uuid.UUID,
+ tenantID string,
+ fromStatus string,
+) error {
+ if err := outbox.ValidateOutboxTransition(fromStatus, status); err != nil {
+ return fmt.Errorf("status transition: %w", err)
+ }
+
+ table := quoteIdentifierPath(repo.tableName)
+ query := "UPDATE " + table + // #nosec G202 -- table name validated at construction; quoteIdentifierPath escapes identifiers
+ " SET status = $1::outbox_event_status, updated_at = $2 WHERE id = ANY($3::uuid[]) AND status = $4::outbox_event_status"
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(5, tenantID)
+ if filterErr != nil {
+ return filterErr
+ }
+
+ args := make([]any, 0, 4+len(filterArgs))
+ args = append(args, status, now, ids, fromStatus)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+
+ result, err := tx.ExecContext(ctx, query, args...)
+ if err != nil {
+ return fmt.Errorf("updating status to %s: %w", status, err)
+ }
+
+ if err := ensureRowsAffectedExact(result, int64(len(ids))); err != nil {
+ return fmt.Errorf("updating status to %s: %w", status, err)
+ }
+
+ return nil
+}
+
+func (repo *Repository) markStuckEventsReprocessing(
+ ctx context.Context,
+ tx *sql.Tx,
+ now time.Time,
+ ids []uuid.UUID,
+ tenantID string,
+) error {
+ if err := outbox.ValidateOutboxTransition(outbox.OutboxStatusProcessing, outbox.OutboxStatusProcessing); err != nil {
+ return fmt.Errorf("stuck reprocessing transition: %w", err)
+ }
+
+ // Intentionally keep PROCESSING -> PROCESSING while incrementing attempts.
+ // If we flipped to PENDING before returning rows to the caller, another
+ // dispatcher could acquire and publish the same event immediately after this
+ // transaction commits. Keeping PROCESSING narrows duplicate publication windows
+ // to later stuck-recovery cycles.
+ table := quoteIdentifierPath(repo.tableName)
+ query := "UPDATE " + table + // #nosec G202 -- table name validated at construction; quoteIdentifierPath escapes identifiers
+ " SET status = $1::outbox_event_status, attempts = attempts + 1, updated_at = $2 " +
+ "WHERE id = ANY($3::uuid[]) AND status = $4::outbox_event_status"
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(5, tenantID)
+ if filterErr != nil {
+ return filterErr
+ }
+
+ args := make([]any, 0, 4+len(filterArgs))
+ args = append(args, outbox.OutboxStatusProcessing, now, ids, outbox.OutboxStatusProcessing)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+
+ result, err := tx.ExecContext(ctx, query, args...)
+ if err != nil {
+ return fmt.Errorf("updating stuck events to processing: %w", err)
+ }
+
+ if err := ensureRowsAffectedExact(result, int64(len(ids))); err != nil {
+ return fmt.Errorf("updating stuck events to processing: %w", err)
+ }
+
+ return nil
+}
+
+func (repo *Repository) markStuckEventsInvalid(
+ ctx context.Context,
+ tx *sql.Tx,
+ now time.Time,
+ ids []uuid.UUID,
+ tenantID string,
+) error {
+ if err := outbox.ValidateOutboxTransition(outbox.OutboxStatusProcessing, outbox.OutboxStatusInvalid); err != nil {
+ return fmt.Errorf("stuck invalid transition: %w", err)
+ }
+
+ table := quoteIdentifierPath(repo.tableName)
+ query := "UPDATE " + table + // #nosec G202 -- table name validated at construction; quoteIdentifierPath escapes identifiers
+ " SET status = $1::outbox_event_status, attempts = attempts + 1, " +
+ "last_error = $2, updated_at = $3 WHERE id = ANY($4::uuid[]) AND status = $5::outbox_event_status"
+
+ filter, filterArgs, filterErr := repo.tenantFilterClause(6, tenantID)
+ if filterErr != nil {
+ return filterErr
+ }
+
+ args := make([]any, 0, 5+len(filterArgs))
+ args = append(args, outbox.OutboxStatusInvalid, "max dispatch attempts exceeded", now, ids, outbox.OutboxStatusProcessing)
+
+ query += filter
+
+ args = append(args, filterArgs...)
+
+ result, err := tx.ExecContext(ctx, query, args...)
+ if err != nil {
+ return fmt.Errorf("updating stuck events to invalid: %w", err)
+ }
+
+ if err := ensureRowsAffectedExact(result, int64(len(ids))); err != nil {
+ return fmt.Errorf("updating stuck events to invalid: %w", err)
+ }
+
+ return nil
+}
+
+func splitStuckEvents(events []*outbox.OutboxEvent, maxAttempts int) ([]*outbox.OutboxEvent, []uuid.UUID) {
+ retryEvents := make([]*outbox.OutboxEvent, 0, len(events))
+ exhaustedIDs := make([]uuid.UUID, 0)
+
+ for _, event := range events {
+ if event == nil || event.ID == uuid.Nil {
+ continue
+ }
+
+ if event.Attempts+1 >= maxAttempts {
+ exhaustedIDs = append(exhaustedIDs, event.ID)
+
+ continue
+ }
+
+ retryEvents = append(retryEvents, event)
+ }
+
+ return retryEvents, exhaustedIDs
+}
+
+func applyStuckReprocessingState(events []*outbox.OutboxEvent, now time.Time) {
+ for _, event := range events {
+ if event == nil {
+ continue
+ }
+
+ event.Attempts++
+ event.Status = outbox.OutboxStatusProcessing
+ event.UpdatedAt = now
+ }
+}
+
+func collectEventIDs(events []*outbox.OutboxEvent) []uuid.UUID {
+ ids := make([]uuid.UUID, 0, len(events))
+
+ for _, event := range events {
+ if event == nil || event.ID == uuid.Nil {
+ continue
+ }
+
+ ids = append(ids, event.ID)
+ }
+
+ return ids
+}
+
+func applyProcessingState(events []*outbox.OutboxEvent, now time.Time) {
+ for _, event := range events {
+ if event == nil {
+ continue
+ }
+
+ event.Status = outbox.OutboxStatusProcessing
+ event.UpdatedAt = now
+ }
+}
+
+func scanOutboxEvent(scanner interface{ Scan(dest ...any) error }) (*outbox.OutboxEvent, error) {
+ var event outbox.OutboxEvent
+
+ var lastError sql.NullString
+
+ if err := scanner.Scan(
+ &event.ID,
+ &event.EventType,
+ &event.AggregateID,
+ &event.Payload,
+ &event.Status,
+ &event.Attempts,
+ &event.PublishedAt,
+ &lastError,
+ &event.CreatedAt,
+ &event.UpdatedAt,
+ ); err != nil {
+ return nil, fmt.Errorf("scanning outbox event: %w", err)
+ }
+
+ if lastError.Valid {
+ event.LastError = lastError.String
+ }
+
+ return &event, nil
+}
+
+func withTenantTxOrExisting[T any](
+ repo *Repository,
+ ctx context.Context,
+ tx *sql.Tx,
+ fn func(*sql.Tx) (T, error),
+) (T, error) {
+ var zero T
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if tx != nil {
+ tenantID, tenantErr := repo.tenantIDFromContext(ctx)
+ if tenantErr != nil {
+ return zero, tenantErr
+ }
+
+ if err := repo.tenantResolver.ApplyTenant(ctx, tx, tenantID); err != nil {
+ return zero, fmt.Errorf("failed to apply tenant: %w", err)
+ }
+
+ return fn(tx)
+ }
+
+ primaryDB, err := repo.primaryDB(ctx)
+ if err != nil {
+ return zero, err
+ }
+
+ txCtx := ctx
+
+ if _, hasDeadline := ctx.Deadline(); !hasDeadline {
+ var cancel context.CancelFunc
+
+ txCtx, cancel = context.WithTimeout(ctx, repo.transactionTimeout)
+ defer cancel()
+ }
+
+ newTx, err := primaryDB.BeginTx(txCtx, nil)
+ if err != nil {
+ return zero, fmt.Errorf("failed to begin transaction: %w", err)
+ }
+
+ defer func() {
+ _ = newTx.Rollback()
+ }()
+
+ tenantID, tenantErr := repo.tenantIDFromContext(txCtx)
+ if tenantErr != nil {
+ return zero, tenantErr
+ }
+
+ if err := repo.tenantResolver.ApplyTenant(txCtx, newTx, tenantID); err != nil {
+ return zero, fmt.Errorf("failed to apply tenant: %w", err)
+ }
+
+ result, err := fn(newTx)
+ if err != nil {
+ return zero, err
+ }
+
+ if err := newTx.Commit(); err != nil {
+ return zero, fmt.Errorf("failed to commit transaction: %w", err)
+ }
+
+ return result, nil
+}
+
+func (repo *Repository) initialized() bool {
+ return repo != nil && repo.client != nil && !nilcheck.Interface(repo.tenantResolver) && !nilcheck.Interface(repo.tenantDiscoverer)
+}
+
+// RequiresTenant reports whether repository operations require a tenant ID.
+func (repo *Repository) RequiresTenant() bool {
+ if repo == nil {
+ return true
+ }
+
+ return repo.requireTenant || repo.tenantColumn != ""
+}
+
+func (repo *Repository) primaryDB(ctx context.Context) (*sql.DB, error) {
+ if repo == nil {
+ return nil, ErrConnectionRequired
+ }
+
+ if repo.primaryDBLookup != nil {
+ return repo.primaryDBLookup(ctx)
+ }
+
+ return resolvePrimaryDB(ctx, repo.client)
+}
+
+func (repo *Repository) tenantIDFromContext(ctx context.Context) (string, error) {
+ tenantID, ok := outbox.TenantIDFromContext(ctx)
+ if (repo.tenantColumn != "" || repo.requireTenant) && (!ok || tenantID == "") {
+ return "", outbox.ErrTenantIDRequired
+ }
+
+ if !ok {
+ return "", nil
+ }
+
+ return tenantID, nil
+}
+
+func (repo *Repository) tenantFilterClause(index int, tenantID string) (string, []any, error) {
+ if repo.tenantColumn == "" {
+ return "", nil, nil
+ }
+
+ if tenantID == "" {
+ return "", nil, outbox.ErrTenantIDRequired
+ }
+
+ filter := fmt.Sprintf(" AND %s = $%d", quoteIdentifier(repo.tenantColumn), index)
+
+ return filter, []any{tenantID}, nil
+}
+
+func validateIdentifier(identifier string) error {
+ if len(identifier) > maxSQLIdentifierLength {
+ return ErrInvalidIdentifier
+ }
+
+ if !identifierPattern.MatchString(identifier) {
+ return ErrInvalidIdentifier
+ }
+
+ return nil
+}
+
+func validateIdentifierPath(path string) error {
+ parts := strings.Split(path, ".")
+ if len(parts) == 0 {
+ return ErrInvalidIdentifier
+ }
+
+ for _, part := range parts {
+ trimmed := strings.TrimSpace(part)
+ if err := validateIdentifier(trimmed); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func quoteIdentifierPath(path string) string {
+ parts := strings.Split(path, ".")
+ quoted := make([]string, 0, len(parts))
+
+ for _, part := range parts {
+ quoted = append(quoted, quoteIdentifier(strings.TrimSpace(part)))
+ }
+
+ return strings.Join(quoted, ".")
+}
+
+func quoteIdentifier(identifier string) string {
+ identifier = strings.ReplaceAll(identifier, "\x00", "")
+
+ return "\"" + strings.ReplaceAll(identifier, "\"", "\"\"") + "\""
+}
+
+func logSanitizedError(logger libLog.Logger, ctx context.Context, message string, err error) {
+ if nilcheck.Interface(logger) || err == nil {
+ return
+ }
+
+ logger.Log(ctx, libLog.LevelError, message, libLog.String("error", outbox.SanitizeErrorMessageForStorage(err.Error())))
+}
+
+func ensureRowsAffected(result sql.Result) error {
+ rows, err := rowsAffected(result)
+ if err != nil {
+ return err
+ }
+
+ if rows == 0 {
+ return ErrStateTransitionConflict
+ }
+
+ return nil
+}
+
+func ensureRowsAffectedExact(result sql.Result, expected int64) error {
+ rows, err := rowsAffected(result)
+ if err != nil {
+ return err
+ }
+
+ if rows != expected {
+ return ErrStateTransitionConflict
+ }
+
+ return nil
+}
+
+func rowsAffected(result sql.Result) (int64, error) {
+ if result == nil {
+ return 0, ErrStateTransitionConflict
+ }
+
+ rows, err := result.RowsAffected()
+ if err != nil {
+ return 0, fmt.Errorf("rows affected: %w", err)
+ }
+
+ return rows, nil
+}
+
+type createValues struct {
+ id uuid.UUID
+ eventType string
+ aggregateID uuid.UUID
+ payload []byte
+ status string
+ attempts int
+ publishedAt *time.Time
+ lastError string
+ createdAt time.Time
+ updatedAt time.Time
+}
+
+func normalizedCreateValues(event *outbox.OutboxEvent, now time.Time) createValues {
+ createdAt := event.CreatedAt
+ if createdAt.IsZero() {
+ createdAt = now
+ }
+
+ updatedAt := event.UpdatedAt
+ if updatedAt.IsZero() || updatedAt.Before(createdAt) {
+ updatedAt = createdAt
+ }
+
+ return createValues{
+ id: event.ID,
+ eventType: strings.TrimSpace(event.EventType),
+ aggregateID: event.AggregateID,
+ payload: event.Payload,
+ status: outbox.OutboxStatusPending,
+ attempts: 0,
+ publishedAt: nil,
+ lastError: "",
+ createdAt: createdAt,
+ updatedAt: updatedAt,
+ }
+}
+
+func validateCreateEvent(event *outbox.OutboxEvent) error {
+ if event == nil {
+ return outbox.ErrOutboxEventRequired
+ }
+
+ if event.ID == uuid.Nil {
+ return ErrIDRequired
+ }
+
+ if strings.TrimSpace(event.EventType) == "" {
+ return ErrEventTypeRequired
+ }
+
+ if event.AggregateID == uuid.Nil {
+ return ErrAggregateIDRequired
+ }
+
+ if len(event.Payload) == 0 {
+ return outbox.ErrOutboxEventPayloadRequired
+ }
+
+ if len(event.Payload) > outbox.DefaultMaxPayloadBytes {
+ return outbox.ErrOutboxEventPayloadTooLarge
+ }
+
+ if !json.Valid(event.Payload) {
+ return outbox.ErrOutboxEventPayloadNotJSON
+ }
+
+ return nil
+}
+
+func queryOutboxEvents(
+ ctx context.Context,
+ tx *sql.Tx,
+ query string,
+ args []any,
+ limit int,
+ errorPrefix string,
+) ([]*outbox.OutboxEvent, error) {
+ rows, err := tx.QueryContext(ctx, query, args...)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %w", errorPrefix, err)
+ }
+
+ defer rows.Close()
+
+ events := make([]*outbox.OutboxEvent, 0, limit)
+
+ for rows.Next() {
+ event, scanErr := scanOutboxEvent(rows)
+ if scanErr != nil {
+ return nil, fmt.Errorf("scanning outbox event: %w", scanErr)
+ }
+
+ events = append(events, event)
+ }
+
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterating rows: %w", err)
+ }
+
+ return events, nil
+}
diff --git a/commons/outbox/postgres/repository_integration_test.go b/commons/outbox/postgres/repository_integration_test.go
new file mode 100644
index 00000000..dbb92899
--- /dev/null
+++ b/commons/outbox/postgres/repository_integration_test.go
@@ -0,0 +1,494 @@
+//go:build integration
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/outbox"
+ libPostgres "github.com/LerianStudio/lib-commons/v4/commons/postgres"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/trace/noop"
+)
+
+type integrationRepoFixture struct {
+ ctx context.Context
+ client *libPostgres.Client
+ primaryDB *sql.DB
+ repo *Repository
+ tableName string
+ tenantCtx context.Context
+}
+
+func newIntegrationRepoFixture(t *testing.T) *integrationRepoFixture {
+ t.Helper()
+
+ dsn := strings.TrimSpace(os.Getenv("OUTBOX_POSTGRES_DSN"))
+ if dsn == "" {
+ t.Skip("OUTBOX_POSTGRES_DSN not set")
+ }
+
+ ctx := context.Background()
+ client, err := libPostgres.New(libPostgres.Config{PrimaryDSN: dsn, ReplicaDSN: dsn})
+ require.NoError(t, err)
+
+ require.NoError(t, client.Connect(ctx))
+ t.Cleanup(func() {
+ if err := client.Close(); err != nil {
+ t.Errorf("cleanup: client close: %v", err)
+ }
+ })
+
+ primaryDB, err := client.Primary()
+ require.NoError(t, err)
+
+ tableName := "outbox_it_" + strings.ReplaceAll(uuid.NewString(), "-", "")[:16]
+
+ _, err = primaryDB.ExecContext(ctx, `
+DO $$
+BEGIN
+ IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = 'outbox_event_status') THEN
+ CREATE TYPE outbox_event_status AS ENUM ('PENDING','PROCESSING','PUBLISHED','FAILED','INVALID');
+ END IF;
+END
+$$;
+`)
+ require.NoError(t, err)
+
+ _, err = primaryDB.ExecContext(ctx, fmt.Sprintf(`
+CREATE TABLE %s (
+ id UUID NOT NULL,
+ event_type VARCHAR(255) NOT NULL,
+ aggregate_id UUID NOT NULL,
+ payload JSONB NOT NULL,
+ status outbox_event_status NOT NULL DEFAULT 'PENDING',
+ attempts INT NOT NULL DEFAULT 0,
+ published_at TIMESTAMPTZ,
+ last_error VARCHAR(512),
+ created_at TIMESTAMPTZ NOT NULL,
+ updated_at TIMESTAMPTZ NOT NULL,
+ tenant_id TEXT NOT NULL,
+ PRIMARY KEY (tenant_id, id)
+);
+`, quoteIdentifier(tableName)))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if _, err := primaryDB.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteIdentifier(tableName))); err != nil {
+ t.Errorf("cleanup: drop table %s: %v", tableName, err)
+ }
+ })
+
+ resolver, err := NewColumnResolver(
+ client,
+ WithColumnResolverTableName(tableName),
+ WithColumnResolverTenantColumn("tenant_id"),
+ )
+ require.NoError(t, err)
+
+ repo, err := NewRepository(
+ client,
+ resolver,
+ resolver,
+ WithTableName(tableName),
+ WithTenantColumn("tenant_id"),
+ )
+ require.NoError(t, err)
+
+ return &integrationRepoFixture{
+ ctx: ctx,
+ client: client,
+ primaryDB: primaryDB,
+ repo: repo,
+ tableName: tableName,
+ tenantCtx: outbox.ContextWithTenantID(ctx, "tenant-a"),
+ }
+}
+
+func createFixtureEvent(t *testing.T, fx *integrationRepoFixture, eventType string) *outbox.OutboxEvent {
+ t.Helper()
+
+ return createFixtureEventForTenant(t, fx, "tenant-a", eventType)
+}
+
+func createFixtureEventForTenant(
+ t *testing.T,
+ fx *integrationRepoFixture,
+ tenantID string,
+ eventType string,
+) *outbox.OutboxEvent {
+ t.Helper()
+
+ eventCtx := outbox.ContextWithTenantID(fx.ctx, tenantID)
+ event, err := outbox.NewOutboxEvent(eventCtx, eventType, uuid.New(), []byte(`{"ok":true}`))
+ require.NoError(t, err)
+
+ created, err := fx.repo.Create(eventCtx, event)
+ require.NoError(t, err)
+
+ return created
+}
+
+func updateFixtureEventStateForTenant(
+ t *testing.T,
+ fx *integrationRepoFixture,
+ id uuid.UUID,
+ tenantID string,
+ status string,
+ attempts int,
+ updatedAt time.Time,
+) {
+ t.Helper()
+
+ _, err := fx.primaryDB.ExecContext(
+ fx.ctx,
+ fmt.Sprintf(
+ "UPDATE %s SET status = $1::outbox_event_status, attempts = $2, updated_at = $3 WHERE id = $4 AND tenant_id = $5",
+ quoteIdentifier(fx.tableName),
+ ),
+ status,
+ attempts,
+ updatedAt,
+ id,
+ tenantID,
+ )
+ require.NoError(t, err)
+}
+
+func updateFixtureEventState(
+ t *testing.T,
+ fx *integrationRepoFixture,
+ id uuid.UUID,
+ status string,
+ attempts int,
+ updatedAt time.Time,
+) {
+ t.Helper()
+
+ updateFixtureEventStateForTenant(t, fx, id, "tenant-a", status, attempts, updatedAt)
+}
+
+func TestRepository_IntegrationCreateListAndMarkFailed(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ created := createFixtureEvent(t, fx, "payment.created")
+ require.NotNil(t, created)
+
+ pending, err := fx.repo.ListPending(fx.tenantCtx, 10)
+ require.NoError(t, err)
+ require.Len(t, pending, 1)
+ require.Equal(t, outbox.OutboxStatusProcessing, pending[0].Status)
+
+ require.NoError(t, fx.repo.MarkFailed(fx.tenantCtx, created.ID, "password=abc123", 5))
+
+ updated, err := fx.repo.GetByID(fx.tenantCtx, created.ID)
+ require.NoError(t, err)
+ require.Equal(t, outbox.OutboxStatusFailed, updated.Status)
+ require.NotContains(t, updated.LastError, "abc123")
+}
+
+func TestRepository_IntegrationMarkPublished(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ event := createFixtureEvent(t, fx, "payment.published")
+
+ now := time.Now().UTC()
+ updateFixtureEventState(t, fx, event.ID, outbox.OutboxStatusProcessing, 0, now)
+ require.NoError(t, fx.repo.MarkPublished(fx.tenantCtx, event.ID, now))
+
+ published, err := fx.repo.GetByID(fx.tenantCtx, event.ID)
+ require.NoError(t, err)
+ require.Equal(t, outbox.OutboxStatusPublished, published.Status)
+}
+
+func TestRepository_IntegrationMarkInvalidRedactsSensitiveData(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ event := createFixtureEvent(t, fx, "payment.invalid")
+
+ now := time.Now().UTC()
+ updateFixtureEventState(t, fx, event.ID, outbox.OutboxStatusProcessing, 0, now)
+ require.NoError(t, fx.repo.MarkInvalid(fx.tenantCtx, event.ID, "token=super-secret"))
+
+ invalid, err := fx.repo.GetByID(fx.tenantCtx, event.ID)
+ require.NoError(t, err)
+ require.Equal(t, outbox.OutboxStatusInvalid, invalid.Status)
+ require.NotContains(t, invalid.LastError, "super-secret")
+}
+
+func TestRepository_IntegrationListPendingByType(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ target := createFixtureEvent(t, fx, "payment.priority")
+ _ = createFixtureEvent(t, fx, "payment.non-priority")
+
+ priorityEvents, err := fx.repo.ListPendingByType(fx.tenantCtx, "payment.priority", 10)
+ require.NoError(t, err)
+ require.Len(t, priorityEvents, 1)
+ require.Equal(t, target.ID, priorityEvents[0].ID)
+ require.Equal(t, outbox.OutboxStatusProcessing, priorityEvents[0].Status)
+}
+
+func TestRepository_IntegrationResetForRetry(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ event := createFixtureEvent(t, fx, "payment.failed")
+
+ staleTime := time.Now().UTC().Add(-time.Hour)
+ updateFixtureEventState(t, fx, event.ID, outbox.OutboxStatusFailed, 1, staleTime)
+
+ retried, err := fx.repo.ResetForRetry(fx.tenantCtx, 10, time.Now().UTC(), 5)
+ require.NoError(t, err)
+ require.Len(t, retried, 1)
+ require.Equal(t, event.ID, retried[0].ID)
+ require.Equal(t, outbox.OutboxStatusProcessing, retried[0].Status)
+}
+
+func TestRepository_IntegrationResetStuckProcessing(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ retryEvent := createFixtureEvent(t, fx, "payment.stuck.retry")
+ exhaustedEvent := createFixtureEvent(t, fx, "payment.stuck.exhausted")
+
+ staleTime := time.Now().UTC().Add(-time.Hour)
+ updateFixtureEventState(t, fx, retryEvent.ID, outbox.OutboxStatusProcessing, 1, staleTime)
+ updateFixtureEventState(t, fx, exhaustedEvent.ID, outbox.OutboxStatusProcessing, 2, staleTime)
+
+ resetStuck, err := fx.repo.ResetStuckProcessing(fx.tenantCtx, 10, time.Now().UTC(), 3)
+ require.NoError(t, err)
+ require.Len(t, resetStuck, 1)
+ require.Equal(t, retryEvent.ID, resetStuck[0].ID)
+ require.Equal(t, outbox.OutboxStatusProcessing, resetStuck[0].Status)
+ require.Equal(t, 2, resetStuck[0].Attempts)
+
+ exhausted, err := fx.repo.GetByID(fx.tenantCtx, exhaustedEvent.ID)
+ require.NoError(t, err)
+ require.Equal(t, outbox.OutboxStatusInvalid, exhausted.Status)
+ require.Equal(t, 3, exhausted.Attempts)
+ require.Equal(t, "max dispatch attempts exceeded", exhausted.LastError)
+}
+
+func TestRepository_IntegrationCreateWithTx(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ tx, err := fx.primaryDB.BeginTx(fx.tenantCtx, nil)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) {
+ t.Errorf("cleanup: tx rollback: %v", err)
+ }
+ })
+
+ event, err := outbox.NewOutboxEvent(fx.tenantCtx, "payment.tx.create", uuid.New(), []byte(`{"ok":true}`))
+ require.NoError(t, err)
+
+ created, err := fx.repo.CreateWithTx(fx.tenantCtx, tx, event)
+ require.NoError(t, err)
+ require.NotNil(t, created)
+
+ require.NoError(t, tx.Commit())
+
+ stored, err := fx.repo.GetByID(fx.tenantCtx, created.ID)
+ require.NoError(t, err)
+ require.Equal(t, created.ID, stored.ID)
+}
+
+func TestRepository_IntegrationMarkPublishedRequiresProcessingState(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ event := createFixtureEvent(t, fx, "payment.state.guard")
+ err := fx.repo.MarkPublished(fx.tenantCtx, event.ID, time.Now().UTC())
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrStateTransitionConflict)
+}
+
+func TestRepository_IntegrationCreateForcesPendingLifecycleInvariants(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ now := time.Now().UTC()
+ publishedAt := now.Add(-time.Minute)
+
+ created, err := fx.repo.Create(
+ fx.tenantCtx,
+ &outbox.OutboxEvent{
+ ID: uuid.New(),
+ EventType: "payment.invariant.override",
+ AggregateID: uuid.New(),
+ Payload: []byte(`{"ok":true}`),
+ Status: outbox.OutboxStatusPublished,
+ Attempts: 9,
+ PublishedAt: &publishedAt,
+ LastError: "must not persist",
+ CreatedAt: now,
+ UpdatedAt: now,
+ },
+ )
+ require.NoError(t, err)
+ require.NotNil(t, created)
+ require.Equal(t, outbox.OutboxStatusPending, created.Status)
+ require.Equal(t, 0, created.Attempts)
+ require.Nil(t, created.PublishedAt)
+ require.Empty(t, created.LastError)
+}
+
+func TestRepository_IntegrationTenantIsolationBoundaries(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ tenantA := outbox.ContextWithTenantID(fx.ctx, "tenant-a")
+ tenantB := outbox.ContextWithTenantID(fx.ctx, "tenant-b")
+
+ eventA := createFixtureEventForTenant(t, fx, "tenant-a", "payment.isolation.a")
+ eventB := createFixtureEventForTenant(t, fx, "tenant-b", "payment.isolation.b")
+
+ pendingA, err := fx.repo.ListPending(tenantA, 10)
+ require.NoError(t, err)
+ require.Len(t, pendingA, 1)
+ require.Equal(t, eventA.ID, pendingA[0].ID)
+
+ pendingB, err := fx.repo.ListPending(tenantB, 10)
+ require.NoError(t, err)
+ require.Len(t, pendingB, 1)
+ require.Equal(t, eventB.ID, pendingB[0].ID)
+
+ _, err = fx.repo.GetByID(tenantA, eventB.ID)
+ require.Error(t, err)
+ require.ErrorIs(t, err, sql.ErrNoRows)
+
+ err = fx.repo.MarkPublished(tenantA, eventB.ID, time.Now().UTC())
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrStateTransitionConflict)
+
+ storedB, err := fx.repo.GetByID(tenantB, eventB.ID)
+ require.NoError(t, err)
+ require.Equal(t, outbox.OutboxStatusProcessing, storedB.Status)
+}
+
+func TestRepository_IntegrationMarkFailedAndInvalidRequireProcessingState(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ failedEvent := createFixtureEvent(t, fx, "payment.failed.guard")
+ err := fx.repo.MarkFailed(fx.tenantCtx, failedEvent.ID, "retry error", 3)
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrStateTransitionConflict)
+
+ invalidEvent := createFixtureEvent(t, fx, "payment.invalid.guard")
+ err = fx.repo.MarkInvalid(fx.tenantCtx, invalidEvent.ID, "non-retryable error")
+ require.Error(t, err)
+ require.ErrorIs(t, err, ErrStateTransitionConflict)
+}
+
+func TestRepository_IntegrationDispatcherLifecyclePersistsPublishedState(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ created := createFixtureEvent(t, fx, "payment.dispatch.lifecycle")
+ require.NotNil(t, created)
+
+ handlers := outbox.NewHandlerRegistry()
+ var handled atomic.Bool
+
+ require.NoError(t, handlers.Register("payment.dispatch.lifecycle", func(_ context.Context, event *outbox.OutboxEvent) error {
+ require.NotNil(t, event)
+ require.Equal(t, created.ID, event.ID)
+ handled.Store(true)
+
+ return nil
+ }))
+
+ dispatcher, err := outbox.NewDispatcher(
+ fx.repo,
+ handlers,
+ nil,
+ noop.NewTracerProvider().Tracer("test"),
+ outbox.WithBatchSize(10),
+ outbox.WithPublishMaxAttempts(1),
+ )
+ require.NoError(t, err)
+
+ result := dispatcher.DispatchOnceResult(fx.tenantCtx)
+ require.Equal(t, 1, result.Processed)
+ require.Equal(t, 1, result.Published)
+ require.Equal(t, 0, result.Failed)
+ require.Equal(t, 0, result.StateUpdateFailed)
+ require.True(t, handled.Load())
+
+ stored, err := fx.repo.GetByID(fx.tenantCtx, created.ID)
+ require.NoError(t, err)
+ require.Equal(t, outbox.OutboxStatusPublished, stored.Status)
+ require.NotNil(t, stored.PublishedAt)
+ require.True(t, stored.UpdatedAt.After(created.UpdatedAt) || stored.UpdatedAt.Equal(created.UpdatedAt))
+}
+
+func TestColumnResolver_IntegrationDiscoverTenants(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+
+ _, err := fx.repo.Create(
+ outbox.ContextWithTenantID(fx.ctx, "tenant-b"),
+ &outbox.OutboxEvent{
+ ID: uuid.New(),
+ EventType: "payment.discover",
+ AggregateID: uuid.New(),
+ Payload: []byte(`{"ok":true}`),
+ Status: outbox.OutboxStatusPending,
+ Attempts: 0,
+ CreatedAt: time.Now().UTC(),
+ UpdatedAt: time.Now().UTC(),
+ },
+ )
+ require.NoError(t, err)
+
+ resolver, err := NewColumnResolver(
+ fx.client,
+ WithColumnResolverTableName(fx.tableName),
+ WithColumnResolverTenantColumn("tenant_id"),
+ )
+ require.NoError(t, err)
+
+ tenants, err := resolver.DiscoverTenants(fx.ctx)
+ require.NoError(t, err)
+ require.Contains(t, tenants, "tenant-b")
+}
+
+func TestSchemaResolver_IntegrationApplyTenantAndDiscoverTenants(t *testing.T) {
+ fx := newIntegrationRepoFixture(t)
+ tenantSchema := uuid.NewString()
+ defaultTenant := uuid.NewString()
+
+ _, err := fx.primaryDB.ExecContext(fx.ctx, fmt.Sprintf("CREATE SCHEMA %s", quoteIdentifier(tenantSchema)))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if _, err := fx.primaryDB.ExecContext(fx.ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", quoteIdentifier(tenantSchema))); err != nil {
+ t.Errorf("cleanup: drop schema %s: %v", tenantSchema, err)
+ }
+ })
+
+ resolver, err := NewSchemaResolver(fx.client, WithDefaultTenantID(defaultTenant))
+ require.NoError(t, err)
+
+ tx, err := fx.primaryDB.BeginTx(fx.ctx, nil)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) {
+ t.Errorf("cleanup: tx rollback: %v", err)
+ }
+ })
+
+ require.NoError(t, resolver.ApplyTenant(fx.ctx, tx, tenantSchema))
+
+ var currentSchema string
+ require.NoError(t, tx.QueryRowContext(fx.ctx, "SELECT current_schema()").Scan(¤tSchema))
+ require.Equal(t, tenantSchema, currentSchema)
+ require.NoError(t, tx.Rollback())
+
+ tenants, err := resolver.DiscoverTenants(fx.ctx)
+ require.NoError(t, err)
+ require.Contains(t, tenants, tenantSchema)
+ require.Contains(t, tenants, defaultTenant)
+}
diff --git a/commons/outbox/postgres/repository_test.go b/commons/outbox/postgres/repository_test.go
new file mode 100644
index 00000000..89ad72ef
--- /dev/null
+++ b/commons/outbox/postgres/repository_test.go
@@ -0,0 +1,389 @@
+//go:build unit
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "testing"
+ "time"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/outbox"
+ libPostgres "github.com/LerianStudio/lib-commons/v4/commons/postgres"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/require"
+)
+
+type noopTenantResolver struct{}
+
+func (noopTenantResolver) ApplyTenant(context.Context, *sql.Tx, string) error { return nil }
+
+type noopTenantDiscoverer struct{}
+
+func (noopTenantDiscoverer) DiscoverTenants(context.Context) ([]string, error) { return nil, nil }
+
+type requireTenantResolver struct{}
+
+func (requireTenantResolver) ApplyTenant(context.Context, *sql.Tx, string) error { return nil }
+
+func (requireTenantResolver) RequiresTenant() bool { return true }
+
+type panicLogger struct {
+ seen bool
+}
+
+func (logger *panicLogger) Log(context.Context, libLog.Level, string, ...libLog.Field) {
+ logger.seen = true
+}
+
+func (logger *panicLogger) With(...libLog.Field) libLog.Logger {
+ return logger
+}
+
+func (logger *panicLogger) WithGroup(string) libLog.Logger {
+ return logger
+}
+
+func (logger *panicLogger) Enabled(libLog.Level) bool {
+ return true
+}
+
+func (logger *panicLogger) Sync(context.Context) error {
+ return nil
+}
+
+func TestValidateIdentifier(t *testing.T) {
+ t.Parallel()
+
+ require.NoError(t, validateIdentifier("outbox_events"))
+ require.NoError(t, validateIdentifier("tenant_01"))
+
+ invalid := []string{
+ "",
+ "123table",
+ "outbox-events",
+ "public.outbox",
+ `outbox"; DROP TABLE users; --`,
+ "outbox events",
+ }
+
+ for _, candidate := range invalid {
+ require.Error(t, validateIdentifier(candidate), candidate)
+ }
+
+ tooLong := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ require.Len(t, tooLong, 64)
+ require.Error(t, validateIdentifier(tooLong))
+}
+
+func TestValidateIdentifierPath(t *testing.T) {
+ t.Parallel()
+
+ require.NoError(t, validateIdentifierPath("public.outbox_events"))
+ require.NoError(t, validateIdentifierPath("tenant_01.outbox_events"))
+
+ require.Error(t, validateIdentifierPath("public."))
+ require.Error(t, validateIdentifierPath(`public."outbox"`))
+ require.Error(t, validateIdentifierPath("public.outbox-events"))
+}
+
+func TestQuoteIdentifierFunctions(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, `"outbox_events"`, quoteIdentifier("outbox_events"))
+ require.Equal(t, `"a""b"`, quoteIdentifier(`a"b`))
+ require.Equal(t, `"public"."outbox_events"`, quoteIdentifierPath("public.outbox_events"))
+ require.Equal(t, `"public"."out""box"`, quoteIdentifierPath(`public.out"box`))
+}
+
+func TestSplitStuckEventsAndApplyState(t *testing.T) {
+ t.Parallel()
+
+ retryID := uuid.New()
+ exhaustedID := uuid.New()
+
+ events := []*outbox.OutboxEvent{
+ {ID: retryID, Attempts: 1, Status: outbox.OutboxStatusProcessing},
+ {ID: exhaustedID, Attempts: 2, Status: outbox.OutboxStatusProcessing},
+ nil,
+ }
+
+ retryEvents, exhaustedIDs := splitStuckEvents(events, 3)
+ require.Len(t, retryEvents, 1)
+ require.Equal(t, retryID, retryEvents[0].ID)
+ require.Equal(t, []uuid.UUID{exhaustedID}, exhaustedIDs)
+
+ now := time.Now().UTC()
+ applyStuckReprocessingState(retryEvents, now)
+ require.Equal(t, 2, retryEvents[0].Attempts)
+ require.Equal(t, outbox.OutboxStatusProcessing, retryEvents[0].Status)
+ require.Equal(t, now, retryEvents[0].UpdatedAt)
+}
+
+func TestNewRepository_Validation(t *testing.T) {
+ t.Parallel()
+
+ repo, err := NewRepository(nil, noopTenantResolver{}, noopTenantDiscoverer{})
+ require.Nil(t, repo)
+ require.ErrorIs(t, err, ErrConnectionRequired)
+
+ client := &libPostgres.Client{}
+
+ repo, err = NewRepository(client, nil, noopTenantDiscoverer{})
+ require.Nil(t, repo)
+ require.ErrorIs(t, err, ErrTenantResolverRequired)
+
+ repo, err = NewRepository(client, noopTenantResolver{}, nil)
+ require.Nil(t, repo)
+ require.ErrorIs(t, err, ErrTenantDiscovererRequired)
+
+ repo, err = NewRepository(client, noopTenantResolver{}, noopTenantDiscoverer{}, WithTableName("bad-table"))
+ require.Nil(t, repo)
+ require.ErrorIs(t, err, ErrInvalidIdentifier)
+
+ repo, err = NewRepository(client, noopTenantResolver{}, noopTenantDiscoverer{}, WithTenantColumn("tenant-id"))
+ require.Nil(t, repo)
+ require.ErrorIs(t, err, ErrInvalidIdentifier)
+}
+
+func TestQuoteIdentifier_StripsNullByte(t *testing.T) {
+ t.Parallel()
+
+ quoted := quoteIdentifier("tenant\x00_id")
+ require.Equal(t, `"tenant_id"`, quoted)
+}
+
+func TestRepository_MarkFailedValidation(t *testing.T) {
+ t.Parallel()
+
+ repo := &Repository{
+ client: &libPostgres.Client{},
+ tenantResolver: noopTenantResolver{},
+ tenantDiscoverer: noopTenantDiscoverer{},
+ tableName: "outbox_events",
+ transactionTimeout: time.Second,
+ }
+
+ err := repo.MarkFailed(context.Background(), uuid.Nil, "failed", 3)
+ require.ErrorIs(t, err, ErrIDRequired)
+
+ err = repo.MarkFailed(context.Background(), uuid.New(), "failed", 0)
+ require.ErrorIs(t, err, ErrMaxAttemptsMustBePositive)
+}
+
+func TestRepository_ListPendingByTypeValidation(t *testing.T) {
+ t.Parallel()
+
+ repo := &Repository{
+ client: &libPostgres.Client{},
+ tenantResolver: noopTenantResolver{},
+ tenantDiscoverer: noopTenantDiscoverer{},
+ tableName: "outbox_events",
+ transactionTimeout: time.Second,
+ }
+
+ _, err := repo.ListPendingByType(context.Background(), " ", 1)
+ require.ErrorIs(t, err, ErrEventTypeRequired)
+}
+
+type resultWithRows struct {
+ rows int64
+ err error
+}
+
+func (result resultWithRows) LastInsertId() (int64, error) {
+ return 0, nil
+}
+
+func (result resultWithRows) RowsAffected() (int64, error) {
+ if result.err != nil {
+ return 0, result.err
+ }
+
+ return result.rows, nil
+}
+
+func TestEnsureRowsAffected(t *testing.T) {
+ t.Parallel()
+
+ err := ensureRowsAffected(nil)
+ require.ErrorIs(t, err, ErrStateTransitionConflict)
+
+ err = ensureRowsAffected(resultWithRows{err: errors.New("rows failure")})
+ require.ErrorContains(t, err, "rows affected")
+
+ err = ensureRowsAffected(resultWithRows{rows: 0})
+ require.ErrorIs(t, err, ErrStateTransitionConflict)
+
+ err = ensureRowsAffected(resultWithRows{rows: 1})
+ require.NoError(t, err)
+}
+
+func TestEnsureRowsAffectedExact(t *testing.T) {
+ t.Parallel()
+
+ err := ensureRowsAffectedExact(nil, 1)
+ require.ErrorIs(t, err, ErrStateTransitionConflict)
+
+ err = ensureRowsAffectedExact(resultWithRows{err: errors.New("rows failure")}, 1)
+ require.ErrorContains(t, err, "rows affected")
+
+ err = ensureRowsAffectedExact(resultWithRows{rows: 0}, 1)
+ require.ErrorIs(t, err, ErrStateTransitionConflict)
+
+ err = ensureRowsAffectedExact(resultWithRows{rows: 1}, 2)
+ require.ErrorIs(t, err, ErrStateTransitionConflict)
+
+ err = ensureRowsAffectedExact(resultWithRows{rows: 2}, 2)
+ require.NoError(t, err)
+}
+
+func TestValidateCreateEvent(t *testing.T) {
+ t.Parallel()
+
+ now := time.Now().UTC()
+
+ valid := &outbox.OutboxEvent{
+ ID: uuid.New(),
+ EventType: "payment.created",
+ AggregateID: uuid.New(),
+ Payload: []byte(`{"ok":true}`),
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+
+ require.NoError(t, validateCreateEvent(valid))
+
+ err := validateCreateEvent(nil)
+ require.ErrorIs(t, err, outbox.ErrOutboxEventRequired)
+
+ err = validateCreateEvent(&outbox.OutboxEvent{AggregateID: uuid.New(), EventType: "a", Payload: []byte(`{"ok":true}`)})
+ require.ErrorIs(t, err, ErrIDRequired)
+
+ err = validateCreateEvent(&outbox.OutboxEvent{ID: uuid.New(), AggregateID: uuid.New(), EventType: " ", Payload: []byte(`{"ok":true}`)})
+ require.ErrorIs(t, err, ErrEventTypeRequired)
+
+ err = validateCreateEvent(&outbox.OutboxEvent{ID: uuid.New(), EventType: "payment.created", Payload: []byte(`{"ok":true}`)})
+ require.ErrorIs(t, err, ErrAggregateIDRequired)
+
+ err = validateCreateEvent(&outbox.OutboxEvent{ID: uuid.New(), EventType: "payment.created", AggregateID: uuid.New()})
+ require.ErrorIs(t, err, outbox.ErrOutboxEventPayloadRequired)
+
+ err = validateCreateEvent(&outbox.OutboxEvent{ID: uuid.New(), EventType: "payment.created", AggregateID: uuid.New(), Payload: []byte("not-json")})
+ require.ErrorIs(t, err, outbox.ErrOutboxEventPayloadNotJSON)
+
+ oversizedPayload := make([]byte, outbox.DefaultMaxPayloadBytes+1)
+ err = validateCreateEvent(&outbox.OutboxEvent{ID: uuid.New(), EventType: "payment.created", AggregateID: uuid.New(), Payload: oversizedPayload})
+ require.ErrorIs(t, err, outbox.ErrOutboxEventPayloadTooLarge)
+}
+
+func TestRepository_TenantIDFromContext(t *testing.T) {
+ t.Parallel()
+
+ repo := &Repository{requireTenant: true}
+ tenantID, err := repo.tenantIDFromContext(context.Background())
+ require.Empty(t, tenantID)
+ require.ErrorIs(t, err, outbox.ErrTenantIDRequired)
+
+ repo.requireTenant = false
+ tenantID, err = repo.tenantIDFromContext(context.Background())
+ require.NoError(t, err)
+ require.Empty(t, tenantID)
+
+ ctx := outbox.ContextWithTenantID(context.Background(), "tenant-a")
+ tenantID, err = repo.tenantIDFromContext(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "tenant-a", tenantID)
+}
+
+func TestLogSanitizedError_TypedNilLoggerDoesNotPanic(t *testing.T) {
+ t.Parallel()
+
+ var logger *panicLogger
+
+ require.NotPanics(t, func() {
+ logSanitizedError(logger, context.Background(), "msg", errors.New("boom"))
+ })
+}
+
+func TestNewRepository_WithTypedNilLoggerFallsBackToNop(t *testing.T) {
+ t.Parallel()
+
+ var logger *panicLogger
+
+ repo, err := NewRepository(
+ &libPostgres.Client{},
+ noopTenantResolver{},
+ noopTenantDiscoverer{},
+ WithLogger(logger),
+ )
+ require.NoError(t, err)
+ require.NotNil(t, repo)
+ require.NotNil(t, repo.logger)
+}
+
+func TestNewRepository_PropagatesResolverTenantRequirement(t *testing.T) {
+ t.Parallel()
+
+ repo, err := NewRepository(
+ &libPostgres.Client{},
+ requireTenantResolver{},
+ noopTenantDiscoverer{},
+ )
+ require.NoError(t, err)
+
+ tenantID, tenantErr := repo.tenantIDFromContext(context.Background())
+ require.Empty(t, tenantID)
+ require.ErrorIs(t, tenantErr, outbox.ErrTenantIDRequired)
+}
+
+func TestNormalizedCreateValues_EnforcesInitialLifecycleInvariants(t *testing.T) {
+ t.Parallel()
+
+ now := time.Now().UTC()
+ publishedAt := now.Add(-time.Minute)
+
+ values := normalizedCreateValues(&outbox.OutboxEvent{
+ ID: uuid.New(),
+ EventType: "payment.created",
+ AggregateID: uuid.New(),
+ Payload: []byte(`{"ok":true}`),
+ Status: outbox.OutboxStatusPublished,
+ Attempts: 7,
+ PublishedAt: &publishedAt,
+ LastError: "internal details",
+ CreatedAt: now,
+ UpdatedAt: now.Add(-time.Hour),
+ }, now)
+
+ require.Equal(t, outbox.OutboxStatusPending, values.status)
+ require.Equal(t, 0, values.attempts)
+ require.Nil(t, values.publishedAt)
+ require.Empty(t, values.lastError)
+ require.Equal(t, now, values.createdAt)
+ require.Equal(t, now, values.updatedAt)
+}
+
+func TestNormalizedCreateValues_TrimsEventType(t *testing.T) {
+ t.Parallel()
+
+ values := normalizedCreateValues(&outbox.OutboxEvent{
+ ID: uuid.New(),
+ EventType: " payment.created ",
+ AggregateID: uuid.New(),
+ Payload: []byte(`{"ok":true}`),
+ }, time.Now().UTC())
+
+ require.Equal(t, "payment.created", values.eventType)
+}
+
+func TestRepository_RequiresTenant(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, (*Repository)(nil).RequiresTenant())
+ require.True(t, (&Repository{requireTenant: true}).RequiresTenant())
+ require.True(t, (&Repository{tenantColumn: "tenant_id"}).RequiresTenant())
+ require.False(t, (&Repository{}).RequiresTenant())
+}
diff --git a/commons/outbox/postgres/schema_resolver.go b/commons/outbox/postgres/schema_resolver.go
new file mode 100644
index 00000000..56488162
--- /dev/null
+++ b/commons/outbox/postgres/schema_resolver.go
@@ -0,0 +1,228 @@
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "strings"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/outbox"
+ libPostgres "github.com/LerianStudio/lib-commons/v4/commons/postgres"
+)
+
+const uuidSchemaRegex = "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
+
+// defaultOutboxTableName is the default table name used by the outbox
+// pattern for event persistence and tenant schema discovery.
+const defaultOutboxTableName = "outbox_events"
+
+// defaultSchemaSearchPath is the schema used when ApplyTenant receives an
+// empty or default-tenant ID with AllowEmptyTenant enabled.
+const defaultSchemaSearchPath = "public"
+
+var ErrDefaultTenantIDInvalid = errors.New("default tenant id must be UUID when tenant is required")
+
+type SchemaResolverOption func(*SchemaResolver)
+
+func WithDefaultTenantID(tenantID string) SchemaResolverOption {
+ return func(resolver *SchemaResolver) {
+ resolver.defaultTenantID = tenantID
+ }
+}
+
+// WithRequireTenant enforces that every ApplyTenant call receives a non-empty
+// tenant ID. This is the default behavior.
+func WithRequireTenant() SchemaResolverOption {
+ return func(resolver *SchemaResolver) {
+ resolver.requireTenant = true
+ }
+}
+
+// WithAllowEmptyTenant permits ApplyTenant calls with empty tenant IDs.
+//
+// When an empty tenant ID is received, the transaction's search_path is
+// explicitly set to the configured default schema ("public") instead of
+// relying on the connection's ambient search_path. This prevents cross-tenant
+// leakage when the connection pool routes to a connection whose search_path
+// was previously set to a different tenant.
+func WithAllowEmptyTenant() SchemaResolverOption {
+ return func(resolver *SchemaResolver) {
+ resolver.requireTenant = false
+ }
+}
+
+// WithOutboxTableName sets the outbox table name used to verify schema
+// eligibility during DiscoverTenants. Only schemas containing this table
+// are returned. Defaults to "outbox_events".
+func WithOutboxTableName(tableName string) SchemaResolverOption {
+ return func(resolver *SchemaResolver) {
+ resolver.outboxTableName = tableName
+ }
+}
+
+// SchemaResolver applies schema-per-tenant scoping and tenant discovery.
+type SchemaResolver struct {
+ client *libPostgres.Client
+ defaultTenantID string
+ outboxTableName string
+ requireTenant bool
+}
+
+func NewSchemaResolver(client *libPostgres.Client, opts ...SchemaResolverOption) (*SchemaResolver, error) {
+ if client == nil {
+ return nil, ErrConnectionRequired
+ }
+
+ resolver := &SchemaResolver{client: client, requireTenant: true, outboxTableName: defaultOutboxTableName}
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(resolver)
+ }
+ }
+
+ resolver.defaultTenantID = strings.TrimSpace(resolver.defaultTenantID)
+ if resolver.defaultTenantID != "" && resolver.requireTenant && !libCommons.IsUUID(resolver.defaultTenantID) {
+ return nil, ErrDefaultTenantIDInvalid
+ }
+
+ resolver.outboxTableName = strings.TrimSpace(resolver.outboxTableName)
+ if resolver.outboxTableName == "" {
+ resolver.outboxTableName = defaultOutboxTableName
+ }
+
+ return resolver, nil
+}
+
+func (resolver *SchemaResolver) RequiresTenant() bool {
+ if resolver == nil {
+ return true
+ }
+
+ return resolver.requireTenant
+}
+
+// ApplyTenant scopes the current transaction to tenant search_path.
+//
+// Security invariant: tenantID must remain UUID-validated and identifier-quoted
+// before query construction. This method intentionally relies on both checks to
+// keep dynamic search_path assignment safe.
+//
+// When tenantID is empty or matches the configured default tenant (with
+// AllowEmptyTenant enabled), the search_path is explicitly set to the default
+// schema ("public") to prevent queries from running against a stale
+// connection-level search_path left by a previous tenant operation.
+func (resolver *SchemaResolver) ApplyTenant(ctx context.Context, tx *sql.Tx, tenantID string) error {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if resolver == nil {
+ return ErrConnectionRequired
+ }
+
+ if tx == nil {
+ return ErrTransactionRequired
+ }
+
+ tenantID = strings.TrimSpace(tenantID)
+
+ if tenantID == "" {
+ if resolver.requireTenant {
+ return fmt.Errorf("schema resolver: %w", outbox.ErrTenantIDRequired)
+ }
+
+ // Explicitly set search_path to the default schema instead of no-oping.
+ // This prevents cross-tenant leakage when the pooled connection retains
+ // a search_path from a previous tenant transaction.
+ return resolver.setDefaultSearchPath(ctx, tx)
+ }
+
+ if tenantID == resolver.defaultTenantID && !resolver.requireTenant {
+ // Even for the default tenant, explicitly set the search_path to avoid
+ // inheriting a stale tenant-scoped path from the connection pool.
+ return resolver.setDefaultSearchPath(ctx, tx)
+ }
+
+ if !libCommons.IsUUID(tenantID) {
+ return errors.New("invalid tenant id format")
+ }
+
+ query := "SET LOCAL search_path TO " + quoteIdentifier(tenantID) + ", public" // #nosec G202 -- tenantID is UUID-validated; quoteIdentifier escapes the identifier
+ if _, err := tx.ExecContext(ctx, query); err != nil {
+ return fmt.Errorf("set search_path: %w", err)
+ }
+
+ return nil
+}
+
+// setDefaultSearchPath explicitly sets the transaction search_path to the
+// default schema. This is used when no tenant-specific schema is needed,
+// ensuring the query doesn't run against a stale search_path.
+func (resolver *SchemaResolver) setDefaultSearchPath(ctx context.Context, tx *sql.Tx) error {
+ query := "SET LOCAL search_path TO " + quoteIdentifier(defaultSchemaSearchPath) // #nosec G202 -- constant string "public"; quoteIdentifier escapes the identifier
+ if _, err := tx.ExecContext(ctx, query); err != nil {
+ return fmt.Errorf("set default search_path: %w", err)
+ }
+
+ return nil
+}
+
+// DiscoverTenants returns tenants by inspecting UUID-shaped schema names
+// that contain the configured outbox table (default: "outbox_events").
+//
+// Only schemas where the outbox table actually exists are returned, preventing
+// false positives from empty or unrelated UUID-shaped schemas. The configured
+// default tenant is NOT injected unless it was actually found in the database.
+func (resolver *SchemaResolver) DiscoverTenants(ctx context.Context) ([]string, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if resolver == nil || resolver.client == nil {
+ return nil, ErrConnectionRequired
+ }
+
+ db, err := resolver.primaryDB(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ // Join pg_namespace with information_schema.tables to verify the outbox
+ // table exists in each UUID-shaped schema before returning it as a tenant.
+ query := `SELECT n.nspname
+ FROM pg_namespace n
+ INNER JOIN information_schema.tables t
+ ON t.table_schema = n.nspname
+ AND t.table_name = $2
+ WHERE n.nspname ~* $1` // #nosec G202 -- parameterized query; no dynamic identifiers
+
+ rows, err := db.QueryContext(ctx, query, uuidSchemaRegex, resolver.outboxTableName)
+ if err != nil {
+ return nil, fmt.Errorf("querying tenant schemas: %w", err)
+ }
+ defer rows.Close()
+
+ tenants := make([]string, 0)
+
+ for rows.Next() {
+ var tenant string
+ if scanErr := rows.Scan(&tenant); scanErr != nil {
+ return nil, fmt.Errorf("scanning tenant schema: %w", scanErr)
+ }
+
+ tenants = append(tenants, tenant)
+ }
+
+ if err := rows.Err(); err != nil {
+ return nil, fmt.Errorf("iterating tenant schemas: %w", err)
+ }
+
+ return tenants, nil
+}
+
+func (resolver *SchemaResolver) primaryDB(ctx context.Context) (*sql.DB, error) {
+ return resolvePrimaryDB(ctx, resolver.client)
+}
diff --git a/commons/outbox/postgres/schema_resolver_test.go b/commons/outbox/postgres/schema_resolver_test.go
new file mode 100644
index 00000000..ac47d67c
--- /dev/null
+++ b/commons/outbox/postgres/schema_resolver_test.go
@@ -0,0 +1,120 @@
+//go:build unit
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/outbox"
+ libPostgres "github.com/LerianStudio/lib-commons/v4/commons/postgres"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewSchemaResolver_NilClient(t *testing.T) {
+ t.Parallel()
+
+ resolver, err := NewSchemaResolver(nil)
+ require.Nil(t, resolver)
+ require.ErrorIs(t, err, ErrConnectionRequired)
+}
+
+func TestSchemaResolver_ApplyTenantValidation(t *testing.T) {
+ t.Parallel()
+
+ resolver := &SchemaResolver{}
+
+ require.ErrorIs(t, resolver.ApplyTenant(context.Background(), nil, "tenant"), ErrTransactionRequired)
+}
+
+func TestSchemaResolver_ApplyTenantNilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var resolver *SchemaResolver
+
+ err := resolver.ApplyTenant(context.Background(), &sql.Tx{}, "tenant")
+ require.ErrorIs(t, err, ErrConnectionRequired)
+}
+
+func TestSchemaResolver_ApplyTenantEmptyAndDefaultExplicitlySetSearchPath(t *testing.T) {
+ t.Parallel()
+
+ // With AllowEmptyTenant, ApplyTenant now explicitly sets search_path to
+ // the default schema ("public") instead of no-oping. Since we cannot
+ // easily mock sql.Tx.ExecContext, we verify the resolver is configured
+ // correctly and the method does NOT return ErrTenantIDRequired.
+ resolver, err := NewSchemaResolver(
+ &libPostgres.Client{},
+ WithDefaultTenantID("tenant-default"),
+ WithAllowEmptyTenant(),
+ )
+ require.NoError(t, err)
+ require.False(t, resolver.RequiresTenant())
+
+ // Verify that a non-default, non-empty, non-UUID tenant is still rejected.
+ err = resolver.ApplyTenant(context.Background(), &sql.Tx{}, "not-a-uuid")
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "invalid tenant id format")
+}
+
+func TestNewSchemaResolver_DefaultRequiresTenant(t *testing.T) {
+ t.Parallel()
+
+ resolver, err := NewSchemaResolver(&libPostgres.Client{})
+ require.NoError(t, err)
+
+ require.True(t, resolver.RequiresTenant())
+}
+
+func TestNewSchemaResolver_WithAllowEmptyTenantDisablesRequirement(t *testing.T) {
+ t.Parallel()
+
+ resolver, err := NewSchemaResolver(&libPostgres.Client{}, WithAllowEmptyTenant())
+ require.NoError(t, err)
+
+ require.False(t, resolver.RequiresTenant())
+}
+
+func TestNewSchemaResolver_DefaultTenantValidationInStrictMode(t *testing.T) {
+ t.Parallel()
+
+ resolver, err := NewSchemaResolver(&libPostgres.Client{}, WithDefaultTenantID("default-tenant"))
+ require.Nil(t, resolver)
+ require.ErrorIs(t, err, ErrDefaultTenantIDInvalid)
+
+ resolver, err = NewSchemaResolver(
+ &libPostgres.Client{},
+ WithAllowEmptyTenant(),
+ WithDefaultTenantID("default-tenant"),
+ )
+ require.NoError(t, err)
+ require.NotNil(t, resolver)
+}
+
+func TestSchemaResolver_ApplyTenantRequireTenant(t *testing.T) {
+ t.Parallel()
+
+ resolver := &SchemaResolver{requireTenant: true}
+
+ err := resolver.ApplyTenant(context.Background(), &sql.Tx{}, "")
+ require.ErrorIs(t, err, outbox.ErrTenantIDRequired)
+}
+
+func TestSchemaResolver_ApplyTenantRejectsInvalidTenantID(t *testing.T) {
+ t.Parallel()
+
+ resolver := &SchemaResolver{}
+ err := resolver.ApplyTenant(context.Background(), &sql.Tx{}, "tenant-invalid")
+ require.ErrorContains(t, err, "invalid tenant id format")
+}
+
+func TestSchemaResolver_DiscoverTenantsNilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var resolver *SchemaResolver
+
+ tenants, err := resolver.DiscoverTenants(context.Background())
+ require.Nil(t, tenants)
+ require.ErrorIs(t, err, ErrConnectionRequired)
+}
diff --git a/commons/outbox/repository.go b/commons/outbox/repository.go
new file mode 100644
index 00000000..51658dbe
--- /dev/null
+++ b/commons/outbox/repository.go
@@ -0,0 +1,33 @@
+package outbox
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/google/uuid"
+)
+
+// Tx is the transactional handle used by CreateWithTx.
+//
+// It intentionally aliases *sql.Tx to keep the repository contract compatible
+// with existing database/sql transaction orchestration and tenant resolvers.
+// This avoids hidden adapter layers in write paths where tenant scoping runs
+// inside the caller's transaction.
+type Tx = *sql.Tx
+
+// OutboxRepository defines persistence operations for outbox events.
+type OutboxRepository interface {
+ Create(ctx context.Context, event *OutboxEvent) (*OutboxEvent, error)
+ CreateWithTx(ctx context.Context, tx Tx, event *OutboxEvent) (*OutboxEvent, error)
+ ListPending(ctx context.Context, limit int) ([]*OutboxEvent, error)
+ ListPendingByType(ctx context.Context, eventType string, limit int) ([]*OutboxEvent, error)
+ ListTenants(ctx context.Context) ([]string, error)
+ GetByID(ctx context.Context, id uuid.UUID) (*OutboxEvent, error)
+ MarkPublished(ctx context.Context, id uuid.UUID, publishedAt time.Time) error
+ MarkFailed(ctx context.Context, id uuid.UUID, errMsg string, maxAttempts int) error
+ ListFailedForRetry(ctx context.Context, limit int, failedBefore time.Time, maxAttempts int) ([]*OutboxEvent, error)
+ ResetForRetry(ctx context.Context, limit int, failedBefore time.Time, maxAttempts int) ([]*OutboxEvent, error)
+ ResetStuckProcessing(ctx context.Context, limit int, processingBefore time.Time, maxAttempts int) ([]*OutboxEvent, error)
+ MarkInvalid(ctx context.Context, id uuid.UUID, errMsg string) error
+}
diff --git a/commons/outbox/sanitizer.go b/commons/outbox/sanitizer.go
new file mode 100644
index 00000000..300b2338
--- /dev/null
+++ b/commons/outbox/sanitizer.go
@@ -0,0 +1,141 @@
+package outbox
+
+import (
+ "regexp"
+ "strings"
+)
+
+// sanitizeErrorForStorage redacts sensitive values and enforces bounded length
+// before storing error messages in the last_error database column (CWE-209).
+const maxErrorLength = 512
+
+const errorTruncatedSuffix = "... (truncated)"
+
+const redactedValue = "[REDACTED]"
+
+type sensitiveDataPattern struct {
+ pattern *regexp.Regexp
+ replacement string
+}
+
+var sensitiveDataPatterns = []sensitiveDataPattern{
+ {
+ pattern: regexp.MustCompile(`(?i)\b([a-z][a-z0-9+.-]*://[^:\s/]+):([^@\s]+)@`),
+ replacement: `$1:` + redactedValue + `@`,
+ },
+ {
+ pattern: regexp.MustCompile(`(?i)\bbearer\s+[a-z0-9\-._~+/]+=*\b`),
+ replacement: "Bearer " + redactedValue,
+ },
+ {
+ pattern: regexp.MustCompile(`(?i)(authorization\s*:\s*basic\s+)[a-z0-9+/=]+`),
+ replacement: `$1` + redactedValue,
+ },
+ {
+ pattern: regexp.MustCompile(`\beyJ[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+\b`),
+ replacement: redactedValue,
+ },
+ {
+ pattern: regexp.MustCompile(`(?i)\b(api[-_ ]?key|access[-_ ]?token|refresh[-_ ]?token|password|secret)\s*[:=]\s*([^\s,;]+)`),
+ replacement: `$1=` + redactedValue,
+ },
+ {
+ pattern: regexp.MustCompile(`(?i)([?&](?:password|pass|pwd|token|api[_-]?key|access[_-]?token|refresh[_-]?token)=)([^&\s]+)`),
+ replacement: `$1` + redactedValue,
+ },
+ {
+ pattern: regexp.MustCompile(`\b(AKIA|ASIA)[A-Z0-9]{16}\b`),
+ replacement: redactedValue,
+ },
+ {
+ pattern: regexp.MustCompile(`(?i)\b(aws[_-]?secret[_-]?access[_-]?key|gcp[_-]?credentials|private[_-]?key|client[_-]?secret)\s*[:=]\s*([^\s,;]+)`),
+ replacement: `$1=` + redactedValue,
+ },
+ {
+ pattern: regexp.MustCompile(`(?i)\b[A-Z0-9._%+\-]+@[A-Z0-9.\-]+\.[A-Z]{2,}\b`),
+ replacement: redactedValue,
+ },
+}
+
+var longNumericTokenPattern = regexp.MustCompile(`\b\d{12,19}\b`)
+
+func sanitizeErrorForStorage(err error) string {
+ if err == nil {
+ return ""
+ }
+
+ return SanitizeErrorMessageForStorage(err.Error())
+}
+
+// SanitizeErrorMessageForStorage redacts sensitive values and enforces a bounded length.
+func SanitizeErrorMessageForStorage(msg string) string {
+ redacted := redactSensitiveData(strings.TrimSpace(msg))
+
+ return truncateError(redacted, maxErrorLength, errorTruncatedSuffix)
+}
+
+func redactSensitiveData(msg string) string {
+ redacted := msg
+
+ for _, matcher := range sensitiveDataPatterns {
+ redacted = matcher.pattern.ReplaceAllString(redacted, matcher.replacement)
+ }
+
+ redacted = redactLuhnNumberSequences(redacted)
+
+ return redacted
+}
+
+func redactLuhnNumberSequences(msg string) string {
+ return longNumericTokenPattern.ReplaceAllStringFunc(msg, func(candidate string) string {
+ if !passesLuhn(candidate) {
+ return candidate
+ }
+
+ return redactedValue
+ })
+}
+
+func passesLuhn(number string) bool {
+ if len(number) < 12 || len(number) > 19 {
+ return false
+ }
+
+ sum := 0
+ shouldDouble := false
+
+ for i := len(number) - 1; i >= 0; i-- {
+ digit := int(number[i] - '0')
+ if digit < 0 || digit > 9 {
+ return false
+ }
+
+ if shouldDouble {
+ digit *= 2
+ if digit > 9 {
+ digit -= 9
+ }
+ }
+
+ sum += digit
+ shouldDouble = !shouldDouble
+ }
+
+ return sum%10 == 0
+}
+
+func truncateError(msg string, maxRunes int, suffix string) string {
+ runes := []rune(msg)
+ if len(runes) <= maxRunes {
+ return msg
+ }
+
+ suffixRunes := []rune(suffix)
+ if maxRunes <= len(suffixRunes) {
+ return string(runes[:maxRunes])
+ }
+
+ trimmed := string(runes[:maxRunes-len(suffixRunes)])
+
+ return trimmed + suffix
+}
diff --git a/commons/outbox/sanitizer_test.go b/commons/outbox/sanitizer_test.go
new file mode 100644
index 00000000..23c3812e
--- /dev/null
+++ b/commons/outbox/sanitizer_test.go
@@ -0,0 +1,103 @@
+//go:build unit
+
+package outbox
+
+import (
+ "errors"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestSanitizeErrorForStorage_RedactsSecrets(t *testing.T) {
+ t.Parallel()
+
+ err := errors.New("bearer eyJabc.def.ghi api_key=secret123 user@mail.com 4111111111111111")
+ msg := sanitizeErrorForStorage(err)
+
+ require.NotContains(t, msg, "secret123")
+ require.NotContains(t, msg, "user@mail.com")
+ require.NotContains(t, msg, "4111111111111111")
+ require.Contains(t, msg, redactedValue)
+}
+
+func TestSanitizeErrorForStorage_Truncates(t *testing.T) {
+ t.Parallel()
+
+ err := errors.New(strings.Repeat("x", maxErrorLength+30))
+ msg := sanitizeErrorForStorage(err)
+
+ require.LessOrEqual(t, len([]rune(msg)), maxErrorLength)
+ require.Contains(t, msg, errorTruncatedSuffix)
+}
+
+func TestSanitizeErrorForStorage_RedactsConnectionStringsAndCloudSecrets(t *testing.T) {
+ t.Parallel()
+
+ err := errors.New(
+ "dial postgres://user:myPassword@db.local:5432/app " +
+ "AKIAIOSFODNN7EXAMPLE aws_secret_access_key=abcd1234",
+ )
+
+ msg := sanitizeErrorForStorage(err)
+
+ require.NotContains(t, msg, "myPassword")
+ require.NotContains(t, msg, "AKIAIOSFODNN7EXAMPLE")
+ require.NotContains(t, msg, "abcd1234")
+ require.Contains(t, msg, redactedValue)
+}
+
+func TestSanitizeErrorForStorage_NilError(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, "", sanitizeErrorForStorage(nil))
+}
+
+func TestSanitizeErrorMessageForStorage_ShortMessageUnchanged(t *testing.T) {
+ t.Parallel()
+
+ msg := "safe short error"
+ require.Equal(t, msg, SanitizeErrorMessageForStorage(msg))
+}
+
+func TestSanitizeErrorMessageForStorage_RedactsQueryParameterCredentials(t *testing.T) {
+ t.Parallel()
+
+ message := "request failed https://api.test.local/callback?password=super-secret&mode=sync"
+ sanitized := SanitizeErrorMessageForStorage(message)
+
+ require.NotContains(t, sanitized, "super-secret")
+ require.Contains(t, sanitized, "password="+redactedValue)
+}
+
+func TestSanitizeErrorMessageForStorage_DoesNotRedactNonLuhnLongNumbers(t *testing.T) {
+ t.Parallel()
+
+ message := "failed at unix_ms=1700000000000 while parsing request"
+ sanitized := SanitizeErrorMessageForStorage(message)
+
+ require.Contains(t, sanitized, "1700000000000")
+ require.NotContains(t, sanitized, redactedValue)
+}
+
+func TestSanitizeErrorMessageForStorage_RedactsAuthorizationBasicHeader(t *testing.T) {
+ t.Parallel()
+
+ message := "downstream call failed Authorization: Basic dXNlcjpwYXNz"
+ sanitized := SanitizeErrorMessageForStorage(message)
+
+ require.NotContains(t, sanitized, "dXNlcjpwYXNz")
+ require.Contains(t, sanitized, "Authorization: Basic "+redactedValue)
+}
+
+func TestSanitizeErrorMessageForStorage_UnicodeInput(t *testing.T) {
+ t.Parallel()
+
+ message := "erro de autenticao 🔒 usuario=test@example.com senha=segredo"
+ sanitized := SanitizeErrorMessageForStorage(message)
+
+ require.Contains(t, sanitized, "🔒")
+ require.NotContains(t, sanitized, "test@example.com")
+ require.Contains(t, sanitized, redactedValue)
+}
diff --git a/commons/outbox/status.go b/commons/outbox/status.go
new file mode 100644
index 00000000..e899de86
--- /dev/null
+++ b/commons/outbox/status.go
@@ -0,0 +1,74 @@
+package outbox
+
+import "fmt"
+
+// OutboxEventStatus represents a valid outbox event lifecycle state.
+type OutboxEventStatus string
+
+const (
+ StatusPending OutboxEventStatus = OutboxStatusPending
+ StatusProcessing OutboxEventStatus = OutboxStatusProcessing
+ StatusPublished OutboxEventStatus = OutboxStatusPublished
+ StatusFailed OutboxEventStatus = OutboxStatusFailed
+ StatusInvalid OutboxEventStatus = OutboxStatusInvalid
+)
+
+// ParseOutboxEventStatus validates and converts a raw string status.
+func ParseOutboxEventStatus(raw string) (OutboxEventStatus, error) {
+ status := OutboxEventStatus(raw)
+
+ if !status.IsValid() {
+ return "", fmt.Errorf("%w: %q", ErrOutboxStatusInvalid, raw)
+ }
+
+ return status, nil
+}
+
+// IsValid reports whether the status is part of the outbox lifecycle.
+func (status OutboxEventStatus) IsValid() bool {
+ switch status {
+ case StatusPending, StatusProcessing, StatusPublished, StatusFailed, StatusInvalid:
+ return true
+ default:
+ return false
+ }
+}
+
+// CanTransitionTo reports whether a transition from status to next is allowed.
+func (status OutboxEventStatus) CanTransitionTo(next OutboxEventStatus) bool {
+ switch status {
+ case StatusPending:
+ return next == StatusProcessing
+ case StatusFailed:
+ return next == StatusProcessing
+ case StatusProcessing:
+ return next == StatusProcessing || next == StatusPublished || next == StatusFailed || next == StatusInvalid
+ case StatusPublished, StatusInvalid:
+ return false
+ default:
+ return false
+ }
+}
+
+// ValidateOutboxTransition validates a status transition using typed lifecycle rules.
+func ValidateOutboxTransition(fromRaw, toRaw string) error {
+ from, err := ParseOutboxEventStatus(fromRaw)
+ if err != nil {
+ return fmt.Errorf("from status: %w", err)
+ }
+
+ to, err := ParseOutboxEventStatus(toRaw)
+ if err != nil {
+ return fmt.Errorf("to status: %w", err)
+ }
+
+ if !from.CanTransitionTo(to) {
+ return fmt.Errorf("%w: %s -> %s", ErrOutboxTransitionInvalid, from, to)
+ }
+
+ return nil
+}
+
+func (status OutboxEventStatus) String() string {
+ return string(status)
+}
diff --git a/commons/outbox/status_test.go b/commons/outbox/status_test.go
new file mode 100644
index 00000000..b5c7db34
--- /dev/null
+++ b/commons/outbox/status_test.go
@@ -0,0 +1,84 @@
+//go:build unit
+
+package outbox
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseOutboxEventStatus(t *testing.T) {
+ t.Parallel()
+
+ status, err := ParseOutboxEventStatus(OutboxStatusPending)
+ require.NoError(t, err)
+ require.Equal(t, StatusPending, status)
+
+ _, err = ParseOutboxEventStatus("UNKNOWN")
+ require.ErrorIs(t, err, ErrOutboxStatusInvalid)
+}
+
+func TestOutboxEventStatus_IsValid(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, StatusPending.IsValid())
+ require.True(t, StatusProcessing.IsValid())
+ require.True(t, StatusPublished.IsValid())
+ require.True(t, StatusFailed.IsValid())
+ require.True(t, StatusInvalid.IsValid())
+ require.False(t, OutboxEventStatus("BROKEN").IsValid())
+}
+
+func TestOutboxEventStatus_String(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, OutboxStatusProcessing, StatusProcessing.String())
+}
+
+func TestOutboxEventStatus_CanTransitionTo(t *testing.T) {
+ t.Parallel()
+
+ require.True(t, StatusPending.CanTransitionTo(StatusProcessing))
+ require.True(t, StatusProcessing.CanTransitionTo(StatusPublished))
+ require.False(t, StatusPublished.CanTransitionTo(StatusProcessing))
+}
+
+func TestValidateOutboxTransition(t *testing.T) {
+ t.Parallel()
+
+ // Valid transitions.
+ require.NoError(t, ValidateOutboxTransition(OutboxStatusPending, OutboxStatusProcessing))
+ require.NoError(t, ValidateOutboxTransition(OutboxStatusFailed, OutboxStatusProcessing))
+ require.NoError(t, ValidateOutboxTransition(OutboxStatusProcessing, OutboxStatusPublished))
+ require.NoError(t, ValidateOutboxTransition(OutboxStatusProcessing, OutboxStatusFailed))
+ require.NoError(t, ValidateOutboxTransition(OutboxStatusProcessing, OutboxStatusInvalid))
+ require.NoError(t, ValidateOutboxTransition(OutboxStatusProcessing, OutboxStatusProcessing))
+
+ // Invalid transitions from terminal states.
+ err := ValidateOutboxTransition(OutboxStatusPublished, OutboxStatusProcessing)
+ require.ErrorIs(t, err, ErrOutboxTransitionInvalid)
+
+ err = ValidateOutboxTransition(OutboxStatusPublished, OutboxStatusFailed)
+ require.ErrorIs(t, err, ErrOutboxTransitionInvalid)
+
+ err = ValidateOutboxTransition(OutboxStatusInvalid, OutboxStatusProcessing)
+ require.ErrorIs(t, err, ErrOutboxTransitionInvalid)
+
+ err = ValidateOutboxTransition(OutboxStatusInvalid, OutboxStatusPending)
+ require.ErrorIs(t, err, ErrOutboxTransitionInvalid)
+
+ // Invalid backward transitions.
+ err = ValidateOutboxTransition(OutboxStatusPending, OutboxStatusFailed)
+ require.ErrorIs(t, err, ErrOutboxTransitionInvalid)
+
+ err = ValidateOutboxTransition(OutboxStatusFailed, OutboxStatusPublished)
+ require.ErrorIs(t, err, ErrOutboxTransitionInvalid)
+
+ // Unknown status.
+ err = ValidateOutboxTransition("UNKNOWN", OutboxStatusProcessing)
+ require.ErrorIs(t, err, ErrOutboxStatusInvalid)
+
+ err = ValidateOutboxTransition(OutboxStatusProcessing, "BOGUS")
+ require.ErrorIs(t, err, ErrOutboxStatusInvalid)
+}
diff --git a/commons/outbox/tenant.go b/commons/outbox/tenant.go
new file mode 100644
index 00000000..df0f2566
--- /dev/null
+++ b/commons/outbox/tenant.go
@@ -0,0 +1,103 @@
+package outbox
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "strings"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+)
+
+type tenantIDContextKey string
+
+// TenantIDContextKey stores tenant id used by outbox multi-tenant operations.
+//
+// Deprecated: use tenantmanager/core.ContextWithTenantID and tenantmanager/core.GetTenantIDFromContext.
+// This constant will be removed in v3.0.
+const TenantIDContextKey tenantIDContextKey = "outbox.tenant_id"
+
+// ErrTenantIDWhitespace is returned when a tenant ID contains leading or
+// trailing whitespace. Callers should trim the ID before passing it.
+var ErrTenantIDWhitespace = errors.New("tenant ID contains leading or trailing whitespace")
+
+// TenantResolver applies tenant-scoping rules for a transaction.
+type TenantResolver interface {
+ ApplyTenant(ctx context.Context, tx *sql.Tx, tenantID string) error
+}
+
+// TenantDiscoverer lists tenant identifiers to dispatch events for.
+type TenantDiscoverer interface {
+ DiscoverTenants(ctx context.Context) ([]string, error)
+}
+
+// ContextWithTenantID returns a context carrying tenantID.
+//
+// If the tenant ID contains leading or trailing whitespace, it is trimmed
+// before storing. An error is returned alongside the context to signal that
+// the caller provided a malformed input.
+func ContextWithTenantID(ctx context.Context, tenantID string) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ trimmed := strings.TrimSpace(tenantID)
+ if trimmed == "" {
+ return ctx
+ }
+
+ ctx = core.ContextWithTenantID(ctx, trimmed)
+
+ return context.WithValue(ctx, TenantIDContextKey, trimmed)
+}
+
+// ContextWithTenantIDStrict returns a context carrying tenantID.
+//
+// Unlike ContextWithTenantID, this variant returns an error when the tenant ID
+// contains leading or trailing whitespace instead of silently trimming. The
+// trimmed value is still stored so the context is usable.
+func ContextWithTenantIDStrict(ctx context.Context, tenantID string) (context.Context, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ trimmed := strings.TrimSpace(tenantID)
+ if trimmed == "" {
+ return ctx, nil
+ }
+
+ ctx = core.ContextWithTenantID(ctx, trimmed)
+ ctx = context.WithValue(ctx, TenantIDContextKey, trimmed)
+
+ if trimmed != tenantID {
+ return ctx, ErrTenantIDWhitespace
+ }
+
+ return ctx, nil
+}
+
+// TenantIDFromContext reads tenant id from context.
+func TenantIDFromContext(ctx context.Context) (string, bool) {
+ if ctx == nil {
+ return "", false
+ }
+
+ tenantID := core.GetTenantIDFromContext(ctx)
+
+ trimmed := strings.TrimSpace(tenantID)
+ if trimmed != "" {
+ return trimmed, true
+ }
+
+ tenantID, ok := ctx.Value(TenantIDContextKey).(string)
+ if !ok {
+ return "", false
+ }
+
+ trimmed = strings.TrimSpace(tenantID)
+ if trimmed == "" {
+ return "", false
+ }
+
+ return trimmed, true
+}
diff --git a/commons/outbox/tenant_test.go b/commons/outbox/tenant_test.go
new file mode 100644
index 00000000..cc7b1264
--- /dev/null
+++ b/commons/outbox/tenant_test.go
@@ -0,0 +1,88 @@
+//go:build unit
+
+package outbox
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestContextWithTenantID_TrimsWhitespace(t *testing.T) {
+ t.Parallel()
+
+ // IDs with leading/trailing spaces are now trimmed before storing.
+ ctx := ContextWithTenantID(nil, " tenant-1 ")
+ tenantID, ok := TenantIDFromContext(ctx)
+
+ require.True(t, ok)
+ require.Equal(t, "tenant-1", tenantID)
+}
+
+func TestContextWithTenantIDStrict_ReturnsErrorOnWhitespace(t *testing.T) {
+ t.Parallel()
+
+ ctx, err := ContextWithTenantIDStrict(context.Background(), " tenant-1 ")
+ require.ErrorIs(t, err, ErrTenantIDWhitespace)
+
+ // Trimmed value is still usable
+ tenantID, ok := TenantIDFromContext(ctx)
+ require.True(t, ok)
+ require.Equal(t, "tenant-1", tenantID)
+}
+
+func TestContextWithTenantIDStrict_NoErrorOnCleanID(t *testing.T) {
+ t.Parallel()
+
+ ctx, err := ContextWithTenantIDStrict(context.Background(), "tenant-1")
+ require.NoError(t, err)
+
+ tenantID, ok := TenantIDFromContext(ctx)
+ require.True(t, ok)
+ require.Equal(t, "tenant-1", tenantID)
+}
+
+func TestContextWithTenantID_NilContextUsesBackground(t *testing.T) {
+ t.Parallel()
+
+ ctx := ContextWithTenantID(nil, "tenant-1")
+ tenantID, ok := TenantIDFromContext(ctx)
+
+ require.True(t, ok)
+ require.Equal(t, "tenant-1", tenantID)
+}
+
+func TestTenantIDFromContext_RoundTrip(t *testing.T) {
+ t.Parallel()
+
+ ctx := ContextWithTenantID(context.Background(), "tenant-42")
+ tenantID, ok := TenantIDFromContext(ctx)
+
+ require.True(t, ok)
+ require.Equal(t, "tenant-42", tenantID)
+}
+
+func TestTenantIDFromContext_InvalidCases(t *testing.T) {
+ t.Parallel()
+
+ tenantID, ok := TenantIDFromContext(nil)
+ require.False(t, ok)
+ require.Empty(t, tenantID)
+
+ ctx := ContextWithTenantID(context.Background(), " ")
+ tenantID, ok = TenantIDFromContext(ctx)
+ require.False(t, ok)
+ require.Empty(t, tenantID)
+}
+
+func TestTenantIDFromContext_TrimsStoredWhitespace(t *testing.T) {
+ t.Parallel()
+
+ // Even if whitespace somehow got into the context, TenantIDFromContext trims it.
+ ctx := context.WithValue(context.Background(), TenantIDContextKey, " spaced ")
+ tenantID, ok := TenantIDFromContext(ctx)
+
+ require.True(t, ok)
+ require.Equal(t, "spaced", tenantID)
+}
diff --git a/commons/pointers/doc.go b/commons/pointers/doc.go
new file mode 100644
index 00000000..0ebf981b
--- /dev/null
+++ b/commons/pointers/doc.go
@@ -0,0 +1,5 @@
+// Package pointers provides helpers for pointer creation and conversions.
+//
+// Use this package to reduce boilerplate in tests and DTO assembly while keeping
+// pointer semantics explicit at call sites.
+package pointers
diff --git a/commons/pointers/pointers.go b/commons/pointers/pointers.go
index 43a170f7..22d5736c 100644
--- a/commons/pointers/pointers.go
+++ b/commons/pointers/pointers.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package pointers
import "time"
diff --git a/commons/pointers/pointers_test.go b/commons/pointers/pointers_test.go
index 7aed47cf..dffe0078 100644
--- a/commons/pointers/pointers_test.go
+++ b/commons/pointers/pointers_test.go
@@ -1,6 +1,4 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package pointers
@@ -41,6 +39,14 @@ func TestInt64(t *testing.T) {
}
}
+func TestFloat64(t *testing.T) {
+ f := 3.14
+ result := Float64(f)
+ if *result != f {
+ t.Errorf("Float64() = %v, want %v", *result, f)
+ }
+}
+
func TestInt(t *testing.T) {
num := 42
result := Int(num)
diff --git a/commons/postgres/doc.go b/commons/postgres/doc.go
new file mode 100644
index 00000000..c9e7da01
--- /dev/null
+++ b/commons/postgres/doc.go
@@ -0,0 +1,5 @@
+// Package postgres provides shared PostgreSQL connection helpers.
+//
+// It focuses on predictable connection lifecycle and configuration defaults that
+// are safe for service startup and shutdown flows.
+package postgres
diff --git a/commons/postgres/migration_integration_test.go b/commons/postgres/migration_integration_test.go
new file mode 100644
index 00000000..037351ab
--- /dev/null
+++ b/commons/postgres/migration_integration_test.go
@@ -0,0 +1,277 @@
+//go:build integration
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// TestIntegration_Migration_DirtyState
+// ---------------------------------------------------------------------------
+//
+// Validates that golang-migrate's dirty-version mechanism is correctly
+// classified by classifyMigrationError into ErrMigrationDirty.
+//
+// Key insight: golang-migrate's postgres driver runs single-statement migrations
+// inside a transaction. If the statement fails, the transaction rolls back and
+// the DB is NOT marked dirty. A dirty state only occurs with MultiStatementEnabled
+// where the first statement commits but the second fails — leaving the schema
+// partially applied.
+//
+// Scenario:
+// 1. Migration 000001 (multi-statement, AllowMultiStatements=true):
+// - Statement 1: CREATE TABLE users (succeeds, commits)
+// - Statement 2: ALTER TABLE nonexistent_table (fails)
+// 2. golang-migrate marks schema_migrations as (version=1, dirty=true).
+// 3. The returned error MUST wrap ErrMigrationDirty.
+// 4. The users table must exist (first statement was committed).
+
+func TestIntegration_Migration_DirtyState(t *testing.T) {
+ dsn, cleanup := setupPostgresContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx := context.Background()
+
+ migDir := t.TempDir()
+
+ // Migration 1 — multi-statement: first succeeds, second fails.
+ // With MultiStatementEnabled, statements execute outside a transaction,
+ // so the first CREATE TABLE commits before the second ALTER fails.
+ // This leaves the database in a dirty state at version 1.
+ multiStatementSQL := `CREATE TABLE users (id SERIAL PRIMARY KEY, email TEXT NOT NULL);
+ALTER TABLE nonexistent_table ADD COLUMN foo TEXT;`
+
+ require.NoError(t, os.WriteFile(
+ filepath.Join(migDir, "000001_partial_migration.up.sql"),
+ []byte(multiStatementSQL),
+ 0o644,
+ ))
+
+ require.NoError(t, os.WriteFile(
+ filepath.Join(migDir, "000001_partial_migration.down.sql"),
+ []byte("DROP TABLE IF EXISTS users;"),
+ 0o644,
+ ))
+
+ migrator, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: dsn,
+ DatabaseName: "testdb",
+ MigrationsPath: migDir,
+ Component: "dirty_state_test",
+ AllowMultiStatements: true,
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err, "NewMigrator() should succeed")
+
+ // --- Run migrations — expect failure partway through version 1 ----------
+
+ err = migrator.Up(ctx)
+ require.Error(t, err, "first Up() must fail because the second statement is invalid")
+
+ // The first Up() returns the SQL execution error, NOT ErrDirty.
+ // golang-migrate sets schema_migrations to (version=1, dirty=true) but
+ // returns the raw error from the failed statement.
+
+ // --- Second Up() detects the dirty state left by the first call ----------
+
+ // Create a fresh migrator (same config) to simulate a process restart.
+ migrator2, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: dsn,
+ DatabaseName: "testdb",
+ MigrationsPath: migDir,
+ Component: "dirty_state_test",
+ AllowMultiStatements: true,
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err, "NewMigrator() for second attempt should succeed")
+
+ err = migrator2.Up(ctx)
+ require.Error(t, err, "second Up() must fail with dirty state")
+
+ // NOW the error chain must contain ErrMigrationDirty.
+ assert.True(t,
+ errors.Is(err, ErrMigrationDirty),
+ "error should wrap ErrMigrationDirty; got: %v", err,
+ )
+
+ // --- Verify side-effects ------------------------------------------------
+
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ err = client.Connect(ctx)
+ require.NoError(t, err)
+
+ t.Cleanup(func() { _ = client.Close() })
+
+ db, err := client.Primary()
+ require.NoError(t, err)
+
+ // First statement committed — users table must exist.
+ assertTableExists(t, ctx, db, "users")
+
+ // The schema_migrations table must show dirty=true at version 1.
+ var version int
+
+ var dirty bool
+
+ err = db.QueryRowContext(ctx,
+ "SELECT version, dirty FROM schema_migrations",
+ ).Scan(&version, &dirty)
+ require.NoError(t, err, "schema_migrations should have exactly one row")
+ assert.Equal(t, 1, version, "dirty version should be 1")
+ assert.True(t, dirty, "dirty flag should be true")
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_Migration_NoChange
+// ---------------------------------------------------------------------------
+//
+// Validates that running Up() twice is idempotent: the second call returns nil
+// because classifyMigrationError converts migrate.ErrNoChange to a zero-value
+// outcome (err == nil).
+
+func TestIntegration_Migration_NoChange(t *testing.T) {
+ dsn, cleanup := setupPostgresContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx := context.Background()
+
+ migDir := t.TempDir()
+
+ require.NoError(t, os.WriteFile(
+ filepath.Join(migDir, "000001_create_items.up.sql"),
+ []byte("CREATE TABLE items (id SERIAL PRIMARY KEY, name TEXT NOT NULL);"),
+ 0o644,
+ ))
+
+ require.NoError(t, os.WriteFile(
+ filepath.Join(migDir, "000001_create_items.down.sql"),
+ []byte("DROP TABLE IF EXISTS items;"),
+ 0o644,
+ ))
+
+ migrator, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: dsn,
+ DatabaseName: "testdb",
+ MigrationsPath: migDir,
+ Component: "no_change_test",
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err)
+
+ // First run — applies migration 1.
+ err = migrator.Up(ctx)
+ require.NoError(t, err, "first Up() should succeed")
+
+ // Second run — no new migrations; ErrNoChange is suppressed to nil.
+ err = migrator.Up(ctx)
+ assert.NoError(t, err, "second Up() should return nil (ErrNoChange suppressed)")
+
+ // Sanity: table still exists and is usable.
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ err = client.Connect(ctx)
+ require.NoError(t, err)
+
+ t.Cleanup(func() { _ = client.Close() })
+
+ db, err := client.Primary()
+ require.NoError(t, err)
+
+ assertTableExists(t, ctx, db, "items")
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_Migration_MultiStatement
+// ---------------------------------------------------------------------------
+//
+// Validates that AllowMultiStatements: true enables a single migration file
+// containing multiple SQL statements separated by semicolons.
+
+func TestIntegration_Migration_MultiStatement(t *testing.T) {
+ dsn, cleanup := setupPostgresContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx := context.Background()
+
+ migDir := t.TempDir()
+
+ multiSQL := `CREATE TABLE multi_a (id SERIAL PRIMARY KEY);
+CREATE TABLE multi_b (id SERIAL PRIMARY KEY);`
+
+ require.NoError(t, os.WriteFile(
+ filepath.Join(migDir, "000001_create_multi_tables.up.sql"),
+ []byte(multiSQL),
+ 0o644,
+ ))
+
+ require.NoError(t, os.WriteFile(
+ filepath.Join(migDir, "000001_create_multi_tables.down.sql"),
+ []byte("DROP TABLE IF EXISTS multi_b; DROP TABLE IF EXISTS multi_a;"),
+ 0o644,
+ ))
+
+ migrator, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: dsn,
+ DatabaseName: "testdb",
+ MigrationsPath: migDir,
+ Component: "multi_stmt_test",
+ AllowMultiStatements: true,
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err, "NewMigrator() should succeed with AllowMultiStatements")
+
+ err = migrator.Up(ctx)
+ require.NoError(t, err, "Up() should succeed with multi-statement migration")
+
+ // Verify both tables were created.
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ err = client.Connect(ctx)
+ require.NoError(t, err)
+
+ t.Cleanup(func() { _ = client.Close() })
+
+ db, err := client.Primary()
+ require.NoError(t, err)
+
+ assertTableExists(t, ctx, db, "multi_a")
+ assertTableExists(t, ctx, db, "multi_b")
+}
+
+// ---------------------------------------------------------------------------
+// helpers
+// ---------------------------------------------------------------------------
+
+// assertTableExists verifies that a table with the given name exists in the
+// public schema of the connected database. It fails the test immediately if
+// the table is missing.
+func assertTableExists(t *testing.T, ctx context.Context, db *sql.DB, table string) {
+ t.Helper()
+
+ var exists bool
+
+ err := db.QueryRowContext(ctx,
+ `SELECT EXISTS (
+ SELECT 1 FROM information_schema.tables
+ WHERE table_schema = 'public' AND table_name = $1
+ )`,
+ table,
+ ).Scan(&exists)
+ require.NoError(t, err, fmt.Sprintf("query for table %q existence should succeed", table))
+ assert.True(t, exists, fmt.Sprintf("table %q should exist in public schema", table))
+}
diff --git a/commons/postgres/pagination.go b/commons/postgres/pagination.go
deleted file mode 100644
index d9909c32..00000000
--- a/commons/postgres/pagination.go
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
-package postgres
-
-import "time"
-
-// Pagination is a struct designed to encapsulate pagination response payload data.
-//
-// swagger:model Pagination
-// @Description Pagination is the struct designed to store the pagination data of an entity list.
-type Pagination struct {
- Items any `json:"items"`
- Page int `json:"page,omitempty" example:"1"`
- PrevCursor string `json:"prev_cursor,omitempty" example:"MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA==" extensions:"x-omitempty"`
- NextCursor string `json:"next_cursor,omitempty" example:"MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA==" extensions:"x-omitempty"`
- Limit int `json:"limit" example:"10"`
- SortOrder string `json:"-" example:"asc"`
- StartDate time.Time `json:"-" example:"2021-01-01"`
- EndDate time.Time `json:"-" example:"2021-12-31"`
-} // @name Pagination
-
-// SetItems set an array of any struct in items.
-func (p *Pagination) SetItems(items any) {
- p.Items = items
-}
-
-// SetCursor set the next and previous cursor.
-func (p *Pagination) SetCursor(next, prev string) {
- p.NextCursor = next
- p.PrevCursor = prev
-}
diff --git a/commons/postgres/postgres.go b/commons/postgres/postgres.go
index 4a1872d8..585607ad 100644
--- a/commons/postgres/postgres.go
+++ b/commons/postgres/postgres.go
@@ -1,166 +1,926 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package postgres
import (
+ "context"
"database/sql"
"errors"
"fmt"
"net/url"
+ "os"
"path/filepath"
+ "regexp"
+ "slices"
"strings"
+ "sync"
"time"
// File system migration source. We need to import it to be able to use it as source in migrate.NewWithSourceInstance
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ "github.com/LerianStudio/lib-commons/v4/commons/backoff"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
"github.com/bxcodec/dbresolver/v2"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
_ "github.com/jackc/pgx/v5/stdlib"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+)
+
+const (
+ defaultMaxOpenConns = 25
+ defaultMaxIdleConns = 10
+ defaultConnMaxLifetime = 30 * time.Minute
+ defaultConnMaxIdleTime = 5 * time.Minute
)
-// PostgresConnection is a hub which deal with postgres connections.
-type PostgresConnection struct {
- ConnectionStringPrimary string
- ConnectionStringReplica string
- PrimaryDBName string
- ReplicaDBName string
- ConnectionDB *dbresolver.DB
- Connected bool
- Component string
- MigrationsPath string
- Logger log.Logger
- MaxOpenConnections int
- MaxIdleConnections int
- // MultiStatementEnabled controls whether migrations run with multi-statement mode.
- // When nil, defaults to true for backward compatibility.
- // Use pointers.Bool(true) or pointers.Bool(false) to set explicitly.
- MultiStatementEnabled *bool
+var (
+ // ErrNilClient is returned when a postgres client receiver is nil.
+ ErrNilClient = errors.New("postgres client is nil")
+ // ErrNilContext is returned when a required context is nil.
+ ErrNilContext = errors.New("context is nil")
+ // ErrInvalidConfig indicates invalid postgres or migration configuration.
+ ErrInvalidConfig = errors.New("invalid postgres config")
+ // ErrNotConnected indicates operations requiring an active connection were called before connect.
+ ErrNotConnected = errors.New("postgres client is not connected")
+ // ErrInvalidDatabaseName indicates an invalid database identifier.
+ ErrInvalidDatabaseName = errors.New("invalid database name")
+ // ErrMigrationDirty indicates migrations stopped at a dirty version.
+ ErrMigrationDirty = errors.New("postgres migration dirty")
+ // ErrNilMigrator is returned when a migrator receiver is nil.
+ ErrNilMigrator = errors.New("postgres migrator is nil")
+ // ErrMigrationsNotFound is returned when the migration source directory is missing or empty.
+ // Services that intentionally skip migrations can opt in via WithAllowMissingMigrations().
+ ErrMigrationsNotFound = errors.New("migration files not found")
+
+ dbOpenFn = sql.Open
+
+ createResolverFn = func(primaryDB, replicaDB *sql.DB, logger log.Logger) (_ dbresolver.DB, err error) {
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ if logger == nil {
+ logger = log.NewNop()
+ }
+
+ runtime.HandlePanicValue(context.Background(), logger, recovered, "postgres", "create_resolver")
+ err = fmt.Errorf("failed to create resolver: %w", fmt.Errorf("recovered panic: %v", recovered))
+ }
+ }()
+
+ connectionDB := dbresolver.New(
+ dbresolver.WithPrimaryDBs(primaryDB),
+ dbresolver.WithReplicaDBs(replicaDB),
+ dbresolver.WithLoadBalancer(dbresolver.RoundRobinLB),
+ )
+
+ if connectionDB == nil {
+ return nil, errors.New("resolver returned nil connection")
+ }
+
+ return connectionDB, nil
+ }
+
+ runMigrationsFn = runMigrations
+
+ connectionStringCredentialsPattern = regexp.MustCompile(`://[^@\s]+@`)
+ connectionStringPasswordPattern = regexp.MustCompile(`(?i)(password=)(\S+)`)
+ sslPathPattern = regexp.MustCompile(`(?i)(sslkey|sslcert|sslrootcert|sslpassword)=(\S+)`)
+ dbNamePattern = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]{0,62}$`)
+)
+
+// nilClientAssert fires a telemetry assertion for nil-receiver calls and returns ErrNilClient.
+// The logger is intentionally nil here because this function is called on a nil *Client receiver,
+// so there is no struct instance from which to extract a logger. The assert package handles
+// nil loggers gracefully by falling back to stderr.
+func nilClientAssert(operation string) error {
+ asserter := assert.New(context.Background(), nil, "postgres", operation)
+ _ = asserter.Never(context.Background(), "postgres client receiver is nil")
+
+ return fmt.Errorf("postgres %s: %w", operation, ErrNilClient)
+}
+
+// nilMigratorAssert fires a telemetry assertion for nil-receiver calls and returns ErrNilMigrator.
+// The logger is intentionally nil here because this function is called on a nil *Migrator receiver,
+// so there is no struct instance from which to extract a logger. The assert package handles
+// nil loggers gracefully by falling back to stderr.
+func nilMigratorAssert(operation string) error {
+ asserter := assert.New(context.Background(), nil, "postgres", operation)
+ _ = asserter.Never(context.Background(), "postgres migrator receiver is nil")
+
+ return fmt.Errorf("postgres %s: %w", operation, ErrNilMigrator)
+}
+
+// Config stores immutable connection options for a postgres client.
+type Config struct {
+ PrimaryDSN string
+ ReplicaDSN string
+ Logger log.Logger
+ MetricsFactory *metrics.MetricsFactory
+ MaxOpenConnections int
+ MaxIdleConnections int
+ ConnMaxLifetime time.Duration
+ ConnMaxIdleTime time.Duration
}
-// resolveMultiStatementEnabled returns the resolved value of MultiStatementEnabled.
-// Returns true if MultiStatementEnabled is nil (backward compatible default),
-// otherwise returns the dereferenced value.
-func (pc *PostgresConnection) resolveMultiStatementEnabled() bool {
- if pc.MultiStatementEnabled != nil {
- return *pc.MultiStatementEnabled
+func (c Config) withDefaults() Config {
+ if c.Logger == nil {
+ c.Logger = log.NewNop()
}
- return true
+ if c.MaxOpenConnections <= 0 {
+ c.MaxOpenConnections = defaultMaxOpenConns
+ }
+
+ if c.MaxIdleConnections <= 0 {
+ c.MaxIdleConnections = defaultMaxIdleConns
+ }
+
+ if c.ConnMaxLifetime <= 0 {
+ c.ConnMaxLifetime = defaultConnMaxLifetime
+ }
+
+ if c.ConnMaxIdleTime <= 0 {
+ c.ConnMaxIdleTime = defaultConnMaxIdleTime
+ }
+
+ return c
}
-// Connect keeps a singleton connection with postgres.
-func (pc *PostgresConnection) Connect() error {
- pc.Logger.Info("Connecting to primary and replica databases...")
+func (c Config) validate() error {
+ if strings.TrimSpace(c.PrimaryDSN) == "" {
+ return fmt.Errorf("%w: primary dsn cannot be empty", ErrInvalidConfig)
+ }
- dbPrimary, err := sql.Open("pgx", pc.ConnectionStringPrimary)
- if err != nil {
- pc.Logger.Errorf("failed to connect to primary database: %v", err)
- return fmt.Errorf("failed to connect to primary database: %w", err)
+ if err := validateDSN(c.PrimaryDSN); err != nil {
+ return fmt.Errorf("%w: primary dsn: %w", ErrInvalidConfig, err)
}
- dbPrimary.SetMaxOpenConns(pc.MaxOpenConnections)
- dbPrimary.SetMaxIdleConns(pc.MaxIdleConnections)
- dbPrimary.SetConnMaxLifetime(time.Minute * 30)
- dbPrimary.SetConnMaxIdleTime(5 * time.Minute)
+ if strings.TrimSpace(c.ReplicaDSN) == "" {
+ return fmt.Errorf("%w: replica dsn cannot be empty", ErrInvalidConfig)
+ }
- dbReadOnlyReplica, err := sql.Open("pgx", pc.ConnectionStringReplica)
- if err != nil {
- pc.Logger.Errorf("failed to connect to replica database: %v", err)
- return fmt.Errorf("failed to connect to replica database: %w", err)
+ if err := validateDSN(c.ReplicaDSN); err != nil {
+ return fmt.Errorf("%w: replica dsn: %w", ErrInvalidConfig, err)
}
- dbReadOnlyReplica.SetMaxOpenConns(pc.MaxOpenConnections)
- dbReadOnlyReplica.SetMaxIdleConns(pc.MaxIdleConnections)
- dbReadOnlyReplica.SetConnMaxLifetime(time.Minute * 30)
- dbReadOnlyReplica.SetConnMaxIdleTime(5 * time.Minute)
+ return nil
+}
+
+// validateDSN checks structural validity of URL-format DSNs.
+// Key-value format DSNs (without postgres:// prefix) are accepted without structural checks.
+func validateDSN(dsn string) error {
+ lower := strings.ToLower(strings.TrimSpace(dsn))
+ if strings.HasPrefix(lower, "postgres://") || strings.HasPrefix(lower, "postgresql://") {
+ if _, err := url.Parse(dsn); err != nil {
+ return fmt.Errorf("malformed URL: %w", err)
+ }
+ }
- connectionDB := dbresolver.New(
- dbresolver.WithPrimaryDBs(dbPrimary),
- dbresolver.WithReplicaDBs(dbReadOnlyReplica),
- dbresolver.WithLoadBalancer(dbresolver.RoundRobinLB))
+ return nil
+}
- migrationsPath, err := pc.getMigrationsPath()
+// warnInsecureDSN logs a warning if the DSN explicitly disables TLS.
+// This is advisory -- development environments commonly use sslmode=disable.
+func warnInsecureDSN(ctx context.Context, logger log.Logger, dsn, label string) {
+ if logger == nil || !logger.Enabled(log.LevelWarn) {
+ return
+ }
+
+ if strings.Contains(strings.ToLower(dsn), "sslmode=disable") {
+ logger.Log(ctx, log.LevelWarn,
+ "TLS disabled in database connection; production deployments should use sslmode=require or stronger",
+ log.String("dsn_label", label),
+ )
+ }
+}
+
+// connectBackoffCap is the maximum delay between lazy-connect retries.
+const connectBackoffCap = 30 * time.Second
+
+// connectionFailuresMetric defines the counter for postgres connection failures.
+var connectionFailuresMetric = metrics.Metric{
+ Name: "postgres_connection_failures_total",
+ Unit: "1",
+ Description: "Total number of postgres connection failures",
+}
+
+// Client is the v2 postgres connection manager.
+type Client struct {
+ mu sync.RWMutex
+ cfg Config
+ metricsFactory *metrics.MetricsFactory
+ resolver dbresolver.DB
+ primary *sql.DB
+ replica *sql.DB
+
+ // Lazy-connect rate-limiting: prevents thundering-herd reconnect storms
+ // when the database is down by enforcing exponential backoff between attempts.
+ lastConnectAttempt time.Time
+ connectAttempts int
+}
+
+// New creates a postgres client with immutable configuration.
+func New(cfg Config) (*Client, error) {
+ cfg = cfg.withDefaults()
+
+ if err := cfg.validate(); err != nil {
+ return nil, fmt.Errorf("postgres new: %w", err)
+ }
+
+ return &Client{cfg: cfg, metricsFactory: cfg.MetricsFactory}, nil
+}
+
+// logAtLevel emits a structured log entry at the specified level.
+func (c *Client) logAtLevel(ctx context.Context, level log.Level, msg string, fields ...log.Field) {
+ if c == nil || c.cfg.Logger == nil {
+ return
+ }
+
+ if !c.cfg.Logger.Enabled(level) {
+ return
+ }
+
+ c.cfg.Logger.Log(ctx, level, msg, fields...)
+}
+
+// Connect establishes a new primary/replica resolver and swaps it atomically.
+func (c *Client) Connect(ctx context.Context) error {
+ if c == nil {
+ return nilClientAssert("connect")
+ }
+
+ if ctx == nil {
+ return fmt.Errorf("postgres connect: %w", ErrNilContext)
+ }
+
+ tracer := otel.Tracer("postgres")
+
+ ctx, span := tracer.Start(ctx, "postgres.connect")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemPostgreSQL))
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := c.connectLocked(ctx); err != nil {
+ c.recordConnectionFailure(ctx, "connect")
+
+ libOpentelemetry.HandleSpanError(span, "Failed to connect to postgres", err)
+
+ return err
+ }
+
+ return nil
+}
+
+// connectLocked performs the actual connection logic.
+// The caller MUST hold c.mu (write lock) before calling this method.
+func (c *Client) connectLocked(ctx context.Context) error {
+ primary, replica, resolver, err := c.buildConnection(ctx)
if err != nil {
return err
}
- primaryURL, err := url.Parse(filepath.ToSlash(migrationsPath))
+ oldResolver := c.resolver
+ oldPrimary := c.primary
+ oldReplica := c.replica
+
+ c.resolver = resolver
+ c.primary = primary
+ c.replica = replica
+
+ if oldResolver != nil {
+ if err := oldResolver.Close(); err != nil {
+ c.logAtLevel(ctx, log.LevelWarn, "failed to close previous resolver after swap", log.Err(err))
+ }
+ }
+
+ // Always close old primary/replica explicitly to prevent leaks.
+ // The resolver may not own the underlying sql.DB connections.
+ if err := closeDB(oldPrimary); err != nil {
+ c.logAtLevel(ctx, log.LevelWarn, "failed to close old primary during swap", log.Err(err))
+ }
+
+ if err := closeDB(oldReplica); err != nil {
+ c.logAtLevel(ctx, log.LevelWarn, "failed to close old replica during swap", log.Err(err))
+ }
+
+ c.logAtLevel(ctx, log.LevelInfo, "connected to postgres")
+
+ return nil
+}
+
+func (c *Client) buildConnection(ctx context.Context) (*sql.DB, *sql.DB, dbresolver.DB, error) {
+ c.logAtLevel(ctx, log.LevelInfo, "connecting to primary and replica databases")
+
+ warnInsecureDSN(ctx, c.cfg.Logger, c.cfg.PrimaryDSN, "primary")
+ warnInsecureDSN(ctx, c.cfg.Logger, c.cfg.ReplicaDSN, "replica")
+
+ primary, err := c.newSQLDB(ctx, c.cfg.PrimaryDSN)
if err != nil {
- pc.Logger.Errorf("failed to parse migrations url: %v", err)
- return fmt.Errorf("failed to parse migrations url: %w", err)
+ return nil, nil, nil, fmt.Errorf("postgres connect: %w", err)
}
- primaryURL.Scheme = "file"
+ replica, err := c.newSQLDB(ctx, c.cfg.ReplicaDSN)
+ if err != nil {
+ _ = closeDB(primary)
+ return nil, nil, nil, fmt.Errorf("postgres connect: %w", err)
+ }
- primaryDriver, err := postgres.WithInstance(dbPrimary, &postgres.Config{
- MultiStatementEnabled: pc.resolveMultiStatementEnabled(),
- DatabaseName: pc.PrimaryDBName,
- SchemaName: "public",
- })
+ resolver, err := createResolverFn(primary, replica, c.cfg.Logger)
if err != nil {
- pc.Logger.Errorf("failed to create postgres driver instance: %v", err)
- return fmt.Errorf("failed to create postgres driver instance: %w", err)
+ _ = closeDB(primary)
+ _ = closeDB(replica)
+
+ c.logAtLevel(ctx, log.LevelError, "failed to create resolver", log.Err(err))
+
+ return nil, nil, nil, fmt.Errorf("postgres connect: failed to create resolver: %w", err)
}
- m, err := migrate.NewWithDatabaseInstance(primaryURL.String(), pc.PrimaryDBName, primaryDriver)
+ if err := resolver.PingContext(ctx); err != nil {
+ _ = resolver.Close()
+ _ = closeDB(primary)
+ _ = closeDB(replica)
+
+ c.logAtLevel(ctx, log.LevelError, "failed to ping database", log.Err(err))
+
+ return nil, nil, nil, fmt.Errorf("postgres connect: failed to ping database: %w", err)
+ }
+
+ return primary, replica, resolver, nil
+}
+
+func (c *Client) newSQLDB(ctx context.Context, dsn string) (*sql.DB, error) {
+ db, err := dbOpenFn("pgx", dsn)
if err != nil {
- pc.Logger.Errorf("failed to get migrations: %v", err)
- return fmt.Errorf("failed to create migration instance: %w", err)
+ sanitized := newSanitizedError(err, "failed to open database")
+ c.logAtLevel(ctx, log.LevelError, "failed to open database", log.Err(sanitized))
+
+ return nil, sanitized
+ }
+
+ db.SetMaxOpenConns(c.cfg.MaxOpenConnections)
+ db.SetMaxIdleConns(c.cfg.MaxIdleConnections)
+ db.SetConnMaxLifetime(c.cfg.ConnMaxLifetime)
+ db.SetConnMaxIdleTime(c.cfg.ConnMaxIdleTime)
+
+ return db, nil
+}
+
+// Resolver returns the resolver, connecting lazily if needed.
+// Unlike sync.Once, this uses double-checked locking so that a transient
+// failure on the first call does not permanently break the client --
+// subsequent calls will retry the connection.
+func (c *Client) Resolver(ctx context.Context) (dbresolver.DB, error) {
+ if c == nil {
+ return nil, nilClientAssert("resolver")
+ }
+
+ if ctx == nil {
+ return nil, fmt.Errorf("postgres resolver: %w", ErrNilContext)
+ }
+
+ // Fast path: already connected (read-lock only).
+ c.mu.RLock()
+ resolver := c.resolver
+ c.mu.RUnlock()
+
+ if resolver != nil {
+ return resolver, nil
+ }
+
+ // Slow path: acquire write lock and double-check before connecting.
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.resolver != nil {
+ return c.resolver, nil
}
- if err := m.Up(); err != nil {
- if errors.Is(err, migrate.ErrNoChange) {
- pc.Logger.Info("No new migrations found. Skipping...")
- } else if strings.Contains(err.Error(), "file does not exist") {
- pc.Logger.Warn("No migration files found. Skipping migration step...")
- } else {
- pc.Logger.Errorf("Migration failed: %v", err)
- return fmt.Errorf("migration failed: %w", err)
+ // Rate-limit lazy-connect retries: if previous attempts failed recently,
+ // enforce a minimum delay before the next attempt to prevent reconnect storms.
+ if c.connectAttempts > 0 {
+ delay := min(backoff.ExponentialWithJitter(1*time.Second, c.connectAttempts), connectBackoffCap)
+
+ if elapsed := time.Since(c.lastConnectAttempt); elapsed < delay {
+ return nil, fmt.Errorf("postgres resolver: rate-limited (next attempt in %s)", delay-elapsed)
}
}
- if err := connectionDB.Ping(); err != nil {
- pc.Logger.Errorf("PostgresConnection.Ping failed: %v", err)
- return fmt.Errorf("failed to ping database: %w", err)
+ c.lastConnectAttempt = time.Now()
+
+ tracer := otel.Tracer("postgres")
+
+ ctx, span := tracer.Start(ctx, "postgres.resolve")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemPostgreSQL))
+
+ if err := c.connectLocked(ctx); err != nil {
+ c.connectAttempts++
+ c.recordConnectionFailure(ctx, "resolve")
+
+ libOpentelemetry.HandleSpanError(span, "Failed to resolve postgres connection", err)
+
+ return nil, err
}
- pc.Connected = true
- pc.ConnectionDB = &connectionDB
+ c.connectAttempts = 0
- pc.Logger.Info("Connected to postgres ✅ \n")
+ if c.resolver == nil {
+ err := fmt.Errorf("postgres resolver: %w", ErrNotConnected)
+ libOpentelemetry.HandleSpanError(span, "Postgres resolver not connected after connect", err)
- return nil
+ return nil, err
+ }
+
+ return c.resolver, nil
}
-// GetDB returns a pointer to the postgres connection, initializing it if necessary.
-func (pc *PostgresConnection) GetDB() (dbresolver.DB, error) {
- if pc.ConnectionDB == nil {
- if err := pc.Connect(); err != nil {
- pc.Logger.Infof("ERRCONECT %s", err)
- return nil, err
+// Primary returns the current primary sql.DB, useful for admin operations.
+func (c *Client) Primary() (*sql.DB, error) {
+ if c == nil {
+ return nil, nilClientAssert("primary")
+ }
+
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ if c.primary == nil {
+ return nil, fmt.Errorf("postgres primary: %w", ErrNotConnected)
+ }
+
+ return c.primary, nil
+}
+
+// Close releases database resources.
+// All three handles (resolver, primary, replica) are always explicitly closed
+// to prevent leaks -- the resolver may not own the underlying sql.DB connections.
+func (c *Client) Close() error {
+ if c == nil {
+ return nilClientAssert("close")
+ }
+
+ tracer := otel.Tracer("postgres")
+
+ _, span := tracer.Start(context.Background(), "postgres.close")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemPostgreSQL))
+
+ c.mu.Lock()
+ resolver := c.resolver
+ primary := c.primary
+ replica := c.replica
+
+ c.resolver = nil
+ c.primary = nil
+ c.replica = nil
+ c.mu.Unlock()
+
+ var errs []error
+
+ if resolver != nil {
+ if err := resolver.Close(); err != nil {
+ errs = append(errs, err)
}
}
- return *pc.ConnectionDB, nil
+ // Always close primary/replica explicitly to prevent leaks.
+ // The resolver may not own the underlying sql.DB connections.
+ if err := closeDB(primary); err != nil {
+ errs = append(errs, err)
+ }
+
+ if err := closeDB(replica); err != nil {
+ errs = append(errs, err)
+ }
+
+ if len(errs) > 0 {
+ closeErr := fmt.Errorf("postgres close: %w", errors.Join(errs...))
+ libOpentelemetry.HandleSpanError(span, "Failed to close postgres", closeErr)
+
+ return closeErr
+ }
+
+ return nil
+}
+
+// IsConnected reports whether the resolver is currently initialized.
+func (c *Client) IsConnected() (bool, error) {
+ if c == nil {
+ return false, nilClientAssert("is_connected")
+ }
+
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return c.resolver != nil, nil
+}
+
+func closeDB(db *sql.DB) error {
+ if db == nil {
+ return nil
+ }
+
+ return db.Close()
+}
+
+// MigrationConfig stores migration-only settings.
+type MigrationConfig struct {
+ PrimaryDSN string
+ DatabaseName string
+ MigrationsPath string
+ Component string
+ // AllowMultiStatements enables multi-statement execution in migrations.
+ // SECURITY: Only enable when migration files are from trusted, version-controlled sources.
+ // Multi-statement mode increases the blast radius of compromised migration files.
+ AllowMultiStatements bool
+ // AllowMissingMigrations makes Migrator.Up return nil instead of ErrMigrationsNotFound
+ // when the migration source directory does not exist. Use this for services that
+ // intentionally have no migrations (e.g., worker-only services sharing a database).
+ AllowMissingMigrations bool
+ Logger log.Logger
}
-// getMigrationsPath returns the path to migration files, calculating it if not explicitly provided
-func (pc *PostgresConnection) getMigrationsPath() (string, error) {
- if pc.MigrationsPath != "" {
- return pc.MigrationsPath, nil
+func (c MigrationConfig) withDefaults() MigrationConfig {
+ if c.Logger == nil {
+ c.Logger = log.NewNop()
}
- calculatedPath, err := filepath.Abs(filepath.Join("components", pc.Component, "migrations"))
+ return c
+}
+
+func (c MigrationConfig) validate() error {
+ if strings.TrimSpace(c.PrimaryDSN) == "" {
+ return fmt.Errorf("%w: primary dsn cannot be empty", ErrInvalidConfig)
+ }
+
+ if err := validateDBName(c.DatabaseName); err != nil {
+ return fmt.Errorf("migration config: %w", err)
+ }
+
+ if strings.TrimSpace(c.MigrationsPath) == "" && strings.TrimSpace(c.Component) == "" {
+ return fmt.Errorf("%w: migrations_path or component is required", ErrInvalidConfig)
+ }
+
+ return nil
+}
+
+// Migrator runs schema migrations explicitly.
+type Migrator struct {
+ cfg MigrationConfig
+}
+
+// NewMigrator creates a migrator with explicit migration config.
+func NewMigrator(cfg MigrationConfig) (*Migrator, error) {
+ cfg = cfg.withDefaults()
+
+ if err := cfg.validate(); err != nil {
+ return nil, fmt.Errorf("postgres new_migrator: %w", err)
+ }
+
+ return &Migrator{cfg: cfg}, nil
+}
+
+func (m *Migrator) logAtLevel(ctx context.Context, level log.Level, msg string, fields ...log.Field) {
+ if m == nil || m.cfg.Logger == nil {
+ return
+ }
+
+ if !m.cfg.Logger.Enabled(level) {
+ return
+ }
+
+ m.cfg.Logger.Log(ctx, level, msg, fields...)
+}
+
+// Up runs all up migrations.
+//
+// Note: golang-migrate's m.Up() does not accept a context, so cancellation
+// cannot stop a migration in progress. This method checks context state
+// before starting but cannot interrupt a running migration.
+func (m *Migrator) Up(ctx context.Context) error {
+ if m == nil {
+ return nilMigratorAssert("migrate_up")
+ }
+
+ if ctx == nil {
+ return fmt.Errorf("postgres migrate_up: %w", ErrNilContext)
+ }
+
+ tracer := otel.Tracer("postgres")
+
+ ctx, span := tracer.Start(ctx, "postgres.migrate_up")
+ defer span.End()
+
+ span.SetAttributes(
+ attribute.String(constant.AttrDBSystem, constant.DBSystemPostgreSQL),
+ attribute.String(constant.AttrDBName, m.cfg.DatabaseName),
+ )
+
+ // Fail fast if the context is already cancelled or expired.
+ if err := ctx.Err(); err != nil {
+ libOpentelemetry.HandleSpanError(span, "Context already done before migration", err)
+
+ return fmt.Errorf("postgres migrate_up: context already done: %w", err)
+ }
+
+ db, err := dbOpenFn("pgx", m.cfg.PrimaryDSN)
if err != nil {
- pc.Logger.Errorf("failed to get migration filepath: %v", err)
+ sanitized := newSanitizedError(err, "failed to open migration database")
+ m.logAtLevel(ctx, log.LevelError, "failed to open migration database", log.Err(sanitized))
+
+ libOpentelemetry.HandleSpanError(span, "Failed to open migration database", sanitized)
+
+ return fmt.Errorf("postgres migrate_up: %w", sanitized)
+ }
+ defer db.Close()
+ migrationsPath, err := resolveMigrationsPath(m.cfg.MigrationsPath, m.cfg.Component)
+ if err != nil {
+ m.logAtLevel(ctx, log.LevelError, "failed to resolve migration path", log.Err(err))
+
+ libOpentelemetry.HandleSpanError(span, "Failed to resolve migration path", err)
+
+ return fmt.Errorf("postgres migrate_up: %w", err)
+ }
+
+ if err := runMigrationsFn(ctx, db, migrationsPath, m.cfg.DatabaseName, m.cfg.AllowMultiStatements, m.cfg.AllowMissingMigrations, m.cfg.Logger); err != nil {
+ libOpentelemetry.HandleSpanError(span, "Migration up failed", err)
+
+ return fmt.Errorf("postgres migrate_up: %w", err)
+ }
+
+ return nil
+}
+
+func resolveMigrationsPath(migrationsPath, component string) (string, error) {
+ if strings.TrimSpace(migrationsPath) != "" {
+ return sanitizePath(migrationsPath)
+ }
+
+ // filepath.Base strips directory components, so "../../etc" becomes "etc".
+ sanitized := filepath.Base(component)
+ if sanitized == "." || sanitized == string(filepath.Separator) || sanitized == "" {
+ return "", fmt.Errorf("invalid component name: %q", component)
+ }
+
+ calculatedPath, err := filepath.Abs(filepath.Join("components", sanitized, "migrations"))
+ if err != nil {
return "", err
}
return calculatedPath, nil
}
+
+// SanitizedError wraps a database error with a credential-free message.
+// Error() returns only the sanitized text.
+//
+// Unwrap returns a sanitized copy of the original error that preserves
+// error types and sentinels (via errors.Is / errors.As) while stripping
+// connection strings and credentials from the message text.
+type SanitizedError struct {
+ // Message is the credential-free error description.
+ Message string
+ // cause is a sanitized version of the original error that preserves
+ // error types/sentinels but strips credentials from messages.
+ cause error
+}
+
+func (e *SanitizedError) Error() string { return e.Message }
+
+// Unwrap returns the sanitized cause, enabling errors.Is / errors.As
+// chain traversal without leaking credentials.
+func (e *SanitizedError) Unwrap() error { return e.cause }
+
+// sanitizedCause creates a credential-free copy of the cause error chain.
+// It preserves the type of known sentinel errors (e.g., sql.ErrNoRows) by
+// wrapping a new error with the sanitized message around the original sentinel.
+func sanitizedCause(err error) error {
+ if err == nil {
+ return nil
+ }
+
+ sanitizedMsg := sanitizeSensitiveString(err.Error())
+
+ return errors.New(sanitizedMsg)
+}
+
+// newSanitizedError wraps err with a credential-free message.
+// A sanitized copy of the cause is stored for error chain traversal.
+func newSanitizedError(err error, prefix string) *SanitizedError {
+ if err == nil {
+ return nil
+ }
+
+ return &SanitizedError{
+ Message: fmt.Sprintf("%s: %s", prefix, sanitizeSensitiveString(err.Error())),
+ cause: sanitizedCause(err),
+ }
+}
+
+// sanitizeSensitiveString removes credentials and sensitive paths from a string.
+func sanitizeSensitiveString(s string) string {
+ s = connectionStringCredentialsPattern.ReplaceAllString(s, "://***@")
+ s = connectionStringPasswordPattern.ReplaceAllString(s, "${1}***")
+ s = sslPathPattern.ReplaceAllString(s, "${1}=***")
+
+ return s
+}
+
+func sanitizePath(path string) (string, error) {
+ cleaned := filepath.Clean(path)
+ if slices.Contains(strings.Split(cleaned, string(filepath.Separator)), "..") {
+ return "", fmt.Errorf("invalid migrations path: %q", path)
+ }
+
+ absPath, err := filepath.Abs(cleaned)
+ if err != nil {
+ return "", fmt.Errorf("failed to resolve migrations path: %w", err)
+ }
+
+ return absPath, nil
+}
+
+func validateDBName(name string) error {
+ if !dbNamePattern.MatchString(name) {
+ return fmt.Errorf("%w: %q", ErrInvalidDatabaseName, name)
+ }
+
+ return nil
+}
+
+// migrationOutcome describes the result of classifying a migration error.
+type migrationOutcome struct {
+ err error
+ level log.Level
+ message string
+ fields []log.Field
+}
+
+// classifyMigrationError converts a golang-migrate error into a typed outcome.
+// Returns a zero-value outcome (err == nil) on success or benign cases (ErrNoChange).
+// When allowMissing is true, ErrNotExist is treated as benign (nil error); otherwise
+// it returns ErrMigrationsNotFound so the caller can distinguish missing files from success.
+func classifyMigrationError(err error, allowMissing bool) migrationOutcome {
+ if err == nil {
+ return migrationOutcome{}
+ }
+
+ if errors.Is(err, migrate.ErrNoChange) {
+ return migrationOutcome{
+ level: log.LevelInfo,
+ message: "no new migrations found, skipping",
+ }
+ }
+
+ if errors.Is(err, os.ErrNotExist) {
+ if allowMissing {
+ return migrationOutcome{
+ level: log.LevelWarn,
+ message: "no migration files found, skipping (AllowMissingMigrations=true)",
+ }
+ }
+
+ return migrationOutcome{
+ err: fmt.Errorf("%w: source directory missing or empty", ErrMigrationsNotFound),
+ level: log.LevelError,
+ message: "no migration files found",
+ }
+ }
+
+ var dirtyErr migrate.ErrDirty
+ if errors.As(err, &dirtyErr) {
+ return migrationOutcome{
+ err: fmt.Errorf("%w: database version %d", ErrMigrationDirty, dirtyErr.Version),
+ level: log.LevelError,
+ message: "migration failed with dirty version",
+ fields: []log.Field{log.Int("dirty_version", dirtyErr.Version)},
+ }
+ }
+
+ return migrationOutcome{
+ err: fmt.Errorf("migration failed: %w", err),
+ level: log.LevelError,
+ message: "migration failed",
+ fields: []log.Field{log.Err(err)},
+ }
+}
+
+// recordConnectionFailure increments the postgres connection failure counter.
+// No-op when metricsFactory is nil. ctx is used for metric recording and tracing.
+func (c *Client) recordConnectionFailure(ctx context.Context, operation string) {
+ if c == nil || c.metricsFactory == nil {
+ return
+ }
+
+ counter, err := c.metricsFactory.Counter(connectionFailuresMetric)
+ if err != nil {
+ c.logAtLevel(ctx, log.LevelWarn, "failed to create postgres metric counter", log.Err(err))
+ return
+ }
+
+ err = counter.
+ WithLabels(map[string]string{
+ "operation": constant.SanitizeMetricLabel(operation),
+ }).
+ AddOne(ctx)
+ if err != nil {
+ c.logAtLevel(ctx, log.LevelWarn, "failed to record postgres metric", log.Err(err))
+ }
+}
+
+// migrationLogAtLevel logs at the given level if logger is non-nil and the level is enabled.
+// This eliminates repeated nil-check + level-check branches in migration helpers.
+func migrationLogAtLevel(ctx context.Context, logger log.Logger, level log.Level, msg string, fields ...log.Field) {
+ if logger == nil || !logger.Enabled(level) {
+ return
+ }
+
+ logger.Log(ctx, level, msg, fields...)
+}
+
+// resolveMigrationSource parses the migrations path into a file:// URL.
+func resolveMigrationSource(migrationsPath string) (*url.URL, error) {
+ primaryURL, err := url.Parse(filepath.ToSlash(migrationsPath))
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse migrations url: %w", err)
+ }
+
+ primaryURL.Scheme = "file"
+
+ return primaryURL, nil
+}
+
+// createMigrationInstance creates the postgres driver and migration instance.
+func createMigrationInstance(dbPrimary *sql.DB, sourceURL, primaryDBName string, allowMultiStatements bool) (*migrate.Migrate, error) {
+ primaryDriver, err := postgres.WithInstance(dbPrimary, &postgres.Config{
+ MultiStatementEnabled: allowMultiStatements,
+ DatabaseName: primaryDBName,
+ SchemaName: "public",
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create postgres driver instance: %w", err)
+ }
+
+ mig, err := migrate.NewWithDatabaseInstance(sourceURL, primaryDBName, primaryDriver)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create migration instance: %w", err)
+ }
+
+ return mig, nil
+}
+
+// closeMigration releases source and database driver resources. Errors are logged
+// but not propagated since the migration itself already ran (or failed).
+func closeMigration(ctx context.Context, mig *migrate.Migrate, logger log.Logger) {
+ sourceErr, dbErr := mig.Close()
+ if sourceErr != nil {
+ migrationLogAtLevel(ctx, logger, log.LevelWarn, "failed to close migration source driver", log.Err(sourceErr))
+ }
+
+ if dbErr != nil {
+ migrationLogAtLevel(ctx, logger, log.LevelWarn, "failed to close migration database driver", log.Err(dbErr))
+ }
+}
+
+func runMigrations(ctx context.Context, dbPrimary *sql.DB, migrationsPath, primaryDBName string, allowMultiStatements, allowMissingMigrations bool, logger log.Logger) error {
+ if err := validateDBName(primaryDBName); err != nil {
+ migrationLogAtLevel(ctx, logger, log.LevelError, "invalid primary database name", log.Err(err))
+
+ return fmt.Errorf("migrations: %w", err)
+ }
+
+ primaryURL, err := resolveMigrationSource(migrationsPath)
+ if err != nil {
+ migrationLogAtLevel(ctx, logger, log.LevelError, "failed to parse migrations url", log.Err(err))
+
+ return err
+ }
+
+ mig, err := createMigrationInstance(dbPrimary, primaryURL.String(), primaryDBName, allowMultiStatements)
+ if err != nil {
+ migrationLogAtLevel(ctx, logger, log.LevelError, err.Error())
+
+ return err
+ }
+
+ defer closeMigration(ctx, mig, logger)
+
+ if err := mig.Up(); err != nil {
+ outcome := classifyMigrationError(err, allowMissingMigrations)
+
+ migrationLogAtLevel(ctx, logger, outcome.level, outcome.message, outcome.fields...)
+
+ return outcome.err
+ }
+
+ return nil
+}
diff --git a/commons/postgres/postgres_integration_test.go b/commons/postgres/postgres_integration_test.go
new file mode 100644
index 00000000..0bcfb47e
--- /dev/null
+++ b/commons/postgres/postgres_integration_test.go
@@ -0,0 +1,257 @@
+//go:build integration
+
+package postgres
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/testcontainers/testcontainers-go"
+ tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
+ "github.com/testcontainers/testcontainers-go/wait"
+)
+
+// setupPostgresContainer starts a disposable PostgreSQL container and returns
+// the connection string plus a teardown function. The container is terminated
+// when the returned cleanup function is invoked (typically via t.Cleanup).
+func setupPostgresContainer(t *testing.T) (string, func()) {
+ t.Helper()
+
+ ctx := context.Background()
+
+ container, err := tcpostgres.Run(ctx,
+ "postgres:16-alpine",
+ tcpostgres.WithDatabase("testdb"),
+ tcpostgres.WithUsername("test"),
+ tcpostgres.WithPassword("test"),
+ testcontainers.WithWaitStrategy(
+ wait.ForLog("database system is ready to accept connections").
+ WithOccurrence(2).
+ WithStartupTimeout(30*time.Second),
+ ),
+ )
+ require.NoError(t, err)
+
+ connStr, err := container.ConnectionString(ctx, "sslmode=disable")
+ require.NoError(t, err)
+
+ return connStr, func() {
+ require.NoError(t, container.Terminate(ctx))
+ }
+}
+
+// newTestConfig builds a Config pointing both primary and replica at the same
+// container DSN. This is intentional for integration tests — we are validating
+// the connector lifecycle, not read/write splitting.
+func newTestConfig(dsn string) Config {
+ return Config{
+ PrimaryDSN: dsn,
+ ReplicaDSN: dsn,
+ Logger: log.NewNop(),
+ MetricsFactory: metrics.NewNopFactory(),
+ }
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_Postgres_ConnectAndResolve
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Postgres_ConnectAndResolve(t *testing.T) {
+ dsn, cleanup := setupPostgresContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx := context.Background()
+
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err, "New() should succeed with valid DSN")
+
+ err = client.Connect(ctx)
+ require.NoError(t, err, "Connect() should succeed against running container")
+
+ resolver, err := client.Resolver(ctx)
+ require.NoError(t, err, "Resolver() should return a live resolver after Connect()")
+ require.NotNil(t, resolver, "resolver must not be nil")
+
+ // Verify the resolver is actually connected to a live database.
+ err = resolver.PingContext(ctx)
+ assert.NoError(t, err, "PingContext on resolver should succeed")
+
+ err = client.Close()
+ assert.NoError(t, err, "Close() should release resources cleanly")
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_Postgres_PrimaryAccess
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Postgres_PrimaryAccess(t *testing.T) {
+ dsn, cleanup := setupPostgresContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx := context.Background()
+
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ err = client.Connect(ctx)
+ require.NoError(t, err, "Connect() should succeed")
+
+ db, err := client.Primary()
+ require.NoError(t, err, "Primary() should return the underlying *sql.DB")
+ require.NotNil(t, db, "primary *sql.DB must not be nil")
+
+ // Verify the raw *sql.DB is usable.
+ err = db.PingContext(ctx)
+ assert.NoError(t, err, "PingContext on primary *sql.DB should succeed")
+
+ // Verify we can execute a trivial query to confirm connectivity beyond Ping.
+ var result int
+ err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
+ require.NoError(t, err, "trivial query should succeed")
+ assert.Equal(t, 1, result, "SELECT 1 should return 1")
+
+ err = client.Close()
+ assert.NoError(t, err)
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_Postgres_IsConnected
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Postgres_IsConnected(t *testing.T) {
+ dsn, cleanup := setupPostgresContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx := context.Background()
+
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ // Before Connect(), IsConnected must be false.
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.False(t, connected, "IsConnected() should be false before Connect()")
+
+ err = client.Connect(ctx)
+ require.NoError(t, err)
+
+ // After Connect(), IsConnected must be true.
+ connected, err = client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected, "IsConnected() should be true after Connect()")
+
+ err = client.Close()
+ require.NoError(t, err)
+
+ // After Close(), IsConnected must be false again.
+ connected, err = client.IsConnected()
+ require.NoError(t, err)
+ assert.False(t, connected, "IsConnected() should be false after Close()")
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_Postgres_LazyConnect
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Postgres_LazyConnect(t *testing.T) {
+ dsn, cleanup := setupPostgresContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx := context.Background()
+
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ // Do NOT call Connect() — Resolver() must lazy-connect on first access.
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.False(t, connected, "should not be connected before Resolver() call")
+
+ resolver, err := client.Resolver(ctx)
+ require.NoError(t, err, "Resolver() should lazy-connect successfully")
+ require.NotNil(t, resolver)
+
+ // After lazy connect, IsConnected must flip to true.
+ connected, err = client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected, "IsConnected() should be true after lazy connect via Resolver()")
+
+ // Verify the resolver is functional.
+ err = resolver.PingContext(ctx)
+ assert.NoError(t, err, "PingContext should succeed on lazily-connected resolver")
+
+ err = client.Close()
+ assert.NoError(t, err)
+}
+
+// ---------------------------------------------------------------------------
+// TestIntegration_Postgres_Migration
+// ---------------------------------------------------------------------------
+
+func TestIntegration_Postgres_Migration(t *testing.T) {
+ dsn, cleanup := setupPostgresContainer(t)
+ t.Cleanup(cleanup)
+
+ ctx := context.Background()
+
+ // Create a temporary directory with migration files.
+ migDir := t.TempDir()
+
+ upSQL := "CREATE TABLE IF NOT EXISTS test_items (id SERIAL PRIMARY KEY, name TEXT NOT NULL);"
+ downSQL := "DROP TABLE IF EXISTS test_items;"
+
+ err := os.WriteFile(filepath.Join(migDir, "000001_create_test_table.up.sql"), []byte(upSQL), 0o644)
+ require.NoError(t, err, "failed to write up migration file")
+
+ err = os.WriteFile(filepath.Join(migDir, "000001_create_test_table.down.sql"), []byte(downSQL), 0o644)
+ require.NoError(t, err, "failed to write down migration file")
+
+ // Run the migrator.
+ migrator, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: dsn,
+ DatabaseName: "testdb",
+ MigrationsPath: migDir,
+ Component: "integration_test",
+ Logger: log.NewNop(),
+ })
+ require.NoError(t, err, "NewMigrator() should succeed")
+
+ err = migrator.Up(ctx)
+ require.NoError(t, err, "Migrator.Up() should apply the migration successfully")
+
+ // Verify the table exists by querying it through a fresh client.
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ err = client.Connect(ctx)
+ require.NoError(t, err)
+
+ db, err := client.Primary()
+ require.NoError(t, err)
+
+ // Insert a row to confirm the table schema is correct.
+ _, err = db.ExecContext(ctx, "INSERT INTO test_items (name) VALUES ($1)", "integration_test_item")
+ require.NoError(t, err, "INSERT into migrated table should succeed")
+
+ // Read it back.
+ var name string
+ err = db.QueryRowContext(ctx, "SELECT name FROM test_items WHERE name = $1", "integration_test_item").Scan(&name)
+ require.NoError(t, err, "SELECT from migrated table should succeed")
+ assert.Equal(t, "integration_test_item", name, "should read back the inserted value")
+
+ // Verify the table has exactly one row.
+ var count int
+ err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM test_items").Scan(&count)
+ require.NoError(t, err)
+ assert.Equal(t, 1, count, "migrated table should contain exactly one row")
+
+ err = client.Close()
+ assert.NoError(t, err)
+}
diff --git a/commons/postgres/postgres_test.go b/commons/postgres/postgres_test.go
index 0099ffa5..2b7b3836 100644
--- a/commons/postgres/postgres_test.go
+++ b/commons/postgres/postgres_test.go
@@ -1,145 +1,1554 @@
+//go:build unit
+
package postgres
import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "os"
+ "strings"
+ "sync/atomic"
"testing"
+ "time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
- "github.com/LerianStudio/lib-commons/v2/commons/pointers"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/bxcodec/dbresolver/v2"
+ "github.com/golang-migrate/migrate/v4"
"github.com/stretchr/testify/assert"
- "go.uber.org/mock/gomock"
+ "github.com/stretchr/testify/require"
)
-func TestPostgresConnection_MultiStatementEnabled_Nil_DefaultsToTrue(t *testing.T) {
- t.Parallel()
+type fakeResolver struct {
+ pingErr error
+ closeErr error
+ pingCtx context.Context
+ closeCall atomic.Int32
+}
+
+func (f *fakeResolver) Begin() (dbresolver.Tx, error) { return nil, nil }
+
+func (f *fakeResolver) BeginTx(context.Context, *sql.TxOptions) (dbresolver.Tx, error) {
+ return nil, nil
+}
+
+func (f *fakeResolver) Close() error {
+ f.closeCall.Add(1)
+
+ return f.closeErr
+}
+
+func (f *fakeResolver) Conn(context.Context) (dbresolver.Conn, error) { return nil, nil }
+
+func (f *fakeResolver) Driver() driver.Driver { return nil }
+
+func (f *fakeResolver) Exec(string, ...interface{}) (sql.Result, error) { return nil, nil }
+
+func (f *fakeResolver) ExecContext(context.Context, string, ...interface{}) (sql.Result, error) {
+ return nil, nil
+}
+
+func (f *fakeResolver) Ping() error { return nil }
+
+func (f *fakeResolver) PingContext(ctx context.Context) error {
+ f.pingCtx = ctx
+
+ return f.pingErr
+}
+
+func (f *fakeResolver) Prepare(string) (dbresolver.Stmt, error) { return nil, nil }
+
+func (f *fakeResolver) PrepareContext(context.Context, string) (dbresolver.Stmt, error) {
+ return nil, nil
+}
+
+func (f *fakeResolver) Query(string, ...interface{}) (*sql.Rows, error) { return nil, nil }
+
+func (f *fakeResolver) QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) {
+ return nil, nil
+}
+
+func (f *fakeResolver) QueryRow(string, ...interface{}) *sql.Row { return &sql.Row{} }
+
+func (f *fakeResolver) QueryRowContext(context.Context, string, ...interface{}) *sql.Row {
+ return &sql.Row{}
+}
- ctrl := gomock.NewController(t)
- t.Cleanup(ctrl.Finish)
+func (f *fakeResolver) SetConnMaxIdleTime(time.Duration) {}
- mockLogger := log.NewMockLogger(ctrl)
+func (f *fakeResolver) SetConnMaxLifetime(time.Duration) {}
- pc := &PostgresConnection{
- ConnectionStringPrimary: "postgres://user:pass@localhost:5432/testdb",
- ConnectionStringReplica: "postgres://user:pass@localhost:5432/testdb",
- PrimaryDBName: "testdb",
- ReplicaDBName: "testdb",
- Logger: mockLogger,
- MaxOpenConnections: 10,
- MaxIdleConnections: 5,
- MultiStatementEnabled: nil, // explicitly nil to test default
+func (f *fakeResolver) SetMaxIdleConns(int) {}
+
+func (f *fakeResolver) SetMaxOpenConns(int) {}
+
+func (f *fakeResolver) PrimaryDBs() []*sql.DB { return nil }
+
+func (f *fakeResolver) ReplicaDBs() []*sql.DB { return nil }
+
+func (f *fakeResolver) Stats() sql.DBStats { return sql.DBStats{} }
+
+// testDB opens a sql.DB for test dependency injection.
+// WARNING: Tests using testDB with withPatchedDependencies must NOT call t.Parallel()
+// as withPatchedDependencies mutates global state.
+func testDB(t *testing.T) *sql.DB {
+ t.Helper()
+
+ dsn := os.Getenv("POSTGRES_DSN")
+ if dsn == "" {
+ dsn = "postgres://postgres:secret@localhost:5432/postgres?sslmode=disable"
}
- // Verify the field is nil
- assert.Nil(t, pc.MultiStatementEnabled, "MultiStatementEnabled should be nil by default")
+ db, err := sql.Open("pgx", dsn)
+ if err != nil {
+ t.Skipf("skipping: cannot open postgres connection (set POSTGRES_DSN to configure): %v", err)
+ }
- // Verify default resolution logic (using helper method from Connect())
- assert.True(t, pc.resolveMultiStatementEnabled(), "nil MultiStatementEnabled should resolve to true")
+ t.Cleanup(func() { _ = db.Close() })
+
+ return db
}
-func TestPostgresConnection_MultiStatementEnabled_ExplicitTrue(t *testing.T) {
- t.Parallel()
+// withPatchedDependencies replaces package-level dependency functions for testing.
+// WARNING: Tests using this helper must NOT call t.Parallel() as it mutates global state.
+func withPatchedDependencies(
+ t *testing.T,
+ openFn func(string, string) (*sql.DB, error),
+ resolverFn func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error),
+ migrateFn func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error,
+) {
+ t.Helper()
+
+ originalOpen := dbOpenFn
+ originalResolver := createResolverFn
+ originalMigrations := runMigrationsFn
- ctrl := gomock.NewController(t)
- t.Cleanup(ctrl.Finish)
+ dbOpenFn = openFn
+ createResolverFn = resolverFn
+ runMigrationsFn = migrateFn
- mockLogger := log.NewMockLogger(ctrl)
+ t.Cleanup(func() {
+ dbOpenFn = originalOpen
+ createResolverFn = originalResolver
+ runMigrationsFn = originalMigrations
+ })
+}
- pc := &PostgresConnection{
- ConnectionStringPrimary: "postgres://user:pass@localhost:5432/testdb",
- ConnectionStringReplica: "postgres://user:pass@localhost:5432/testdb",
- PrimaryDBName: "testdb",
- ReplicaDBName: "testdb",
- Logger: mockLogger,
- MaxOpenConnections: 10,
- MaxIdleConnections: 5,
- MultiStatementEnabled: pointers.Bool(true), // explicitly true
+func validConfig() Config {
+ return Config{
+ PrimaryDSN: "postgres://postgres:secret@localhost:5432/postgres?sslmode=disable",
+ ReplicaDSN: "postgres://postgres:secret@localhost:5432/postgres?sslmode=disable",
}
+}
+
+func TestNewConfigValidationAndDefaults(t *testing.T) {
+ t.Run("rejects missing dsn", func(t *testing.T) {
+ _, err := New(Config{})
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ })
+
+ t.Run("applies defaults", func(t *testing.T) {
+ client, err := New(validConfig())
+
+ require.NoError(t, err)
+ require.NotNil(t, client)
+ assert.NotNil(t, client.cfg.Logger)
+ assert.Equal(t, defaultMaxOpenConns, client.cfg.MaxOpenConnections)
+ assert.Equal(t, defaultMaxIdleConns, client.cfg.MaxIdleConnections)
+ })
+}
+
+func TestConnectRequiresContext(t *testing.T) {
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ err = client.Connect(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilContext)
+}
+
+func TestDBRequiresContext(t *testing.T) {
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ _, err = client.Resolver(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilContext)
+}
+
+func TestConnectSanitizesSensitiveError(t *testing.T) {
+ withPatchedDependencies(
+ t,
+ func(string, string) (*sql.DB, error) {
+ return nil, errors.New("parse postgres://alice:supersecret@db.internal:5432/main failed password=supersecret")
+ },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return nil, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
- // Verify the field is set to true
- assert.NotNil(t, pc.MultiStatementEnabled, "MultiStatementEnabled should not be nil")
- assert.True(t, *pc.MultiStatementEnabled, "MultiStatementEnabled should be true")
+ err = client.Connect(context.Background())
+ require.Error(t, err)
+ assert.NotContains(t, err.Error(), "supersecret")
+ assert.Contains(t, err.Error(), "://***@")
+ assert.Contains(t, err.Error(), "password=***")
- // Verify resolution logic (using helper method from Connect())
- assert.True(t, pc.resolveMultiStatementEnabled(), "explicit true should resolve to true")
+ // Verify error chain preservation via SanitizedError
+ var sanitizedErr *SanitizedError
+ assert.True(t, errors.As(err, &sanitizedErr))
}
-func TestPostgresConnection_MultiStatementEnabled_ExplicitFalse(t *testing.T) {
+func TestConnectAtomicSwapKeepsOldOnFailure(t *testing.T) {
+ oldResolver := &fakeResolver{}
+ newResolver := &fakeResolver{pingErr: errors.New("boom")}
+
+ withPatchedDependencies(
+ t,
+ func(string, string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return newResolver, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+ client.resolver = oldResolver
+
+ err = client.Connect(context.Background())
+ require.Error(t, err)
+ assert.Equal(t, oldResolver, client.resolver)
+ assert.Equal(t, int32(0), oldResolver.closeCall.Load())
+ assert.Equal(t, int32(1), newResolver.closeCall.Load())
+}
+
+func TestConnectAtomicSwapClosesPreviousOnSuccess(t *testing.T) {
+ oldResolver := &fakeResolver{}
+ newResolver := &fakeResolver{}
+
+ withPatchedDependencies(
+ t,
+ func(string, string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return newResolver, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+ client.resolver = oldResolver
+
+ err = client.Connect(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, int32(1), oldResolver.closeCall.Load())
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected)
+
+ assert.NoError(t, client.Close())
+}
+
+func TestDBLazyConnect(t *testing.T) {
+ resolver := &fakeResolver{}
+
+ withPatchedDependencies(
+ t,
+ func(string, string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return resolver, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ db, err := client.Resolver(context.Background())
+ require.NoError(t, err)
+ assert.NotNil(t, db)
+ assert.NotNil(t, resolver.pingCtx)
+
+ assert.NoError(t, client.Close())
+}
+
+func TestCloseIsIdempotent(t *testing.T) {
+ resolver := &fakeResolver{}
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+ client.resolver = resolver
+
+ require.NoError(t, client.Close())
+ require.NoError(t, client.Close())
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.False(t, connected)
+ assert.Equal(t, int32(1), resolver.closeCall.Load())
+}
+
+func TestNewMigratorValidation(t *testing.T) {
+ t.Run("requires db name", func(t *testing.T) {
+ _, err := NewMigrator(MigrationConfig{PrimaryDSN: "postgres://localhost:5432/postgres"})
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidDatabaseName)
+ })
+
+ t.Run("requires component or path", func(t *testing.T) {
+ _, err := NewMigrator(MigrationConfig{PrimaryDSN: "postgres://localhost:5432/postgres", DatabaseName: "ledger"})
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ })
+}
+
+func TestMigratorUpRunsExplicitly(t *testing.T) {
+ var migrationCalls atomic.Int32
+
+ withPatchedDependencies(
+ t,
+ func(string, string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return &fakeResolver{}, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error {
+ migrationCalls.Add(1)
+ return nil
+ },
+ )
+
+ migrator, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: "postgres://postgres:secret@localhost:5432/postgres?sslmode=disable",
+ DatabaseName: "postgres",
+ MigrationsPath: "components/ledger/migrations",
+ })
+ require.NoError(t, err)
+
+ err = migrator.Up(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, int32(1), migrationCalls.Load())
+}
+
+// ---------------------------------------------------------------------------
+// Config.withDefaults
+// ---------------------------------------------------------------------------
+
+func TestConfigWithDefaults(t *testing.T) {
t.Parallel()
- ctrl := gomock.NewController(t)
- t.Cleanup(ctrl.Finish)
+ t.Run("nil logger gets default", func(t *testing.T) {
+ t.Parallel()
- mockLogger := log.NewMockLogger(ctrl)
+ cfg := Config{PrimaryDSN: "dsn", ReplicaDSN: "dsn"}.withDefaults()
+ assert.NotNil(t, cfg.Logger)
+ })
- pc := &PostgresConnection{
- ConnectionStringPrimary: "postgres://user:pass@localhost:5432/testdb",
- ConnectionStringReplica: "postgres://user:pass@localhost:5432/testdb",
- PrimaryDBName: "testdb",
- ReplicaDBName: "testdb",
- Logger: mockLogger,
- MaxOpenConnections: 10,
- MaxIdleConnections: 5,
- MultiStatementEnabled: pointers.Bool(false), // explicitly false
- }
+ t.Run("zero MaxOpenConnections gets default", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := Config{PrimaryDSN: "dsn", ReplicaDSN: "dsn"}.withDefaults()
+ assert.Equal(t, defaultMaxOpenConns, cfg.MaxOpenConnections)
+ })
+
+ t.Run("zero MaxIdleConnections gets default", func(t *testing.T) {
+ t.Parallel()
- // Verify the field is set to false
- assert.NotNil(t, pc.MultiStatementEnabled, "MultiStatementEnabled should not be nil")
- assert.False(t, *pc.MultiStatementEnabled, "MultiStatementEnabled should be false")
+ cfg := Config{PrimaryDSN: "dsn", ReplicaDSN: "dsn"}.withDefaults()
+ assert.Equal(t, defaultMaxIdleConns, cfg.MaxIdleConnections)
+ })
- // Verify resolution logic (using helper method from Connect())
- assert.False(t, pc.resolveMultiStatementEnabled(), "explicit false should resolve to false")
+ t.Run("custom values preserved", func(t *testing.T) {
+ t.Parallel()
+
+ logger := log.NewNop()
+ cfg := Config{
+ PrimaryDSN: "dsn",
+ ReplicaDSN: "dsn",
+ Logger: logger,
+ MaxOpenConnections: 50,
+ MaxIdleConnections: 20,
+ }.withDefaults()
+
+ assert.Equal(t, logger, cfg.Logger)
+ assert.Equal(t, 50, cfg.MaxOpenConnections)
+ assert.Equal(t, 20, cfg.MaxIdleConnections)
+ })
}
-func TestPostgresConnection_MultiStatementEnabled_AllCases(t *testing.T) {
+// ---------------------------------------------------------------------------
+// Config.validate
+// ---------------------------------------------------------------------------
+
+func TestConfigValidate(t *testing.T) {
t.Parallel()
- tests := []struct {
- name string
- multiStatementEnabled *bool
- expectedResolved bool
- description string
- }{
- {
- name: "nil_defaults_to_true",
- multiStatementEnabled: nil,
- expectedResolved: true,
- description: "backward compatibility - nil should default to true",
+ t.Run("empty primary DSN", func(t *testing.T) {
+ t.Parallel()
+
+ err := Config{PrimaryDSN: "", ReplicaDSN: "dsn"}.validate()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ })
+
+ t.Run("whitespace-only primary DSN", func(t *testing.T) {
+ t.Parallel()
+
+ err := Config{PrimaryDSN: " ", ReplicaDSN: "dsn"}.validate()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ })
+
+ t.Run("empty replica DSN", func(t *testing.T) {
+ t.Parallel()
+
+ err := Config{PrimaryDSN: "dsn", ReplicaDSN: ""}.validate()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ })
+
+ t.Run("valid config", func(t *testing.T) {
+ t.Parallel()
+
+ err := Config{PrimaryDSN: "dsn", ReplicaDSN: "dsn"}.validate()
+ assert.NoError(t, err)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// New
+// ---------------------------------------------------------------------------
+
+func TestNew(t *testing.T) {
+ t.Run("valid config returns client", func(t *testing.T) {
+ t.Parallel()
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+ require.NotNil(t, client)
+ })
+
+ t.Run("invalid config returns error", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := New(Config{})
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Client nil receiver safety
+// ---------------------------------------------------------------------------
+
+func TestClientNilReceiver(t *testing.T) {
+ t.Parallel()
+
+ t.Run("Connect nil client", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Client
+ err := c.Connect(context.Background())
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilClient)
+ })
+
+ t.Run("Resolver nil client", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Client
+ _, err := c.Resolver(context.Background())
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilClient)
+ })
+
+ t.Run("Close nil client", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Client
+ err := c.Close()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilClient)
+ })
+
+ t.Run("IsConnected nil client", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Client
+ connected, err := c.IsConnected()
+ assert.False(t, connected)
+ assert.ErrorIs(t, err, ErrNilClient)
+ })
+
+ t.Run("Primary nil client", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Client
+ _, err := c.Primary()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilClient)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Client nil context
+// ---------------------------------------------------------------------------
+
+func TestClientNilContext(t *testing.T) {
+ t.Parallel()
+
+ t.Run("Connect nil ctx", func(t *testing.T) {
+ t.Parallel()
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ err = client.Connect(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilContext)
+ })
+
+ t.Run("Resolver nil ctx", func(t *testing.T) {
+ t.Parallel()
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ _, err = client.Resolver(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilContext)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Connect with mock dbOpenFn errors
+// ---------------------------------------------------------------------------
+
+func TestConnectDbOpenError(t *testing.T) {
+ t.Run("primary open fails", func(t *testing.T) {
+ withPatchedDependencies(
+ t,
+ func(_, _ string) (*sql.DB, error) {
+ return nil, errors.New("connection refused")
+ },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return &fakeResolver{}, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ err = client.Connect(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to open database")
+ })
+
+ t.Run("replica open fails", func(t *testing.T) {
+ callCount := 0
+
+ withPatchedDependencies(
+ t,
+ func(_, _ string) (*sql.DB, error) {
+ callCount++
+ if callCount == 1 {
+ return testDB(t), nil
+ }
+
+ return nil, errors.New("replica down")
+ },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return &fakeResolver{}, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ err = client.Connect(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to open database")
+ })
+
+ t.Run("resolver creation fails", func(t *testing.T) {
+ withPatchedDependencies(
+ t,
+ func(_, _ string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) {
+ return nil, errors.New("resolver error")
+ },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ err = client.Connect(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to create resolver")
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Resolver lazy connect - double-checked locking (second call returns cached)
+// ---------------------------------------------------------------------------
+
+func TestResolverCachesResolver(t *testing.T) {
+ resolver := &fakeResolver{}
+
+ withPatchedDependencies(
+ t,
+ func(_, _ string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return resolver, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ // First call connects lazily.
+ r1, err := client.Resolver(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, resolver, r1)
+
+ // Second call returns cached (fast path).
+ r2, err := client.Resolver(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, r1, r2)
+
+ assert.NoError(t, client.Close())
+}
+
+// ---------------------------------------------------------------------------
+// Primary not connected
+// ---------------------------------------------------------------------------
+
+func TestPrimaryNotConnected(t *testing.T) {
+ t.Parallel()
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ _, err = client.Primary()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNotConnected)
+}
+
+// ---------------------------------------------------------------------------
+// Close with error from resolver
+// ---------------------------------------------------------------------------
+
+func TestCloseResolverError(t *testing.T) {
+ resolver := &fakeResolver{closeErr: errors.New("close boom")}
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+ client.resolver = resolver
+
+ err = client.Close()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "close boom")
+}
+
+// ---------------------------------------------------------------------------
+// MigrationConfig
+// ---------------------------------------------------------------------------
+
+func TestMigrationConfigWithDefaults(t *testing.T) {
+ t.Parallel()
+
+ cfg := MigrationConfig{}.withDefaults()
+ assert.NotNil(t, cfg.Logger)
+}
+
+func TestMigrationConfigValidate(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty DSN", func(t *testing.T) {
+ t.Parallel()
+
+ err := MigrationConfig{DatabaseName: "ledger", MigrationsPath: "/tmp"}.validate()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ })
+
+ t.Run("invalid DB name", func(t *testing.T) {
+ t.Parallel()
+
+ err := MigrationConfig{PrimaryDSN: "dsn", DatabaseName: "no-dashes", MigrationsPath: "/tmp"}.validate()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidDatabaseName)
+ })
+
+ t.Run("empty path and component", func(t *testing.T) {
+ t.Parallel()
+
+ err := MigrationConfig{PrimaryDSN: "dsn", DatabaseName: "ledger"}.validate()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ })
+
+ t.Run("valid with path", func(t *testing.T) {
+ t.Parallel()
+
+ err := MigrationConfig{PrimaryDSN: "dsn", DatabaseName: "ledger", MigrationsPath: "/tmp"}.validate()
+ assert.NoError(t, err)
+ })
+
+ t.Run("valid with component", func(t *testing.T) {
+ t.Parallel()
+
+ err := MigrationConfig{PrimaryDSN: "dsn", DatabaseName: "ledger", Component: "ledger"}.validate()
+ assert.NoError(t, err)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// NewMigrator
+// ---------------------------------------------------------------------------
+
+func TestNewMigratorValid(t *testing.T) {
+ t.Parallel()
+
+ m, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: "dsn",
+ DatabaseName: "ledger",
+ MigrationsPath: "/migrations",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, m)
+}
+
+func TestNewMigratorInvalid(t *testing.T) {
+ t.Parallel()
+
+ _, err := NewMigrator(MigrationConfig{})
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+}
+
+// ---------------------------------------------------------------------------
+// Migrator nil receiver and nil context
+// ---------------------------------------------------------------------------
+
+func TestMigratorNilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var m *Migrator
+ err := m.Up(context.Background())
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilMigrator)
+}
+
+func TestMigratorNilContext(t *testing.T) {
+ m, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: "dsn",
+ DatabaseName: "ledger",
+ MigrationsPath: "/migrations",
+ })
+ require.NoError(t, err)
+
+ err = m.Up(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilContext)
+}
+
+func TestMigratorUpDbOpenError(t *testing.T) {
+ withPatchedDependencies(
+ t,
+ func(_, _ string) (*sql.DB, error) {
+ return nil, errors.New("parse postgres://alice:supersecret@db:5432/main failed")
},
- {
- name: "explicit_true",
- multiStatementEnabled: pointers.Bool(true),
- expectedResolved: true,
- description: "explicit true should resolve to true",
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return nil, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ m, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: "postgres://alice:supersecret@db:5432/main?sslmode=disable",
+ DatabaseName: "main",
+ MigrationsPath: "/migrations",
+ })
+ require.NoError(t, err)
+
+ err = m.Up(context.Background())
+ require.Error(t, err)
+ assert.NotContains(t, err.Error(), "supersecret")
+}
+
+func TestMigratorUpResolvesPathFromComponent(t *testing.T) {
+ var capturedPath string
+
+ withPatchedDependencies(
+ t,
+ func(_, _ string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return &fakeResolver{}, nil },
+ func(_ context.Context, _ *sql.DB, path, _ string, _, _ bool, _ log.Logger) error {
+ capturedPath = path
+ return nil
},
- {
- name: "explicit_false",
- multiStatementEnabled: pointers.Bool(false),
- expectedResolved: false,
- description: "explicit false should resolve to false",
+ )
+
+ m, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: "postgres://localhost/db",
+ DatabaseName: "ledger",
+ Component: "ledger",
+ })
+ require.NoError(t, err)
+
+ err = m.Up(context.Background())
+ require.NoError(t, err)
+ assert.Contains(t, capturedPath, "components")
+ assert.Contains(t, capturedPath, "ledger")
+ assert.Contains(t, capturedPath, "migrations")
+}
+
+func TestMigratorUpMigrationError(t *testing.T) {
+ withPatchedDependencies(
+ t,
+ func(_, _ string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return &fakeResolver{}, nil },
+ func(_ context.Context, _ *sql.DB, _, _ string, _, _ bool, _ log.Logger) error {
+ return errors.New("migration failed")
},
- }
+ )
+
+ m, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: "postgres://localhost/db",
+ DatabaseName: "ledger",
+ MigrationsPath: "/migrations",
+ })
+ require.NoError(t, err)
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
+ err = m.Up(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "migration failed")
+}
+
+// ---------------------------------------------------------------------------
+// sanitizeSensitiveString
+// ---------------------------------------------------------------------------
- ctrl := gomock.NewController(t)
- t.Cleanup(ctrl.Finish)
+func TestSanitizeSensitiveString(t *testing.T) {
+ t.Parallel()
- mockLogger := log.NewMockLogger(ctrl)
+ t.Run("masks user:password in DSN", func(t *testing.T) {
+ t.Parallel()
- pc := &PostgresConnection{
- ConnectionStringPrimary: "postgres://user:pass@localhost:5432/testdb",
- ConnectionStringReplica: "postgres://user:pass@localhost:5432/testdb",
- PrimaryDBName: "testdb",
- ReplicaDBName: "testdb",
- Logger: mockLogger,
- MaxOpenConnections: 10,
- MaxIdleConnections: 5,
- MultiStatementEnabled: tt.multiStatementEnabled,
- }
+ result := sanitizeSensitiveString("failed to connect to postgres://alice:supersecret@db.internal:5432/main")
+ assert.NotContains(t, result, "alice")
+ assert.NotContains(t, result, "supersecret")
+ assert.Contains(t, result, "://***@")
+ })
+
+ t.Run("masks password= param", func(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeSensitiveString("connection error password=mysecret host=db")
+ assert.NotContains(t, result, "mysecret")
+ assert.Contains(t, result, "password=***")
+ })
+
+ t.Run("masks password containing ampersand", func(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeSensitiveString("connection error password=sec&ret host=db")
+ assert.NotContains(t, result, "sec&ret")
+ assert.Contains(t, result, "password=***")
+ })
+
+ t.Run("masks sslkey path", func(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeSensitiveString("host=db sslkey=/etc/ssl/private/key.pem port=5432")
+ assert.NotContains(t, result, "/etc/ssl/private/key.pem")
+ assert.Contains(t, result, "sslkey=***")
+ })
+
+ t.Run("masks sslcert and sslrootcert", func(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeSensitiveString("sslcert=/path/cert.pem sslrootcert=/path/ca.pem")
+ assert.NotContains(t, result, "/path/cert.pem")
+ assert.Contains(t, result, "sslcert=***")
+ assert.Contains(t, result, "sslrootcert=***")
+ })
+
+ t.Run("error without credentials passes through", func(t *testing.T) {
+ t.Parallel()
+
+ result := sanitizeSensitiveString("timeout connecting to database")
+ assert.Equal(t, "timeout connecting to database", result)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// sanitizePath
+// ---------------------------------------------------------------------------
+
+func TestSanitizePath(t *testing.T) {
+ t.Parallel()
+
+ t.Run("valid path", func(t *testing.T) {
+ t.Parallel()
+
+ result, err := sanitizePath("components/ledger/migrations")
+ require.NoError(t, err)
+ assert.NotEmpty(t, result)
+ })
+
+ t.Run("path with traversal rejected", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := sanitizePath("../../etc/passwd")
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid migrations path")
+ })
+
+ t.Run("absolute path accepted", func(t *testing.T) {
+ t.Parallel()
+
+ result, err := sanitizePath("/var/migrations")
+ require.NoError(t, err)
+ assert.Equal(t, "/var/migrations", result)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// validateDBName
+// ---------------------------------------------------------------------------
+
+func TestValidateDBName(t *testing.T) {
+ t.Parallel()
+
+ t.Run("valid names", func(t *testing.T) {
+ t.Parallel()
+
+ for _, name := range []string{"postgres", "ledger", "_private", "db_123", "A"} {
+ assert.NoError(t, validateDBName(name), "expected %q to be valid", name)
+ }
+ })
+
+ t.Run("invalid names", func(t *testing.T) {
+ t.Parallel()
+
+ for _, name := range []string{"", "no-dashes", "123start", "has space", "a;drop", "has.dot"} {
+ err := validateDBName(name)
+ require.Error(t, err, "expected %q to be invalid", name)
+ assert.ErrorIs(t, err, ErrInvalidDatabaseName)
+ }
+ })
+
+ t.Run("too long name", func(t *testing.T) {
+ t.Parallel()
+
+ longName := strings.Repeat("a", 64)
+ err := validateDBName(longName)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidDatabaseName)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// resolveMigrationsPath
+// ---------------------------------------------------------------------------
+
+func TestResolveMigrationsPath(t *testing.T) {
+ t.Parallel()
+
+ t.Run("explicit path used", func(t *testing.T) {
+ t.Parallel()
+
+ result, err := resolveMigrationsPath("components/ledger/migrations", "ignored")
+ require.NoError(t, err)
+ assert.NotEmpty(t, result)
+ })
+
+ t.Run("component-based path", func(t *testing.T) {
+ t.Parallel()
+
+ result, err := resolveMigrationsPath("", "ledger")
+ require.NoError(t, err)
+ assert.Contains(t, result, "components")
+ assert.Contains(t, result, "ledger")
+ assert.Contains(t, result, "migrations")
+ })
+
+ t.Run("invalid component (traversal stripped)", func(t *testing.T) {
+ t.Parallel()
+
+ // filepath.Base("../../etc") → "etc", which is valid, so no error.
+ result, err := resolveMigrationsPath("", "../../etc")
+ require.NoError(t, err)
+ assert.Contains(t, result, "etc")
+ })
+
+ t.Run("empty component and empty path", func(t *testing.T) {
+ t.Parallel()
+
+ // filepath.Base("") → ".", which triggers the guard.
+ _, err := resolveMigrationsPath("", "")
+ require.Error(t, err)
+ })
+
+ t.Run("dot-only component", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := resolveMigrationsPath("", ".")
+ require.Error(t, err)
+ })
+
+ t.Run("path with traversal rejected", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := resolveMigrationsPath("../../etc/passwd", "")
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid migrations path")
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Close without resolver falls back to closing primary/replica directly
+// ---------------------------------------------------------------------------
+
+func TestCloseNoResolverClosesPrimaryAndReplica(t *testing.T) {
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ primary := testDB(t)
+ replica := testDB(t)
+
+ client.primary = primary
+ client.replica = replica
+
+ err = client.Close()
+ assert.NoError(t, err)
+
+ // After Close(), primary and replica should be nil.
+ assert.Nil(t, client.primary)
+ assert.Nil(t, client.replica)
+}
+
+func TestCloseNoResolverOnlyPrimary(t *testing.T) {
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ primary := testDB(t)
+ client.primary = primary
+
+ err = client.Close()
+ assert.NoError(t, err)
+ assert.Nil(t, client.primary)
+}
+
+func TestCloseNoResolverOnlyReplica(t *testing.T) {
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ replica := testDB(t)
+ client.replica = replica
+
+ err = client.Close()
+ assert.NoError(t, err)
+ assert.Nil(t, client.replica)
+}
+
+// ---------------------------------------------------------------------------
+// connectLocked old resolver close error path
+// ---------------------------------------------------------------------------
+
+func TestConnectLockedOldResolverCloseError(t *testing.T) {
+ oldResolver := &fakeResolver{closeErr: errors.New("old close failed")}
+ newResolver := &fakeResolver{}
+
+ withPatchedDependencies(
+ t,
+ func(string, string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return newResolver, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+ client.resolver = oldResolver
+
+ // Should succeed — old resolver close error is logged but not returned.
+ err = client.Connect(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, int32(1), oldResolver.closeCall.Load())
+
+ assert.NoError(t, client.Close())
+}
+
+// ---------------------------------------------------------------------------
+// Resolver lazy connect error path
+// ---------------------------------------------------------------------------
+
+func TestResolverLazyConnectError(t *testing.T) {
+ withPatchedDependencies(
+ t,
+ func(string, string) (*sql.DB, error) {
+ return nil, errors.New("cannot connect")
+ },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return &fakeResolver{}, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ _, err = client.Resolver(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to open database")
+}
+
+// ---------------------------------------------------------------------------
+// Resolver double-checked locking — resolver set between RLock and Lock
+// ---------------------------------------------------------------------------
+
+func TestResolverDoubleCheckReturnsExisting(t *testing.T) {
+ resolver := &fakeResolver{}
+
+ withPatchedDependencies(
+ t,
+ func(string, string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return resolver, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ // First call connects lazily
+ r1, err := client.Resolver(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, resolver, r1)
+
+ // Set resolver directly to simulate race (already set when write lock acquired)
+ newResolver := &fakeResolver{}
+ client.mu.Lock()
+ client.resolver = newResolver
+ client.mu.Unlock()
+
+ r2, err := client.Resolver(context.Background())
+ require.NoError(t, err)
+ assert.Equal(t, newResolver, r2)
+}
+
+// ---------------------------------------------------------------------------
+// Primary returns db when connected
+// ---------------------------------------------------------------------------
+
+func TestPrimaryReturnsDBWhenConnected(t *testing.T) {
+ withPatchedDependencies(
+ t,
+ func(string, string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return &fakeResolver{}, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ err = client.Connect(context.Background())
+ require.NoError(t, err)
+
+ db, err := client.Primary()
+ require.NoError(t, err)
+ assert.NotNil(t, db)
+
+ assert.NoError(t, client.Close())
+}
+
+// ---------------------------------------------------------------------------
+// Migrator Up resolveMigrationsPath error
+// ---------------------------------------------------------------------------
+
+func TestMigratorUpResolveMigrationsPathError(t *testing.T) {
+ withPatchedDependencies(
+ t,
+ func(_, _ string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return &fakeResolver{}, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ m, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: "postgres://localhost/db",
+ DatabaseName: "ledger",
+ MigrationsPath: "../../etc/passwd",
+ })
+ require.NoError(t, err)
+
+ err = m.Up(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid migrations path")
+}
+
+// ---------------------------------------------------------------------------
+// closeDB
+// ---------------------------------------------------------------------------
+
+func TestCloseDBNil(t *testing.T) {
+ t.Parallel()
+
+ err := closeDB(nil)
+ assert.NoError(t, err)
+}
+
+// ---------------------------------------------------------------------------
+// Client logAtLevel nil safety
+// ---------------------------------------------------------------------------
- // Use the same resolution logic as Connect()
- assert.Equal(t, tt.expectedResolved, pc.resolveMultiStatementEnabled(), tt.description)
+func TestClientLogAtLevelNilSafety(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil client does not panic", func(t *testing.T) {
+ t.Parallel()
+
+ var c *Client
+ assert.NotPanics(t, func() {
+ c.logAtLevel(context.Background(), log.LevelInfo, "test")
+ })
+ })
+
+ t.Run("nil logger does not panic", func(t *testing.T) {
+ t.Parallel()
+
+ c := &Client{}
+ assert.NotPanics(t, func() {
+ c.logAtLevel(context.Background(), log.LevelInfo, "test")
+ })
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Migrator logAtLevel nil safety
+// ---------------------------------------------------------------------------
+
+func TestMigratorLogAtLevelNilSafety(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil migrator does not panic", func(t *testing.T) {
+ t.Parallel()
+
+ var m *Migrator
+ assert.NotPanics(t, func() {
+ m.logAtLevel(context.Background(), log.LevelInfo, "test")
})
+ })
+
+ t.Run("nil logger does not panic", func(t *testing.T) {
+ t.Parallel()
+
+ m := &Migrator{}
+ assert.NotPanics(t, func() {
+ m.logAtLevel(context.Background(), log.LevelError, "test")
+ })
+ })
+}
+
+// ---------------------------------------------------------------------------
+// SanitizedError
+// ---------------------------------------------------------------------------
+
+func TestSanitizedError(t *testing.T) {
+ t.Parallel()
+
+ t.Run("Error returns sanitized message", func(t *testing.T) {
+ t.Parallel()
+
+ cause := errors.New("connect to postgres://alice:supersecret@db:5432 failed")
+ se := newSanitizedError(cause, "failed to open database")
+ assert.NotContains(t, se.Error(), "supersecret")
+ assert.NotContains(t, se.Error(), "alice")
+ assert.Contains(t, se.Error(), "://***@")
+ })
+
+ t.Run("Unwrap returns sanitized cause without credentials", func(t *testing.T) {
+ t.Parallel()
+
+ cause := errors.New("connect to postgres://alice:supersecret@db:5432 failed")
+ se := newSanitizedError(cause, "open failed")
+ unwrapped := se.Unwrap()
+ require.NotNil(t, unwrapped, "Unwrap must return a sanitized cause for error chain traversal")
+ assert.NotContains(t, unwrapped.Error(), "supersecret", "Unwrap must not leak credentials")
+ assert.NotContains(t, unwrapped.Error(), "alice", "Unwrap must not leak credentials")
+ assert.Contains(t, unwrapped.Error(), "://***@", "Unwrap must contain sanitized URI")
+ })
+
+ t.Run("nil error returns nil", func(t *testing.T) {
+ t.Parallel()
+
+ assert.Nil(t, newSanitizedError(nil, "prefix"))
+ })
+
+ t.Run("errors.Is does not match original cause directly", func(t *testing.T) {
+ t.Parallel()
+
+ inner := errors.New("inner")
+ wrapped := fmt.Errorf("wrapped: %w", inner)
+ se := newSanitizedError(wrapped, "outer")
+ // The sanitized cause is a new error with the sanitized message text,
+ // so errors.Is will not match the original inner error.
+ assert.NotErrorIs(t, se, inner, "sanitized cause is a new error, not the original")
+ assert.Contains(t, se.Error(), "outer", "sanitized message should contain prefix")
+ // But Unwrap works for typed assertions.
+ assert.NotNil(t, se.Unwrap())
+ })
+}
+
+// ---------------------------------------------------------------------------
+// classifyMigrationError
+// ---------------------------------------------------------------------------
+
+func TestClassifyMigrationError(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil error returns zero outcome", func(t *testing.T) {
+ t.Parallel()
+
+ outcome := classifyMigrationError(nil, false)
+ assert.Nil(t, outcome.err)
+ })
+
+ t.Run("ErrNoChange returns nil error with info level", func(t *testing.T) {
+ t.Parallel()
+
+ outcome := classifyMigrationError(migrate.ErrNoChange, false)
+ assert.Nil(t, outcome.err)
+ assert.Equal(t, log.LevelInfo, outcome.level)
+ assert.NotEmpty(t, outcome.message)
+ })
+
+ t.Run("ErrNotExist returns ErrMigrationsNotFound by default", func(t *testing.T) {
+ t.Parallel()
+
+ outcome := classifyMigrationError(os.ErrNotExist, false)
+ require.Error(t, outcome.err)
+ assert.ErrorIs(t, outcome.err, ErrMigrationsNotFound)
+ assert.Equal(t, log.LevelError, outcome.level)
+ })
+
+ t.Run("ErrNotExist returns nil error when allowMissing is true", func(t *testing.T) {
+ t.Parallel()
+
+ outcome := classifyMigrationError(os.ErrNotExist, true)
+ assert.Nil(t, outcome.err)
+ assert.Equal(t, log.LevelWarn, outcome.level)
+ assert.NotEmpty(t, outcome.message)
+ })
+
+ t.Run("ErrDirty returns wrapped sentinel with version", func(t *testing.T) {
+ t.Parallel()
+
+ outcome := classifyMigrationError(migrate.ErrDirty{Version: 42}, false)
+ require.Error(t, outcome.err)
+ assert.ErrorIs(t, outcome.err, ErrMigrationDirty)
+ assert.Contains(t, outcome.err.Error(), "42")
+ assert.Equal(t, log.LevelError, outcome.level)
+ assert.NotEmpty(t, outcome.fields)
+ })
+
+ t.Run("generic error returns wrapped error", func(t *testing.T) {
+ t.Parallel()
+
+ cause := errors.New("disk full")
+ outcome := classifyMigrationError(cause, false)
+ require.Error(t, outcome.err)
+ assert.ErrorIs(t, outcome.err, cause)
+ assert.Equal(t, log.LevelError, outcome.level)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// createResolverFn panic recovery
+// ---------------------------------------------------------------------------
+
+func TestCreateResolverFnPanicRecovery(t *testing.T) {
+ // dbresolver.New doesn't panic with nil DBs (it wraps them), so we test
+ // the recovery pattern by installing a resolver factory that panics and
+ // verifying buildConnection converts it to an error, not a crash.
+ original := createResolverFn
+ origOpen := dbOpenFn
+ t.Cleanup(func() {
+ createResolverFn = original
+ dbOpenFn = origOpen
+ })
+
+ dbOpenFn = func(_, _ string) (*sql.DB, error) { return testDB(t), nil }
+ createResolverFn = func(_ *sql.DB, _ *sql.DB, logger log.Logger) (_ dbresolver.DB, err error) {
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ err = fmt.Errorf("failed to create resolver: %v", recovered)
+ }
+ }()
+
+ panic("dbresolver exploded")
}
-}
\ No newline at end of file
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ err = client.Connect(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to create resolver")
+ assert.Contains(t, err.Error(), "dbresolver exploded")
+}
+
+// ---------------------------------------------------------------------------
+// Config expansion: ConnMaxLifetime, ConnMaxIdleTime
+// ---------------------------------------------------------------------------
+
+func TestConfigWithDefaultsNewFields(t *testing.T) {
+ t.Parallel()
+
+ t.Run("zero ConnMaxLifetime gets default", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := Config{PrimaryDSN: "dsn", ReplicaDSN: "dsn"}.withDefaults()
+ assert.Equal(t, defaultConnMaxLifetime, cfg.ConnMaxLifetime)
+ })
+
+ t.Run("zero ConnMaxIdleTime gets default", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := Config{PrimaryDSN: "dsn", ReplicaDSN: "dsn"}.withDefaults()
+ assert.Equal(t, defaultConnMaxIdleTime, cfg.ConnMaxIdleTime)
+ })
+
+ t.Run("custom values preserved", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := Config{
+ PrimaryDSN: "dsn",
+ ReplicaDSN: "dsn",
+ ConnMaxLifetime: 1 * time.Hour,
+ ConnMaxIdleTime: 10 * time.Minute,
+ }.withDefaults()
+ assert.Equal(t, 1*time.Hour, cfg.ConnMaxLifetime)
+ assert.Equal(t, 10*time.Minute, cfg.ConnMaxIdleTime)
+ })
+}
+
+// ---------------------------------------------------------------------------
+// validateDSN
+// ---------------------------------------------------------------------------
+
+func TestValidateDSN(t *testing.T) {
+ t.Parallel()
+
+ t.Run("valid postgres:// URL", func(t *testing.T) {
+ t.Parallel()
+
+ assert.NoError(t, validateDSN("postgres://localhost:5432/db"))
+ })
+
+ t.Run("valid postgresql:// URL", func(t *testing.T) {
+ t.Parallel()
+
+ assert.NoError(t, validateDSN("postgresql://localhost:5432/db"))
+ })
+
+ t.Run("key-value format accepted", func(t *testing.T) {
+ t.Parallel()
+
+ assert.NoError(t, validateDSN("host=localhost port=5432 dbname=mydb"))
+ })
+
+ t.Run("empty string accepted (checked elsewhere)", func(t *testing.T) {
+ t.Parallel()
+
+ assert.NoError(t, validateDSN(""))
+ })
+}
+
+// ---------------------------------------------------------------------------
+// warnInsecureDSN
+// ---------------------------------------------------------------------------
+
+func TestWarnInsecureDSN(t *testing.T) {
+ t.Parallel()
+
+ t.Run("no panic with nil logger", func(t *testing.T) {
+ t.Parallel()
+
+ assert.NotPanics(t, func() {
+ warnInsecureDSN(context.Background(), nil, "postgres://host/db?sslmode=disable", "primary")
+ })
+ })
+
+ t.Run("no panic with secure DSN", func(t *testing.T) {
+ t.Parallel()
+
+ warnInsecureDSN(context.Background(), log.NewNop(), "postgres://host/db?sslmode=require", "primary")
+ })
+
+ t.Run("no panic with insecure DSN", func(t *testing.T) {
+ t.Parallel()
+
+ warnInsecureDSN(context.Background(), log.NewNop(), "postgres://host/db?sslmode=disable", "primary")
+ })
+}
+
+// ---------------------------------------------------------------------------
+// Migrator.Up context deadline check
+// ---------------------------------------------------------------------------
+
+func TestMigratorUpContextAlreadyCancelled(t *testing.T) {
+ t.Parallel()
+
+ m, err := NewMigrator(MigrationConfig{
+ PrimaryDSN: "postgres://localhost/db",
+ DatabaseName: "ledger",
+ MigrationsPath: "/migrations",
+ })
+ require.NoError(t, err)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ err = m.Up(ctx)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, context.Canceled)
+}
+
+// ---------------------------------------------------------------------------
+// Close defensive cleanup
+// ---------------------------------------------------------------------------
+
+func TestCloseDefensiveCleanup(t *testing.T) {
+ t.Run("closes primary and replica even when resolver succeeds", func(t *testing.T) {
+ resolver := &fakeResolver{}
+
+ withPatchedDependencies(
+ t,
+ func(_, _ string) (*sql.DB, error) { return testDB(t), nil },
+ func(*sql.DB, *sql.DB, log.Logger) (dbresolver.DB, error) { return resolver, nil },
+ func(context.Context, *sql.DB, string, string, bool, bool, log.Logger) error { return nil },
+ )
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+
+ err = client.Connect(context.Background())
+ require.NoError(t, err)
+
+ err = client.Close()
+ assert.NoError(t, err)
+ assert.Equal(t, int32(1), resolver.closeCall.Load())
+
+ // Verify that primary and replica handles are cleared after Close.
+ client.mu.Lock()
+ assert.Nil(t, client.primary, "primary should be nil after Close")
+ assert.Nil(t, client.replica, "replica should be nil after Close")
+ assert.Nil(t, client.resolver, "resolver should be nil after Close")
+ client.mu.Unlock()
+ })
+
+ t.Run("collects multiple close errors", func(t *testing.T) {
+ resolver := &fakeResolver{closeErr: errors.New("resolver close failed")}
+
+ client, err := New(validConfig())
+ require.NoError(t, err)
+ client.resolver = resolver
+
+ err = client.Close()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "resolver close failed")
+ })
+}
diff --git a/commons/postgres/resilience_integration_test.go b/commons/postgres/resilience_integration_test.go
new file mode 100644
index 00000000..6bb10b7d
--- /dev/null
+++ b/commons/postgres/resilience_integration_test.go
@@ -0,0 +1,447 @@
+//go:build integration
+
+package postgres
+
+import (
+ "context"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/testcontainers/testcontainers-go"
+ tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
+ "github.com/testcontainers/testcontainers-go/wait"
+)
+
+// setupPostgresContainerRaw starts a PostgreSQL 16 container and returns the
+// container handle (for Stop/Start control), its connection string, and a
+// cleanup function. Unlike setupPostgresContainer, this returns the raw
+// container so tests can simulate server outages by stopping and restarting it.
+func setupPostgresContainerRaw(t *testing.T) (*tcpostgres.PostgresContainer, string, func()) {
+ t.Helper()
+
+ ctx := context.Background()
+
+ container, err := tcpostgres.Run(ctx,
+ "postgres:16-alpine",
+ tcpostgres.WithDatabase("testdb"),
+ tcpostgres.WithUsername("test"),
+ tcpostgres.WithPassword("test"),
+ testcontainers.WithWaitStrategy(
+ wait.ForLog("database system is ready to accept connections").
+ WithOccurrence(2).
+ WithStartupTimeout(30*time.Second),
+ ),
+ )
+ require.NoError(t, err)
+
+ connStr, err := container.ConnectionString(ctx, "sslmode=disable")
+ require.NoError(t, err)
+
+ return container, connStr, func() {
+ _ = container.Terminate(ctx)
+ }
+}
+
+// waitForPostgresReady polls the restarted container until PostgreSQL is
+// accepting connections. After a container restart the mapped port may change,
+// so the caller must provide the current DSN. We try New()+Connect() every
+// pollInterval for up to timeout.
+func waitForPostgresReady(t *testing.T, dsn string, timeout, pollInterval time.Duration) {
+ t.Helper()
+
+ ctx := context.Background()
+ deadline := time.Now().Add(timeout)
+
+ for time.Now().Before(deadline) {
+ probe, err := New(newTestConfig(dsn))
+ if err != nil {
+ time.Sleep(pollInterval)
+ continue
+ }
+
+ if connectErr := probe.Connect(ctx); connectErr == nil {
+ _ = probe.Close()
+ return
+ }
+
+ _ = probe.Close()
+ time.Sleep(pollInterval)
+ }
+
+ t.Fatalf("PostgreSQL at DSN did not become ready within %s", timeout)
+}
+
+// TestIntegration_Postgres_Resilience_ReconnectAfterRestart validates the full
+// outage-recovery cycle:
+// 1. Connect and verify operations work (SELECT 1).
+// 2. Stop the container (simulates server crash / network partition).
+// 3. Verify that operations fail while the server is down.
+// 4. Restart the container and re-read the DSN (port may change).
+// 5. Create a fresh client with the new DSN and verify operations succeed.
+//
+// This is the most realistic resilience scenario: the backing PostgreSQL goes
+// down and comes back, possibly on a different port.
+func TestIntegration_Postgres_Resilience_ReconnectAfterRestart(t *testing.T) {
+ container, dsn, cleanup := setupPostgresContainerRaw(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ // Phase 1: establish a healthy connection and verify operations.
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err, "New() should succeed with valid DSN")
+
+ err = client.Connect(ctx)
+ require.NoError(t, err, "Connect() should succeed against running container")
+
+ db, err := client.Primary()
+ require.NoError(t, err, "Primary() should return a live *sql.DB")
+
+ var result int
+ err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
+ require.NoError(t, err, "SELECT 1 must succeed while server is healthy")
+ assert.Equal(t, 1, result)
+
+ // Phase 2: stop the container to simulate server going down.
+ t.Log("Stopping PostgreSQL container to simulate outage...")
+ require.NoError(t, container.Stop(ctx, nil))
+
+ // The existing *sql.DB handle is now pointing at a dead socket.
+ // Operations should fail (the exact error varies by OS/timing).
+ err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
+ assert.Error(t, err, "SELECT 1 must fail while server is down")
+
+ // Phase 3: restart the container. The mapped port may change after
+ // restart, so we must re-read the connection string from the container.
+ t.Log("Restarting PostgreSQL container...")
+ require.NoError(t, container.Start(ctx))
+
+ newDSN, err := container.ConnectionString(ctx, "sslmode=disable")
+ require.NoError(t, err, "must be able to read connection string after restart")
+ t.Logf("PostgreSQL DSN after restart: %s (was: %s)", newDSN, dsn)
+
+ // Poll until the server is accepting connections at the (potentially new) DSN.
+ waitForPostgresReady(t, newDSN, 15*time.Second, 200*time.Millisecond)
+ t.Log("PostgreSQL container is ready after restart")
+
+ // Phase 4: close the old client and create a fresh one with the new DSN.
+ _ = client.Close()
+
+ client2, err := New(newTestConfig(newDSN))
+ require.NoError(t, err, "New() must succeed after server restart")
+
+ defer func() { _ = client2.Close() }()
+
+ err = client2.Connect(ctx)
+ require.NoError(t, err, "Connect() must succeed against restarted container")
+
+ // Phase 5: verify the reconnected client can operate.
+ db2, err := client2.Primary()
+ require.NoError(t, err, "Primary() must return a live *sql.DB after reconnect")
+
+ var result2 int
+ err = db2.QueryRowContext(ctx, "SELECT 1").Scan(&result2)
+ require.NoError(t, err, "SELECT 1 must succeed after reconnect")
+ assert.Equal(t, 1, result2, "query result must be correct after reconnect")
+
+ connected, err := client2.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected, "client must report connected after successful reconnect")
+}
+
+// TestIntegration_Postgres_Resilience_BackoffRateLimiting validates that the
+// reconnect rate-limiter prevents thundering-herd storms. When the resolver is
+// nil and Resolver() is called rapidly, the first call attempts a real
+// reconnect; subsequent calls within the backoff window return a "rate-limited"
+// error without hitting the network.
+//
+// Mechanism (from postgres.go Resolver):
+// - connectAttempts tracks consecutive failures.
+// - Each failure increments connectAttempts and records lastConnectAttempt.
+// - The next Resolver() computes delay = ExponentialWithJitter(1s, attempts).
+// - If elapsed < delay, it returns "rate-limited" immediately.
+//
+// To trigger this path, we connect to a real PostgreSQL, stop the container so
+// reconnect attempts genuinely fail, then close the client (resolver=nil) and
+// fire rapid Resolver() calls.
+func TestIntegration_Postgres_Resilience_BackoffRateLimiting(t *testing.T) {
+ container, dsn, cleanup := setupPostgresContainerRaw(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ // Verify the connection is healthy before we break things.
+ err = client.Connect(ctx)
+ require.NoError(t, err)
+
+ resolver, err := client.Resolver(ctx)
+ require.NoError(t, err)
+ require.NoError(t, resolver.PingContext(ctx))
+
+ // Stop the container so reconnect attempts genuinely fail.
+ t.Log("Stopping container to make reconnect attempts fail...")
+ require.NoError(t, container.Stop(ctx, nil))
+
+ // Close the wrapper client to nil out the resolver. This puts the client
+ // into the "needs reconnect" state where Resolver() will attempt lazy-connect.
+ require.NoError(t, client.Close())
+
+ // First Resolver() call: should attempt a real reconnect to the stopped
+ // server, fail, and increment connectAttempts to 1.
+ _, err = client.Resolver(ctx)
+ require.Error(t, err, "first Resolver() must fail because server is stopped")
+ t.Logf("First Resolver() error (expected): %v", err)
+
+ // Rapid subsequent calls: should be rate-limited because we're within
+ // the backoff window. The delay after 1 failure is in
+ // [0, 1s * 2^1) = [0, 2s). Even with jitter at its minimum (0ms),
+ // consecutive calls within microseconds should be rate-limited after the
+ // first real attempt set lastConnectAttempt.
+ rateLimitedCount := 0
+ realAttemptCount := 0
+
+ const rapidCalls = 20
+
+ for range rapidCalls {
+ _, callErr := client.Resolver(ctx)
+ require.Error(t, callErr)
+
+ if strings.Contains(callErr.Error(), "rate-limited") {
+ rateLimitedCount++
+ } else {
+ realAttemptCount++
+ }
+ }
+
+ t.Logf("Of %d rapid calls: %d rate-limited, %d real attempts",
+ rapidCalls, rateLimitedCount, realAttemptCount)
+
+ // Due to the jitter in ExponentialWithJitter, the exact split between
+ // rate-limited and real attempts is non-deterministic. However, we
+ // expect the majority to be rate-limited since the calls happen in
+ // microseconds and the backoff window is at least hundreds of milliseconds.
+ assert.Greater(t, rateLimitedCount, 0,
+ "at least some calls must be rate-limited to prevent reconnect storms")
+
+ // Verify that real reconnect attempts are significantly fewer than
+ // rate-limited ones. This proves the backoff is working.
+ if rateLimitedCount > 0 && realAttemptCount > 0 {
+ assert.Greater(t, rateLimitedCount, realAttemptCount,
+ "rate-limited calls should outnumber real reconnect attempts")
+ }
+}
+
+// TestIntegration_Postgres_Resilience_GracefulDegradation validates that the
+// client degrades gracefully under failure conditions without panics or
+// undefined behavior:
+// 1. After server goes down, IsConnected() still returns true because the
+// resolver struct field was set during Connect() and has not been cleared.
+// 2. Primary() returns the stale *sql.DB (struct access, not a wire check).
+// 3. PingContext on the stale *sql.DB fails (server is down).
+// 4. Close() succeeds cleanly.
+// 5. Resolver() after Close() fails with an error (not a panic).
+// 6. No panics throughout the entire degradation sequence.
+func TestIntegration_Postgres_Resilience_GracefulDegradation(t *testing.T) {
+ container, dsn, cleanup := setupPostgresContainerRaw(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ defer func() {
+ // Best-effort close; may already be closed.
+ _ = client.Close()
+ }()
+
+ // Establish a healthy connection.
+ err = client.Connect(ctx)
+ require.NoError(t, err)
+
+ db, err := client.Primary()
+ require.NoError(t, err)
+ require.NoError(t, db.PingContext(ctx), "PingContext must succeed while server is healthy")
+
+ // Stop the server while the client still holds connection handles.
+ t.Log("Stopping PostgreSQL container...")
+ require.NoError(t, container.Stop(ctx, nil))
+
+ // IsConnected() checks c.resolver != nil. The struct field was set during
+ // Connect() and hasn't been cleared, so it still returns true even though
+ // the server is unreachable.
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected,
+ "IsConnected must still be true immediately after server stop "+
+ "(the struct field hasn't been cleared)")
+
+ // Primary() returns the stale *sql.DB — this is a struct read, not a
+ // wire check. The handle itself is still non-nil.
+ staleDB, err := client.Primary()
+ require.NoError(t, err, "Primary() must return the stale *sql.DB without error")
+ require.NotNil(t, staleDB, "stale *sql.DB must be non-nil")
+
+ // But PingContext on the stale handle must fail because the server is down.
+ pingErr := staleDB.PingContext(ctx)
+ assert.Error(t, pingErr, "PingContext on stale handle must fail when server is down")
+
+ // Close() should succeed cleanly, releasing all handles.
+ closeErr := client.Close()
+ assert.NoError(t, closeErr, "Close() must succeed even when server is unreachable")
+
+ // After Close(), IsConnected must be false (resolver was set to nil).
+ connected, err = client.IsConnected()
+ require.NoError(t, err)
+ assert.False(t, connected, "IsConnected must be false after Close()")
+
+ // Resolver() should attempt reconnect, fail (server is still down),
+ // and return an error — not panic.
+ _, resolverErr := client.Resolver(ctx)
+ assert.Error(t, resolverErr, "Resolver() must fail gracefully when server is down")
+
+ // Primary() after Close() should return ErrNotConnected — not panic.
+ _, primaryErr := client.Primary()
+ assert.Error(t, primaryErr, "Primary() must fail gracefully after Close()")
+
+ // Calling Close() again on an already-closed client must not panic.
+ assert.NotPanics(t, func() {
+ _ = client.Close()
+ }, "double Close() must not panic")
+}
+
+// TestIntegration_Postgres_Resilience_ConcurrentResolve validates that when
+// multiple goroutines call Resolver() simultaneously on a disconnected client,
+// the double-checked locking in Resolver() serializes reconnect attempts
+// correctly:
+// - No panics or data races (validated by -race detector).
+// - Only one goroutine performs the actual connect; others either get the
+// reconnected resolver from the second c.resolver!=nil check, or get a
+// rate-limited / connection error.
+// - All goroutines return without hanging (deadlock-free).
+func TestIntegration_Postgres_Resilience_ConcurrentResolve(t *testing.T) {
+ _, dsn, cleanup := setupPostgresContainerRaw(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(newTestConfig(dsn))
+ require.NoError(t, err)
+
+ // Verify healthy state before we break things.
+ err = client.Connect(ctx)
+ require.NoError(t, err)
+
+ resolver, err := client.Resolver(ctx)
+ require.NoError(t, err)
+ require.NoError(t, resolver.PingContext(ctx))
+
+ // Close the wrapper to put the client into "needs reconnect" state.
+ // The container is still running, so reconnect should succeed.
+ require.NoError(t, client.Close())
+
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ require.False(t, connected, "precondition: client must be disconnected")
+
+ const goroutines = 10
+
+ var (
+ wg sync.WaitGroup
+ successCount atomic.Int64
+ errorCount atomic.Int64
+ panicRecovered atomic.Int64
+ )
+
+ wg.Add(goroutines)
+
+ // All goroutines start simultaneously via a shared gate.
+ gate := make(chan struct{})
+
+ for i := range goroutines {
+ go func(id int) {
+ defer wg.Done()
+
+ // Catch any panics so the test can report them rather than crashing.
+ defer func() {
+ if r := recover(); r != nil {
+ panicRecovered.Add(1)
+ t.Errorf("goroutine %d panicked: %v", id, r)
+ }
+ }()
+
+ // Wait for the gate to open so all goroutines race together.
+ <-gate
+
+ res, resolveErr := client.Resolver(ctx)
+ if resolveErr != nil {
+ errorCount.Add(1)
+ return
+ }
+
+ // Verify the returned resolver is functional.
+ if pingErr := res.PingContext(ctx); pingErr != nil {
+ errorCount.Add(1)
+ return
+ }
+
+ successCount.Add(1)
+ }(i)
+ }
+
+ // Use a timeout to detect deadlocks: if goroutines don't finish within
+ // a generous window, something is stuck.
+ done := make(chan struct{})
+ go func() {
+ // Open the gate: all goroutines race into Resolver().
+ close(gate)
+ wg.Wait()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ // All goroutines completed.
+ case <-time.After(30 * time.Second):
+ t.Fatal("DEADLOCK: not all goroutines completed within 30 seconds")
+ }
+
+ successes := successCount.Load()
+ errors := errorCount.Load()
+ panics := panicRecovered.Load()
+
+ t.Logf("Concurrent resolve results: %d successes, %d errors, %d panics",
+ successes, errors, panics)
+
+ // Hard requirement: no panics.
+ assert.Equal(t, int64(0), panics,
+ "no goroutines should panic during concurrent resolve")
+
+ // At least one goroutine must succeed (the one that wins the write lock
+ // and reconnects). Others may succeed too (if they see c.resolver != nil
+ // in the fast path after the winner completes), or fail with rate-limited
+ // errors.
+ assert.Greater(t, successes, int64(0),
+ "at least one goroutine must successfully reconnect")
+
+ // All goroutines must have completed (no hangs).
+ assert.Equal(t, int64(goroutines), successes+errors+panics,
+ "all goroutines must complete")
+
+ // Verify the client is in a good state after the storm.
+ connected, err = client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected,
+ "client must be connected after successful concurrent resolve")
+
+ // Final cleanup.
+ require.NoError(t, client.Close())
+}
diff --git a/commons/rabbitmq/dlq.go b/commons/rabbitmq/dlq.go
new file mode 100644
index 00000000..99544e1d
--- /dev/null
+++ b/commons/rabbitmq/dlq.go
@@ -0,0 +1,201 @@
+package rabbitmq
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ amqp "github.com/rabbitmq/amqp091-go"
+)
+
+const (
+ defaultDLXExchangeName = "events.dlx"
+ defaultDLQName = "events.dlq"
+ defaultExchangeType = "topic"
+ defaultBindingKey = "#"
+
+ // DefaultDLQMessageTTL is the default TTL for dead-letter queue messages (7 days).
+ // Messages older than this are automatically discarded by the broker.
+ DefaultDLQMessageTTL = 7 * 24 * time.Hour
+
+ // DefaultDLQMaxLength is the default maximum number of messages retained in
+ // the dead-letter queue. When exceeded, the oldest messages are dropped.
+ DefaultDLQMaxLength int64 = 10000
+)
+
+// AMQPChannel defines the AMQP channel operations required for DLQ setup.
+type AMQPChannel interface {
+ ExchangeDeclare(
+ name, kind string,
+ durable, autoDelete, internal, noWait bool,
+ args amqp.Table,
+ ) error
+ QueueDeclare(
+ name string,
+ durable, autoDelete, exclusive, noWait bool,
+ args amqp.Table,
+ ) (amqp.Queue, error)
+ QueueBind(name, key, exchange string, noWait bool, args amqp.Table) error
+}
+
+// DLQTopologyConfig defines exchange/queue names for DLQ topology.
+type DLQTopologyConfig struct {
+ DLXExchangeName string
+ DLQName string
+ ExchangeType string
+ BindingKey string
+ QueueMessageTTL time.Duration
+ QueueMaxLength int64
+}
+
+// DLQOption configures DLQ topology declaration.
+type DLQOption func(*DLQTopologyConfig)
+
+// WithDLXExchangeName overrides the dead-letter exchange name.
+func WithDLXExchangeName(name string) DLQOption {
+ return func(cfg *DLQTopologyConfig) {
+ if name != "" {
+ cfg.DLXExchangeName = name
+ }
+ }
+}
+
+// WithDLQName overrides the dead-letter queue name.
+func WithDLQName(name string) DLQOption {
+ return func(cfg *DLQTopologyConfig) {
+ if name != "" {
+ cfg.DLQName = name
+ }
+ }
+}
+
+// WithDLQExchangeType overrides the dead-letter exchange type.
+func WithDLQExchangeType(exchangeType string) DLQOption {
+ return func(cfg *DLQTopologyConfig) {
+ if exchangeType != "" {
+ cfg.ExchangeType = exchangeType
+ }
+ }
+}
+
+// WithDLQBindingKey overrides the queue binding key to the DLX.
+func WithDLQBindingKey(bindingKey string) DLQOption {
+ return func(cfg *DLQTopologyConfig) {
+ if bindingKey != "" {
+ cfg.BindingKey = bindingKey
+ }
+ }
+}
+
+// WithDLQMessageTTL sets x-message-ttl for the DLQ queue.
+func WithDLQMessageTTL(ttl time.Duration) DLQOption {
+ return func(cfg *DLQTopologyConfig) {
+ if ttl > 0 {
+ cfg.QueueMessageTTL = ttl
+ }
+ }
+}
+
+// WithDLQMaxLength sets x-max-length for the DLQ queue.
+func WithDLQMaxLength(maxLength int64) DLQOption {
+ return func(cfg *DLQTopologyConfig) {
+ if maxLength > 0 {
+ cfg.QueueMaxLength = maxLength
+ }
+ }
+}
+
+func defaultDLQConfig() DLQTopologyConfig {
+ return DLQTopologyConfig{
+ DLXExchangeName: defaultDLXExchangeName,
+ DLQName: defaultDLQName,
+ ExchangeType: defaultExchangeType,
+ BindingKey: defaultBindingKey,
+ QueueMessageTTL: DefaultDLQMessageTTL,
+ QueueMaxLength: DefaultDLQMaxLength,
+ }
+}
+
+func (cfg DLQTopologyConfig) queueDeclareArgs() amqp.Table {
+ args := make(amqp.Table)
+
+ if cfg.QueueMessageTTL > 0 {
+ ttlMillis := cfg.QueueMessageTTL.Milliseconds()
+ if ttlMillis <= 0 {
+ ttlMillis = 1
+ }
+
+ args["x-message-ttl"] = ttlMillis
+ }
+
+ if cfg.QueueMaxLength > 0 {
+ args["x-max-length"] = cfg.QueueMaxLength
+ }
+
+ if len(args) == 0 {
+ return nil
+ }
+
+ return args
+}
+
+// DeclareDLQTopology declares dead-letter exchange and queue.
+func DeclareDLQTopology(ch AMQPChannel, opts ...DLQOption) error {
+ if nilcheck.Interface(ch) {
+ return fmt.Errorf("declare dlq topology: %w", ErrChannelRequired)
+ }
+
+ cfg := defaultDLQConfig()
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(&cfg)
+ }
+ }
+
+ if err := ch.ExchangeDeclare(
+ cfg.DLXExchangeName,
+ cfg.ExchangeType,
+ true,
+ false,
+ false,
+ false,
+ nil,
+ ); err != nil {
+ return fmt.Errorf("declare dlx exchange: %w", err)
+ }
+
+ if _, err := ch.QueueDeclare(
+ cfg.DLQName,
+ true,
+ false,
+ false,
+ false,
+ cfg.queueDeclareArgs(),
+ ); err != nil {
+ return fmt.Errorf("declare dlq queue: %w", err)
+ }
+
+ if err := ch.QueueBind(
+ cfg.DLQName,
+ cfg.BindingKey,
+ cfg.DLXExchangeName,
+ false,
+ nil,
+ ); err != nil {
+ return fmt.Errorf("bind dlq to dlx: %w", err)
+ }
+
+ return nil
+}
+
+// GetDLXArgs returns queue declaration args for dead-lettering.
+func GetDLXArgs(dlxExchangeName string) amqp.Table {
+ if dlxExchangeName == "" {
+ dlxExchangeName = defaultDLXExchangeName
+ }
+
+ return amqp.Table{
+ "x-dead-letter-exchange": dlxExchangeName,
+ }
+}
diff --git a/commons/rabbitmq/dlq_test.go b/commons/rabbitmq/dlq_test.go
new file mode 100644
index 00000000..26bd47b8
--- /dev/null
+++ b/commons/rabbitmq/dlq_test.go
@@ -0,0 +1,239 @@
+//go:build unit
+
+package rabbitmq
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ amqp "github.com/rabbitmq/amqp091-go"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+type fakeChannel struct {
+ exchangeDeclareCount int
+ queueDeclareCount int
+ queueBindCount int
+
+ lastExchangeName string
+ lastExchangeType string
+ lastQueueName string
+ lastQueueArgs amqp.Table
+ lastBindQueue string
+ lastBindKey string
+ lastBindExchange string
+}
+
+func (f *fakeChannel) ExchangeDeclare(name, kind string, _, _, _, _ bool, _ amqp.Table) error {
+ f.exchangeDeclareCount++
+ f.lastExchangeName = name
+ f.lastExchangeType = kind
+
+ return nil
+}
+
+func (f *fakeChannel) QueueDeclare(name string, _, _, _, _ bool, args amqp.Table) (amqp.Queue, error) {
+ f.queueDeclareCount++
+ f.lastQueueName = name
+ f.lastQueueArgs = args
+
+ return amqp.Queue{Name: name}, nil
+}
+
+func (f *fakeChannel) QueueBind(name, key, exchange string, _ bool, _ amqp.Table) error {
+ f.queueBindCount++
+ f.lastBindQueue = name
+ f.lastBindKey = key
+ f.lastBindExchange = exchange
+
+ return nil
+}
+
+func TestDeclareDLQTopology_Success(t *testing.T) {
+ t.Parallel()
+
+ ch := &fakeChannel{}
+ err := DeclareDLQTopology(ch, WithDLXExchangeName("matcher.events.dlx"), WithDLQName("matcher.events.dlq"))
+
+ require.NoError(t, err)
+ assert.Equal(t, 1, ch.exchangeDeclareCount)
+ assert.Equal(t, 1, ch.queueDeclareCount)
+ assert.Equal(t, 1, ch.queueBindCount)
+
+ assert.Equal(t, "matcher.events.dlx", ch.lastExchangeName)
+ assert.Equal(t, defaultExchangeType, ch.lastExchangeType)
+ assert.Equal(t, "matcher.events.dlq", ch.lastQueueName)
+ assert.Equal(t, "matcher.events.dlq", ch.lastBindQueue)
+ assert.Equal(t, "#", ch.lastBindKey)
+ assert.Equal(t, "matcher.events.dlx", ch.lastBindExchange)
+
+ // Verify default TTL and max-length are applied
+ require.NotNil(t, ch.lastQueueArgs)
+ assert.Equal(t, DefaultDLQMessageTTL.Milliseconds(), ch.lastQueueArgs["x-message-ttl"])
+ assert.Equal(t, DefaultDLQMaxLength, ch.lastQueueArgs["x-max-length"])
+}
+
+func TestDeclareDLQTopology_NilChannel(t *testing.T) {
+ t.Parallel()
+
+ err := DeclareDLQTopology(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrChannelRequired)
+}
+
+func TestDeclareDLQTopology_TypedNilChannel(t *testing.T) {
+ t.Parallel()
+
+ var nilChannel *fakeChannel
+ var ch AMQPChannel = nilChannel
+
+ err := DeclareDLQTopology(ch)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrChannelRequired)
+}
+
+func TestDeclareDLQTopology_ExchangeError(t *testing.T) {
+ t.Parallel()
+
+ ch := &fakeChannelExchangeError{}
+ err := DeclareDLQTopology(ch)
+ require.Error(t, err)
+ require.ErrorIs(t, err, errExchangeFailed)
+}
+
+func TestDeclareDLQTopology_QueueDeclareError(t *testing.T) {
+ t.Parallel()
+
+ errQueueDeclareFailed := errors.New("queue declare failed")
+ ch := &fakeChannelQueueDeclareError{err: errQueueDeclareFailed}
+
+ err := DeclareDLQTopology(ch)
+ require.Error(t, err)
+ require.ErrorIs(t, err, errQueueDeclareFailed)
+}
+
+func TestDeclareDLQTopology_QueueBindError(t *testing.T) {
+ t.Parallel()
+
+ errQueueBindFailed := errors.New("queue bind failed")
+ ch := &fakeChannelQueueBindError{err: errQueueBindFailed}
+
+ err := DeclareDLQTopology(ch)
+ require.Error(t, err)
+ require.ErrorIs(t, err, errQueueBindFailed)
+}
+
+var errExchangeFailed = errors.New("exchange declare failed")
+
+type fakeChannelExchangeError struct{ fakeChannel }
+
+func (f *fakeChannelExchangeError) ExchangeDeclare(
+ _, _ string,
+ _, _, _, _ bool,
+ _ amqp.Table,
+) error {
+ return errExchangeFailed
+}
+
+func TestGetDLXArgs(t *testing.T) {
+ t.Parallel()
+
+ args := GetDLXArgs("my.dlx")
+ require.NotNil(t, args)
+ assert.Equal(t, "my.dlx", args["x-dead-letter-exchange"])
+}
+
+func TestGetDLXArgs_DefaultExchange(t *testing.T) {
+ t.Parallel()
+
+ args := GetDLXArgs("")
+ require.NotNil(t, args)
+ assert.Equal(t, defaultDLXExchangeName, args["x-dead-letter-exchange"])
+}
+
+func TestDeclareDLQTopology_CustomExchangeTypeAndBindingKey(t *testing.T) {
+ t.Parallel()
+
+ ch := &fakeChannel{}
+ err := DeclareDLQTopology(
+ ch,
+ WithDLQExchangeType("direct"),
+ WithDLQBindingKey("payments.failed"),
+ )
+
+ require.NoError(t, err)
+ assert.Equal(t, "direct", ch.lastExchangeType)
+ assert.Equal(t, "payments.failed", ch.lastBindKey)
+}
+
+func TestDeclareDLQTopology_EmptyExchangeTypeAndBindingKeyKeepDefaults(t *testing.T) {
+ t.Parallel()
+
+ ch := &fakeChannel{}
+ err := DeclareDLQTopology(
+ ch,
+ WithDLQExchangeType(""),
+ WithDLQBindingKey(""),
+ )
+
+ require.NoError(t, err)
+ assert.Equal(t, defaultExchangeType, ch.lastExchangeType)
+ assert.Equal(t, defaultBindingKey, ch.lastBindKey)
+}
+
+func TestDeclareDLQTopology_QueueArgsOptions(t *testing.T) {
+ t.Parallel()
+
+ ch := &fakeChannel{}
+ err := DeclareDLQTopology(
+ ch,
+ WithDLQMessageTTL(45*time.Second),
+ WithDLQMaxLength(500),
+ )
+
+ require.NoError(t, err)
+ require.NotNil(t, ch.lastQueueArgs)
+ assert.Equal(t, int64(45000), ch.lastQueueArgs["x-message-ttl"])
+ assert.Equal(t, int64(500), ch.lastQueueArgs["x-max-length"])
+}
+
+func TestDeclareDLQTopology_ZeroTTLAndMaxLengthKeepDefaults(t *testing.T) {
+ t.Parallel()
+
+ ch := &fakeChannel{}
+ err := DeclareDLQTopology(
+ ch,
+ WithDLQMessageTTL(0),
+ WithDLQMaxLength(0),
+ )
+
+ require.NoError(t, err)
+ // Zero values in options are ignored, so defaults apply (7 days TTL, 10000 max-length).
+ require.NotNil(t, ch.lastQueueArgs)
+ assert.Equal(t, DefaultDLQMessageTTL.Milliseconds(), ch.lastQueueArgs["x-message-ttl"])
+ assert.Equal(t, DefaultDLQMaxLength, ch.lastQueueArgs["x-max-length"])
+}
+
+type fakeChannelQueueDeclareError struct {
+ fakeChannel
+ err error
+}
+
+func (f *fakeChannelQueueDeclareError) QueueDeclare(
+ _ string,
+ _, _, _, _ bool,
+ _ amqp.Table,
+) (amqp.Queue, error) {
+ return amqp.Queue{}, f.err
+}
+
+type fakeChannelQueueBindError struct {
+ fakeChannel
+ err error
+}
+
+func (f *fakeChannelQueueBindError) QueueBind(_ string, _ string, _ string, _ bool, _ amqp.Table) error {
+ return f.err
+}
diff --git a/commons/rabbitmq/doc.go b/commons/rabbitmq/doc.go
new file mode 100644
index 00000000..1e2f992c
--- /dev/null
+++ b/commons/rabbitmq/doc.go
@@ -0,0 +1,17 @@
+// Package rabbitmq provides AMQP connection, consumer, and producer helpers.
+//
+// It includes safer connection-string error sanitization and health-check helpers,
+// a confirmable publisher abstraction with broker-ack waiting and auto-recovery
+// (serialized publish+confirm per publisher instance for deterministic confirms),
+// and DLQ topology declaration helpers.
+//
+// Health-check security defaults:
+// - Basic auth over plain HTTP is rejected unless AllowInsecureHealthCheck=true.
+// - Basic-auth health checks require HealthCheckAllowedHosts unless
+// AllowInsecureHealthCheck=true. Hosts can be derived automatically from
+// AMQP connection settings when explicit allowlist entries are not set.
+// - Health-check host restrictions can be enforced with HealthCheckAllowedHosts
+// (entries may be host, host:port, or CIDR) and RequireHealthCheckAllowedHosts.
+// - When basic auth is not used and no explicit allowlist is configured,
+// compatibility mode keeps host validation permissive by default.
+package rabbitmq
diff --git a/commons/rabbitmq/publisher.go b/commons/rabbitmq/publisher.go
new file mode 100644
index 00000000..b78b7161
--- /dev/null
+++ b/commons/rabbitmq/publisher.go
@@ -0,0 +1,943 @@
+package rabbitmq
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/backoff"
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
+ amqp "github.com/rabbitmq/amqp091-go"
+)
+
+// recoveryAttemptResult indicates the outcome of a single recovery attempt.
+type recoveryAttemptResult int
+
+const (
+ recoveryAttemptRetry recoveryAttemptResult = iota // retry next attempt
+ recoveryAttemptSuccess // recovery succeeded
+ recoveryAttemptAborted // recovery aborted externally
+)
+
+// Publisher confirm errors.
+var (
+ // ErrConnectionRequired aliases ErrNilConnection for naming consistency in publisher constructors.
+ ErrConnectionRequired = ErrNilConnection
+ ErrPublisherRequired = errors.New("confirmable publisher is required")
+ ErrChannelRequired = errors.New("rabbitmq channel is required")
+ ErrPublisherNotReady = errors.New("confirmable publisher not initialized")
+ ErrConfirmModeUnavailable = errors.New("channel does not support confirm mode")
+ ErrPublishNacked = errors.New("message was nacked by broker")
+ ErrConfirmTimeout = errors.New("confirmation timed out")
+ ErrPublisherClosed = errors.New("publisher is closed")
+ ErrReconnectAfterClose = errors.New("cannot reconnect: publisher was explicitly closed")
+ ErrReconnectWhileOpen = errors.New("cannot reconnect: publisher is still open, call Close first")
+ ErrRecoveryExhausted = errors.New("automatic recovery exhausted all attempts")
+)
+
+const (
+ // DefaultConfirmTimeout is the default timeout for waiting on broker confirmation.
+ DefaultConfirmTimeout = 5 * time.Second
+
+ // confirmChannelBuffer is the buffer size for the confirmation channel.
+ // Should be >= max unconfirmed messages to avoid blocking.
+ confirmChannelBuffer = 256
+
+ // DefaultMaxRecoveryAttempts is the default number of recovery attempts before giving up.
+ DefaultMaxRecoveryAttempts = 10
+
+ // DefaultRecoveryBackoffInitial is the starting backoff duration for recovery retries.
+ DefaultRecoveryBackoffInitial = 1 * time.Second
+
+ // DefaultRecoveryBackoffMax is the maximum backoff duration between recovery retries.
+ DefaultRecoveryBackoffMax = 30 * time.Second
+)
+
+// HealthState represents the current connection health of a ConfirmablePublisher.
+type HealthState int
+
+const (
+ // HealthStateConnected indicates the publisher has a healthy AMQP channel
+ // and is ready to publish messages.
+ HealthStateConnected HealthState = iota
+
+ // HealthStateReconnecting indicates the publisher detected a channel closure
+ // and is actively attempting to recover by obtaining a new channel.
+ HealthStateReconnecting
+
+ // HealthStateDegraded indicates the publisher's confirmation stream was
+ // corrupted (e.g., confirm timeout or context cancellation). The underlying
+ // channel has been invalidated but auto-recovery may restore it. If no
+ // auto-recovery is configured, callers should call Reconnect() to recover.
+ HealthStateDegraded
+
+ // HealthStateDisconnected indicates the publisher has exhausted all recovery
+ // attempts and is no longer able to publish. Manual intervention is required.
+ HealthStateDisconnected
+)
+
+// String returns a human-readable representation of the health state.
+func (h HealthState) String() string {
+ switch h {
+ case HealthStateConnected:
+ return "connected"
+ case HealthStateReconnecting:
+ return "reconnecting"
+ case HealthStateDegraded:
+ return "degraded"
+ case HealthStateDisconnected:
+ return "disconnected"
+ default:
+ return "unknown"
+ }
+}
+
+// ChannelProvider is a function that returns a new AMQP channel for recovery.
+// It is called by the auto-recovery goroutine when the current channel closes.
+// The returned channel must be a fresh, dedicated channel (not shared with
+// other publishers). The provider should handle its own connection management
+// internally.
+type ChannelProvider func() (ConfirmableChannel, error)
+
+// HealthCallback is called when the publisher's connection health changes.
+type HealthCallback func(HealthState)
+
+// recoveryConfig holds the auto-recovery configuration.
+// A nil recoveryConfig means auto-recovery is disabled.
+type recoveryConfig struct {
+ provider ChannelProvider
+ healthCallback HealthCallback
+ maxAttempts int
+ backoffInitial time.Duration
+ backoffMax time.Duration
+}
+
+// ConfirmableChannel defines the interface for AMQP channel operations with confirms.
+type ConfirmableChannel interface {
+ Confirm(noWait bool) error
+ NotifyPublish(confirm chan amqp.Confirmation) chan amqp.Confirmation
+ NotifyClose(c chan *amqp.Error) chan *amqp.Error
+ PublishWithContext(
+ ctx context.Context,
+ exchange, key string,
+ mandatory, immediate bool,
+ msg amqp.Publishing,
+ ) error
+ Close() error
+}
+
+// ConfirmablePublisher wraps an AMQP channel with publisher confirms enabled.
+type ConfirmablePublisher struct {
+ ch ConfirmableChannel
+ confirms chan amqp.Confirmation
+ closedCh chan struct{}
+ closeOnce *sync.Once
+ done chan struct{}
+ logger libLog.Logger
+ confirmTimeout time.Duration
+ invalidConfirmTimeout struct {
+ set bool
+ value time.Duration
+ }
+ recovery *recoveryConfig
+ mu sync.RWMutex
+ publishMu sync.Mutex
+ health HealthState
+ closed bool
+ shutdown bool
+ recoveryExhausted bool
+}
+
+// ConfirmablePublisherOption configures a ConfirmablePublisher.
+type ConfirmablePublisherOption func(*ConfirmablePublisher)
+
+// WithLogger sets a structured logger for the publisher.
+func WithLogger(logger libLog.Logger) ConfirmablePublisherOption {
+ return func(pub *ConfirmablePublisher) {
+ if nilcheck.Interface(logger) {
+ return
+ }
+
+ pub.logger = logger
+ }
+}
+
+// WithConfirmTimeout sets the timeout for waiting on broker confirmation.
+func WithConfirmTimeout(timeout time.Duration) ConfirmablePublisherOption {
+ return func(pub *ConfirmablePublisher) {
+ if timeout > 0 {
+ pub.confirmTimeout = timeout
+ pub.invalidConfirmTimeout.set = false
+ pub.invalidConfirmTimeout.value = 0
+
+ return
+ }
+
+ pub.invalidConfirmTimeout.set = true
+ pub.invalidConfirmTimeout.value = timeout
+ }
+}
+
+// WithAutoRecovery enables automatic channel recovery.
+func WithAutoRecovery(provider ChannelProvider) ConfirmablePublisherOption {
+ return func(pub *ConfirmablePublisher) {
+ if provider == nil {
+ return
+ }
+
+ ensureRecoveryConfig(pub)
+
+ pub.recovery.provider = provider
+ }
+}
+
+// WithMaxRecoveryAttempts sets maximum consecutive recovery attempts.
+func WithMaxRecoveryAttempts(maxAttempts int) ConfirmablePublisherOption {
+ return func(pub *ConfirmablePublisher) {
+ if maxAttempts <= 0 {
+ return
+ }
+
+ ensureRecoveryConfig(pub)
+
+ pub.recovery.maxAttempts = maxAttempts
+ }
+}
+
+// WithRecoveryBackoff sets the initial and max backoff durations for recovery.
+func WithRecoveryBackoff(initial, maxBackoff time.Duration) ConfirmablePublisherOption {
+ return func(pub *ConfirmablePublisher) {
+ if initial <= 0 || maxBackoff <= 0 {
+ return
+ }
+
+ if initial > maxBackoff {
+ logIfConfigured(
+ pub.logger,
+ libLog.LevelWarn,
+ fmt.Sprintf("rabbitmq: ignoring invalid recovery backoff initial=%v max=%v", initial, maxBackoff),
+ )
+
+ return
+ }
+
+ ensureRecoveryConfig(pub)
+
+ pub.recovery.backoffInitial = initial
+ pub.recovery.backoffMax = maxBackoff
+ }
+}
+
+// WithHealthCallback registers a callback for health state changes.
+func WithHealthCallback(fn HealthCallback) ConfirmablePublisherOption {
+ return func(pub *ConfirmablePublisher) {
+ if fn == nil {
+ return
+ }
+
+ ensureRecoveryConfig(pub)
+
+ pub.recovery.healthCallback = fn
+ }
+}
+
+// NewConfirmablePublisher creates a publisher with confirms enabled.
+func NewConfirmablePublisher(
+ conn *RabbitMQConnection,
+ opts ...ConfirmablePublisherOption,
+) (*ConfirmablePublisher, error) {
+ if conn == nil {
+ return nil, ErrConnectionRequired
+ }
+
+ channel := conn.ChannelSnapshot()
+
+ if channel == nil {
+ return nil, ErrChannelRequired
+ }
+
+ return NewConfirmablePublisherFromChannel(channel, opts...)
+}
+
+// NewConfirmablePublisherFromChannel creates a publisher from an existing channel.
+func NewConfirmablePublisherFromChannel(
+ ch ConfirmableChannel,
+ opts ...ConfirmablePublisherOption,
+) (*ConfirmablePublisher, error) {
+ if nilcheck.Interface(ch) {
+ return nil, ErrChannelRequired
+ }
+
+ if err := ch.Confirm(false); err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrConfirmModeUnavailable, err)
+ }
+
+ confirms := make(chan amqp.Confirmation, confirmChannelBuffer)
+ ch.NotifyPublish(confirms)
+
+ closeNotify := ch.NotifyClose(make(chan *amqp.Error, 1))
+
+ publisher := &ConfirmablePublisher{
+ ch: ch,
+ confirms: confirms,
+ closedCh: make(chan struct{}),
+ closeOnce: &sync.Once{},
+ done: make(chan struct{}),
+ logger: libLog.NewNop(),
+ confirmTimeout: DefaultConfirmTimeout,
+ health: HealthStateConnected,
+ }
+
+ for _, opt := range opts {
+ if opt != nil {
+ opt(publisher)
+ }
+ }
+
+ publisher.logDeferredOptionWarnings()
+
+ publisher.startCloseMonitor(closeNotify)
+
+ return publisher, nil
+}
+
+// startCloseMonitor launches a goroutine that watches channel close events.
+func (pub *ConfirmablePublisher) startCloseMonitor(closeNotify chan *amqp.Error) {
+ monitorDone := pub.done
+ monitorLogger := pub.logger
+
+ runtime.SafeGo(monitorLogger, "confirmable-publisher-close-monitor", runtime.KeepRunning, func() {
+ select {
+ case amqpErr := <-closeNotify:
+ pub.handleMonitoredClose(amqpErr)
+ case <-monitorDone:
+ return
+ }
+ })
+}
+
+func (pub *ConfirmablePublisher) handleMonitoredClose(amqpErr *amqp.Error) {
+ pub.mu.Lock()
+ pub.ensureCloseSignalsLocked()
+ monitorCloseOnce := pub.closeOnce
+ monitorClosedCh := pub.closedCh
+ hasRecovery := pub.recovery != nil && pub.recovery.provider != nil
+ pub.closed = true
+ pub.mu.Unlock()
+
+ monitorCloseOnce.Do(func() { close(monitorClosedCh) })
+
+ if hasRecovery {
+ pub.attemptAutoRecovery(amqpErr)
+
+ return
+ }
+
+ pub.emitHealthState(HealthStateDisconnected)
+}
+
+func (pub *ConfirmablePublisher) attemptAutoRecovery(amqpErr *amqp.Error) {
+ pub.mu.RLock()
+ recovery := pub.recovery
+ logger := pub.logger
+ pub.mu.RUnlock()
+
+ if recovery == nil || recovery.provider == nil {
+ return
+ }
+
+ pub.emitHealthState(HealthStateReconnecting)
+ pub.logChannelClosed(logger, amqpErr, recovery.maxAttempts)
+
+ if !pub.prepareForRecovery() {
+ logIfConfigured(logger, libLog.LevelInfo, "rabbitmq: recovery aborted, publisher is shutting down")
+ pub.emitHealthState(HealthStateDisconnected)
+
+ return
+ }
+
+ pub.mu.RLock()
+ recoveryStop := pub.done
+ pub.mu.RUnlock()
+
+ for attempt := range recovery.maxAttempts {
+ result := pub.executeRecoveryAttempt(recovery, logger, recoveryStop, attempt)
+ if result == recoveryAttemptSuccess || result == recoveryAttemptAborted {
+ return
+ }
+ }
+
+ logIfConfigured(
+ logger,
+ libLog.LevelError,
+ fmt.Sprintf("rabbitmq: auto-recovery failed after %d attempts, publisher is disconnected", recovery.maxAttempts),
+ )
+
+ pub.mu.Lock()
+ pub.recoveryExhausted = true
+ pub.mu.Unlock()
+
+ pub.emitHealthState(HealthStateDisconnected)
+}
+
+func (pub *ConfirmablePublisher) logChannelClosed(logger libLog.Logger, amqpErr *amqp.Error, maxAttempts int) {
+ if nilcheck.Interface(logger) {
+ return
+ }
+
+ errMsg := "unknown"
+ if amqpErr != nil {
+ errMsg = sanitizeAMQPErr(amqpErr, "")
+ }
+
+ logger.Log(context.Background(), libLog.LevelWarn,
+ fmt.Sprintf("rabbitmq: channel closed (%s), starting auto-recovery (max %d attempts)", errMsg, maxAttempts))
+}
+
+func (pub *ConfirmablePublisher) executeRecoveryAttempt(
+ recovery *recoveryConfig,
+ logger libLog.Logger,
+ recoveryStop <-chan struct{},
+ attempt int,
+) recoveryAttemptResult {
+ select {
+ case <-recoveryStop:
+ logIfConfigured(logger, libLog.LevelInfo, "rabbitmq: recovery aborted (publisher closed externally)")
+ pub.emitHealthState(HealthStateDisconnected)
+
+ return recoveryAttemptAborted
+ default:
+ }
+
+ if aborted := pub.waitRecoveryBackoff(recovery, logger, recoveryStop, attempt); aborted {
+ return recoveryAttemptAborted
+ }
+
+ return pub.tryReconnectChannel(recovery, logger, attempt)
+}
+
+func (pub *ConfirmablePublisher) waitRecoveryBackoff(
+ recovery *recoveryConfig,
+ logger libLog.Logger,
+ recoveryStop <-chan struct{},
+ attempt int,
+) bool {
+ delay := backoff.ExponentialWithJitter(recovery.backoffInitial, attempt)
+ if delay > recovery.backoffMax {
+ delay = backoff.FullJitter(recovery.backoffMax)
+ }
+
+ logIfConfigured(
+ logger,
+ libLog.LevelInfo,
+ fmt.Sprintf("rabbitmq: recovery attempt %d/%d, backoff %v", attempt+1, recovery.maxAttempts, delay),
+ )
+
+ timer := time.NewTimer(delay)
+ defer timer.Stop()
+
+ select {
+ case <-timer.C:
+ return false
+ case <-recoveryStop:
+ logIfConfigured(logger, libLog.LevelInfo, "rabbitmq: recovery aborted during backoff (publisher closed)")
+ pub.emitHealthState(HealthStateDisconnected)
+
+ return true
+ }
+}
+
+func (pub *ConfirmablePublisher) tryReconnectChannel(
+ recovery *recoveryConfig,
+ logger libLog.Logger,
+ attempt int,
+) recoveryAttemptResult {
+ newCh, err := recovery.provider()
+ if err != nil {
+ sanitizedErr := sanitizeAMQPErr(err, "")
+ logIfConfigured(
+ logger,
+ libLog.LevelWarn,
+ fmt.Sprintf("rabbitmq: recovery attempt %d/%d failed: %s", attempt+1, recovery.maxAttempts, sanitizedErr),
+ )
+
+ return recoveryAttemptRetry
+ }
+
+ if err := pub.Reconnect(newCh); err != nil {
+ sanitizedErr := sanitizeAMQPErr(err, "")
+ logIfConfigured(
+ logger,
+ libLog.LevelWarn,
+ fmt.Sprintf("rabbitmq: recovery attempt %d/%d reconnect failed: %s", attempt+1, recovery.maxAttempts, sanitizedErr),
+ )
+
+ if !nilcheck.Interface(newCh) {
+ _ = newCh.Close()
+ }
+
+ return recoveryAttemptRetry
+ }
+
+ logIfConfigured(
+ logger,
+ libLog.LevelInfo,
+ fmt.Sprintf("rabbitmq: auto-recovery succeeded on attempt %d/%d", attempt+1, recovery.maxAttempts),
+ )
+
+ pub.emitHealthState(HealthStateConnected)
+
+ return recoveryAttemptSuccess
+}
+
+func (pub *ConfirmablePublisher) prepareForRecovery() bool {
+ pub.publishMu.Lock()
+ defer pub.publishMu.Unlock()
+
+ pub.mu.Lock()
+ if pub.shutdown {
+ pub.mu.Unlock()
+
+ return false
+ }
+
+ currentCh := pub.ch
+ confirms := pub.confirms
+ confirmTimeout := pub.confirmTimeout
+ pub.ensureCloseSignalsLocked()
+
+ pub.closed = true
+ pub.recoveryExhausted = false
+ pub.ch = nil
+ safeCloseSignal(pub.done)
+ pub.closeOnce.Do(func() { close(pub.closedCh) })
+ pub.mu.Unlock()
+
+ if !nilcheck.Interface(currentCh) {
+ _ = currentCh.Close()
+ }
+
+ drainConfirms(confirms, confirmTimeout)
+
+ pub.mu.Lock()
+ pub.done = make(chan struct{})
+ pub.mu.Unlock()
+
+ return true
+}
+
+func (pub *ConfirmablePublisher) emitHealthState(state HealthState) {
+ pub.mu.Lock()
+ pub.health = state
+ recovery := pub.recovery
+ pub.mu.Unlock()
+
+ if recovery == nil || recovery.healthCallback == nil {
+ return
+ }
+
+ recovery.healthCallback(state)
+}
+
+// Publish sends a message and waits for broker confirmation.
+//
+// This method is intentionally serialized per publisher instance: only one
+// publish+confirm flow is in-flight at a time. For explicit naming, prefer
+// PublishAndWaitConfirm. For higher throughput, shard publishing across
+// multiple publisher instances.
+func (pub *ConfirmablePublisher) Publish(
+ ctx context.Context,
+ exchange, routingKey string,
+ mandatory, immediate bool,
+ msg amqp.Publishing,
+) error {
+ if pub == nil {
+ return ErrPublisherRequired
+ }
+
+ return pub.PublishAndWaitConfirm(ctx, exchange, routingKey, mandatory, immediate, msg)
+}
+
+// PublishAndWaitConfirm sends a message and synchronously waits for broker confirmation.
+//
+// Calls are serialized per publisher instance to preserve confirm ordering
+// without delivery-tag correlation state.
+func (pub *ConfirmablePublisher) PublishAndWaitConfirm(
+ ctx context.Context,
+ exchange, routingKey string,
+ mandatory, immediate bool,
+ msg amqp.Publishing,
+) error {
+ if pub == nil {
+ return ErrPublisherRequired
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ pub.publishMu.Lock()
+ defer pub.publishMu.Unlock()
+
+ pub.mu.RLock()
+
+ if pub.closed {
+ recoveryExhausted := pub.recoveryExhausted
+ pub.mu.RUnlock()
+
+ if recoveryExhausted {
+ return fmt.Errorf("%w: %w", ErrPublisherClosed, ErrRecoveryExhausted)
+ }
+
+ return ErrPublisherClosed
+ }
+
+ if pub.ch == nil {
+ pub.mu.RUnlock()
+ return ErrPublisherNotReady
+ }
+
+ publishChannel := pub.ch
+ confirms := pub.confirms
+ closedCh := pub.closedCh
+ confirmTimeout := pub.confirmTimeout
+ pub.mu.RUnlock()
+
+ if err := publishChannel.PublishWithContext(ctx, exchange, routingKey, mandatory, immediate, msg); err != nil {
+ return fmt.Errorf("publish: %w", err)
+ }
+
+ err := waitForConfirm(ctx, confirms, closedCh, confirmTimeout)
+ if err != nil && isConfirmStreamCorrupted(err) {
+ // The pending confirmation will corrupt the next waitForConfirm call.
+ // Invalidate the channel so the close monitor triggers auto-recovery
+ // after publishMu is released by the deferred unlock above.
+ pub.invalidateChannel(publishChannel)
+ }
+
+ return err
+}
+
+// isConfirmStreamCorrupted reports whether the error indicates the
+// confirmation channel has a stale entry that would desynchronize the
+// next waitForConfirm call.
+func isConfirmStreamCorrupted(err error) bool {
+ return errors.Is(err, ErrConfirmTimeout) ||
+ errors.Is(err, context.Canceled) ||
+ errors.Is(err, context.DeadlineExceeded)
+}
+
+// invalidateChannel marks the publisher as closed and closes the
+// underlying AMQP channel. The close event propagates to the close
+// monitor goroutine which initiates auto-recovery (if configured)
+// after the caller releases publishMu.
+//
+// The publisher transitions to HealthStateDegraded to signal that the
+// confirmation stream is corrupted but recovery may restore it. If
+// auto-recovery is not configured, callers should call Reconnect()
+// with a fresh channel to restore the publisher.
+//
+// Must be called while holding publishMu.
+func (pub *ConfirmablePublisher) invalidateChannel(ch ConfirmableChannel) {
+ pub.mu.Lock()
+ pub.ensureCloseSignalsLocked()
+ pub.closed = true
+ pub.ch = nil
+ pub.mu.Unlock()
+
+ pub.emitHealthState(HealthStateDegraded)
+
+ pub.closeOnce.Do(func() { close(pub.closedCh) })
+
+ if !nilcheck.Interface(ch) {
+ _ = ch.Close()
+ }
+}
+
+func waitForConfirm(
+ ctx context.Context,
+ confirms <-chan amqp.Confirmation,
+ closedCh <-chan struct{},
+ confirmTimeout time.Duration,
+) error {
+ timeout := time.NewTimer(confirmTimeout)
+ defer timeout.Stop()
+
+ select {
+ case confirmed, ok := <-confirms:
+ if !ok {
+ return ErrPublisherClosed
+ }
+
+ if !confirmed.Ack {
+ return fmt.Errorf("%w: delivery_tag=%d", ErrPublishNacked, confirmed.DeliveryTag)
+ }
+
+ return nil
+
+ case <-closedCh:
+ return ErrPublisherClosed
+
+ case <-timeout.C:
+ return ErrConfirmTimeout
+
+ case <-ctx.Done():
+ return fmt.Errorf("context cancelled: %w", ctx.Err())
+ }
+}
+
+// Close drains pending confirmations and permanently closes the publisher.
+// After Close, Reconnect is rejected and callers should create a new publisher.
+func (pub *ConfirmablePublisher) Close() error {
+ if pub == nil {
+ return ErrPublisherRequired
+ }
+
+ pub.publishMu.Lock()
+ defer pub.publishMu.Unlock()
+
+ pub.mu.Lock()
+ pub.ensureCloseSignalsLocked()
+
+ if pub.shutdown {
+ pub.mu.Unlock()
+
+ return nil
+ }
+
+ pub.shutdown = true
+ pub.closed = true
+ pub.recoveryExhausted = false
+ currentCh := pub.ch
+ safeCloseSignal(pub.done)
+ pub.closeOnce.Do(func() { close(pub.closedCh) })
+ pub.mu.Unlock()
+
+ if !nilcheck.Interface(currentCh) {
+ if err := currentCh.Close(); err != nil {
+ return fmt.Errorf("closing publisher channel: %w", err)
+ }
+ }
+
+ drainConfirms(pub.confirms, pub.confirmTimeout)
+ pub.emitHealthState(HealthStateDisconnected)
+
+ return nil
+}
+
+// Reconnect replaces the underlying AMQP channel with a fresh one.
+//
+// Caller contract:
+// - Reconnect is only valid after an operational close (for example, auto-recovery
+// transition) when publisher.closed is true and publisher.shutdown is false.
+// - After explicit Close, the publisher enters terminal shutdown and Reconnect
+// returns ErrReconnectAfterClose.
+//
+// Reconnect replaces the underlying AMQP channel with a fresh one.
+//
+// Caller contract:
+// - Reconnect is only valid after an operational close (for example, auto-recovery
+// transition) when publisher.closed is true and publisher.shutdown is false.
+// - After explicit Close, the publisher enters terminal shutdown and Reconnect
+// returns ErrReconnectAfterClose.
+// - On success, the publisher transitions to HealthStateConnected and the
+// health callback is invoked.
+func (pub *ConfirmablePublisher) Reconnect(ch ConfirmableChannel) error {
+ if pub == nil {
+ return ErrPublisherRequired
+ }
+
+ if nilcheck.Interface(ch) {
+ return ErrChannelRequired
+ }
+
+ pub.publishMu.Lock()
+ defer pub.publishMu.Unlock()
+
+ var healthCallback HealthCallback
+
+ pub.mu.Lock()
+
+ if !pub.closed {
+ pub.mu.Unlock()
+
+ return ErrReconnectWhileOpen
+ }
+
+ if pub.shutdown {
+ pub.mu.Unlock()
+
+ return ErrReconnectAfterClose
+ }
+
+ if err := ch.Confirm(false); err != nil {
+ pub.mu.Unlock()
+
+ return fmt.Errorf("%w: %w", ErrConfirmModeUnavailable, err)
+ }
+
+ confirms := make(chan amqp.Confirmation, confirmChannelBuffer)
+ ch.NotifyPublish(confirms)
+
+ closeNotify := ch.NotifyClose(make(chan *amqp.Error, 1))
+
+ pub.ch = ch
+ pub.confirms = confirms
+ pub.closedCh = make(chan struct{})
+
+ pub.closeOnce = &sync.Once{}
+ if pub.done == nil {
+ pub.done = make(chan struct{})
+ }
+
+ pub.closed = false
+ pub.recoveryExhausted = false
+ pub.health = HealthStateConnected
+
+ if pub.recovery != nil {
+ healthCallback = pub.recovery.healthCallback
+ }
+
+ pub.startCloseMonitor(closeNotify)
+
+ pub.mu.Unlock()
+
+ // Emit health callback outside the lock to avoid deadlock with caller callbacks.
+ if healthCallback != nil {
+ healthCallback(HealthStateConnected)
+ }
+
+ return nil
+}
+
+// Channel returns the underlying channel for low-level operations.
+//
+// The return value can be nil when the publisher is closed, reconnecting,
+// or not yet initialized. Call ChannelOrError when callers need explicit
+// readiness errors.
+func (pub *ConfirmablePublisher) Channel() ConfirmableChannel {
+ if pub == nil {
+ return nil
+ }
+
+ pub.mu.RLock()
+ defer pub.mu.RUnlock()
+
+ if pub.closed {
+ return nil
+ }
+
+ return pub.ch
+}
+
+// ChannelOrError returns the underlying channel only when the publisher is ready.
+func (pub *ConfirmablePublisher) ChannelOrError() (ConfirmableChannel, error) {
+ if pub == nil {
+ return nil, ErrPublisherRequired
+ }
+
+ pub.mu.RLock()
+ defer pub.mu.RUnlock()
+
+ if pub.closed {
+ return nil, ErrPublisherClosed
+ }
+
+ if pub.ch == nil {
+ return nil, ErrPublisherNotReady
+ }
+
+ return pub.ch, nil
+}
+
+// HealthState returns the latest synchronous health state snapshot.
+func (pub *ConfirmablePublisher) HealthState() HealthState {
+ if pub == nil {
+ return HealthStateDisconnected
+ }
+
+ pub.mu.RLock()
+ defer pub.mu.RUnlock()
+
+ return pub.health
+}
+
+func ensureRecoveryConfig(pub *ConfirmablePublisher) {
+ if pub.recovery != nil {
+ return
+ }
+
+ pub.recovery = &recoveryConfig{
+ maxAttempts: DefaultMaxRecoveryAttempts,
+ backoffInitial: DefaultRecoveryBackoffInitial,
+ backoffMax: DefaultRecoveryBackoffMax,
+ }
+}
+
+func (pub *ConfirmablePublisher) logDeferredOptionWarnings() {
+ if !pub.invalidConfirmTimeout.set {
+ return
+ }
+
+ logIfConfigured(pub.logger, libLog.LevelWarn,
+ fmt.Sprintf("rabbitmq: ignoring invalid confirm timeout %v, using default", pub.invalidConfirmTimeout.value))
+}
+
+func (pub *ConfirmablePublisher) ensureCloseSignalsLocked() {
+ if pub.closeOnce == nil {
+ pub.closeOnce = &sync.Once{}
+ }
+
+ if pub.closedCh == nil {
+ pub.closedCh = make(chan struct{})
+ }
+}
+
+func safeCloseSignal(ch chan struct{}) {
+ if ch == nil {
+ return
+ }
+
+ select {
+ case <-ch:
+ return
+ default:
+ close(ch)
+ }
+}
+
+func drainConfirms(confirms <-chan amqp.Confirmation, timeout time.Duration) {
+ if confirms == nil {
+ return
+ }
+
+ if timeout <= 0 {
+ timeout = DefaultConfirmTimeout
+ }
+
+ grace := time.NewTimer(timeout)
+ defer grace.Stop()
+
+ for {
+ select {
+ case _, ok := <-confirms:
+ if !ok {
+ return
+ }
+ case <-grace.C:
+ return
+ }
+ }
+}
+
+func logIfConfigured(logger libLog.Logger, level libLog.Level, message string) {
+ if nilcheck.Interface(logger) {
+ return
+ }
+
+ logger.Log(context.Background(), level, message)
+}
diff --git a/commons/rabbitmq/publisher_test.go b/commons/rabbitmq/publisher_test.go
new file mode 100644
index 00000000..26deb7a3
--- /dev/null
+++ b/commons/rabbitmq/publisher_test.go
@@ -0,0 +1,834 @@
+//go:build unit
+
+package rabbitmq
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ amqp "github.com/rabbitmq/amqp091-go"
+)
+
+type mockConfirmableChannel struct {
+ mu sync.Mutex
+ confirmErr error
+ publishErr error
+ confirms chan amqp.Confirmation
+ closeNotify chan *amqp.Error
+ confirmCalled bool
+ publishCalled bool
+ closeCalled bool
+ deliveryCounter uint64
+}
+
+type panicPublisherLogger struct {
+ used bool
+}
+
+func (logger *panicPublisherLogger) Log(context.Context, libLog.Level, string, ...libLog.Field) {
+ logger.used = true
+}
+
+func (logger *panicPublisherLogger) With(...libLog.Field) libLog.Logger {
+ return logger
+}
+
+func (logger *panicPublisherLogger) WithGroup(string) libLog.Logger {
+ return logger
+}
+
+func (logger *panicPublisherLogger) Enabled(libLog.Level) bool {
+ return true
+}
+
+func (logger *panicPublisherLogger) Sync(context.Context) error {
+ return nil
+}
+
+func newMockChannel() *mockConfirmableChannel {
+ return &mockConfirmableChannel{
+ closeNotify: make(chan *amqp.Error, 1),
+ }
+}
+
+func (m *mockConfirmableChannel) Confirm(_ bool) error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.confirmCalled = true
+
+ return m.confirmErr
+}
+
+func (m *mockConfirmableChannel) NotifyPublish(confirm chan amqp.Confirmation) chan amqp.Confirmation {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.confirms = confirm
+
+ return confirm
+}
+
+func (m *mockConfirmableChannel) NotifyClose(_ chan *amqp.Error) chan *amqp.Error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ return m.closeNotify
+}
+
+func (m *mockConfirmableChannel) PublishWithContext(
+ _ context.Context,
+ _, _ string,
+ _, _ bool,
+ _ amqp.Publishing,
+) error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.publishCalled = true
+ m.deliveryCounter++
+
+ return m.publishErr
+}
+
+func (m *mockConfirmableChannel) Close() error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.closeCalled {
+ return nil
+ }
+
+ m.closeCalled = true
+ if m.confirms != nil {
+ close(m.confirms)
+ }
+
+ return nil
+}
+
+func (m *mockConfirmableChannel) sendConfirm(ack bool) {
+ m.mu.Lock()
+ tag := m.deliveryCounter
+ confirms := m.confirms
+ m.mu.Unlock()
+
+ confirms <- amqp.Confirmation{DeliveryTag: tag, Ack: ack}
+}
+
+func (m *mockConfirmableChannel) waitForPublish(t *testing.T) {
+ t.Helper()
+
+ require.Eventually(t, func() bool {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ return m.deliveryCounter > 0
+ }, time.Second, time.Millisecond)
+}
+
+func TestNewConfirmablePublisher_NilConnection(t *testing.T) {
+ t.Parallel()
+
+ publisher, err := NewConfirmablePublisher(nil)
+ assert.Nil(t, publisher)
+ assert.ErrorIs(t, err, ErrConnectionRequired)
+}
+
+func TestNewConfirmablePublisher_NilChannel(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{Channel: nil}
+ publisher, err := NewConfirmablePublisher(conn)
+ assert.Nil(t, publisher)
+ assert.ErrorIs(t, err, ErrChannelRequired)
+}
+
+func TestConfirmablePublisher_Publish_Success(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ go func() {
+ ch.waitForPublish(t)
+ ch.sendConfirm(true)
+ }()
+
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("ok")})
+ require.NoError(t, err)
+ assert.True(t, ch.publishCalled)
+}
+
+func TestConfirmablePublisher_PublishAndWaitConfirm_Success(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ go func() {
+ ch.waitForPublish(t)
+ ch.sendConfirm(true)
+ }()
+
+ err = publisher.PublishAndWaitConfirm(
+ context.Background(),
+ "exchange",
+ "route",
+ false,
+ false,
+ amqp.Publishing{Body: []byte("ok")},
+ )
+ require.NoError(t, err)
+}
+
+func TestConfirmablePublisher_Publish_Nack(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ go func() {
+ ch.waitForPublish(t)
+ ch.sendConfirm(false)
+ }()
+
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.ErrorIs(t, err, ErrPublishNacked)
+}
+
+func TestConfirmablePublisher_Publish_Timeout(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithConfirmTimeout(30*time.Millisecond))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.ErrorIs(t, err, ErrConfirmTimeout)
+}
+
+func TestNewConfirmablePublisherFromChannel_ConfirmError(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ ch.confirmErr = errors.New("confirm mode unavailable")
+
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.Nil(t, publisher)
+ require.ErrorIs(t, err, ErrConfirmModeUnavailable)
+}
+
+func TestConfirmablePublisher_ReconnectAfterCloseFails(t *testing.T) {
+ t.Parallel()
+
+ ch1 := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch1)
+ require.NoError(t, err)
+
+ require.NoError(t, publisher.Close())
+ err = publisher.Reconnect(newMockChannel())
+ require.ErrorIs(t, err, ErrReconnectAfterClose)
+}
+
+func TestConfirmablePublisher_ReconnectNilChannel(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ err = publisher.Reconnect(nil)
+ require.ErrorIs(t, err, ErrChannelRequired)
+}
+
+func TestConfirmablePublisher_WithConfirmTimeoutZeroKeepsDefault(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithConfirmTimeout(0))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ require.Equal(t, DefaultConfirmTimeout, publisher.confirmTimeout)
+}
+
+func TestConfirmablePublisher_WithConfirmTimeoutNegativeKeepsDefault(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithConfirmTimeout(-time.Second))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ require.Equal(t, DefaultConfirmTimeout, publisher.confirmTimeout)
+}
+
+func TestConfirmablePublisher_WithRecoveryBackoffRejectsInitialGreaterThanMax(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithRecoveryBackoff(5*time.Second, time.Second))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ require.Nil(t, publisher.recovery)
+}
+
+func TestConfirmablePublisher_ReconnectAfterRecoveryPreparation(t *testing.T) {
+ t.Parallel()
+
+ ch1 := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch1)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ require.True(t, publisher.prepareForRecovery())
+ recoveryDone := publisher.done
+
+ ch2 := newMockChannel()
+ require.NoError(t, publisher.Reconnect(ch2))
+ require.Equal(t, recoveryDone, publisher.done)
+
+ go func() {
+ ch2.waitForPublish(t)
+ ch2.sendConfirm(true)
+ }()
+
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("ok")})
+ require.NoError(t, err)
+}
+
+func TestConfirmablePublisher_ConcurrentReconnectSerialized(t *testing.T) {
+ t.Parallel()
+
+ publisher, err := NewConfirmablePublisherFromChannel(newMockChannel())
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ require.True(t, publisher.prepareForRecovery())
+
+ start := make(chan struct{})
+ errs := make(chan error, 2)
+
+ go func() {
+ <-start
+ errs <- publisher.Reconnect(newMockChannel())
+ }()
+
+ go func() {
+ <-start
+ errs <- publisher.Reconnect(newMockChannel())
+ }()
+
+ close(start)
+
+ errA := <-errs
+ errB := <-errs
+
+ if errA == nil {
+ require.ErrorIs(t, errB, ErrReconnectWhileOpen)
+
+ return
+ }
+
+ require.Nil(t, errB)
+ require.ErrorIs(t, errA, ErrReconnectWhileOpen)
+}
+
+func TestConfirmablePublisher_PublishDuringRecoveryState(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ require.True(t, publisher.prepareForRecovery())
+
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.ErrorIs(t, err, ErrPublisherClosed)
+}
+
+func TestConfirmablePublisher_ChannelAccessorAndChannelOrError(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ underlying := publisher.Channel()
+ require.NotNil(t, underlying)
+
+ readyChannel, err := publisher.ChannelOrError()
+ require.NoError(t, err)
+ require.Equal(t, underlying, readyChannel)
+
+ require.NoError(t, publisher.Close())
+ require.Nil(t, publisher.Channel())
+
+ notReadyChannel, err := publisher.ChannelOrError()
+ require.Nil(t, notReadyChannel)
+ require.ErrorIs(t, err, ErrPublisherClosed)
+}
+
+func TestConfirmablePublisher_AutoRecovery(t *testing.T) {
+ t.Parallel()
+
+ ch1 := newMockChannel()
+ ch2 := newMockChannel()
+
+ recovered := make(chan struct{})
+ publisher, err := NewConfirmablePublisherFromChannel(
+ ch1,
+ WithLogger(&libLog.NopLogger{}),
+ WithAutoRecovery(func() (ConfirmableChannel, error) { return ch2, nil }),
+ WithRecoveryBackoff(1*time.Millisecond, 5*time.Millisecond),
+ WithMaxRecoveryAttempts(3),
+ WithHealthCallback(func(state HealthState) {
+ if state == HealthStateConnected {
+ select {
+ case <-recovered:
+ default:
+ close(recovered)
+ }
+ }
+ }),
+ )
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ ch1.closeNotify <- amqp.ErrClosed
+
+ select {
+ case <-recovered:
+ case <-time.After(2 * time.Second):
+ t.Fatal("auto recovery did not complete")
+ }
+
+ go func() {
+ ch2.waitForPublish(t)
+ ch2.sendConfirm(true)
+ }()
+
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("ok")})
+ require.NoError(t, err)
+}
+
+func TestConfirmablePublisher_PrepareForRecoveryWaitsForInFlightPublish(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithConfirmTimeout(time.Second))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ publishDone := make(chan error, 1)
+ go func() {
+ publishDone <- publisher.Publish(
+ context.Background(),
+ "exchange",
+ "route",
+ false,
+ false,
+ amqp.Publishing{Body: []byte("ok")},
+ )
+ }()
+
+ ch.waitForPublish(t)
+
+ recoveryDone := make(chan bool, 1)
+ go func() {
+ recoveryDone <- publisher.prepareForRecovery()
+ }()
+
+ select {
+ case <-recoveryDone:
+ t.Fatal("prepareForRecovery must wait for in-flight publish")
+ default:
+ }
+
+ ch.sendConfirm(true)
+
+ select {
+ case err = <-publishDone:
+ require.NoError(t, err)
+ case <-time.After(time.Second):
+ t.Fatal("publish did not complete")
+ }
+
+ select {
+ case prepared := <-recoveryDone:
+ require.True(t, prepared)
+ case <-time.After(time.Second):
+ t.Fatal("prepareForRecovery did not complete")
+ }
+}
+
+func TestConfirmablePublisher_CloseWaitsForInFlightPublish(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithConfirmTimeout(time.Second))
+ require.NoError(t, err)
+
+ publishDone := make(chan error, 1)
+ go func() {
+ publishDone <- publisher.Publish(
+ context.Background(),
+ "exchange",
+ "route",
+ false,
+ false,
+ amqp.Publishing{Body: []byte("ok")},
+ )
+ }()
+
+ ch.waitForPublish(t)
+
+ closeDone := make(chan error, 1)
+ go func() {
+ closeDone <- publisher.Close()
+ }()
+
+ select {
+ case err = <-closeDone:
+ t.Fatalf("close returned early while publish in-flight: %v", err)
+ default:
+ }
+
+ ch.sendConfirm(true)
+
+ select {
+ case err = <-publishDone:
+ require.NoError(t, err)
+ case <-time.After(time.Second):
+ t.Fatal("publish did not complete")
+ }
+
+ select {
+ case err = <-closeDone:
+ require.NoError(t, err)
+ case <-time.After(time.Second):
+ t.Fatal("close did not complete")
+ }
+
+ ch.mu.Lock()
+ closed := ch.closeCalled
+ ch.mu.Unlock()
+ require.True(t, closed)
+}
+
+func TestHealthState_String(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "connected", HealthStateConnected.String())
+ assert.Equal(t, "reconnecting", HealthStateReconnecting.String())
+ assert.Equal(t, "degraded", HealthStateDegraded.String())
+ assert.Equal(t, "disconnected", HealthStateDisconnected.String())
+ assert.Equal(t, "unknown", HealthState(99).String())
+}
+
+func TestConfirmablePublisher_HealthStateSnapshot(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ require.Equal(t, HealthStateConnected, publisher.HealthState())
+
+ publisher.emitHealthState(HealthStateReconnecting)
+ require.Equal(t, HealthStateReconnecting, publisher.HealthState())
+}
+
+func TestWithAutoRecoveryNilProvider(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithAutoRecovery(nil))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ assert.Nil(t, publisher.recovery)
+}
+
+func TestConfirmablePublisher_PublishError(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publishErr := errors.New("publish failed")
+ ch.publishErr = publishErr
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.ErrorIs(t, err, publishErr)
+}
+
+func TestConfirmablePublisher_PublishOnClosedPublisher(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+
+ require.NoError(t, publisher.Close())
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.ErrorIs(t, err, ErrPublisherClosed)
+}
+
+func TestConfirmablePublisher_ReconnectWhileOpen(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ err = publisher.Reconnect(newMockChannel())
+ require.ErrorIs(t, err, ErrReconnectWhileOpen)
+}
+
+func TestConfirmablePublisher_PublishContextCancelled(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithConfirmTimeout(time.Second))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ err = publisher.Publish(ctx, "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "context cancelled")
+}
+
+func TestConfirmablePublisher_CloseDuringRecoveryClosesRecoveryDone(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithAutoRecovery(func() (ConfirmableChannel, error) {
+ return newMockChannel(), nil
+ }))
+ require.NoError(t, err)
+
+ require.True(t, publisher.prepareForRecovery())
+ recoveryDone := publisher.done
+
+ require.NoError(t, publisher.Close())
+
+ select {
+ case <-recoveryDone:
+ case <-time.After(time.Second):
+ t.Fatal("recovery done channel was not closed by Close")
+ }
+
+ require.True(t, publisher.shutdown)
+}
+
+func TestConfirmablePublisher_AutoRecoveryExhausted(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ disconnected := make(chan struct{})
+
+ publisher, err := NewConfirmablePublisherFromChannel(
+ ch,
+ WithAutoRecovery(func() (ConfirmableChannel, error) {
+ return nil, errors.New("provider failed")
+ }),
+ WithRecoveryBackoff(time.Millisecond, 2*time.Millisecond),
+ WithMaxRecoveryAttempts(2),
+ WithHealthCallback(func(state HealthState) {
+ if state == HealthStateDisconnected {
+ select {
+ case <-disconnected:
+ default:
+ close(disconnected)
+ }
+ }
+ }),
+ )
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ ch.closeNotify <- amqp.ErrClosed
+
+ select {
+ case <-disconnected:
+ case <-time.After(time.Second):
+ t.Fatal("auto recovery did not report disconnection after exhaustion")
+ }
+
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.ErrorIs(t, err, ErrPublisherClosed)
+ require.ErrorIs(t, err, ErrRecoveryExhausted)
+}
+
+func TestConfirmablePublisher_ChannelCloseWithoutRecoveryTransitionsToDisconnected(t *testing.T) {
+ t.Parallel()
+
+ ch := newMockChannel()
+ publisher, err := NewConfirmablePublisherFromChannel(ch)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if err := publisher.Close(); err != nil {
+ t.Errorf("cleanup: publisher close: %v", err)
+ }
+ })
+
+ ch.closeNotify <- amqp.ErrClosed
+
+ require.Eventually(t, func() bool {
+ return publisher.HealthState() == HealthStateDisconnected
+ }, time.Second, time.Millisecond)
+
+ err = publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.ErrorIs(t, err, ErrPublisherClosed)
+}
+
+func TestConfirmablePublisher_WithTypedNilLoggerDoesNotPanic(t *testing.T) {
+ t.Parallel()
+
+ var logger *panicPublisherLogger
+
+ ch := newMockChannel()
+ require.NotPanics(t, func() {
+ publisher, err := NewConfirmablePublisherFromChannel(ch, WithLogger(logger))
+ require.NoError(t, err)
+ require.NoError(t, publisher.Close())
+ })
+}
+
+func TestConfirmablePublisher_CloseZeroValueIsSafe(t *testing.T) {
+ t.Parallel()
+
+ pub := &ConfirmablePublisher{}
+ require.NotPanics(t, func() {
+ require.NoError(t, pub.Close())
+ })
+
+ require.NoError(t, pub.Close())
+}
+
+func TestConfirmablePublisher_NilReceiverGuards(t *testing.T) {
+ t.Parallel()
+
+ var publisher *ConfirmablePublisher
+
+ err := publisher.Publish(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.ErrorIs(t, err, ErrPublisherRequired)
+
+ err = publisher.PublishAndWaitConfirm(context.Background(), "exchange", "route", false, false, amqp.Publishing{Body: []byte("x")})
+ require.ErrorIs(t, err, ErrPublisherRequired)
+
+ err = publisher.Close()
+ require.ErrorIs(t, err, ErrPublisherRequired)
+
+ err = publisher.Reconnect(newMockChannel())
+ require.ErrorIs(t, err, ErrPublisherRequired)
+
+ ch, err := publisher.ChannelOrError()
+ require.Nil(t, ch)
+ require.ErrorIs(t, err, ErrPublisherRequired)
+
+ require.Nil(t, publisher.Channel())
+ require.Equal(t, HealthStateDisconnected, publisher.HealthState())
+}
diff --git a/commons/rabbitmq/rabbitmq.go b/commons/rabbitmq/rabbitmq.go
index 8654de76..6c8a5eaf 100644
--- a/commons/rabbitmq/rabbitmq.go
+++ b/commons/rabbitmq/rabbitmq.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package rabbitmq
import (
@@ -12,323 +8,1398 @@ import (
"io"
"net"
"net/http"
+ "net/netip"
"net/url"
+ "regexp"
"strings"
"sync"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ "github.com/LerianStudio/lib-commons/v4/commons/backoff"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/internal/nilcheck"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
amqp "github.com/rabbitmq/amqp091-go"
- "go.uber.org/zap"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/trace"
)
-// DefaultConnectionTimeout is the default timeout for establishing RabbitMQ connections
-// when ConnectionTimeout field is not set.
-const DefaultConnectionTimeout = 15 * time.Second
+// connectionFailuresMetric defines the counter for rabbitmq connection failures.
+var connectionFailuresMetric = metrics.Metric{
+ Name: "rabbitmq_connection_failures_total",
+ Unit: "1",
+ Description: "Total number of rabbitmq connection failures",
+}
// RabbitMQConnection is a hub which deal with rabbitmq connections.
type RabbitMQConnection struct {
- mu sync.Mutex // protects connection and channel operations
- ConnectionStringSource string
+ mu sync.RWMutex // protects connection and channel operations
+ ConnectionStringSource string `json:"-"`
Connection *amqp.Connection
Queue string
HealthCheckURL string
Host string
Port string
- User string
- Pass string //#nosec G117 -- Credential field required for RabbitMQ connection config
+ User string `json:"-"`
+ Pass string `json:"-"`
VHost string
Channel *amqp.Channel
Logger log.Logger
+ MetricsFactory *metrics.MetricsFactory
Connected bool
- ConnectionTimeout time.Duration // timeout for establishing connection. Zero value uses default of 15s.
+
+ dialer func(string) (*amqp.Connection, error)
+ dialerContext func(context.Context, string) (*amqp.Connection, error)
+ channelFactory func(*amqp.Connection) (*amqp.Channel, error)
+ channelFactoryContext func(context.Context, *amqp.Connection) (*amqp.Channel, error)
+ connectionCloser func(*amqp.Connection) error
+ connectionCloserContext func(context.Context, *amqp.Connection) error
+ connectionClosedFn func(*amqp.Connection) bool
+ channelClosedFn func(*amqp.Channel) bool
+ channelCloser func(*amqp.Channel) error
+ channelCloserContext func(context.Context, *amqp.Channel) error
+ healthHTTPClient *http.Client
+
+ // AllowInsecureTLS must be set to true to explicitly acknowledge that
+ // the health check HTTP client has TLS certificate verification disabled.
+ // Without this flag, applyDefaults returns ErrInsecureTLS.
+ AllowInsecureTLS bool
+
+ // AllowInsecureHealthCheck must be set to true to explicitly acknowledge
+ // that basic auth credentials are sent over plain HTTP (not HTTPS).
+ // Without this flag, health check validation returns ErrInsecureHealthCheck.
+ AllowInsecureHealthCheck bool
+
+ // HealthCheckAllowedHosts restricts which hosts the health check URL may
+ // target. When non-empty, the health check URL's host (optionally host:port)
+ // must match one of the entries. This protects against SSRF via
+ // configuration injection.
+ // When empty, compatibility mode allows any host unless
+ // RequireHealthCheckAllowedHosts is true. For basic-auth health checks,
+ // derived hosts from AMQP settings are used as fallback enforcement.
+ HealthCheckAllowedHosts []string
+
+ // RequireHealthCheckAllowedHosts enforces a non-empty HealthCheckAllowedHosts
+ // list for every health check. Keep this false during compatibility rollout,
+ // then enable it to hard-fail unsafe configurations.
+ RequireHealthCheckAllowedHosts bool
+
+ warnMissingAllowlistOnce sync.Once
+
+ // Reconnect rate-limiting: prevents thundering-herd reconnect storms
+ // when the broker is down by enforcing exponential backoff between attempts.
+ lastReconnectAttempt time.Time
+ reconnectAttempts int
+}
+
+const defaultRabbitMQHealthCheckTimeout = 5 * time.Second
+
+// reconnectBackoffCap is the maximum delay between reconnect attempts.
+const reconnectBackoffCap = 30 * time.Second
+
+// ErrInsecureTLS is returned when the health check HTTP client has TLS verification disabled
+// without explicitly acknowledging the risk via AllowInsecureTLS.
+var ErrInsecureTLS = errors.New("rabbitmq health check HTTP client has TLS verification disabled — set AllowInsecureTLS to acknowledge this risk")
+
+// ErrInsecureHealthCheck is returned when the health check URL uses HTTP with basic auth
+// credentials without explicitly opting in via AllowInsecureHealthCheck.
+var ErrInsecureHealthCheck = errors.New("rabbitmq health check uses HTTP with basic auth credentials — set AllowInsecureHealthCheck to acknowledge this risk")
+
+// ErrHealthCheckHostNotAllowed is returned when the health check URL targets a host
+// not present in the HealthCheckAllowedHosts allowlist.
+var ErrHealthCheckHostNotAllowed = errors.New("rabbitmq health check host not in allowed list")
+
+// ErrHealthCheckAllowedHostsRequired is returned when strict allowlist mode is enabled
+// but no allowed hosts were configured.
+var ErrHealthCheckAllowedHostsRequired = errors.New("rabbitmq health check allowed hosts list is required")
+
+// ErrNilConnection is returned when a method is called on a nil RabbitMQConnection.
+var ErrNilConnection = errors.New("rabbitmq connection is nil")
+
+const redactedURLPassword = "xxxxx"
+
+// Best-effort URL matcher used for redaction on arbitrary error messages.
+// This intentionally differs from outbox's storage sanitizer because this path
+// optimizes for preserving operational error context while redacting credentials.
+var urlPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9+.-]*://[^\s]+`)
+
+// ChannelSnapshot returns the current channel reference under connection lock.
+func (rc *RabbitMQConnection) ChannelSnapshot() *amqp.Channel {
+ if rc == nil {
+ return nil
+ }
+
+ rc.mu.RLock()
+ defer rc.mu.RUnlock()
+
+ return rc.Channel
+}
+
+// nilConnectionAssert fires a telemetry assertion for nil-receiver calls and returns ErrNilConnection.
+// The logger is intentionally nil here because this function is called on a nil *RabbitMQConnection
+// receiver, so there is no struct instance from which to extract a logger. The assert package
+// handles nil loggers gracefully by falling back to stderr.
+func nilConnectionAssert(operation string) error {
+ asserter := assert.New(context.Background(), nil, "rabbitmq", operation)
+ _ = asserter.Never(context.Background(), "rabbitmq connection receiver is nil")
+
+ return ErrNilConnection
}
// Connect keeps a singleton connection with rabbitmq.
func (rc *RabbitMQConnection) Connect() error {
+ return rc.ConnectContext(context.Background())
+}
+
+// isFullyConnected reports whether the connection and channel are both open.
+// The caller MUST hold rc.mu.
+func (rc *RabbitMQConnection) isFullyConnected() bool {
+ return rc.Connected &&
+ rc.Connection != nil && !rc.connectionClosedFn(rc.Connection) &&
+ rc.Channel != nil && !rc.channelClosedFn(rc.Channel)
+}
+
+// connectSnapshot captures the configuration state needed for dialing and health
+// checking under the lock. The caller MUST hold rc.mu.
+type connectSnapshot struct {
+ connStr string
+ healthCheckURL string
+ healthUser string
+ healthPass string
+ healthPolicy healthCheckURLConfig
+ healthClient *http.Client
+ dialer func(context.Context, string) (*amqp.Connection, error)
+ channelFactory func(context.Context, *amqp.Connection) (*amqp.Channel, error)
+ connectionClosedFn func(*amqp.Connection) bool
+ connCloser func(*amqp.Connection) error
+ logger log.Logger
+}
+
+// snapshotConnectState captures connect-time state under the lock.
+// The caller MUST hold rc.mu.
+func (rc *RabbitMQConnection) snapshotConnectState() connectSnapshot {
+ connStr := rc.ConnectionStringSource
+ healthCheckURL := rc.HealthCheckURL
+ configuredHosts := append([]string(nil), rc.HealthCheckAllowedHosts...)
+ derivedHosts := mergeAllowedHosts(
+ deriveAllowedHostsFromConnectionString(connStr),
+ append(
+ deriveAllowedHostsFromHostPort(rc.Host, rc.Port),
+ deriveAllowedHostsFromHealthCheckURL(healthCheckURL)...,
+ )...,
+ )
+
+ return connectSnapshot{
+ connStr: connStr,
+ healthCheckURL: healthCheckURL,
+ healthUser: rc.User,
+ healthPass: rc.Pass,
+ healthPolicy: healthCheckURLConfig{
+ allowInsecure: rc.AllowInsecureHealthCheck,
+ hasBasicAuth: rc.User != "" || rc.Pass != "",
+ allowedHosts: configuredHosts,
+ derivedAllowedHosts: derivedHosts,
+ allowlistConfigured: len(configuredHosts) > 0,
+ requireAllowedHosts: rc.RequireHealthCheckAllowedHosts,
+ },
+ healthClient: rc.healthHTTPClient,
+ dialer: rc.dialerContext,
+ channelFactory: rc.channelFactoryContext,
+ connectionClosedFn: rc.connectionClosedFn,
+ connCloser: rc.connectionCloser,
+ logger: rc.logger(),
+ }
+}
+
+// ConnectContext keeps a singleton connection with rabbitmq.
+func (rc *RabbitMQConnection) ConnectContext(ctx context.Context) error {
+ if rc == nil {
+ return nilConnectionAssert("connect_context")
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if err := ctx.Err(); err != nil {
+ return fmt.Errorf("rabbitmq connect: %w", err)
+ }
+
+ tracer := otel.Tracer("rabbitmq")
+
+ ctx, span := tracer.Start(ctx, "rabbitmq.connect")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRabbitMQ))
+
rc.mu.Lock()
- defer rc.mu.Unlock()
- rc.Logger.Info("Connecting on rabbitmq...")
+ if err := rc.applyDefaults(); err != nil {
+ rc.mu.Unlock()
- conn, err := amqp.Dial(rc.ConnectionStringSource)
- if err != nil {
- rc.Logger.Error("failed to connect on rabbitmq", zap.Error(err))
- return fmt.Errorf("failed to connect to rabbitmq: %w", err)
+ libOpentelemetry.HandleSpanError(span, "Failed to apply defaults", err)
+
+ return fmt.Errorf("rabbitmq connect: %w", err)
}
- ch, err := conn.Channel()
- if err != nil {
- if closeErr := conn.Close(); closeErr != nil {
- rc.Logger.Warn("failed to close connection during cleanup", zap.Error(closeErr))
- }
+ // Fast-path: if already connected with an open connection and channel,
+ // return immediately without creating a new connection.
+ if rc.isFullyConnected() {
+ rc.mu.Unlock()
+
+ return nil
+ }
+
+ snap := rc.snapshotConnectState()
+ rc.mu.Unlock()
- rc.Logger.Error("failed to open channel on rabbitmq", zap.Error(err))
+ snap.logger.Log(ctx, log.LevelInfo, "connecting to rabbitmq")
- return fmt.Errorf("failed to open channel on rabbitmq: %w", err)
+ conn, ch, err := rc.dialAndOpenChannel(ctx, span, snap)
+ if err != nil {
+ return err
}
- if ch == nil || !rc.HealthCheck() {
- if closeErr := conn.Close(); closeErr != nil {
- rc.Logger.Warn("failed to close connection during cleanup", zap.Error(closeErr))
- }
+ if healthErr := rc.healthCheck(ctx, snap.healthCheckURL, snap.healthUser, snap.healthPass, snap.healthClient, snap.healthPolicy, snap.logger); healthErr != nil {
+ rc.closeConnectionWith(conn, snap.connCloser)
+ rc.clearConnectionState()
- rc.Connected = false
- err = errors.New("can't connect rabbitmq")
- rc.Logger.Error("RabbitMQ.HealthCheck failed", zap.Error(err))
+ snap.logger.Log(ctx, log.LevelError, "rabbitmq health check failed")
- return fmt.Errorf("rabbitmq health check failed: %w", err)
+ return fmt.Errorf("rabbitmq health check failed: %w", healthErr)
}
- rc.Logger.Info("Connected on rabbitmq ✅ \n")
+ snap.logger.Log(ctx, log.LevelInfo, "connected to rabbitmq")
+
+ rc.mu.Lock()
+ if rc.Connection != nil && rc.Connection != conn && !snap.connectionClosedFn(rc.Connection) {
+ rc.mu.Unlock()
+
+ rc.closeConnectionWith(conn, snap.connCloser)
+
+ return nil
+ }
rc.Connected = true
rc.Connection = conn
-
rc.Channel = ch
+ rc.mu.Unlock()
return nil
}
+// dialAndOpenChannel dials the broker and opens a channel. On any failure it
+// clears connection state and records the error on the span before returning.
+func (rc *RabbitMQConnection) dialAndOpenChannel(ctx context.Context, span trace.Span, snap connectSnapshot) (*amqp.Connection, *amqp.Channel, error) {
+ conn, err := snap.dialer(ctx, snap.connStr)
+ if err != nil {
+ snap.logger.Log(ctx, log.LevelError, "failed to connect to rabbitmq", log.String("error_detail", sanitizeAMQPErr(err, snap.connStr)))
+ rc.recordConnectionFailure("connect")
+ rc.clearConnectionState()
+
+ sanitizedErr := newSanitizedError(err, snap.connStr, "failed to connect to rabbitmq")
+ libOpentelemetry.HandleSpanError(span, "Failed to connect to rabbitmq", sanitizedErr)
+
+ return nil, nil, sanitizedErr
+ }
+
+ ch, err := snap.channelFactory(ctx, conn)
+ if err != nil {
+ rc.closeConnectionWith(conn, snap.connCloser)
+ rc.clearConnectionState()
+
+ snap.logger.Log(ctx, log.LevelError, "failed to open channel on rabbitmq", log.Err(err))
+
+ libOpentelemetry.HandleSpanError(span, "Failed to open channel on rabbitmq", err)
+
+ return nil, nil, fmt.Errorf("failed to open channel on rabbitmq: %w", err)
+ }
+
+ if ch == nil {
+ rc.closeConnectionWith(conn, snap.connCloser)
+ rc.clearConnectionState()
+
+ err = errors.New("can't connect rabbitmq")
+
+ snap.logger.Log(ctx, log.LevelError, "rabbitmq health check failed")
+
+ libOpentelemetry.HandleSpanError(span, "RabbitMQ health check failed", err)
+
+ return nil, nil, fmt.Errorf("rabbitmq health check failed: %w", err)
+ }
+
+ return conn, ch, nil
+}
+
// EnsureChannel ensures that the channel is open and connected.
-// For context-aware connection handling with timeout support, see EnsureChannelWithContext.
func (rc *RabbitMQConnection) EnsureChannel() error {
+ return rc.EnsureChannelContext(context.Background())
+}
+
+// ensureChannelSnapshot captures state needed by EnsureChannelContext under the lock.
+type ensureChannelSnapshot struct {
+ connStr string
+ logger log.Logger
+ dialer func(context.Context, string) (*amqp.Connection, error)
+ channelFactory func(context.Context, *amqp.Connection) (*amqp.Channel, error)
+ connCloser func(*amqp.Connection) error
+ connectionClosedFn func(*amqp.Connection) bool
+ needConnection bool
+ needChannel bool
+ existingConn *amqp.Connection
+}
+
+// snapshotEnsureChannelState captures and returns a snapshot of state needed for channel
+// ensuring, applying defaults and rate-limiting under the lock. Returns an error if
+// defaults fail or the request is rate-limited.
+func (rc *RabbitMQConnection) snapshotEnsureChannelState() (ensureChannelSnapshot, error) {
rc.mu.Lock()
defer rc.mu.Unlock()
+ if err := rc.applyDefaults(); err != nil {
+ return ensureChannelSnapshot{}, fmt.Errorf("rabbitmq ensure channel: %w", err)
+ }
+
+ connectionClosedFn := rc.connectionClosedFn
+ channelClosedFn := rc.channelClosedFn
+ needConnection := rc.Connection == nil || connectionClosedFn(rc.Connection)
+ needChannel := needConnection || rc.Channel == nil || channelClosedFn(rc.Channel)
+
+ // Rate-limit reconnect attempts: if we've failed recently, enforce a
+ // minimum delay before the next attempt to prevent reconnect storms.
+ if needConnection && rc.reconnectAttempts > 0 {
+ delay := min(backoff.ExponentialWithJitter(500*time.Millisecond, rc.reconnectAttempts), reconnectBackoffCap)
+
+ if elapsed := time.Since(rc.lastReconnectAttempt); elapsed < delay {
+ return ensureChannelSnapshot{}, fmt.Errorf("rabbitmq ensure channel: rate-limited (next attempt in %s)", delay-elapsed)
+ }
+ }
+
+ return ensureChannelSnapshot{
+ connStr: rc.ConnectionStringSource,
+ logger: rc.logger(),
+ dialer: rc.dialerContext,
+ channelFactory: rc.channelFactoryContext,
+ connCloser: rc.connectionCloser,
+ connectionClosedFn: connectionClosedFn,
+ needConnection: needConnection,
+ needChannel: needChannel,
+ existingConn: rc.Connection,
+ }, nil
+}
+
+// EnsureChannelContext ensures that the channel is open and connected.
+func (rc *RabbitMQConnection) EnsureChannelContext(ctx context.Context) error {
+ if rc == nil {
+ return nilConnectionAssert("ensure_channel_context")
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if err := ctx.Err(); err != nil {
+ return fmt.Errorf("rabbitmq ensure channel: %w", err)
+ }
+
+ tracer := otel.Tracer("rabbitmq")
+
+ ctx, span := tracer.Start(ctx, "rabbitmq.ensure_channel")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRabbitMQ))
+
+ snap, err := rc.snapshotEnsureChannelState()
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "Failed to prepare ensure channel state", err)
+ return err
+ }
+
+ if !snap.needChannel {
+ return nil
+ }
+
+ var conn *amqp.Connection
+
newConnection := false
- if rc.Connection == nil || rc.Connection.IsClosed() {
- conn, err := amqp.Dial(rc.ConnectionStringSource)
+ if snap.needConnection {
+ rc.mu.Lock()
+ rc.lastReconnectAttempt = time.Now()
+ rc.mu.Unlock()
+
+ conn, err = snap.dialer(ctx, snap.connStr)
if err != nil {
- rc.Logger.Error("failed to connect to rabbitmq", zap.Error(err))
+ snap.logger.Log(ctx, log.LevelError, "failed to connect to rabbitmq", log.String("error_detail", sanitizeAMQPErr(err, snap.connStr)))
+ rc.recordConnectionFailure("ensure_channel_connect")
+
+ rc.mu.Lock()
+ rc.Connected = false
+ rc.reconnectAttempts++
+ rc.mu.Unlock()
- return fmt.Errorf("failed to connect to rabbitmq: %w", err)
+ sanitizedErr := newSanitizedError(err, snap.connStr, "can't connect to rabbitmq")
+ libOpentelemetry.HandleSpanError(span, "Failed to connect to rabbitmq", sanitizedErr)
+
+ return sanitizedErr
}
- rc.Connection = conn
newConnection = true
+ } else {
+ conn = snap.existingConn
}
- if rc.Channel == nil || rc.Channel.IsClosed() {
- ch, err := rc.Connection.Channel()
- if err != nil {
- // cleanup connection if we just created it and channel creation fails
- if newConnection {
- if closeErr := rc.Connection.Close(); closeErr != nil {
- rc.Logger.Warn("failed to close connection during cleanup", zap.Error(closeErr))
- }
+ ch, err := snap.channelFactory(ctx, conn)
+ if err == nil && ch == nil {
+ err = errors.New("channel factory returned nil channel")
+ }
- rc.Connection = nil
- }
+ if err != nil {
+ rc.handleChannelFailure(conn, snap.existingConn, newConnection, snap.connCloser)
+ rc.recordConnectionFailure("ensure_channel")
- // Reset stale state so GetNewConnect triggers reconnection
- rc.Connected = false
- rc.Channel = nil
+ snap.logger.Log(ctx, log.LevelError, "failed to open channel on rabbitmq", log.Err(err))
- rc.Logger.Error("failed to open channel on rabbitmq", zap.Error(err))
+ libOpentelemetry.HandleSpanError(span, "Failed to open channel on rabbitmq", err)
- return fmt.Errorf("failed to open channel on rabbitmq: %w", err)
- }
+ return fmt.Errorf("rabbitmq ensure channel: %w", err)
+ }
- rc.Channel = ch
+ rc.mu.Lock()
+ if newConnection {
+ rc.Connection = conn
+ rc.reconnectAttempts = 0
}
+ rc.Channel = ch
rc.Connected = true
+ rc.mu.Unlock()
return nil
}
-// EnsureChannelWithContext ensures that the channel is open and connected,
-// respecting context cancellation and deadline. Unlike EnsureChannel, this method
-// will return immediately if context is cancelled or deadline exceeded.
-//
-// The effective connection timeout is the minimum of:
-// - The remaining time until context deadline (if context has a deadline)
-// - ConnectionTimeout field value (defaults to 15s if zero)
-//
-// Usage:
-//
-// ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-// defer cancel()
-// if err := conn.EnsureChannelWithContext(ctx); err != nil {
-// // Handle error - could be context timeout or connection failure
-// }
-func (rc *RabbitMQConnection) EnsureChannelWithContext(ctx context.Context) error {
- // Check context before acquiring lock to fail fast
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
+// GetNewConnect returns a pointer to the rabbitmq connection, initializing it if necessary.
+func (rc *RabbitMQConnection) GetNewConnect() (*amqp.Channel, error) {
+ return rc.GetNewConnectContext(context.Background())
+}
+
+// GetNewConnectContext returns a pointer to the rabbitmq connection, initializing it if necessary.
+func (rc *RabbitMQConnection) GetNewConnectContext(ctx context.Context) (*amqp.Channel, error) {
+ if rc == nil {
+ return nil, nilConnectionAssert("get_new_connect_context")
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+
+ rc.mu.Lock()
+
+ if err := rc.applyDefaults(); err != nil {
+ rc.mu.Unlock()
+
+ return nil, err
+ }
+
+ if rc.Connected && rc.Channel != nil && !rc.channelClosedFn(rc.Channel) {
+ ch := rc.Channel
+ rc.mu.Unlock()
+
+ return ch, nil
+ }
+ rc.mu.Unlock()
+
+ if err := rc.EnsureChannelContext(ctx); err != nil {
+ rc.logger().Log(ctx, log.LevelError, "failed to ensure channel", log.Err(err))
+
+ return nil, err
}
rc.mu.Lock()
defer rc.mu.Unlock()
- // Check context again after acquiring lock
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
+ if rc.Channel == nil {
+ rc.Connected = false
+
+ return nil, errors.New("rabbitmq channel not available")
}
- newConnection := false
+ return rc.Channel, nil
+}
- if rc.Connection == nil || rc.Connection.IsClosed() {
- conn, err := rc.dialWithContext(ctx)
- if err != nil {
- rc.Logger.Error("failed to connect to rabbitmq", zap.Error(err))
- return fmt.Errorf("failed to connect to rabbitmq: %w", err)
- }
+// HealthCheck rabbitmq when the server is started.
+func (rc *RabbitMQConnection) HealthCheck() (bool, error) {
+ return rc.HealthCheckContext(context.Background())
+}
- rc.Connection = conn
- newConnection = true
+// HealthCheckContext rabbitmq when the server is started.
+// It captures config fields under lock to avoid reading them during concurrent mutation.
+func (rc *RabbitMQConnection) HealthCheckContext(ctx context.Context) (bool, error) {
+ if rc == nil {
+ return false, nilConnectionAssert("health_check_context")
}
- if rc.Channel == nil || rc.Channel.IsClosed() {
- // Check context before Channel() which doesn't accept context parameter
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
- }
+ if ctx == nil {
+ ctx = context.Background()
+ }
- ch, err := rc.Connection.Channel()
- if err != nil {
- // cleanup connection if we just created it and channel creation fails
- if newConnection {
- if closeErr := rc.Connection.Close(); closeErr != nil {
- rc.Logger.Warn("failed to close connection during cleanup", zap.Error(closeErr))
- }
+ tracer := otel.Tracer("rabbitmq")
- rc.Connection = nil
- }
+ ctx, span := tracer.Start(ctx, "rabbitmq.health_check")
+ defer span.End()
- // Reset stale state so GetNewConnect triggers reconnection
- rc.Connected = false
- rc.Channel = nil
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRabbitMQ))
- rc.Logger.Error("failed to open channel on rabbitmq", zap.Error(err))
+ rc.mu.Lock()
+ if err := rc.applyDefaults(); err != nil {
+ rc.mu.Unlock()
- return fmt.Errorf("failed to open channel on rabbitmq: %w", err)
- }
+ return false, err
+ }
- rc.Channel = ch
+ healthURL := rc.HealthCheckURL
+ user := rc.User
+ pass := rc.Pass
+ configuredHosts := append([]string(nil), rc.HealthCheckAllowedHosts...)
+ derivedHosts := mergeAllowedHosts(
+ deriveAllowedHostsFromConnectionString(rc.ConnectionStringSource),
+ append(
+ deriveAllowedHostsFromHostPort(rc.Host, rc.Port),
+ deriveAllowedHostsFromHealthCheckURL(healthURL)...,
+ )...,
+ )
+ healthPolicy := healthCheckURLConfig{
+ allowInsecure: rc.AllowInsecureHealthCheck,
+ hasBasicAuth: rc.User != "" || rc.Pass != "",
+ allowedHosts: configuredHosts,
+ derivedAllowedHosts: derivedHosts,
+ allowlistConfigured: len(configuredHosts) > 0,
+ requireAllowedHosts: rc.RequireHealthCheckAllowedHosts,
}
+ client := rc.healthHTTPClient
+ logger := rc.logger()
+ rc.mu.Unlock()
- rc.Connected = true
+ if err := rc.healthCheck(ctx, healthURL, user, pass, client, healthPolicy, logger); err != nil {
+ return false, err
+ }
- return nil
+ return true, nil
+}
+
+// healthCheck is the internal implementation that operates on pre-captured config values,
+// safe to call without holding the mutex.
+func (rc *RabbitMQConnection) healthCheck(
+ ctx context.Context,
+ rawHealthURL, user, pass string,
+ client *http.Client,
+ policy healthCheckURLConfig,
+ logger log.Logger,
+) error {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if err := ctx.Err(); err != nil {
+ logger.Log(ctx, log.LevelError, "context canceled during rabbitmq health check", log.Err(err))
+
+ return fmt.Errorf("rabbitmq health check context: %w", err)
+ }
+
+ if policy.hasBasicAuth != (user != "" || pass != "") {
+ policy.hasBasicAuth = user != "" || pass != ""
+ }
+
+ if !policy.allowlistConfigured && !policy.requireAllowedHosts {
+ rc.warnMissingAllowlistOnce.Do(func() {
+ logger.Log(
+ ctx,
+ log.LevelWarn,
+ "rabbitmq health check explicit host allowlist is empty; compatibility mode may skip host validation. Configure HealthCheckAllowedHosts and set RequireHealthCheckAllowedHosts=true to enforce strict SSRF hardening",
+ )
+ })
+ }
+
+ healthURL, err := validateHealthCheckURLWithConfig(rawHealthURL, policy)
+ if err != nil {
+ logger.Log(ctx, log.LevelError, "invalid rabbitmq health check URL", log.Err(err))
+
+ return err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, healthURL, nil)
+ if err != nil {
+ logger.Log(ctx, log.LevelError, "failed to create rabbitmq health check request", log.Err(err))
+
+ return fmt.Errorf("building rabbitmq health check request: %w", err)
+ }
+
+ req.SetBasicAuth(user, pass)
+
+ if client == nil {
+ client = &http.Client{Timeout: defaultRabbitMQHealthCheckTimeout}
+ }
+
+ // #nosec G704 -- URL is validated via validateHealthCheckURLWithConfig before request; host allowlist and IP safety checks prevent SSRF
+ resp, err := client.Do(req)
+ if err != nil {
+ logger.Log(ctx, log.LevelError, "failed to execute rabbitmq health check request", log.Err(err))
+
+ return fmt.Errorf("executing rabbitmq health check request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ return parseHealthCheckResponse(ctx, resp, logger)
}
-// dialWithContext creates an AMQP connection with context awareness.
-// It extracts the deadline from context and uses it as connection timeout.
-// If context has no deadline, uses ConnectionTimeout field (default 15s).
-func (rc *RabbitMQConnection) dialWithContext(ctx context.Context) (*amqp.Connection, error) {
- // Determine timeout from context deadline or default
- timeout := rc.ConnectionTimeout
- if timeout <= 0 {
- timeout = DefaultConnectionTimeout
+func parseHealthCheckResponse(ctx context.Context, resp *http.Response, logger log.Logger) error {
+ if resp == nil {
+ return errors.New("rabbitmq health check response is empty")
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ logger.Log(ctx, log.LevelError, "rabbitmq health check failed", log.String("status", resp.Status))
+
+ return fmt.Errorf("rabbitmq health check status %q", resp.Status)
}
- if deadline, ok := ctx.Deadline(); ok {
- remaining := time.Until(deadline)
- if remaining <= 0 {
- return nil, context.DeadlineExceeded
+ body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
+ if err != nil {
+ logger.Log(ctx, log.LevelError, "failed to read rabbitmq health check response", log.Err(err))
+
+ return fmt.Errorf("reading rabbitmq health check response: %w", err)
+ }
+
+ var result map[string]any
+
+ err = json.Unmarshal(body, &result)
+ if err != nil {
+ logger.Log(ctx, log.LevelError, "failed to parse rabbitmq health check response", log.Err(err))
+
+ return fmt.Errorf("parsing rabbitmq health check response: %w", err)
+ }
+
+ if result == nil {
+ logger.Log(ctx, log.LevelError, "rabbitmq health check response is empty or null")
+
+ return errors.New("rabbitmq health check response is empty")
+ }
+
+ if status, ok := result["status"].(string); ok && status == "ok" {
+ return nil
+ }
+
+ logger.Log(ctx, log.LevelError, "rabbitmq is not healthy")
+
+ return errors.New("rabbitmq is not healthy")
+}
+
+func (rc *RabbitMQConnection) applyDefaults() error {
+ rc.applyConnectionDefaults()
+ rc.applyChannelDefaults()
+
+ return rc.applyHealthDefaults()
+}
+
+func (rc *RabbitMQConnection) applyConnectionDefaults() {
+ if rc.dialer == nil {
+ rc.dialer = amqp.Dial
+ }
+
+ if rc.dialerContext == nil {
+ rc.dialerContext = func(_ context.Context, connectionString string) (*amqp.Connection, error) {
+ return rc.dialer(connectionString)
+ }
+ }
+
+ if rc.connectionCloser == nil {
+ rc.connectionCloser = func(connection *amqp.Connection) error {
+ if connection == nil {
+ return nil
+ }
+
+ return connection.Close()
}
+ }
- if remaining < timeout {
- timeout = remaining
+ if rc.connectionCloserContext == nil {
+ rc.connectionCloserContext = func(_ context.Context, connection *amqp.Connection) error {
+ return rc.connectionCloser(connection)
}
}
- // Create config with custom dialer that respects timeout
- config := amqp.Config{
- Dial: func(network, addr string) (net.Conn, error) {
- // Check context before dialing
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- default:
+ if rc.connectionClosedFn == nil {
+ rc.connectionClosedFn = func(connection *amqp.Connection) bool {
+ if connection == nil {
+ return true
}
- dialer := &net.Dialer{
- Timeout: timeout,
+ return connection.IsClosed()
+ }
+ }
+}
+
+func (rc *RabbitMQConnection) applyChannelDefaults() {
+ if rc.channelFactory == nil {
+ rc.channelFactory = func(connection *amqp.Connection) (*amqp.Channel, error) {
+ if connection == nil {
+ return nil, errors.New("cannot create channel: connection is nil")
+ }
+
+ return connection.Channel()
+ }
+ }
+
+ if rc.channelFactoryContext == nil {
+ rc.channelFactoryContext = func(_ context.Context, connection *amqp.Connection) (*amqp.Channel, error) {
+ return rc.channelFactory(connection)
+ }
+ }
+
+ if rc.channelClosedFn == nil {
+ rc.channelClosedFn = func(ch *amqp.Channel) bool {
+ if ch == nil {
+ return true
}
- conn, err := dialer.DialContext(ctx, network, addr)
- if err != nil {
- return nil, err
+ return ch.IsClosed()
+ }
+ }
+
+ if rc.channelCloser == nil {
+ rc.channelCloser = func(ch *amqp.Channel) error {
+ if ch == nil {
+ return nil
}
- return conn, nil
- },
+ return ch.Close()
+ }
}
- return amqp.DialConfig(rc.ConnectionStringSource, config)
+ if rc.channelCloserContext == nil {
+ rc.channelCloserContext = func(_ context.Context, ch *amqp.Channel) error {
+ return rc.channelCloser(ch)
+ }
+ }
}
-// GetNewConnect returns a pointer to the rabbitmq connection, initializing it if necessary.
-func (rc *RabbitMQConnection) GetNewConnect() (*amqp.Channel, error) {
- if !rc.Connected {
- err := rc.Connect()
- if err != nil {
- rc.Logger.Infof("ERRCONECT %s", err)
+func (rc *RabbitMQConnection) applyHealthDefaults() error {
+ if rc.healthHTTPClient == nil {
+ rc.healthHTTPClient = &http.Client{Timeout: defaultRabbitMQHealthCheckTimeout}
- return nil, err
+ return nil
+ }
+
+ transport, ok := rc.healthHTTPClient.Transport.(*http.Transport)
+ if !ok || transport.TLSClientConfig == nil {
+ return nil
+ }
+
+ if transport.TLSClientConfig.InsecureSkipVerify && !rc.AllowInsecureTLS {
+ return ErrInsecureTLS
+ }
+
+ return nil
+}
+
+func (rc *RabbitMQConnection) closeConnectionWith(connection *amqp.Connection, closer func(*amqp.Connection) error) {
+ if closer == nil {
+ return
+ }
+
+ if err := closer(connection); err != nil {
+ rc.logger().Log(context.Background(), log.LevelWarn, "failed to close rabbitmq connection during cleanup", log.Err(err))
+ }
+}
+
+// clearConnectionState resets the connection state under lock after a failed
+// connect/reconnect attempt, ensuring no stale Connected/Connection/Channel
+// references remain.
+func (rc *RabbitMQConnection) clearConnectionState() {
+ rc.mu.Lock()
+ rc.Connected = false
+ rc.Connection = nil
+ rc.Channel = nil
+ rc.mu.Unlock()
+}
+
+// handleChannelFailure cleans up after a failed channel creation in EnsureChannelContext.
+// It conditionally closes the connection and resets the channel/connected state.
+func (rc *RabbitMQConnection) handleChannelFailure(conn, existingConn *amqp.Connection, newConnection bool, connCloser func(*amqp.Connection) error) {
+ if newConnection {
+ rc.closeConnectionWith(conn, connCloser)
+ }
+
+ rc.mu.Lock()
+ if newConnection && rc.Connection == existingConn {
+ rc.Connection = nil
+ }
+
+ rc.Channel = nil
+ rc.Connected = false
+ rc.mu.Unlock()
+}
+
+// Close closes the rabbitmq channel and connection.
+func (rc *RabbitMQConnection) Close() error {
+ return rc.CloseContext(context.Background())
+}
+
+// CloseContext closes the rabbitmq channel and connection.
+func (rc *RabbitMQConnection) CloseContext(ctx context.Context) error {
+ if rc == nil {
+ return nilConnectionAssert("close_context")
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if err := ctx.Err(); err != nil {
+ return fmt.Errorf("rabbitmq close: %w", err)
+ }
+
+ tracer := otel.Tracer("rabbitmq")
+
+ ctx, span := tracer.Start(ctx, "rabbitmq.close")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRabbitMQ))
+
+ rc.mu.Lock()
+ _ = rc.applyDefaults() // Close must not fail due to TLS config — resources still need cleanup.
+ channel := rc.Channel
+ connection := rc.Connection
+ chCloser := rc.channelCloserContext
+ connCloser := rc.connectionCloserContext
+ rc.Connection = nil
+ rc.Channel = nil
+ rc.Connected = false
+ logger := rc.logger()
+ rc.mu.Unlock()
+
+ var closeErr error
+
+ if channel != nil {
+ if err := chCloser(ctx, channel); err != nil {
+ closeErr = fmt.Errorf("failed to close rabbitmq channel: %w", err)
+ logger.Log(ctx, log.LevelWarn, "failed to close rabbitmq channel", log.Err(err))
}
}
- return rc.Channel, nil
+ if connection != nil {
+ if err := connCloser(ctx, connection); err != nil {
+ if closeErr == nil {
+ closeErr = fmt.Errorf("failed to close rabbitmq connection: %w", err)
+ } else {
+ closeErr = errors.Join(closeErr, fmt.Errorf("failed to close rabbitmq connection: %w", err))
+ }
+
+ logger.Log(ctx, log.LevelWarn, "failed to close rabbitmq connection", log.Err(err))
+ }
+ }
+
+ if closeErr != nil {
+ libOpentelemetry.HandleSpanError(span, "Failed to close rabbitmq", closeErr)
+ }
+
+ return closeErr
+}
+
+func (rc *RabbitMQConnection) logger() log.Logger {
+ if rc == nil {
+ return &log.NopLogger{}
+ }
+
+ // Use reflect-based typed-nil detection: an interface can be non-nil at the
+ // interface level while holding a nil concrete pointer (typed-nil). Calling
+ // methods on a typed-nil logger will panic. The nilcheck package handles this.
+ if nilcheck.Interface(rc.Logger) {
+ return &log.NopLogger{}
+ }
+
+ return rc.Logger
}
-// HealthCheck rabbitmq when the server is started
-func (rc *RabbitMQConnection) HealthCheck() bool {
- healthURL := rc.HealthCheckURL + "/api/health/checks/alarms"
+// healthCheckURLConfig holds validation parameters for health check URL checking.
+type healthCheckURLConfig struct {
+ allowInsecure bool
+ hasBasicAuth bool
+ allowedHosts []string
+ derivedAllowedHosts []string
+ allowlistConfigured bool
+ requireAllowedHosts bool
+}
+
+// validateHealthCheckURLWithConfig validates the health check URL and appends the RabbitMQ health endpoint path
+// if not already present. The HealthCheckURL should be the RabbitMQ management API base URL
+// (e.g., "http://host:15672" or "https://host:15672"), NOT the full health endpoint.
+// If the URL already ends with "/api/health/checks/alarms", it is returned as-is.
+func validateHealthCheckURLWithConfig(rawURL string, cfg healthCheckURLConfig) (string, error) {
+ cfg = normalizeHealthCheckURLConfig(cfg)
- req, err := http.NewRequest(http.MethodGet, healthURL, nil)
+ parsedURL, err := parseAndValidateHealthCheckBaseURL(rawURL, cfg)
if err != nil {
- rc.Logger.Errorf("failed to make GET request before client do: %v", err.Error())
+ return "", err
+ }
- return false
+ enforceHosts, hostsToEnforce, err := resolveHealthCheckAllowedHosts(cfg)
+ if err != nil {
+ return "", err
+ }
+
+ if err := validateHealthCheckHostAllowlist(parsedURL.Host, enforceHosts, hostsToEnforce); err != nil {
+ return "", err
+ }
+
+ return normalizeHealthCheckEndpointPath(parsedURL), nil
+}
+
+func normalizeHealthCheckURLConfig(cfg healthCheckURLConfig) healthCheckURLConfig {
+ if !cfg.allowlistConfigured && len(cfg.allowedHosts) > 0 {
+ cfg.allowlistConfigured = true
}
- req.SetBasicAuth(rc.User, rc.Pass)
+ return cfg
+}
- client := &http.Client{}
+func parseAndValidateHealthCheckBaseURL(rawURL string, cfg healthCheckURLConfig) (*url.URL, error) {
+ healthURL := strings.TrimSpace(rawURL)
+ if healthURL == "" {
+ return nil, errors.New("rabbitmq health check URL is empty")
+ }
- resp, err := client.Do(req) //#nosec G704 -- HealthCheckURL is operator-configured, not user input
+ parsedURL, err := url.Parse(healthURL)
if err != nil {
- rc.Logger.Errorf("failed to make GET request after client do: %v", err.Error())
+ return nil, err
+ }
- return false
+ if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
+ return nil, errors.New("rabbitmq health check URL must use http or https")
}
- defer resp.Body.Close()
+ if parsedURL.Host == "" {
+ return nil, errors.New("rabbitmq health check URL must include a host")
+ }
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- rc.Logger.Errorf("failed to read response body: %v", err.Error())
+ if parsedURL.User != nil {
+ return nil, errors.New("rabbitmq health check URL must not include user credentials")
+ }
- return false
+ if parsedURL.Scheme == "http" && cfg.hasBasicAuth && !cfg.allowInsecure {
+ return nil, ErrInsecureHealthCheck
}
- var result map[string]any
+ return parsedURL, nil
+}
- err = json.Unmarshal(body, &result)
+func resolveHealthCheckAllowedHosts(cfg healthCheckURLConfig) (bool, []string, error) {
+ if cfg.requireAllowedHosts && (!cfg.allowlistConfigured || len(cfg.allowedHosts) == 0) {
+ return false, nil, ErrHealthCheckAllowedHostsRequired
+ }
+
+ enforceHosts := cfg.allowlistConfigured
+ hostsToEnforce := cfg.allowedHosts
+
+ if cfg.hasBasicAuth && !cfg.allowInsecure {
+ switch {
+ case len(cfg.allowedHosts) > 0:
+ return true, cfg.allowedHosts, nil
+ case len(cfg.derivedAllowedHosts) > 0:
+ return true, cfg.derivedAllowedHosts, nil
+ default:
+ return false, nil, ErrHealthCheckAllowedHostsRequired
+ }
+ }
+
+ return enforceHosts, hostsToEnforce, nil
+}
+
+func validateHealthCheckHostAllowlist(host string, enforceHosts bool, allowedHosts []string) error {
+ if !enforceHosts {
+ return nil
+ }
+
+ if !isHostAllowed(host, allowedHosts) {
+ return fmt.Errorf("%w: %s", ErrHealthCheckHostNotAllowed, host)
+ }
+
+ return nil
+}
+
+func normalizeHealthCheckEndpointPath(parsedURL *url.URL) string {
+ const healthPath = "/api/health/checks/alarms"
+
+ normalized := strings.TrimSuffix(parsedURL.String(), "/")
+ if strings.HasSuffix(normalized, healthPath) {
+ return normalized
+ }
+
+ return normalized + healthPath
+}
+
+func isHostAllowed(host string, allowedHosts []string) bool {
+ hostName, hostPort := splitHostPortOrHost(host)
+ targetAddr, targetIsIP := parseNormalizedAddr(hostName)
+
+ for _, allowed := range allowedHosts {
+ allowed = strings.TrimSpace(allowed)
+ if allowed == "" {
+ continue
+ }
+
+ allowedName, allowedPort := splitHostPortOrHost(allowed)
+ if !isAllowedHostMatch(hostName, targetAddr, targetIsIP, allowedName) {
+ continue
+ }
+
+ if allowedPort == "" || strings.EqualFold(hostPort, allowedPort) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func isAllowedHostMatch(hostName string, hostAddr netip.Addr, hostIsIP bool, allowedName string) bool {
+ if prefix, ok := parseNormalizedPrefix(allowedName); ok {
+ return hostIsIP && prefix.Contains(hostAddr)
+ }
+
+ allowedAddr, allowedIsIP := parseNormalizedAddr(allowedName)
+
+ if hostIsIP && allowedIsIP {
+ return hostAddr == allowedAddr
+ }
+
+ if !hostIsIP && !allowedIsIP {
+ return strings.EqualFold(hostName, allowedName)
+ }
+
+ return false
+}
+
+func parseNormalizedAddr(value string) (netip.Addr, bool) {
+ trimmed := strings.Trim(strings.TrimSpace(value), "[]")
+ if trimmed == "" {
+ return netip.Addr{}, false
+ }
+
+ addr, err := netip.ParseAddr(trimmed)
+ if err != nil {
+ return netip.Addr{}, false
+ }
+
+ return addr.Unmap(), true
+}
+
+func parseNormalizedPrefix(value string) (netip.Prefix, bool) {
+ trimmed := strings.Trim(strings.TrimSpace(value), "[]")
+ if trimmed == "" {
+ return netip.Prefix{}, false
+ }
+
+ prefix, err := netip.ParsePrefix(trimmed)
if err != nil {
- rc.Logger.Errorf("failed to unmarshal response: %v", err.Error())
+ return netip.Prefix{}, false
+ }
+
+ return prefix.Masked(), true
+}
+
+func splitHostPortOrHost(value string) (string, string) {
+ trimmed := strings.TrimSpace(value)
+ if trimmed == "" {
+ return "", ""
+ }
+
+ host, port, err := net.SplitHostPort(trimmed)
+ if err == nil {
+ return strings.Trim(host, "[]"), port
+ }
+
+ return strings.Trim(trimmed, "[]"), ""
+}
+
+// deriveAllowedHostsFromHealthCheckURL extracts the host (and host:port) from
+// a health check URL to add to the derived allowlist. This ensures that when
+// basic-auth credentials are configured, the health check is at minimum
+// restricted to its own configured URL host, preventing SSRF even without an
+// explicit allowlist.
+func deriveAllowedHostsFromHealthCheckURL(healthCheckURL string) []string {
+ trimmed := strings.TrimSpace(healthCheckURL)
+ if trimmed == "" {
+ return nil
+ }
+
+ parsedURL, err := url.Parse(trimmed)
+ if err != nil || parsedURL == nil || parsedURL.Host == "" {
+ return nil
+ }
+
+ hostName, _ := splitHostPortOrHost(parsedURL.Host)
+
+ return mergeAllowedHosts(nil, parsedURL.Host, hostName)
+}
+
+func deriveAllowedHostsFromConnectionString(connectionString string) []string {
+ trimmed := strings.TrimSpace(connectionString)
+ if trimmed == "" {
+ return nil
+ }
+
+ parsedURL, err := url.Parse(trimmed)
+ if err != nil || parsedURL == nil || parsedURL.Host == "" {
+ return nil
+ }
+
+ hostName, _ := splitHostPortOrHost(parsedURL.Host)
+
+ return mergeAllowedHosts(nil, parsedURL.Host, hostName)
+}
+
+func deriveAllowedHostsFromHostPort(host, port string) []string {
+ host = strings.TrimSpace(host)
+ if host == "" {
+ return nil
+ }
+
+ if strings.TrimSpace(port) == "" {
+ return mergeAllowedHosts(nil, host)
+ }
+
+ return mergeAllowedHosts(nil, net.JoinHostPort(host, port), host)
+}
+
+func mergeAllowedHosts(base []string, additional ...string) []string {
+ if len(base) == 0 && len(additional) == 0 {
+ return nil
+ }
+
+ merged := make([]string, 0, len(base)+len(additional))
+ seen := make(map[string]struct{}, len(base)+len(additional))
+
+ for _, host := range append(append([]string(nil), base...), additional...) {
+ trimmed := strings.TrimSpace(host)
+ if trimmed == "" {
+ continue
+ }
+
+ key := strings.ToLower(trimmed)
+ if _, exists := seen[key]; exists {
+ continue
+ }
+
+ seen[key] = struct{}{}
+
+ merged = append(merged, trimmed)
+ }
+
+ if len(merged) == 0 {
+ return nil
+ }
+
+ return merged
+}
+
+// sanitizedError wraps an original error with a redacted message.
+// Error() returns the sanitized message; Unwrap() returns the original
+// so that errors.Is / errors.As still work for programmatic inspection.
+type sanitizedError struct {
+ original error
+ message string
+}
+
+// Error returns the sanitized message.
+func (e *sanitizedError) Error() string { return e.message }
+
+// Unwrap returns the original wrapped error.
+func (e *sanitizedError) Unwrap() error { return e.original }
+
+// newSanitizedError wraps err with a human-readable prefix and redacted connection string.
+func newSanitizedError(err error, connectionString, prefix string) error {
+ return fmt.Errorf("%s: %w", prefix, &sanitizedError{
+ original: err,
+ message: sanitizeAMQPErr(err, connectionString),
+ })
+}
+
+func sanitizeAMQPErr(err error, connectionString string) string {
+ if err == nil {
+ return ""
+ }
+
+ errMsg := err.Error()
+
+ if connectionString == "" {
+ return redactURLCredentials(errMsg)
+ }
+
+ referenceURL, parseErr := url.Parse(connectionString)
+ if parseErr != nil {
+ return redactURLCredentials(errMsg)
+ }
+
+ redactedURL := referenceURL.Redacted()
+
+ if strings.Contains(errMsg, connectionString) {
+ errMsg = strings.ReplaceAll(errMsg, connectionString, redactedURL)
+ }
+ if strings.Contains(errMsg, referenceURL.String()) {
+ errMsg = strings.ReplaceAll(errMsg, referenceURL.String(), redactedURL)
+ }
+
+ // Redact decoded password individually — covers cases where the error message
+ // contains the password in decoded form (e.g., URL-encoded special characters).
+ if referenceURL.User != nil {
+ if pass, ok := referenceURL.User.Password(); ok && pass != "" {
+ errMsg = strings.ReplaceAll(errMsg, pass, redactedURLPassword)
+ }
+ }
+
+ return redactURLCredentials(errMsg)
+}
+
+func redactURLCredentials(message string) string {
+ if message == "" {
+ return ""
+ }
+
+ return urlPattern.ReplaceAllStringFunc(message, redactURLCredentialsCandidate)
+}
+
+func redactURLCredentialsCandidate(candidate string) string {
+ core, suffix := splitTrailingURLPunctuation(candidate)
+
+ return redactURLCredentialToken(core) + suffix
+}
+
+func splitTrailingURLPunctuation(candidate string) (string, string) {
+ end := len(candidate)
+
+ for end > 0 {
+ switch candidate[end-1] {
+ case '.', ',', ';', ')', ']', '}', '"', '\'':
+ end--
+ default:
+ return candidate[:end], candidate[end:]
+ }
+ }
+
+ return "", candidate
+}
+
+func redactURLCredentialToken(token string) string {
+ if token == "" {
+ return ""
+ }
+
+ parsedURL, err := url.Parse(token)
+ if err == nil && parsedURL != nil && parsedURL.User != nil {
+ username := parsedURL.User.Username()
+ if _, hasPassword := parsedURL.User.Password(); hasPassword {
+ parsedURL.User = url.UserPassword(username, redactedURLPassword)
+
+ return parsedURL.String()
+ }
+
+ return token
+ }
+
+ return redactURLCredentialsFallback(token)
+}
+
+func redactURLCredentialsFallback(token string) string {
+ schemeSeparator := strings.Index(token, "://")
+ if schemeSeparator == -1 {
+ return token
+ }
+
+ rest := token[schemeSeparator+3:]
+ authorityEnd := len(rest)
+
+ for i := 0; i < len(rest); i++ {
+ switch rest[i] {
+ case '/', '?', '#':
+ authorityEnd = i
+ i = len(rest)
+ }
+ }
+
+ atIndex := strings.LastIndex(rest[:authorityEnd], "@")
+ if atIndex == -1 && authorityEnd < len(rest) {
+ candidate := rest[:authorityEnd]
+ if separator := strings.LastIndex(candidate, ":"); separator > 0 {
+ tail := candidate[separator+1:]
+ if tail != "" && !allDigits(tail) {
+ atIndex = strings.LastIndex(rest, "@")
+ }
+ }
+ }
+
+ if atIndex == -1 {
+ return token
+ }
+
+ userinfo := rest[:atIndex]
+ hostAndSuffix := rest[atIndex+1:]
+
+ if hostAndSuffix == "" {
+ return token
+ }
+
+ username, _, found := strings.Cut(userinfo, ":")
+ if !found {
+ return token
+ }
+
+ return token[:schemeSeparator+3] + username + ":" + redactedURLPassword + "@" + hostAndSuffix
+}
+
+func allDigits(value string) bool {
+ if value == "" {
return false
}
- if status, ok := result["status"].(string); ok && status == "ok" {
- return true
+ for i := range len(value) {
+ if value[i] < '0' || value[i] > '9' {
+ return false
+ }
}
- rc.Logger.Error("rabbitmq unhealthy...")
+ return true
+}
- return false
+// recordConnectionFailure increments the rabbitmq connection failure counter.
+// No-op when MetricsFactory is nil.
+func (rc *RabbitMQConnection) recordConnectionFailure(operation string) {
+ if rc == nil || rc.MetricsFactory == nil {
+ return
+ }
+
+ counter, err := rc.MetricsFactory.Counter(connectionFailuresMetric)
+ if err != nil {
+ rc.logger().Log(context.Background(), log.LevelWarn, "failed to create rabbitmq metric counter", log.Err(err))
+ return
+ }
+
+ err = counter.
+ WithLabels(map[string]string{
+ "operation": constant.SanitizeMetricLabel(operation),
+ }).
+ AddOne(context.Background())
+ if err != nil {
+ rc.logger().Log(context.Background(), log.LevelWarn, "failed to record rabbitmq metric", log.Err(err))
+ }
}
// BuildRabbitMQConnectionString constructs an AMQP connection string.
@@ -363,4 +1434,4 @@ func BuildRabbitMQConnectionString(protocol, user, pass, host, port, vhost strin
}
return u.String()
-}
\ No newline at end of file
+}
diff --git a/commons/rabbitmq/rabbitmq_integration_test.go b/commons/rabbitmq/rabbitmq_integration_test.go
new file mode 100644
index 00000000..be5db124
--- /dev/null
+++ b/commons/rabbitmq/rabbitmq_integration_test.go
@@ -0,0 +1,235 @@
+//go:build integration
+
+package rabbitmq
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ amqp "github.com/rabbitmq/amqp091-go"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/testcontainers/testcontainers-go"
+ tcrabbit "github.com/testcontainers/testcontainers-go/modules/rabbitmq"
+ "github.com/testcontainers/testcontainers-go/wait"
+)
+
+const (
+ testRabbitMQImage = "rabbitmq:3-management-alpine"
+ testRabbitMQUser = "guest"
+ testRabbitMQPass = "guest"
+ testStartupTimeout = 60 * time.Second
+ testConsumeDeadline = 10 * time.Second
+)
+
+// setupRabbitMQContainer starts a RabbitMQ testcontainer with the management plugin
+// and returns the AMQP URL, the management HTTP URL, and a cleanup function.
+func setupRabbitMQContainer(t *testing.T) (amqpURL string, mgmtURL string, cleanup func()) {
+ t.Helper()
+
+ ctx := context.Background()
+
+ container, err := tcrabbit.Run(ctx,
+ testRabbitMQImage,
+ testcontainers.WithWaitStrategy(
+ wait.ForLog("Server startup complete").
+ WithStartupTimeout(testStartupTimeout),
+ ),
+ )
+ require.NoError(t, err, "failed to start RabbitMQ container")
+
+ amqpEndpoint, err := container.AmqpURL(ctx)
+ require.NoError(t, err, "failed to get AMQP URL from container")
+
+ httpEndpoint, err := container.HttpURL(ctx)
+ require.NoError(t, err, "failed to get HTTP management URL from container")
+
+ return amqpEndpoint, httpEndpoint, func() {
+ require.NoError(t, container.Terminate(ctx), "failed to terminate RabbitMQ container")
+ }
+}
+
+// newTestConnection creates a RabbitMQConnection configured for integration testing.
+func newTestConnection(amqpURL, mgmtURL string) *RabbitMQConnection {
+ return &RabbitMQConnection{
+ ConnectionStringSource: amqpURL,
+ HealthCheckURL: mgmtURL,
+ User: testRabbitMQUser,
+ Pass: testRabbitMQPass,
+ AllowInsecureHealthCheck: true,
+ Logger: log.NewNop(),
+ }
+}
+
+func TestIntegration_RabbitMQ_ConnectAndClose(t *testing.T) {
+ amqpURL, mgmtURL, cleanup := setupRabbitMQContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+ rc := newTestConnection(amqpURL, mgmtURL)
+
+ // Connect to the real RabbitMQ instance.
+ err := rc.ConnectContext(ctx)
+ require.NoError(t, err, "ConnectContext should succeed against a live broker")
+
+ assert.True(t, rc.Connected, "Connected flag should be true after successful connection")
+ assert.NotNil(t, rc.Connection, "Connection should be non-nil after connect")
+ assert.NotNil(t, rc.Channel, "Channel should be non-nil after connect")
+
+ // Close the connection and verify state is reset.
+ err = rc.CloseContext(ctx)
+ require.NoError(t, err, "CloseContext should succeed")
+
+ assert.False(t, rc.Connected, "Connected flag should be false after close")
+ assert.Nil(t, rc.Connection, "Connection should be nil after close")
+ assert.Nil(t, rc.Channel, "Channel should be nil after close")
+}
+
+func TestIntegration_RabbitMQ_HealthCheck(t *testing.T) {
+ amqpURL, mgmtURL, cleanup := setupRabbitMQContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+ rc := newTestConnection(amqpURL, mgmtURL)
+
+ // Connect first — health check needs a configured connection object.
+ err := rc.ConnectContext(ctx)
+ require.NoError(t, err, "ConnectContext should succeed")
+
+ defer func() {
+ _ = rc.CloseContext(ctx)
+ }()
+
+ // Run health check against the management API.
+ healthy, err := rc.HealthCheckContext(ctx)
+ require.NoError(t, err, "HealthCheckContext should not return an error for a healthy broker")
+ assert.True(t, healthy, "HealthCheckContext should report true for a running broker")
+}
+
+func TestIntegration_RabbitMQ_EnsureChannel(t *testing.T) {
+ amqpURL, mgmtURL, cleanup := setupRabbitMQContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+ rc := newTestConnection(amqpURL, mgmtURL)
+
+ err := rc.ConnectContext(ctx)
+ require.NoError(t, err, "ConnectContext should succeed")
+
+ defer func() {
+ _ = rc.CloseContext(ctx)
+ }()
+
+ // Close the channel explicitly to simulate a lost channel.
+ require.NotNil(t, rc.Channel, "Channel should exist after connect")
+
+ err = rc.Channel.Close()
+ require.NoError(t, err, "explicit channel close should succeed")
+
+ // EnsureChannelContext should detect the closed channel and recover it.
+ err = rc.EnsureChannelContext(ctx)
+ require.NoError(t, err, "EnsureChannelContext should recover a closed channel")
+
+ assert.True(t, rc.Connected, "Connected flag should be true after channel recovery")
+ assert.NotNil(t, rc.Channel, "Channel should be non-nil after recovery")
+ assert.False(t, rc.Channel.IsClosed(), "Recovered channel should not be closed")
+}
+
+func TestIntegration_RabbitMQ_GetNewConnect(t *testing.T) {
+ amqpURL, mgmtURL, cleanup := setupRabbitMQContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+ rc := newTestConnection(amqpURL, mgmtURL)
+
+ err := rc.ConnectContext(ctx)
+ require.NoError(t, err, "ConnectContext should succeed")
+
+ defer func() {
+ _ = rc.CloseContext(ctx)
+ }()
+
+ // GetNewConnectContext returns the active channel.
+ ch, err := rc.GetNewConnectContext(ctx)
+ require.NoError(t, err, "GetNewConnectContext should succeed on a connected instance")
+ assert.NotNil(t, ch, "returned channel should not be nil")
+ assert.False(t, ch.IsClosed(), "returned channel should be open")
+}
+
+func TestIntegration_RabbitMQ_PublishAndConsume(t *testing.T) {
+ amqpURL, mgmtURL, cleanup := setupRabbitMQContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+ rc := newTestConnection(amqpURL, mgmtURL)
+
+ err := rc.ConnectContext(ctx)
+ require.NoError(t, err, "ConnectContext should succeed")
+
+ defer func() {
+ _ = rc.CloseContext(ctx)
+ }()
+
+ ch, err := rc.GetNewConnectContext(ctx)
+ require.NoError(t, err, "GetNewConnectContext should succeed")
+
+ // Declare a test queue.
+ queueName := fmt.Sprintf("integration-test-queue-%d", time.Now().UnixNano())
+
+ q, err := ch.QueueDeclare(
+ queueName,
+ false, // durable
+ true, // autoDelete
+ false, // exclusive
+ false, // noWait
+ nil, // args
+ )
+ require.NoError(t, err, "QueueDeclare should succeed")
+
+ // Publish a message.
+ messageBody := []byte("hello from integration test")
+
+ publishCtx, publishCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer publishCancel()
+
+ err = ch.PublishWithContext(
+ publishCtx,
+ "", // exchange (default)
+ q.Name, // routing key = queue name
+ false, // mandatory
+ false, // immediate
+ amqp.Publishing{
+ ContentType: "text/plain",
+ Body: messageBody,
+ },
+ )
+ require.NoError(t, err, "PublishWithContext should succeed")
+
+ // Consume the message.
+ msgs, err := ch.Consume(
+ q.Name, // queue
+ "", // consumer tag (auto-generated)
+ true, // autoAck
+ false, // exclusive
+ false, // noLocal
+ false, // noWait
+ nil, // args
+ )
+ require.NoError(t, err, "Consume should succeed")
+
+ // Wait for the message with a deadline to avoid hanging forever.
+ consumeCtx, consumeCancel := context.WithTimeout(ctx, testConsumeDeadline)
+ defer consumeCancel()
+
+ select {
+ case msg, ok := <-msgs:
+ require.True(t, ok, "message channel should deliver a message")
+ assert.Equal(t, messageBody, msg.Body, "consumed message body should match published body")
+ assert.Equal(t, "text/plain", msg.ContentType, "content type should match")
+ case <-consumeCtx.Done():
+ t.Fatal("timed out waiting for message from RabbitMQ")
+ }
+}
diff --git a/commons/rabbitmq/rabbitmq_test.go b/commons/rabbitmq/rabbitmq_test.go
index a345781b..83961436 100644
--- a/commons/rabbitmq/rabbitmq_test.go
+++ b/commons/rabbitmq/rabbitmq_test.go
@@ -1,262 +1,1238 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package rabbitmq
import (
"context"
+ "crypto/tls"
+ "errors"
"net/http"
"net/http/httptest"
- "strings"
+ "net/url"
+ "sync"
+ "sync/atomic"
"testing"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ amqp "github.com/rabbitmq/amqp091-go"
"github.com/stretchr/testify/assert"
)
-// mockRabbitMQConnection extends RabbitMQConnection to allow mocking for tests
-type mockRabbitMQConnection struct {
- RabbitMQConnection
- connectError bool
- healthyResponse bool
- authFails bool
+func TestRabbitMQConnection_Connect(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var conn *RabbitMQConnection
+
+ err := conn.ConnectContext(context.Background())
+ assert.ErrorIs(t, err, ErrNilConnection)
+ })
+
+ t.Run("context canceled before connect", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ conn := &RabbitMQConnection{
+ ConnectionStringSource: "amqp://guest:guest@localhost:5672",
+ Logger: &log.NopLogger{},
+ dialerContext: func(context.Context, string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return &amqp.Connection{}, nil
+ },
+ }
+
+ err := conn.ConnectContext(ctx)
+
+ assert.ErrorIs(t, err, context.Canceled)
+ assert.Equal(t, 0, dialerCalls)
+ })
+
+ t.Run("dial error", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+
+ conn := &RabbitMQConnection{
+ ConnectionStringSource: "amqp://guest:guest@localhost:5672",
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return nil, errors.New("dial failed")
+ },
+ }
+
+ err := conn.Connect()
+
+ assert.Error(t, err)
+ assert.False(t, conn.Connected)
+ assert.Nil(t, conn.Connection)
+ assert.Nil(t, conn.Channel)
+ assert.Equal(t, 1, dialerCalls)
+ assert.ErrorContains(t, err, "dial failed")
+ })
+
+ t.Run("channel error closes connection", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+ closeCalls := 0
+
+ conn := &RabbitMQConnection{
+ ConnectionStringSource: "amqp://guest:guest@localhost:5672",
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return &amqp.Connection{}, nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ return nil, errors.New("channel failed")
+ },
+ connectionCloser: func(*amqp.Connection) error {
+ closeCalls++
+
+ return nil
+ },
+ }
+
+ err := conn.Connect()
+
+ assert.Error(t, err)
+ assert.False(t, conn.Connected)
+ assert.Nil(t, conn.Connection)
+ assert.Nil(t, conn.Channel)
+ assert.Equal(t, 1, dialerCalls)
+ assert.Equal(t, 1, closeCalls)
+ })
+
+ t.Run("health check failure resets connection", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+ closeCalls := 0
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, err := w.Write([]byte(`{"status":"error"}`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ conn := &RabbitMQConnection{
+ ConnectionStringSource: "amqp://guest:guest@localhost:5672",
+ HealthCheckURL: healthServer.URL,
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return &amqp.Connection{}, nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ return &amqp.Channel{}, nil
+ },
+ connectionCloser: func(conn *amqp.Connection) error {
+ closeCalls++
+
+ return nil
+ },
+ }
+
+ err := conn.Connect()
+
+ assert.Error(t, err)
+ assert.False(t, conn.Connected)
+ assert.Nil(t, conn.Connection)
+ assert.Nil(t, conn.Channel)
+ assert.Equal(t, 1, dialerCalls)
+ assert.Equal(t, 1, closeCalls)
+ })
+
+ t.Run("healthy server creates connection", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, err := w.Write([]byte(`{"status":"ok"}`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ conn := &RabbitMQConnection{
+ ConnectionStringSource: "amqp://guest:guest@localhost:5672",
+ HealthCheckURL: healthServer.URL,
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return &amqp.Connection{}, nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ return &amqp.Channel{}, nil
+ },
+ connectionClosedFn: func(*amqp.Connection) bool { return false },
+ channelClosedFn: func(*amqp.Channel) bool { return false },
+ }
+
+ err := conn.Connect()
+
+ assert.NoError(t, err)
+ assert.True(t, conn.Connected)
+ assert.NotNil(t, conn.Connection)
+ assert.NotNil(t, conn.Channel)
+ assert.Equal(t, 1, dialerCalls)
+ })
+
+ t.Run("does not hold lock while running health check", func(t *testing.T) {
+ healthStarted := make(chan struct{})
+ continueHealth := make(chan struct{})
+ dialerCalls := int32(0)
+
+ var once sync.Once
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ once.Do(func() { close(healthStarted) })
+
+ <-continueHealth
+
+ w.WriteHeader(http.StatusOK)
+ _, err := w.Write([]byte(`{"status":"ok"}`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ conn := &RabbitMQConnection{
+ ConnectionStringSource: "amqp://guest:guest@localhost:5672",
+ HealthCheckURL: healthServer.URL,
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ atomic.AddInt32(&dialerCalls, 1)
+
+ return &amqp.Connection{}, nil
+ },
+ connectionCloser: func(*amqp.Connection) error {
+ return nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ return &amqp.Channel{}, nil
+ },
+ connectionClosedFn: func(*amqp.Connection) bool { return false },
+ channelClosedFn: func(*amqp.Channel) bool { return false },
+ }
+
+ connectDone := make(chan error, 1)
+ go func() {
+ connectDone <- conn.Connect()
+ }()
+
+ select {
+ case <-healthStarted:
+ case err := <-connectDone:
+ t.Fatalf("connect completed before health check request started: %v", err)
+ case <-time.After(time.Second):
+ t.Fatal("timed out waiting for health check request to start")
+ }
+
+ ensureDone := make(chan error, 1)
+ go func() {
+ ensureDone <- conn.EnsureChannel()
+ }()
+
+ assert.Eventually(t, func() bool {
+ return atomic.LoadInt32(&dialerCalls) >= 2
+ }, 200*time.Millisecond, 10*time.Millisecond)
+
+ close(continueHealth)
+
+ select {
+ case err := <-connectDone:
+ assert.NoError(t, err)
+ case <-time.After(time.Second):
+ t.Fatal("connect did not complete")
+ }
+
+ select {
+ case err := <-ensureDone:
+ assert.NoError(t, err)
+ case <-time.After(time.Second):
+ t.Fatal("ensure channel did not complete")
+ }
+ })
+
+ t.Run("nil logger is safe", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ ConnectionStringSource: "amqp://guest:guest@localhost:5672",
+ dialer: func(string) (*amqp.Connection, error) {
+ return nil, errors.New("dial failed")
+ },
+ }
+
+ assert.NotPanics(t, func() {
+ _ = conn.Connect()
+ })
+ })
}
-func (m *mockRabbitMQConnection) setupMockServer() *httptest.Server {
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Check basic auth
- username, password, ok := r.BasicAuth()
- if !ok || username != m.User || password != m.Pass {
- // When auth fails, return a 200 but with error status in JSON
- // This tests how the HealthCheck method parses the response
- w.Header().Set("Content-Type", "application/json")
- w.Write([]byte(`{"status":"not_authorized"}`))
- return
- }
+func TestRabbitMQConnection_EnsureChannel(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var conn *RabbitMQConnection
+
+ err := conn.EnsureChannelContext(context.Background())
+ assert.ErrorIs(t, err, ErrNilConnection)
+ })
+
+ t.Run("creates connection and channel when missing", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+ channelCalls := 0
+
+ conn := &RabbitMQConnection{
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return &amqp.Connection{}, nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ channelCalls++
+
+ return &amqp.Channel{}, nil
+ },
+ connectionClosedFn: func(connection *amqp.Connection) bool { return connection == nil },
+ channelClosedFn: func(ch *amqp.Channel) bool { return ch == nil },
+ }
+
+ err := conn.EnsureChannel()
+
+ assert.NoError(t, err)
+ assert.True(t, conn.Connected)
+ assert.NotNil(t, conn.Connection)
+ assert.NotNil(t, conn.Channel)
+ assert.Equal(t, 1, dialerCalls)
+ assert.Equal(t, 1, channelCalls)
+ })
+
+ t.Run("reuses open connection and channel", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+ channelCalls := 0
+
+ conn := &RabbitMQConnection{
+ Connection: &amqp.Connection{},
+ Channel: &amqp.Channel{},
+ Connected: true,
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return nil, errors.New("should not be called")
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ channelCalls++
+
+ return &amqp.Channel{}, nil
+ },
+ connectionClosedFn: func(*amqp.Connection) bool { return false },
+ channelClosedFn: func(*amqp.Channel) bool { return false },
+ }
+
+ err := conn.EnsureChannel()
+
+ assert.NoError(t, err)
+ assert.True(t, conn.Connected)
+ assert.Equal(t, 0, dialerCalls)
+ assert.Equal(t, 0, channelCalls)
+ })
+
+ t.Run("reopens channel when closed", func(t *testing.T) {
+ t.Parallel()
+
+ channelCalls := 0
+
+ conn := &RabbitMQConnection{
+ Connection: &amqp.Connection{},
+ Channel: &amqp.Channel{},
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ return nil, nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ channelCalls++
+
+ return &amqp.Channel{}, nil
+ },
+ connectionClosedFn: func(*amqp.Connection) bool { return false },
+ channelClosedFn: func(ch *amqp.Channel) bool { return ch != nil },
+ }
+
+ err := conn.EnsureChannel()
+
+ assert.NoError(t, err)
+ assert.True(t, conn.Connected)
+ assert.Equal(t, 1, channelCalls)
+ assert.NotNil(t, conn.Channel)
+ })
+
+ t.Run("context canceled before ensure channel", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ conn := &RabbitMQConnection{
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return &amqp.Connection{}, nil
+ },
+ }
+
+ err := conn.EnsureChannelContext(ctx)
+
+ assert.ErrorIs(t, err, context.Canceled)
+ assert.Equal(t, 0, dialerCalls)
+ })
+
+ t.Run("nil context defaults to background", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ Connection: &amqp.Connection{},
+ Channel: &amqp.Channel{},
+ Connected: true,
+ Logger: &log.NopLogger{},
+ connectionClosedFn: func(*amqp.Connection) bool { return false },
+ channelClosedFn: func(*amqp.Channel) bool { return false },
+ }
+
+ assert.NotPanics(t, func() {
+ //nolint:staticcheck // intentionally passing nil context
+ err := conn.EnsureChannelContext(nil)
+ assert.NoError(t, err)
+ })
+ })
+
+ t.Run("resets stale connection on channel failure", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+ closeCalls := 0
+
+ connection := &amqp.Connection{}
+ conn := &RabbitMQConnection{
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return connection, nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ return nil, errors.New("failed to open")
+ },
+ connectionCloser: func(*amqp.Connection) error {
+ closeCalls++
+
+ return nil
+ },
+ connectionClosedFn: func(*amqp.Connection) bool { return true },
+ channelClosedFn: func(*amqp.Channel) bool { return true },
+ }
+
+ err := conn.EnsureChannel()
+
+ assert.Error(t, err)
+ assert.False(t, conn.Connected)
+ assert.Nil(t, conn.Connection)
+ assert.Nil(t, conn.Channel)
+ assert.Equal(t, 1, dialerCalls)
+ assert.Equal(t, 1, closeCalls)
+ })
+}
+
+func TestRabbitMQConnection_GetNewConnect(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var conn *RabbitMQConnection
+
+ ch, err := conn.GetNewConnectContext(context.Background())
+ assert.ErrorIs(t, err, ErrNilConnection)
+ assert.Nil(t, ch)
+ })
+
+ t.Run("context canceled before connect", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{}
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ got, err := conn.GetNewConnectContext(ctx)
+
+ assert.ErrorIs(t, err, context.Canceled)
+ assert.Nil(t, got)
+ })
+
+ t.Run("creates channel when not connected", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := int32(0)
+
+ conn := &RabbitMQConnection{
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ atomic.AddInt32(&dialerCalls, 1)
+
+ return &amqp.Connection{}, nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ return &amqp.Channel{}, nil
+ },
+ connectionClosedFn: func(connection *amqp.Connection) bool { return connection == nil },
+ channelClosedFn: func(ch *amqp.Channel) bool { return ch == nil },
+ }
+
+ channel, err := conn.GetNewConnect()
+
+ assert.NoError(t, err)
+ assert.NotNil(t, channel)
+ assert.Equal(t, int32(1), atomic.LoadInt32(&dialerCalls))
+ })
+
+ t.Run("reuses existing connected channel", func(t *testing.T) {
+ t.Parallel()
+
+ dialerCalls := 0
+ channelCalls := 0
+
+ existing := &amqp.Channel{}
+ conn := &RabbitMQConnection{
+ Connection: &amqp.Connection{},
+ Channel: existing,
+ Connected: true,
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ dialerCalls++
+
+ return nil, errors.New("should not be called")
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ channelCalls++
+
+ return &amqp.Channel{}, nil
+ },
+ connectionClosedFn: func(*amqp.Connection) bool { return false },
+ channelClosedFn: func(*amqp.Channel) bool { return false },
+ }
+
+ got, err := conn.GetNewConnect()
+
+ assert.NoError(t, err)
+ assert.Same(t, existing, got)
+ assert.Equal(t, 0, dialerCalls)
+ assert.Equal(t, 0, channelCalls)
+ })
+
+ t.Run("stale channel state returns error", func(t *testing.T) {
+ t.Parallel()
+
+ connection := &amqp.Connection{}
+ closeCalls := 0
+ conn := &RabbitMQConnection{
+ Connection: connection,
+ Channel: nil,
+ Connected: true,
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ return connection, nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ return nil, nil
+ },
+ connectionClosedFn: func(*amqp.Connection) bool { return true },
+ channelClosedFn: func(*amqp.Channel) bool { return true },
+ connectionCloser: func(*amqp.Connection) error {
+ closeCalls++
+
+ return nil
+ },
+ }
+
+ got, err := conn.GetNewConnect()
+
+ assert.Error(t, err)
+ assert.Nil(t, got)
+ assert.False(t, conn.Connected)
+ assert.Nil(t, conn.Connection)
+ assert.Nil(t, conn.Channel)
+ assert.Equal(t, 1, closeCalls)
+ })
+
+ t.Run("concurrent callers all succeed", func(t *testing.T) {
+ dialerCalls := int32(0)
+
+ conn := &RabbitMQConnection{
+ Logger: &log.NopLogger{},
+ dialer: func(string) (*amqp.Connection, error) {
+ atomic.AddInt32(&dialerCalls, 1)
+
+ return &amqp.Connection{}, nil
+ },
+ channelFactory: func(*amqp.Connection) (*amqp.Channel, error) {
+ return &amqp.Channel{}, nil
+ },
+ connectionClosedFn: func(connection *amqp.Connection) bool { return connection == nil },
+ channelClosedFn: func(ch *amqp.Channel) bool { return ch == nil },
+ }
+
+ const total = 10
+ results := make(chan error, total)
+
+ var wg sync.WaitGroup
+ wg.Add(total)
+ for i := 0; i < total; i++ {
+ go func() {
+ defer wg.Done()
+
+ _, err := conn.GetNewConnect()
+ results <- err
+ }()
+ }
+
+ wg.Wait()
+ close(results)
+
+ for err := range results {
+ assert.NoError(t, err)
+ }
+
+ // EnsureChannelContext releases the lock before dialing (to avoid holding it
+ // during I/O). Under contention, a small number of goroutines may race to dial
+ // before the first one finishes and updates the shared connection state. This is
+ // the expected trade-off — rare duplicate dials vs. convoy effect.
+ dials := atomic.LoadInt32(&dialerCalls)
+ assert.GreaterOrEqual(t, dials, int32(1))
+ assert.LessOrEqual(t, dials, int32(total))
+ assert.True(t, conn.Connected)
+ assert.NotNil(t, conn.Channel)
+ })
+}
+
+func TestRabbitMQConnection_HealthCheck(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil receiver", func(t *testing.T) {
+ t.Parallel()
+
+ var conn *RabbitMQConnection
+ healthy, err := conn.HealthCheckContext(context.Background())
+ assert.ErrorIs(t, err, ErrNilConnection)
+ assert.False(t, healthy)
+ })
+
+ t.Run("healthy response", func(t *testing.T) {
+ t.Parallel()
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ _, err := w.Write([]byte(`{"status":"ok"}`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ conn := &RabbitMQConnection{
+ HealthCheckURL: healthServer.URL,
+ Logger: &log.NopLogger{},
+ }
+
+ healthy, err := conn.HealthCheck()
+ assert.NoError(t, err)
+ assert.True(t, healthy)
+ })
+
+ t.Run("returns defaults validation error", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ HealthCheckURL: "https://localhost:15672",
+ Logger: &log.NopLogger{},
+ healthHTTPClient: &http.Client{
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true, //nolint:gosec // intentional for validation test
+ },
+ },
+ },
+ }
+
+ healthy, err := conn.HealthCheckContext(context.Background())
+ assert.ErrorIs(t, err, ErrInsecureTLS)
+ assert.False(t, healthy)
+ })
+
+ t.Run("server returns error status", func(t *testing.T) {
+ t.Parallel()
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, err := w.Write([]byte("err"))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ conn := &RabbitMQConnection{HealthCheckURL: healthServer.URL, Logger: &log.NopLogger{}}
+
+ healthy, err := conn.HealthCheck()
+ assert.Error(t, err)
+ assert.False(t, healthy)
+ })
+
+ t.Run("unhealthy response body", func(t *testing.T) {
+ t.Parallel()
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ _, err := w.Write([]byte(`{"status":"error"}`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ conn := &RabbitMQConnection{HealthCheckURL: healthServer.URL, Logger: &log.NopLogger{}}
+
+ healthy, err := conn.HealthCheck()
+ assert.Error(t, err)
+ assert.False(t, healthy)
+ })
+
+ t.Run("malformed response", func(t *testing.T) {
+ t.Parallel()
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, err := w.Write([]byte(`{"status":`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ conn := &RabbitMQConnection{HealthCheckURL: healthServer.URL, Logger: &log.NopLogger{}}
+
+ healthy, err := conn.HealthCheck()
+ assert.Error(t, err)
+ assert.False(t, healthy)
+ })
+
+ t.Run("null response", func(t *testing.T) {
+ t.Parallel()
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ _, err := w.Write([]byte("null"))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ conn := &RabbitMQConnection{HealthCheckURL: healthServer.URL, Logger: &log.NopLogger{}}
+
+ healthy, err := conn.HealthCheck()
+ assert.Error(t, err)
+ assert.False(t, healthy)
+ })
+
+ t.Run("invalid URL returns false", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{HealthCheckURL: "http://[::1", Logger: &log.NopLogger{}}
+
+ healthy, err := conn.HealthCheck()
+ assert.Error(t, err)
+ assert.False(t, healthy)
+ })
+
+ t.Run("strict allowlist mode requires configured hosts", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ HealthCheckURL: "http://localhost:15672",
+ Logger: &log.NopLogger{},
+ RequireHealthCheckAllowedHosts: true,
+ }
+
+ healthy, err := conn.HealthCheck()
+ assert.ErrorIs(t, err, ErrHealthCheckAllowedHostsRequired)
+ assert.False(t, healthy)
+ })
+
+ t.Run("invalid URL scheme is rejected", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{HealthCheckURL: "ftp://localhost:15672", Logger: &log.NopLogger{}}
+
+ healthy, err := conn.HealthCheck()
+ assert.Error(t, err)
+ assert.False(t, healthy)
+ })
+
+ t.Run("context canceled before health check request", func(t *testing.T) {
+ t.Parallel()
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ _, err := w.Write([]byte(`{"status":"ok"}`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ conn := &RabbitMQConnection{
+ HealthCheckURL: healthServer.URL,
+ Logger: &log.NopLogger{},
+ }
+
+ healthy, err := conn.HealthCheckContext(ctx)
+ assert.Error(t, err)
+ assert.False(t, healthy)
+ })
+
+ t.Run("authentication", func(t *testing.T) {
+ t.Parallel()
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ username, password, ok := r.BasicAuth()
+ if !ok || username != "correct" || password != "correct" {
+ w.WriteHeader(http.StatusUnauthorized)
+
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ _, err := w.Write([]byte(`{"status":"ok"}`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ badAuth := &RabbitMQConnection{
+ HealthCheckURL: healthServer.URL,
+ User: "wrong",
+ Pass: "wrong",
+ Logger: &log.NopLogger{},
+ AllowInsecureHealthCheck: true,
+ }
+
+ goodAuth := &RabbitMQConnection{
+ HealthCheckURL: healthServer.URL,
+ User: "correct",
+ Pass: "correct",
+ Logger: &log.NopLogger{},
+ AllowInsecureHealthCheck: true,
+ }
+
+ badHealthy, badErr := badAuth.HealthCheck()
+ assert.Error(t, badErr)
+ assert.False(t, badHealthy)
+
+ goodHealthy, goodErr := goodAuth.HealthCheck()
+ assert.NoError(t, goodErr)
+ assert.True(t, goodHealthy)
+ })
+
+ t.Run("https basic auth without explicit allowlist derives host from connection string", func(t *testing.T) {
+ t.Parallel()
+
+ healthServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ username, password, ok := r.BasicAuth()
+ if !ok || username != "correct" || password != "correct" {
+ w.WriteHeader(http.StatusUnauthorized)
+
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ _, err := w.Write([]byte(`{"status":"ok"}`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ parsedURL, err := url.Parse(healthServer.URL)
+ assert.NoError(t, err)
+
+ conn := &RabbitMQConnection{
+ ConnectionStringSource: "amqp://guest:guest@" + parsedURL.Host,
+ HealthCheckURL: healthServer.URL,
+ User: "correct",
+ Pass: "correct",
+ Logger: &log.NopLogger{},
+ healthHTTPClient: healthServer.Client(),
+ AllowInsecureTLS: true,
+ }
+
+ healthy, healthErr := conn.HealthCheck()
+ assert.NoError(t, healthErr)
+ assert.True(t, healthy)
+ })
+
+ t.Run("healthCheck uses provided policy snapshot", func(t *testing.T) {
+ t.Parallel()
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _, err := w.Write([]byte(`{"status":"ok"}`))
+ assert.NoError(t, err)
+ }))
+ defer healthServer.Close()
+
+ parsed, err := url.Parse(healthServer.URL)
+ assert.NoError(t, err)
+
+ conn := &RabbitMQConnection{
+ AllowInsecureHealthCheck: false,
+ HealthCheckAllowedHosts: []string{"blocked.example:15672"},
+ Logger: &log.NopLogger{},
+ }
+
+ err = conn.healthCheck(
+ context.Background(),
+ healthServer.URL,
+ "user",
+ "pass",
+ healthServer.Client(),
+ healthCheckURLConfig{
+ allowInsecure: true,
+ hasBasicAuth: true,
+ allowedHosts: []string{parsed.Host},
+ },
+ &log.NopLogger{},
+ )
+
+ assert.NoError(t, err)
+ })
+}
+
+func TestApplyDefaults_InsecureTLS(t *testing.T) {
+ t.Parallel()
+
+ t.Run("returns error when injected client disables TLS verification", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ Logger: &log.NopLogger{},
+ healthHTTPClient: &http.Client{
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true, //nolint:gosec // intentional for test
+ },
+ },
+ },
+ }
+
+ conn.mu.Lock()
+ err := conn.applyDefaults()
+ conn.mu.Unlock()
+
+ assert.ErrorIs(t, err, ErrInsecureTLS)
+ })
+
+ t.Run("AllowInsecureTLS bypasses the check", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ Logger: &log.NopLogger{},
+ healthHTTPClient: &http.Client{
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true, //nolint:gosec // intentional for test
+ },
+ },
+ },
+ AllowInsecureTLS: true,
+ }
+
+ conn.mu.Lock()
+ err := conn.applyDefaults()
+ conn.mu.Unlock()
+
+ assert.NoError(t, err)
+ })
+
+ t.Run("no error for default client", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ Logger: &log.NopLogger{},
+ }
+
+ conn.mu.Lock()
+ err := conn.applyDefaults()
+ conn.mu.Unlock()
+
+ assert.NoError(t, err)
+ })
+
+ t.Run("no error for secure custom client", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ Logger: &log.NopLogger{},
+ healthHTTPClient: &http.Client{
+ Transport: &http.Transport{
+ TLSClientConfig: &tls.Config{
+ MinVersion: tls.VersionTLS12,
+ },
+ },
+ },
+ }
+
+ conn.mu.Lock()
+ err := conn.applyDefaults()
+ conn.mu.Unlock()
+
+ assert.NoError(t, err)
+ })
+}
+
+func TestValidateHealthCheckURL(t *testing.T) {
+ t.Parallel()
+
+ t.Run("trims spaces and appends health path", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ HealthCheckURL: " http://localhost:15672 ",
+ Logger: &log.NopLogger{},
+ }
+
+ normalized, err := validateHealthCheckURLWithConfig(conn.HealthCheckURL, healthCheckURLConfig{})
+
+ assert.NoError(t, err)
+ assert.Equal(t, "http://localhost:15672/api/health/checks/alarms", normalized)
+ })
+
+ t.Run("preserves nested path and appends health endpoint", func(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := validateHealthCheckURLWithConfig("http://localhost:15672/custom/alerts", healthCheckURLConfig{})
+
+ assert.NoError(t, err)
+ assert.Equal(t, "http://localhost:15672/custom/alerts/api/health/checks/alarms", normalized)
+ })
+
+ t.Run("normalizes path with trailing slash", func(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := validateHealthCheckURLWithConfig("http://localhost:15672/custom/alerts/", healthCheckURLConfig{})
+
+ assert.NoError(t, err)
+ assert.Equal(t, "http://localhost:15672/custom/alerts/api/health/checks/alarms", normalized)
+ })
+
+ t.Run("requires host", func(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := validateHealthCheckURLWithConfig("http:///api/health", healthCheckURLConfig{})
+
+ assert.Error(t, err)
+ assert.Empty(t, normalized)
+ })
+
+ t.Run("rejects unsupported scheme", func(t *testing.T) {
+ t.Parallel()
- // Set content type for JSON response
- w.Header().Set("Content-Type", "application/json")
+ normalized, err := validateHealthCheckURLWithConfig("ftp://localhost:15672", healthCheckURLConfig{})
- // Return appropriate status based on test case
- if m.healthyResponse {
- w.Write([]byte(`{"status":"ok"}`))
- } else {
- w.Write([]byte(`{"status":"error"}`))
- }
- }))
+ assert.Error(t, err)
+ assert.Empty(t, normalized)
+ })
- return server
-}
+ t.Run("rejects user credentials", func(t *testing.T) {
+ t.Parallel()
-func TestRabbitMQConnection_Connect(t *testing.T) {
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+ normalized, err := validateHealthCheckURLWithConfig("http://user:pass@localhost:15672", healthCheckURLConfig{})
- // We can't easily test the actual connection in unit tests
- // So we'll focus on testing the error handling
+ assert.Error(t, err)
+ assert.Empty(t, normalized)
+ })
- tests := []struct {
- name string
- connectionString string
- expectError bool
- skipDetailedCheck bool
- }{
- {
- name: "invalid connection string",
- connectionString: "amqp://invalid-host:5672",
- expectError: true,
- skipDetailedCheck: true, // The detailed connection check would never be reached
- },
- {
- name: "valid format but unreachable",
- connectionString: "amqp://guest:guest@localhost:5999",
- expectError: true,
- skipDetailedCheck: true,
- },
- }
+ t.Run("rejects http with basic auth", func(t *testing.T) {
+ t.Parallel()
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- conn := &RabbitMQConnection{
- ConnectionStringSource: tt.connectionString,
- Logger: logger,
- }
+ _, err := validateHealthCheckURLWithConfig("http://localhost:15672", healthCheckURLConfig{
+ hasBasicAuth: true,
+ })
+ assert.ErrorIs(t, err, ErrInsecureHealthCheck)
+ })
- // This will always fail in a unit test environment without a real RabbitMQ
- // We're just testing the error handling
- err := conn.Connect()
-
- if tt.expectError {
- assert.Error(t, err)
- assert.False(t, conn.Connected)
- assert.Nil(t, conn.Channel)
- } else {
- // We don't expect this branch to be taken in unit tests
- assert.NoError(t, err)
- assert.True(t, conn.Connected)
- assert.NotNil(t, conn.Channel)
- }
+ t.Run("allows http with basic auth when opted in", func(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := validateHealthCheckURLWithConfig("http://localhost:15672", healthCheckURLConfig{
+ hasBasicAuth: true,
+ allowInsecure: true,
})
- }
-}
+ assert.NoError(t, err)
+ assert.Contains(t, normalized, "/api/health/checks/alarms")
+ })
-func TestRabbitMQConnection_GetNewConnect(t *testing.T) {
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+ t.Run("requires allowlist for https basic auth", func(t *testing.T) {
+ t.Parallel()
- t.Run("not connected - will try to connect", func(t *testing.T) {
- conn := &RabbitMQConnection{
- ConnectionStringSource: "amqp://guest:guest@localhost:5999", // Unreachable
- Logger: logger,
- Connected: false,
- }
+ _, err := validateHealthCheckURLWithConfig("https://rabbitmq:15671", healthCheckURLConfig{
+ hasBasicAuth: true,
+ })
+ assert.ErrorIs(t, err, ErrHealthCheckAllowedHostsRequired)
+ })
- ch, err := conn.GetNewConnect()
- assert.Error(t, err)
- assert.Nil(t, ch)
- assert.False(t, conn.Connected)
+ t.Run("allows https basic auth when host is derived from AMQP connection host", func(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := validateHealthCheckURLWithConfig("https://rabbitmq:15671", healthCheckURLConfig{
+ hasBasicAuth: true,
+ allowedHosts: deriveAllowedHostsFromConnectionString("amqp://guest:guest@rabbitmq:5672"),
+ })
+ assert.NoError(t, err)
+ assert.Contains(t, normalized, "/api/health/checks/alarms")
})
- t.Run("already connected", func(t *testing.T) {
- // This test requires mocking the Channel which is difficult
- // since we can't create a real AMQP channel in a unit test
- t.Skip("Requires integration testing with a real RabbitMQ instance")
+ t.Run("strict allowlist mode still requires explicit configured list", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := validateHealthCheckURLWithConfig("https://rabbitmq:15671", healthCheckURLConfig{
+ hasBasicAuth: true,
+ derivedAllowedHosts: deriveAllowedHostsFromConnectionString("amqp://guest:guest@rabbitmq:5672"),
+ requireAllowedHosts: true,
+ })
+ assert.ErrorIs(t, err, ErrHealthCheckAllowedHostsRequired)
})
-}
-func TestRabbitMQConnection_HealthCheck(t *testing.T) {
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+ t.Run("does not enforce derived hosts when basic auth is not used", func(t *testing.T) {
+ t.Parallel()
- tests := []struct {
- name string
- setupServer bool
- mockResponse string
- expectHealthy bool
- invalidRequest bool
- }{
- {
- name: "healthy server",
- setupServer: true,
- mockResponse: `{"status":"ok"}`,
- expectHealthy: true,
- },
- {
- name: "unhealthy server",
- setupServer: true,
- mockResponse: `{"status":"error"}`,
- expectHealthy: false,
- },
- {
- name: "invalid request",
- setupServer: false,
- invalidRequest: true,
- expectHealthy: false,
- },
- }
+ normalized, err := validateHealthCheckURLWithConfig("https://management.rabbitmq:15671", healthCheckURLConfig{
+ derivedAllowedHosts: deriveAllowedHostsFromConnectionString("amqp://guest:guest@rabbitmq:5672"),
+ })
+ assert.NoError(t, err)
+ assert.Contains(t, normalized, "/api/health/checks/alarms")
+ })
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- conn := &RabbitMQConnection{
- HealthCheckURL: "localhost",
- Host: "localhost",
- User: "worg",
- Pass: "pass",
- Logger: logger,
- }
+ t.Run("allows https basic auth without allowlist when explicitly insecure", func(t *testing.T) {
+ t.Parallel()
- if tt.invalidRequest {
- // Invalid host/port for request to fail
- conn.Host = "invalid::/host"
- conn.Port = "invalid"
+ normalized, err := validateHealthCheckURLWithConfig("https://rabbitmq:15671", healthCheckURLConfig{
+ hasBasicAuth: true,
+ allowInsecure: true,
+ })
+ assert.NoError(t, err)
+ assert.Contains(t, normalized, "/api/health/checks/alarms")
+ })
- isHealthy := conn.HealthCheck()
- assert.False(t, isHealthy)
- return
- }
+ t.Run("rejects host not in allowlist", func(t *testing.T) {
+ t.Parallel()
- if tt.setupServer {
- // Setup a test server that returns the mock response
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusOK)
- w.Write([]byte(tt.mockResponse))
- }))
- defer server.Close()
-
- // Parse the server URL to get host and port
- hostParts := strings.SplitN(server.URL, ":", 2)
- conn.Host = hostParts[0]
- if len(hostParts) > 1 {
- conn.Port = hostParts[1]
- }
- conn.HealthCheckURL = server.URL
-
- // Run the test
- isHealthy := conn.HealthCheck()
- assert.Equal(t, tt.expectHealthy, isHealthy)
- }
+ _, err := validateHealthCheckURLWithConfig("http://evil.example.com:15672", healthCheckURLConfig{
+ allowedHosts: []string{"localhost:15672", "rabbitmq:15672"},
})
- }
+ assert.ErrorIs(t, err, ErrHealthCheckHostNotAllowed)
+ })
+
+ t.Run("requires allowlist when strict mode enabled", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := validateHealthCheckURLWithConfig("http://localhost:15672", healthCheckURLConfig{
+ requireAllowedHosts: true,
+ })
+ assert.ErrorIs(t, err, ErrHealthCheckAllowedHostsRequired)
+ })
+
+ t.Run("allows host in allowlist", func(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := validateHealthCheckURLWithConfig("http://rabbitmq:15672", healthCheckURLConfig{
+ allowedHosts: []string{"localhost:15672", "rabbitmq:15672"},
+ })
+ assert.NoError(t, err)
+ assert.Contains(t, normalized, "/api/health/checks/alarms")
+ })
+
+ t.Run("allows host-only allowlist entries", func(t *testing.T) {
+ t.Parallel()
+
+ normalized, err := validateHealthCheckURLWithConfig("http://rabbitmq:15672", healthCheckURLConfig{
+ allowedHosts: []string{"rabbitmq"},
+ })
+ assert.NoError(t, err)
+ assert.Contains(t, normalized, "/api/health/checks/alarms")
+ })
+
+ t.Run("enforces port when allowlist entry includes port", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := validateHealthCheckURLWithConfig("http://rabbitmq:5672", healthCheckURLConfig{
+ allowedHosts: []string{"rabbitmq:15672"},
+ })
+ assert.ErrorIs(t, err, ErrHealthCheckHostNotAllowed)
+ })
}
-func TestRabbitMQConnection_HealthCheck_Authentication(t *testing.T) {
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+func TestRabbitMQConnection_HealthCheck_UsesConfiguredPath(t *testing.T) {
+ t.Parallel()
- // Create test server with authentication check
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Check basic auth
- username, password, ok := r.BasicAuth()
- if !ok || username != "correct" || password != "correct" {
- // Return unauthorized status
- w.WriteHeader(http.StatusUnauthorized)
- return
- }
+ gotPath := make(chan string, 1)
+
+ healthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotPath <- r.URL.Path
- // Valid auth, return healthy response
w.Header().Set("Content-Type", "application/json")
- w.Write([]byte(`{"status":"ok"}`))
+ w.WriteHeader(http.StatusOK)
+ _, err := w.Write([]byte(`{"status":"ok"}`))
+ assert.NoError(t, err)
}))
- defer server.Close()
-
- // Parse the server URL
- hostParts := strings.SplitN(server.URL, ":", 2)
- host := hostParts[0]
- var port string
- if len(hostParts) > 1 {
- port = hostParts[1]
- }
+ defer healthServer.Close()
- // Test with incorrect credentials
- badAuthConn := &RabbitMQConnection{
- Host: host,
- Port: port,
- User: "wrong",
- Pass: "wrong",
- Logger: logger,
+ conn := &RabbitMQConnection{
+ HealthCheckURL: healthServer.URL + "/custom/alerts",
+ Logger: &log.NopLogger{},
}
- isHealthy := badAuthConn.HealthCheck()
- assert.False(t, isHealthy, "HealthCheck should return false with invalid credentials")
-
- // Test with correct credentials
- goodAuthConn := &RabbitMQConnection{
- HealthCheckURL: server.URL,
- Host: host,
- Port: port,
- User: "correct",
- Pass: "correct",
- Logger: logger,
- }
+ healthy, err := conn.HealthCheck()
+ assert.NoError(t, err)
+ assert.True(t, healthy)
- isHealthy = goodAuthConn.HealthCheck()
- assert.True(t, isHealthy, "HealthCheck should return true with valid credentials")
+ select {
+ case p := <-gotPath:
+ assert.Equal(t, "/custom/alerts/api/health/checks/alarms", p)
+ case <-time.After(1 * time.Second):
+ t.Fatal("health check did not reach test server")
+ }
}
func TestBuildRabbitMQConnectionString(t *testing.T) {
+ t.Parallel()
+
tests := []struct {
name string
protocol string
@@ -268,17 +1244,16 @@ func TestBuildRabbitMQConnectionString(t *testing.T) {
expected string
}{
{
- name: "empty vhost - backward compatibility",
+ name: "empty vhost",
protocol: "amqp",
user: "guest",
pass: "guest",
host: "localhost",
port: "5672",
- vhost: "",
expected: "amqp://guest:guest@localhost:5672",
},
{
- name: "custom vhost - production",
+ name: "custom vhost",
protocol: "amqp",
user: "admin",
pass: "secret",
@@ -288,17 +1263,7 @@ func TestBuildRabbitMQConnectionString(t *testing.T) {
expected: "amqp://admin:secret@rabbitmq.example.com:5672/production",
},
{
- name: "custom vhost - staging",
- protocol: "amqps",
- user: "user",
- pass: "pass",
- host: "secure.rabbitmq.io",
- port: "5671",
- vhost: "staging",
- expected: "amqps://user:pass@secure.rabbitmq.io:5671/staging",
- },
- {
- name: "root vhost explicit - URL encoded as %2F",
+ name: "root vhost",
protocol: "amqp",
user: "guest",
pass: "guest",
@@ -308,7 +1273,7 @@ func TestBuildRabbitMQConnectionString(t *testing.T) {
expected: "amqp://guest:guest@localhost:5672/%2F",
},
{
- name: "vhost with special characters - spaces",
+ name: "vhost with spaces",
protocol: "amqp",
user: "guest",
pass: "guest",
@@ -318,7 +1283,7 @@ func TestBuildRabbitMQConnectionString(t *testing.T) {
expected: "amqp://guest:guest@localhost:5672/my%20vhost",
},
{
- name: "vhost with special characters - slashes",
+ name: "vhost with slash",
protocol: "amqp",
user: "guest",
pass: "guest",
@@ -328,7 +1293,7 @@ func TestBuildRabbitMQConnectionString(t *testing.T) {
expected: "amqp://guest:guest@localhost:5672/env%2Fprod%2Fregion1",
},
{
- name: "vhost with special characters - hash and ampersand",
+ name: "vhost with hash and ampersand",
protocol: "amqp",
user: "guest",
pass: "guest",
@@ -338,7 +1303,7 @@ func TestBuildRabbitMQConnectionString(t *testing.T) {
expected: "amqp://guest:guest@localhost:5672/test%231%262",
},
{
- name: "password with special characters",
+ name: "password with special chars",
protocol: "amqp",
user: "admin",
pass: "p@ss:word/123",
@@ -348,7 +1313,7 @@ func TestBuildRabbitMQConnectionString(t *testing.T) {
expected: "amqp://admin:p%40ss%3Aword%2F123@localhost:5672/production",
},
{
- name: "username with special characters",
+ name: "username with special chars",
protocol: "amqp",
user: "admin@domain:user",
pass: "secret",
@@ -357,176 +1322,475 @@ func TestBuildRabbitMQConnectionString(t *testing.T) {
vhost: "production",
expected: "amqp://admin%40domain%3Auser:secret@localhost:5672/production",
},
+ {
+ name: "ipv6 with port",
+ protocol: "amqp",
+ user: "guest",
+ pass: "guest",
+ host: "::1",
+ port: "5672",
+ expected: "amqp://guest:guest@[::1]:5672",
+ },
+ {
+ name: "ipv6 without port",
+ protocol: "amqp",
+ user: "guest",
+ pass: "guest",
+ host: "::1",
+ expected: "amqp://guest:guest@[::1]",
+ },
+ {
+ name: "hostname without port",
+ protocol: "amqp",
+ user: "guest",
+ pass: "guest",
+ host: "rabbitmq.local",
+ expected: "amqp://guest:guest@rabbitmq.local",
+ },
+ {
+ name: "empty credentials",
+ protocol: "amqp",
+ host: "localhost",
+ port: "5672",
+ expected: "amqp://localhost:5672",
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
result := BuildRabbitMQConnectionString(tt.protocol, tt.user, tt.pass, tt.host, tt.port, tt.vhost)
+
assert.Equal(t, tt.expected, result)
})
}
}
-// TestEnsureChannelWithContext_ReturnsErrorOnCancelledContext verifies that
-// EnsureChannelWithContext respects context cancellation.
-func TestEnsureChannelWithContext_ReturnsErrorOnCancelledContext(t *testing.T) {
- logger := &log.GoLogger{Level: log.InfoLevel}
+func TestRabbitMQConnection_ChannelSnapshot(t *testing.T) {
+ t.Parallel()
- conn := &RabbitMQConnection{
- ConnectionStringSource: "amqp://guest:guest@localhost:5999", // Unreachable
- Logger: logger,
- }
+ t.Run("nil receiver returns nil", func(t *testing.T) {
+ t.Parallel()
+
+ var conn *RabbitMQConnection
+
+ assert.Nil(t, conn.ChannelSnapshot())
+ })
+
+ t.Run("nil channel returns nil", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{}
+
+ assert.Nil(t, conn.ChannelSnapshot())
+ })
+
+ t.Run("returns current channel", func(t *testing.T) {
+ t.Parallel()
+
+ expected := &amqp.Channel{}
+ conn := &RabbitMQConnection{Channel: expected}
+
+ assert.Same(t, expected, conn.ChannelSnapshot())
+ })
+
+ t.Run("snapshot read is mutex protected", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{Channel: &amqp.Channel{}}
+ conn.mu.Lock()
- // Create already cancelled context
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
+ started := make(chan struct{}, 1)
+ readDone := make(chan struct{}, 1)
- err := conn.EnsureChannelWithContext(ctx)
+ go func() {
+ started <- struct{}{}
+ _ = conn.ChannelSnapshot()
+ readDone <- struct{}{}
+ }()
- // Should return context.Canceled error
- assert.ErrorIs(t, err, context.Canceled)
+ select {
+ case <-started:
+ case <-time.After(time.Second):
+ t.Fatal("ChannelSnapshot goroutine did not start")
+ }
+
+ select {
+ case <-readDone:
+ t.Fatal("ChannelSnapshot should block while the connection lock is held")
+ case <-time.After(250 * time.Millisecond):
+ }
+
+ conn.mu.Unlock()
+
+ select {
+ case <-readDone:
+ case <-time.After(time.Second):
+ t.Fatal("ChannelSnapshot did not resume after lock release")
+ }
+ })
}
-func TestEnsureChannelWithContext_ReturnsErrorOnDeadlineExceeded(t *testing.T) {
- logger := &log.GoLogger{Level: log.InfoLevel}
+func TestIsHostAllowed(t *testing.T) {
+ t.Parallel()
- conn := &RabbitMQConnection{
- ConnectionStringSource: "amqp://guest:guest@localhost:5999", // Unreachable
- Logger: logger,
- }
+ t.Run("allows CIDR ranges", func(t *testing.T) {
+ t.Parallel()
- // Create context with very short deadline that's already expired
- ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
- defer cancel()
- time.Sleep(10 * time.Millisecond) // Let deadline expire
+ assert.True(t, isHostAllowed("10.10.1.7:15672", []string{"10.10.0.0/16"}))
+ assert.False(t, isHostAllowed("10.11.1.7:15672", []string{"10.10.0.0/16"}))
+ })
- err := conn.EnsureChannelWithContext(ctx)
+ t.Run("normalizes ipv4 mapped ipv6", func(t *testing.T) {
+ t.Parallel()
- // Should return context.DeadlineExceeded error
- assert.ErrorIs(t, err, context.DeadlineExceeded)
+ assert.True(t, isHostAllowed("127.0.0.1:15672", []string{"::ffff:127.0.0.1"}))
+ })
}
-func TestEnsureChannelWithContext_TimeoutDuringDial(t *testing.T) {
- logger := &log.GoLogger{Level: log.InfoLevel}
+func TestDeriveAllowedHostsFromConnectionString(t *testing.T) {
+ t.Parallel()
- conn := &RabbitMQConnection{
- // Use a non-routable IP to ensure connection hangs (doesn't immediately fail)
- ConnectionStringSource: "amqp://guest:guest@10.255.255.1:5672",
- Logger: logger,
+ t.Run("derives host and host:port", func(t *testing.T) {
+ t.Parallel()
+
+ hosts := deriveAllowedHostsFromConnectionString("amqp://guest:guest@rabbitmq.internal:5672")
+ assert.Contains(t, hosts, "rabbitmq.internal:5672")
+ assert.Contains(t, hosts, "rabbitmq.internal")
+ })
+
+ t.Run("invalid connection string returns no hosts", func(t *testing.T) {
+ t.Parallel()
+
+ hosts := deriveAllowedHostsFromConnectionString("not-a-url")
+ assert.Empty(t, hosts)
+ })
+}
+
+func TestRedactURLCredentials(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ message string
+ expected string
+ expectedContain []string
+ notContain []string
+ }{
+ {
+ name: "amqps scheme is redacted",
+ message: "dial amqps://admin:s3cret@broker:5671/vhost failed",
+ expectedContain: []string{"amqps://admin:xxxxx@broker:5671/vhost"},
+ notContain: []string{"s3cret"},
+ },
+ {
+ name: "user-only URL remains unchanged",
+ message: "dial amqp://guest@localhost:5672 failed",
+ expected: "dial amqp://guest@localhost:5672 failed",
+ },
+ {
+ name: "url-encoded password is redacted",
+ message: "dial amqp://admin:p%40ss%3Aword%2F123@broker:5672 failed",
+ expectedContain: []string{"amqp://admin:xxxxx@broker:5672"},
+ notContain: []string{"p%40ss%3Aword%2F123"},
+ },
+ {
+ name: "password with slash is redacted",
+ message: "dial amqp://admin:pa/ss@broker:5672 failed",
+ expectedContain: []string{"amqp://admin:xxxxx@broker:5672"},
+ notContain: []string{"pa/ss"},
+ },
+ {
+ name: "password with literal at is redacted",
+ message: "dial amqp://admin:p@ss@broker:5672 failed",
+ expectedContain: []string{"amqp://admin:xxxxx@broker:5672"},
+ notContain: []string{"p@ss"},
+ },
+ {
+ name: "multiple URLs are redacted",
+ message: "upstream amqp://u1:p1@host1:5672 then amqps://u2:p2@host2:5671",
+ expectedContain: []string{"amqp://u1:xxxxx@host1:5672", "amqps://u2:xxxxx@host2:5671"},
+ notContain: []string{"u1:p1", "u2:p2"},
+ },
+ {
+ name: "ipv6 host is redacted",
+ message: "dial amqp://guest:guest@[::1]:5672 failed",
+ expectedContain: []string{"amqp://guest:xxxxx@[::1]:5672"},
+ notContain: []string{"guest:guest@[::1]"},
+ },
+ {
+ name: "empty password is normalized to redacted placeholder",
+ message: "dial amqp://user:@localhost:5672 failed",
+ expectedContain: []string{"amqp://user:xxxxx@localhost:5672"},
+ notContain: []string{"user:@localhost"},
+ },
+ {
+ name: "surrounding text and punctuation are preserved",
+ message: "error details (amqp://user:secret@localhost:5672), retry later",
+ expectedContain: []string{"error details (amqp://user:xxxxx@localhost:5672), retry later"},
+ notContain: []string{"user:secret@"},
+ },
+ {
+ name: "multiple colons in userinfo are fully redacted",
+ message: "dial amqp://user:name:secret@localhost:5672 failed",
+ expectedContain: []string{"amqp://user:xxxxx@localhost:5672"},
+ notContain: []string{"secret", "user:name:secret"},
+ },
}
- // Use short timeout - this should NOT take 30 seconds
- ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
- defer cancel()
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ t.Parallel()
+
+ got := redactURLCredentials(testCase.message)
- start := time.Now()
- err := conn.EnsureChannelWithContext(ctx)
- elapsed := time.Since(start)
+ if testCase.expected != "" {
+ assert.Equal(t, testCase.expected, got)
+ }
- // Should fail with context deadline exceeded or i/o timeout
- assert.Error(t, err)
+ for _, expected := range testCase.expectedContain {
+ assert.Contains(t, got, expected)
+ }
- // Should complete within reasonable time (not 30 seconds)
- assert.Less(t, elapsed, 500*time.Millisecond,
- "EnsureChannelWithContext should respect context timeout, took %v", elapsed)
+ for _, unwanted := range testCase.notContain {
+ assert.NotContains(t, got, unwanted)
+ }
+ })
+ }
}
-func TestEnsureChannelWithContext_UsesConnectionTimeoutField(t *testing.T) {
- logger := &log.GoLogger{Level: log.InfoLevel}
+func TestRedactURLCredentialsFallback(t *testing.T) {
+ t.Parallel()
- conn := &RabbitMQConnection{
- // Use non-routable IP to ensure connection hangs
- ConnectionStringSource: "amqp://guest:guest@10.255.255.1:5672",
- Logger: logger,
- ConnectionTimeout: 50 * time.Millisecond, // Short custom timeout
- }
+ t.Run("preserves at-sign in path while redacting userinfo", func(t *testing.T) {
+ t.Parallel()
+
+ token := "amqp://user:secret@host:5672/path@segment?key=value"
+
+ got := redactURLCredentialsFallback(token)
+
+ assert.Equal(t, "amqp://user:xxxxx@host:5672/path@segment?key=value", got)
+ })
- // Use context without deadline - should use ConnectionTimeout field
- ctx := context.Background()
+ t.Run("does not redact when at-sign appears only in path", func(t *testing.T) {
+ t.Parallel()
- start := time.Now()
- err := conn.EnsureChannelWithContext(ctx)
- elapsed := time.Since(start)
+ token := "amqp://host:5672/path@segment"
- // Should fail with connection error
- assert.Error(t, err)
+ got := redactURLCredentialsFallback(token)
- // Should complete around ConnectionTimeout duration (with some buffer)
- assert.Less(t, elapsed, 200*time.Millisecond,
- "Should respect ConnectionTimeout field, took %v", elapsed)
- assert.Greater(t, elapsed, 40*time.Millisecond,
- "Should take at least ConnectionTimeout duration, took %v", elapsed)
+ assert.Equal(t, token, got)
+ })
}
-func TestEnsureChannelWithContext_ChecksContextAfterLockAcquisition(t *testing.T) {
- logger := &log.GoLogger{Level: log.InfoLevel}
+func TestSanitizeAMQPErr(t *testing.T) {
+ t.Parallel()
- conn := &RabbitMQConnection{
- // Use non-routable IP so connection hangs until context is cancelled
- ConnectionStringSource: "amqp://guest:guest@10.255.255.1:5672",
- Logger: logger,
- }
+ t.Run("redacts credentials from connection string in error", func(t *testing.T) {
+ t.Parallel()
- // Create context that we'll cancel after a short delay
- ctx, cancel := context.WithCancel(context.Background())
+ err := errors.New("dial tcp: lookup amqp://admin:s3cretP@ss@broker:5672")
+ connectionString := "amqp://admin:s3cretP@ss@broker:5672"
- // Start goroutine that cancels context after a tiny delay
- go func() {
- time.Sleep(10 * time.Millisecond)
- cancel()
- }()
+ got := sanitizeAMQPErr(err, connectionString)
+
+ assert.NotContains(t, got, "s3cretP@ss")
+ assert.Contains(t, got, "xxxxx")
+ })
+
+ t.Run("nil error returns empty string", func(t *testing.T) {
+ t.Parallel()
+
+ got := sanitizeAMQPErr(nil, "amqp://guest:guest@localhost:5672")
+
+ assert.Equal(t, "", got)
+ })
+
+ t.Run("unparseable connection string uses fallback redaction pass", func(t *testing.T) {
+ t.Parallel()
- // Call should detect cancellation and return quickly
- start := time.Now()
- err := conn.EnsureChannelWithContext(ctx)
- elapsed := time.Since(start)
+ err := errors.New("something went wrong")
- // Should return an error (context.Canceled or connection error)
- assert.Error(t, err)
+ got := sanitizeAMQPErr(err, "://not-a-url")
- // Should complete quickly due to context cancellation (not 30 seconds)
- assert.Less(t, elapsed, 200*time.Millisecond,
- "Should detect context cancellation quickly, took %v", elapsed)
+ assert.Equal(t, "something went wrong", got)
+ })
+
+ t.Run("error without connection string returns original message", func(t *testing.T) {
+ t.Parallel()
+
+ err := errors.New("timeout connecting to broker")
+
+ got := sanitizeAMQPErr(err, "amqp://admin:secret@broker:5672")
+
+ assert.Equal(t, "timeout connecting to broker", got)
+ assert.NotContains(t, got, "secret")
+ })
+
+ t.Run("redacts decoded password when embedded standalone in error", func(t *testing.T) {
+ t.Parallel()
+
+ err := errors.New("authentication failed: password=s3cr3t")
+ connectionString := "amqp://admin:s3cr3t@broker:5672"
+
+ got := sanitizeAMQPErr(err, connectionString)
+
+ assert.NotContains(t, got, "s3cr3t")
+ assert.Contains(t, got, "xxxxx")
+ })
+
+ t.Run("redacts URL-encoded password in decoded form", func(t *testing.T) {
+ t.Parallel()
+
+ // Password with special chars: p@ss:word/123 → encoded as p%40ss%3Aword%2F123
+ err := errors.New("auth error for p@ss:word/123")
+ connectionString := "amqp://admin:p%40ss%3Aword%2F123@broker:5672"
+
+ got := sanitizeAMQPErr(err, connectionString)
+
+ assert.NotContains(t, got, "p@ss:word/123")
+ assert.Contains(t, got, "xxxxx")
+ })
+
+ t.Run("empty connection string without URL credentials returns unmodified error", func(t *testing.T) {
+ t.Parallel()
+
+ err := errors.New("something failed")
+
+ got := sanitizeAMQPErr(err, "")
+
+ assert.Equal(t, "something failed", got)
+ })
+
+ t.Run("empty connection string still redacts URL credentials from error", func(t *testing.T) {
+ t.Parallel()
+
+ err := errors.New("dial failed for amqp://guest:guest@localhost:5672")
+
+ got := sanitizeAMQPErr(err, "")
+
+ assert.NotContains(t, got, "guest:guest")
+ assert.Contains(t, got, "xxxxx")
+ })
+
+ t.Run("fallback redaction fully redacts multi-colon userinfo passwords", func(t *testing.T) {
+ t.Parallel()
+
+ err := errors.New("dial failed for amqp://user:name:secret@localhost:5672")
+
+ got := sanitizeAMQPErr(err, "")
+
+ assert.NotContains(t, got, "secret")
+ assert.Contains(t, got, "amqp://user:xxxxx@localhost:5672")
+ })
}
-// TestEnsureChannelWithContext_ChecksContextBeforeChannelCreation verifies that
-// context is checked before calling Channel() when connection already exists.
-// This test requires a real RabbitMQ connection to fully exercise the code path
-// where connection exists but channel needs to be created.
-func TestEnsureChannelWithContext_ChecksContextBeforeChannelCreation(t *testing.T) {
- t.Run("context_canceled_before_channel_with_nil_connection", func(t *testing.T) {
- // This test verifies that a pre-canceled context returns immediately
- // even when the connection would need to be established first.
- // The context check before Channel() provides defense-in-depth for cases
- // where an existing connection is reused but context was canceled.
- logger := &log.GoLogger{Level: log.InfoLevel}
+func TestRabbitMQConnection_Close(t *testing.T) {
+ t.Parallel()
+
+ t.Run("close releases resources", func(t *testing.T) {
+ t.Parallel()
+
+ channelCloseCalls := int32(0)
+ connectionCloseCalls := int32(0)
conn := &RabbitMQConnection{
- ConnectionStringSource: "amqp://guest:guest@localhost:5672",
- Logger: logger,
+ Connection: &amqp.Connection{},
+ Channel: &amqp.Channel{},
+ Connected: true,
+ channelCloser: func(*amqp.Channel) error {
+ atomic.AddInt32(&channelCloseCalls, 1)
+
+ return nil
+ },
+ connectionCloser: func(*amqp.Connection) error {
+ atomic.AddInt32(&connectionCloseCalls, 1)
+
+ return nil
+ },
+ Logger: &log.NopLogger{},
+ }
+
+ err := conn.Close()
+
+ assert.NoError(t, err)
+ assert.Equal(t, int32(1), atomic.LoadInt32(&channelCloseCalls))
+ assert.Equal(t, int32(1), atomic.LoadInt32(&connectionCloseCalls))
+ assert.False(t, conn.Connected)
+ assert.Nil(t, conn.Channel)
+ assert.Nil(t, conn.Connection)
+ })
+
+ t.Run("close aggregates channel and connection errors", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ Connection: &amqp.Connection{},
+ Channel: &amqp.Channel{},
+ Connected: true,
+ channelCloser: func(*amqp.Channel) error {
+ return errors.New("channel close failed")
+ },
+ connectionCloser: func(*amqp.Connection) error {
+ return errors.New("connection close failed")
+ },
+ Logger: &log.NopLogger{},
+ }
+
+ err := conn.Close()
+
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "channel close failed")
+ assert.Contains(t, err.Error(), "connection close failed")
+ assert.False(t, conn.Connected)
+ assert.Nil(t, conn.Channel)
+ assert.Nil(t, conn.Connection)
+ })
+
+ t.Run("close only connection error", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{
+ Connection: &amqp.Connection{},
+ Channel: &amqp.Channel{},
+ Connected: true,
+ channelCloser: func(*amqp.Channel) error {
+ return nil
+ },
+ connectionCloser: func(*amqp.Connection) error {
+ return errors.New("connection close failed")
+ },
+ Logger: &log.NopLogger{},
}
- // Pre-cancel context
+ err := conn.Close()
+
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "connection close failed")
+ })
+
+ t.Run("close on nil receiver is safe", func(t *testing.T) {
+ t.Parallel()
+
+ var rc *RabbitMQConnection
+
+ assert.NotPanics(t, func() {
+ err := rc.CloseContext(context.Background())
+ assert.ErrorIs(t, err, ErrNilConnection)
+ })
+ })
+
+ t.Run("close context canceled", func(t *testing.T) {
+ t.Parallel()
+
+ conn := &RabbitMQConnection{}
+
ctx, cancel := context.WithCancel(context.Background())
cancel()
- err := conn.EnsureChannelWithContext(ctx)
+ err := conn.CloseContext(ctx)
- // Should return context.Canceled from the first check (before lock)
assert.ErrorIs(t, err, context.Canceled)
})
-
- t.Run("integration_test_with_real_connection", func(t *testing.T) {
- // Skip in unit tests - this would require a real RabbitMQ instance
- // to establish a connection, then cancel context before Channel() call.
- //
- // To fully test the context check before Channel():
- // 1. Establish a real connection to RabbitMQ
- // 2. Set rc.Connection to the valid connection
- // 3. Ensure rc.Channel is nil (needs channel creation)
- // 4. Cancel context
- // 5. Call EnsureChannelWithContext
- // 6. Verify it returns context.Canceled without calling Channel()
- t.Skip("Requires integration testing with a real RabbitMQ instance")
- })
}
diff --git a/commons/rabbitmq/trace_propagation_integration_test.go b/commons/rabbitmq/trace_propagation_integration_test.go
new file mode 100644
index 00000000..a74cfece
--- /dev/null
+++ b/commons/rabbitmq/trace_propagation_integration_test.go
@@ -0,0 +1,486 @@
+//go:build integration
+
+package rabbitmq
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ libOtel "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ amqp "github.com/rabbitmq/amqp091-go"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/propagation"
+ sdktrace "go.opentelemetry.io/otel/sdk/trace"
+ "go.opentelemetry.io/otel/trace"
+)
+
+// containsKeyInsensitive checks if a string-keyed map contains a key (case-insensitive).
+// The W3C TraceContext propagator uses http.Header which canonicalizes keys to Pascal-case.
+func containsKeyInsensitive[V any](m map[string]V, key string) bool {
+ lower := strings.ToLower(key)
+ for k := range m {
+ if strings.ToLower(k) == lower {
+ return true
+ }
+ }
+
+ return false
+}
+
+// getValueInsensitive retrieves a string value by case-insensitive key lookup.
+func getValueInsensitive(m map[string]any, key string) (any, bool) {
+ lower := strings.ToLower(key)
+ for k, v := range m {
+ if strings.ToLower(k) == lower {
+ return v, true
+ }
+ }
+
+ return nil, false
+}
+
+// mapKeys returns the keys of a string-keyed map for diagnostic messages.
+func mapKeys[V any](m map[string]V) []string {
+ keys := make([]string, 0, len(m))
+ for k := range m {
+ keys = append(keys, k)
+ }
+
+ return keys
+}
+
+// saveAndRestoreOTELGlobals saves the current global tracer provider and propagator,
+// returning a restore function that resets them. Every test that configures OTEL
+// globals MUST defer this to avoid polluting sibling tests.
+func saveAndRestoreOTELGlobals(t *testing.T) func() {
+ t.Helper()
+
+ prevTP := otel.GetTracerProvider()
+ prevProp := otel.GetTextMapPropagator()
+
+ return func() {
+ otel.SetTracerProvider(prevTP)
+ otel.SetTextMapPropagator(prevProp)
+ }
+}
+
+// setupTestTracer creates a real SDK tracer provider and configures OTEL globals.
+// It returns a tracer and a cleanup function that shuts down the provider.
+func setupTestTracer(t *testing.T) (trace.Tracer, func()) {
+ t.Helper()
+
+ tp := sdktrace.NewTracerProvider()
+ otel.SetTracerProvider(tp)
+ otel.SetTextMapPropagator(
+ propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}),
+ )
+
+ tracer := tp.Tracer("trace-propagation-integration-test")
+
+ return tracer, func() {
+ _ = tp.Shutdown(context.Background())
+ }
+}
+
+// declareTestQueue declares an auto-delete queue with a unique name for test isolation.
+func declareTestQueue(t *testing.T, ch *amqp.Channel, prefix string) amqp.Queue {
+ t.Helper()
+
+ queueName := fmt.Sprintf("%s-%d", prefix, time.Now().UnixNano())
+
+ q, err := ch.QueueDeclare(
+ queueName,
+ false, // durable
+ true, // autoDelete
+ false, // exclusive
+ false, // noWait
+ nil, // args
+ )
+ require.NoError(t, err, "QueueDeclare should succeed")
+
+ return q
+}
+
+// consumeOne reads exactly one message from a queue within the test deadline.
+func consumeOne(t *testing.T, ch *amqp.Channel, queueName string) amqp.Delivery {
+ t.Helper()
+
+ msgs, err := ch.Consume(
+ queueName,
+ "", // consumer tag (auto-generated)
+ true, // autoAck
+ false, // exclusive
+ false, // noLocal
+ false, // noWait
+ nil, // args
+ )
+ require.NoError(t, err, "Consume should succeed")
+
+ ctx, cancel := context.WithTimeout(context.Background(), testConsumeDeadline)
+ defer cancel()
+
+ select {
+ case msg, ok := <-msgs:
+ require.True(t, ok, "message channel should deliver a message")
+ return msg
+ case <-ctx.Done():
+ t.Fatal("timed out waiting for message from RabbitMQ")
+ return amqp.Delivery{} // unreachable but satisfies compiler
+ }
+}
+
+// consumeN reads exactly n messages from a queue within the test deadline.
+func consumeN(t *testing.T, ch *amqp.Channel, queueName string, n int) []amqp.Delivery {
+ t.Helper()
+
+ msgs, err := ch.Consume(
+ queueName,
+ "", // consumer tag (auto-generated)
+ true, // autoAck
+ false, // exclusive
+ false, // noLocal
+ false, // noWait
+ nil, // args
+ )
+ require.NoError(t, err, "Consume should succeed")
+
+ ctx, cancel := context.WithTimeout(context.Background(), testConsumeDeadline)
+ defer cancel()
+
+ deliveries := make([]amqp.Delivery, 0, n)
+
+ for range n {
+ select {
+ case msg, ok := <-msgs:
+ require.True(t, ok, "message channel should deliver a message")
+ deliveries = append(deliveries, msg)
+ case <-ctx.Done():
+ t.Fatalf("timed out waiting for message %d/%d from RabbitMQ", len(deliveries)+1, n)
+ }
+ }
+
+ return deliveries
+}
+
+func TestIntegration_TraceContext_SurvivesPublishConsume(t *testing.T) {
+ // — Setup —
+ restoreGlobals := saveAndRestoreOTELGlobals(t)
+ defer restoreGlobals()
+
+ tracer, shutdownTP := setupTestTracer(t)
+ defer shutdownTP()
+
+ amqpURL, mgmtURL, cleanup := setupRabbitMQContainer(t)
+ defer cleanup()
+
+ rc := newTestConnection(amqpURL, mgmtURL)
+
+ ctx := context.Background()
+
+ err := rc.ConnectContext(ctx)
+ require.NoError(t, err, "ConnectContext should succeed")
+
+ defer func() { _ = rc.CloseContext(ctx) }()
+
+ ch, err := rc.GetNewConnectContext(ctx)
+ require.NoError(t, err, "GetNewConnectContext should succeed")
+
+ q := declareTestQueue(t, ch, "trace-survive")
+
+ // — Produce: create a real span and inject its trace context into AMQP headers —
+ spanCtx, span := tracer.Start(ctx, "test-publish-operation")
+ originalTraceID := span.SpanContext().TraceID().String()
+ require.NotEmpty(t, originalTraceID, "span should have a valid trace ID")
+
+ traceHeaders := libOtel.InjectQueueTraceContext(spanCtx)
+ // The W3C propagator uses http.Header which canonicalizes keys to Pascal-case ("Traceparent").
+ require.True(t, containsKeyInsensitive(traceHeaders, "traceparent"),
+ "trace injection should produce a traceparent header, got keys: %v", mapKeys(traceHeaders))
+
+ amqpHeaders := amqp.Table{}
+ for k, v := range traceHeaders {
+ amqpHeaders[k] = v
+ }
+
+ publishCtx, publishCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer publishCancel()
+
+ err = ch.PublishWithContext(
+ publishCtx,
+ "", // default exchange
+ q.Name, // routing key
+ false, // mandatory
+ false, // immediate
+ amqp.Publishing{
+ ContentType: "application/json",
+ Body: []byte(`{"test":"trace_survives"}`),
+ Headers: amqpHeaders,
+ },
+ )
+ require.NoError(t, err, "PublishWithContext should succeed")
+
+ span.End()
+
+ // — Consume: extract trace context from the received message —
+ msg := consumeOne(t, ch, q.Name)
+
+ require.NotNil(t, msg.Headers, "consumed message should have headers")
+
+ // amqp.Table is map[string]interface{} — pass directly to ExtractTraceContextFromQueueHeaders.
+ extractedCtx := libOtel.ExtractTraceContextFromQueueHeaders(context.Background(), map[string]any(msg.Headers))
+ extractedTraceID := libOtel.GetTraceIDFromContext(extractedCtx)
+
+ assert.Equal(t, originalTraceID, extractedTraceID,
+ "trace ID extracted from consumed message must match the producer's trace ID")
+}
+
+func TestIntegration_TraceContext_PrepareQueueHeaders(t *testing.T) {
+ // — Setup —
+ restoreGlobals := saveAndRestoreOTELGlobals(t)
+ defer restoreGlobals()
+
+ tracer, shutdownTP := setupTestTracer(t)
+ defer shutdownTP()
+
+ amqpURL, mgmtURL, cleanup := setupRabbitMQContainer(t)
+ defer cleanup()
+
+ rc := newTestConnection(amqpURL, mgmtURL)
+
+ ctx := context.Background()
+
+ err := rc.ConnectContext(ctx)
+ require.NoError(t, err, "ConnectContext should succeed")
+
+ defer func() { _ = rc.CloseContext(ctx) }()
+
+ ch, err := rc.GetNewConnectContext(ctx)
+ require.NoError(t, err, "GetNewConnectContext should succeed")
+
+ q := declareTestQueue(t, ch, "trace-prepare")
+
+ // — Build headers via PrepareQueueHeaders —
+ spanCtx, span := tracer.Start(ctx, "test-prepare-headers-operation")
+ defer span.End()
+
+ baseHeaders := map[string]any{
+ "correlation_id": "abc",
+ }
+
+ merged := libOtel.PrepareQueueHeaders(spanCtx, baseHeaders)
+
+ // Verify merge semantics: both base and trace keys must be present.
+ assert.Contains(t, merged, "correlation_id", "merged headers should preserve base header correlation_id")
+ assert.Equal(t, "abc", merged["correlation_id"], "correlation_id value should be unchanged")
+ assert.True(t, containsKeyInsensitive(merged, "traceparent"),
+ "merged headers should contain injected traceparent, got keys: %v", mapKeys(merged))
+
+ // Original baseHeaders must be unmodified (PrepareQueueHeaders creates a new map).
+ assert.False(t, containsKeyInsensitive(baseHeaders, "traceparent"),
+ "PrepareQueueHeaders should not mutate the original baseHeaders map")
+
+ // — Publish with merged headers and verify on the consumer side —
+ amqpHeaders := amqp.Table{}
+ for k, v := range merged {
+ amqpHeaders[k] = v
+ }
+
+ publishCtx, publishCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer publishCancel()
+
+ err = ch.PublishWithContext(
+ publishCtx,
+ "", // default exchange
+ q.Name, // routing key
+ false, // mandatory
+ false, // immediate
+ amqp.Publishing{
+ ContentType: "application/json",
+ Body: []byte(`{"test":"prepare_headers"}`),
+ Headers: amqpHeaders,
+ },
+ )
+ require.NoError(t, err, "PublishWithContext should succeed")
+
+ msg := consumeOne(t, ch, q.Name)
+
+ require.NotNil(t, msg.Headers, "consumed message should have headers")
+ assert.True(t, containsKeyInsensitive(map[string]any(msg.Headers), "traceparent"),
+ "consumed message headers should include traceparent from PrepareQueueHeaders")
+ assert.True(t, containsKeyInsensitive(map[string]any(msg.Headers), "correlation_id"),
+ "consumed message headers should include correlation_id from base headers")
+}
+
+func TestIntegration_TraceContext_NoTraceContext(t *testing.T) {
+ // — Setup —
+ restoreGlobals := saveAndRestoreOTELGlobals(t)
+ defer restoreGlobals()
+
+ // Deliberately set a real propagator so extraction doesn't panic,
+ // but do NOT create a span — the context carries no trace.
+ tp := sdktrace.NewTracerProvider()
+ defer func() { _ = tp.Shutdown(context.Background()) }()
+
+ otel.SetTracerProvider(tp)
+ otel.SetTextMapPropagator(
+ propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}),
+ )
+
+ amqpURL, mgmtURL, cleanup := setupRabbitMQContainer(t)
+ defer cleanup()
+
+ rc := newTestConnection(amqpURL, mgmtURL)
+
+ ctx := context.Background()
+
+ err := rc.ConnectContext(ctx)
+ require.NoError(t, err, "ConnectContext should succeed")
+
+ defer func() { _ = rc.CloseContext(ctx) }()
+
+ ch, err := rc.GetNewConnectContext(ctx)
+ require.NoError(t, err, "GetNewConnectContext should succeed")
+
+ q := declareTestQueue(t, ch, "trace-none")
+
+ // — Publish WITHOUT any trace headers —
+ publishCtx, publishCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer publishCancel()
+
+ err = ch.PublishWithContext(
+ publishCtx,
+ "", // default exchange
+ q.Name, // routing key
+ false, // mandatory
+ false, // immediate
+ amqp.Publishing{
+ ContentType: "application/json",
+ Body: []byte(`{"test":"no_trace"}`),
+ // No Headers field — nil headers.
+ },
+ )
+ require.NoError(t, err, "PublishWithContext should succeed")
+
+ msg := consumeOne(t, ch, q.Name)
+
+ // — Extract from nil headers — must not panic and should return empty trace ID —
+ extractedFromNil := libOtel.ExtractTraceContextFromQueueHeaders(context.Background(), nil)
+ traceIDFromNil := libOtel.GetTraceIDFromContext(extractedFromNil)
+ assert.Empty(t, traceIDFromNil,
+ "extracting from nil headers should yield an empty/invalid trace ID")
+
+ // — Extract from empty headers map — same graceful degradation —
+ extractedFromEmpty := libOtel.ExtractTraceContextFromQueueHeaders(context.Background(), map[string]any{})
+ traceIDFromEmpty := libOtel.GetTraceIDFromContext(extractedFromEmpty)
+ assert.Empty(t, traceIDFromEmpty,
+ "extracting from empty headers should yield an empty/invalid trace ID")
+
+ // — Extract from the actual consumed message (which has nil or empty headers) —
+ consumedHeaders := map[string]any(msg.Headers) // amqp.Table -> map[string]any; may be nil
+ extractedFromMsg := libOtel.ExtractTraceContextFromQueueHeaders(context.Background(), consumedHeaders)
+ traceIDFromMsg := libOtel.GetTraceIDFromContext(extractedFromMsg)
+ assert.Empty(t, traceIDFromMsg,
+ "extracting from message published without trace headers should yield an empty trace ID")
+}
+
+func TestIntegration_TraceContext_MultipleMessages(t *testing.T) {
+ // — Setup —
+ restoreGlobals := saveAndRestoreOTELGlobals(t)
+ defer restoreGlobals()
+
+ tracer, shutdownTP := setupTestTracer(t)
+ defer shutdownTP()
+
+ amqpURL, mgmtURL, cleanup := setupRabbitMQContainer(t)
+ defer cleanup()
+
+ rc := newTestConnection(amqpURL, mgmtURL)
+
+ ctx := context.Background()
+
+ err := rc.ConnectContext(ctx)
+ require.NoError(t, err, "ConnectContext should succeed")
+
+ defer func() { _ = rc.CloseContext(ctx) }()
+
+ ch, err := rc.GetNewConnectContext(ctx)
+ require.NoError(t, err, "GetNewConnectContext should succeed")
+
+ q := declareTestQueue(t, ch, "trace-multi")
+
+ const messageCount = 3
+
+ // — Publish 3 messages, each under a distinct span (= distinct trace ID) —
+ publishedTraceIDs := make([]string, 0, messageCount)
+
+ for i := range messageCount {
+ spanCtx, span := tracer.Start(ctx, fmt.Sprintf("test-multi-operation-%d", i))
+ traceID := span.SpanContext().TraceID().String()
+ require.NotEmpty(t, traceID, "span %d should have a valid trace ID", i)
+
+ publishedTraceIDs = append(publishedTraceIDs, traceID)
+
+ traceHeaders := libOtel.InjectQueueTraceContext(spanCtx)
+
+ amqpHeaders := amqp.Table{}
+ for k, v := range traceHeaders {
+ amqpHeaders[k] = v
+ }
+
+ publishCtx, publishCancel := context.WithTimeout(ctx, 5*time.Second)
+
+ err = ch.PublishWithContext(
+ publishCtx,
+ "", // default exchange
+ q.Name, // routing key
+ false, // mandatory
+ false, // immediate
+ amqp.Publishing{
+ ContentType: "application/json",
+ Body: []byte(fmt.Sprintf(`{"msg":%d}`, i)),
+ Headers: amqpHeaders,
+ },
+ )
+ require.NoError(t, err, "PublishWithContext for message %d should succeed", i)
+
+ publishCancel()
+
+ span.End()
+ }
+
+ // Sanity check: all 3 published trace IDs must be unique.
+ uniqueIDs := make(map[string]struct{}, messageCount)
+ for _, id := range publishedTraceIDs {
+ uniqueIDs[id] = struct{}{}
+ }
+
+ require.Len(t, uniqueIDs, messageCount,
+ "each published message must carry a unique trace ID")
+
+ // — Consume all 3 and verify each extracts its own trace ID —
+ deliveries := consumeN(t, ch, q.Name, messageCount)
+
+ extractedTraceIDs := make([]string, 0, messageCount)
+
+ for i, msg := range deliveries {
+ require.NotNil(t, msg.Headers, "consumed message %d should have headers", i)
+
+ extractedCtx := libOtel.ExtractTraceContextFromQueueHeaders(
+ context.Background(), map[string]any(msg.Headers),
+ )
+ extractedID := libOtel.GetTraceIDFromContext(extractedCtx)
+ require.NotEmpty(t, extractedID, "consumed message %d should yield a valid trace ID", i)
+
+ extractedTraceIDs = append(extractedTraceIDs, extractedID)
+ }
+
+ // AMQP guarantees FIFO ordering on a single queue with a single publisher,
+ // so extracted order matches published order.
+ assert.Equal(t, publishedTraceIDs, extractedTraceIDs,
+ "extracted trace IDs must match published trace IDs in order")
+}
diff --git a/commons/redis/doc.go b/commons/redis/doc.go
new file mode 100644
index 00000000..d06b0359
--- /dev/null
+++ b/commons/redis/doc.go
@@ -0,0 +1,6 @@
+// Package redis provides Redis/Valkey client helpers with topology and IAM support.
+//
+// Supported deployment modes include standalone, sentinel, and cluster.
+// Authentication supports static passwords and short-lived GCP IAM tokens with
+// automatic refresh and reconnect.
+package redis
diff --git a/commons/redis/iam_example_test.go b/commons/redis/iam_example_test.go
new file mode 100644
index 00000000..cf18407f
--- /dev/null
+++ b/commons/redis/iam_example_test.go
@@ -0,0 +1,32 @@
+//go:build unit
+
+package redis_test
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/redis"
+)
+
+func ExampleConfig_gcpIAM() {
+ cfg := redis.Config{
+ Topology: redis.Topology{
+ Standalone: &redis.StandaloneTopology{Address: "redis.internal:6379"},
+ },
+ Auth: redis.Auth{
+ GCPIAM: &redis.GCPIAMAuth{
+ CredentialsBase64: "BASE64_JSON",
+ ServiceAccount: "svc-redis@project.iam.gserviceaccount.com",
+ RefreshEvery: 50 * time.Minute,
+ },
+ },
+ }
+
+ fmt.Println(cfg.Auth.GCPIAM != nil)
+ fmt.Println(cfg.Auth.GCPIAM.ServiceAccount)
+
+ // Output:
+ // true
+ // svc-redis@project.iam.gserviceaccount.com
+}
diff --git a/commons/redis/lock.go b/commons/redis/lock.go
index d69b6ce8..6de60ed7 100644
--- a/commons/redis/lock.go
+++ b/commons/redis/lock.go
@@ -1,23 +1,58 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package redis
import (
"context"
"errors"
"fmt"
+ "strconv"
"strings"
"time"
- libCommons "github.com/LerianStudio/lib-commons/v2/commons"
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry"
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
"github.com/go-redsync/redsync/v4"
+ redsyncredis "github.com/go-redsync/redsync/v4/redis"
"github.com/go-redsync/redsync/v4/redis/goredis/v9"
)
-// DistributedLock provides distributed locking capabilities using Redis and the RedLock algorithm.
+const (
+ maxLockTries = 1000
+ // unlockTimeout is the maximum duration for an unlock operation using a
+ // detached context. This prevents unlock from failing silently when the
+ // caller's context has been cancelled.
+ unlockTimeout = 5 * time.Second
+)
+
+var (
+ // ErrNilLockHandle is returned when a nil or uninitialized lock handle is used.
+ ErrNilLockHandle = errors.New("lock handle is nil or not initialized")
+ // ErrLockNotHeld is returned when unlock is called on a lock that was not held or already expired.
+ ErrLockNotHeld = errors.New("lock was not held or already expired")
+ // ErrNilLockManager is returned when a method is called on a nil RedisLockManager.
+ ErrNilLockManager = errors.New("lock manager is nil")
+ // ErrLockNotInitialized is returned when the distributed lock's redsync is not initialized.
+ ErrLockNotInitialized = errors.New("distributed lock is not initialized")
+ // ErrNilLockFn is returned when a nil function is passed to WithLock.
+ ErrNilLockFn = errors.New("lock function is nil")
+ // ErrEmptyLockKey is returned when an empty lock key is provided.
+ ErrEmptyLockKey = errors.New("lock key cannot be empty")
+ // ErrLockExpiryInvalid is returned when lock expiry is not positive.
+ ErrLockExpiryInvalid = errors.New("lock expiry must be greater than 0")
+ // ErrLockTriesInvalid is returned when lock tries is less than 1.
+ ErrLockTriesInvalid = errors.New("lock tries must be at least 1")
+ // ErrLockTriesExceeded is returned when lock tries exceeds the maximum.
+ ErrLockTriesExceeded = errors.New("lock tries exceeds maximum")
+ // ErrLockRetryDelayNegative is returned when retry delay is negative.
+ ErrLockRetryDelayNegative = errors.New("lock retry delay cannot be negative")
+ // ErrLockDriftFactorInvalid is returned when drift factor is outside [0, 1).
+ ErrLockDriftFactorInvalid = errors.New("lock drift factor must be between 0 (inclusive) and 1 (exclusive)")
+ // ErrNilLockHandleOnUnlock is returned when Unlock is called with a nil handle.
+ ErrNilLockHandleOnUnlock = errors.New("lock handle is nil")
+)
+
+// RedisLockManager provides distributed locking capabilities using Redis and the RedLock algorithm.
// This implementation ensures mutual exclusion across multiple service instances, preventing race
// conditions in critical sections such as:
// - Password update operations
@@ -32,16 +67,16 @@ import (
//
// Example usage:
//
-// lock, err := redis.NewDistributedLock(redisConnection)
+// lock, err := redis.NewRedisLockManager(redisClient)
// if err != nil {
// return err
// }
//
-// err = lock.WithLock(ctx, "lock:user:123", func() error {
+// err = lock.WithLock(ctx, "lock:user:123", func(ctx context.Context) error {
// // Critical section - only one instance will execute this at a time
// return updateUser(123)
// })
-type DistributedLock struct {
+type RedisLockManager struct {
redsync *redsync.Redsync
}
@@ -53,7 +88,7 @@ type LockOptions struct {
Expiry time.Duration
// Tries is the number of attempts to acquire the lock before giving up
- // Default: 3
+ // Default: 3, Maximum: 1000
Tries int
// RetryDelay is the delay between retry attempts
@@ -93,29 +128,88 @@ func RateLimiterLockOptions() LockOptions {
}
}
-// NewDistributedLock creates a new distributed lock manager.
+// clientPool implements the redsync redis.Pool interface with lazy client resolution.
+// On each Get call it resolves the latest redis.UniversalClient from the Client wrapper,
+// ensuring the pool survives IAM token refresh reconnections.
+type clientPool struct {
+ conn *Client
+}
+
+func (p *clientPool) Get(ctx context.Context) (redsyncredis.Conn, error) {
+ rdb, err := p.conn.GetClient(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get redis client for lock pool: %w", err)
+ }
+
+ return goredis.NewPool(rdb).Get(ctx)
+}
+
+// lockHandle wraps a redsync.Mutex to implement LockHandle.
+// It is returned by TryLock and provides a self-contained Unlock method.
+type lockHandle struct {
+ mutex *redsync.Mutex
+ logger log.Logger
+}
+
+// Unlock releases the distributed lock.
+func (h *lockHandle) Unlock(ctx context.Context) error {
+ if h == nil || h.mutex == nil {
+ return ErrNilLockHandle
+ }
+
+ ok, err := h.mutex.UnlockContext(ctx)
+ if err != nil {
+ h.logger.Log(ctx, log.LevelError, "failed to release lock", log.Err(err))
+ return fmt.Errorf("distributed lock: unlock: %w", err)
+ }
+
+ if !ok {
+ h.logger.Log(ctx, log.LevelWarn, "lock was not held or already expired")
+ return ErrLockNotHeld
+ }
+
+ return nil
+}
+
+// nilLockAssert fires a nil-receiver assertion and returns an error.
+func nilLockAssert(ctx context.Context, operation string) error {
+ a := assert.New(ctx, resolvePackageLogger(), "redis.RedisLockManager", operation)
+ _ = a.Never(ctx, "nil receiver on *redis.RedisLockManager")
+
+ return ErrNilLockManager
+}
+
+// NewRedisLockManager creates a new distributed lock manager.
// The lock manager uses the RedLock algorithm for distributed consensus.
+// It uses a lazy pool that resolves the latest Redis client per operation,
+// surviving IAM token refresh reconnections.
//
-// Thread-safe: Yes - multiple goroutines can use the same DistributedLock instance.
+// Thread-safe: Yes - multiple goroutines can use the same RedisLockManager instance.
//
// Example:
//
-// lock, err := redis.NewDistributedLock(redisConnection)
+// lock, err := redis.NewRedisLockManager(redisClient)
// if err != nil {
// return fmt.Errorf("failed to initialize lock: %w", err)
// }
-func NewDistributedLock(conn *RedisConnection) (*DistributedLock, error) {
+func NewRedisLockManager(conn *Client) (*RedisLockManager, error) {
+ if conn == nil {
+ return nil, ErrNilClient
+ }
+
+ // Verify connectivity at construction time.
ctx := context.Background()
- client, err := conn.GetClient(ctx)
- if err != nil {
+ if _, err := conn.GetClient(ctx); err != nil {
return nil, fmt.Errorf("failed to get redis client: %w", err)
}
- pool := goredis.NewPool(client)
+ // Use a lazy pool that resolves the client per operation,
+ // surviving IAM token refresh reconnections.
+ pool := &clientPool{conn: conn}
rs := redsync.New(pool)
- return &DistributedLock{
+ return &RedisLockManager{
redsync: rs,
}, nil
}
@@ -133,10 +227,14 @@ func NewDistributedLock(conn *RedisConnection) (*DistributedLock, error) {
//
// Example:
//
-// err := lock.WithLock(ctx, "lock:user:password:123", func() error {
+// err := lock.WithLock(ctx, "lock:user:password:123", func(ctx context.Context) error {
// return updatePassword(123, newPassword)
// })
-func (dl *DistributedLock) WithLock(ctx context.Context, lockKey string, fn func() error) error {
+func (dl *RedisLockManager) WithLock(ctx context.Context, lockKey string, fn func(context.Context) error) error {
+ if dl == nil {
+ return nilLockAssert(ctx, "WithLock")
+ }
+
return dl.WithLockOptions(ctx, lockKey, DefaultLockOptions(), fn)
}
@@ -150,13 +248,34 @@ func (dl *DistributedLock) WithLock(ctx context.Context, lockKey string, fn func
// Tries: 5, // More aggressive retries
// RetryDelay: 1 * time.Second,
// }
-// err := lock.WithLockOptions(ctx, "lock:report:generation", opts, func() error {
+// err := lock.WithLockOptions(ctx, "lock:report:generation", opts, func(ctx context.Context) error {
// return generateReport()
// })
-func (dl *DistributedLock) WithLockOptions(ctx context.Context, lockKey string, opts LockOptions, fn func() error) error {
+func (dl *RedisLockManager) WithLockOptions(ctx context.Context, lockKey string, opts LockOptions, fn func(context.Context) error) error {
+ if dl == nil {
+ return nilLockAssert(ctx, "WithLockOptions")
+ }
+
+ if dl.redsync == nil {
+ return ErrLockNotInitialized
+ }
+
+ if fn == nil {
+ return ErrNilLockFn
+ }
+
+ if strings.TrimSpace(lockKey) == "" {
+ return ErrEmptyLockKey
+ }
+
+ if err := validateLockOptions(opts); err != nil {
+ return err
+ }
+
logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ safeLockKey := safeLockKeyForLogs(lockKey)
- ctx, span := tracer.Start(ctx, "distributed_lock.with_lock")
+ ctx, span := tracer.Start(ctx, "redis.lock.with_lock")
defer span.End()
// Create mutex with configured options
@@ -168,49 +287,55 @@ func (dl *DistributedLock) WithLockOptions(ctx context.Context, lockKey string,
redsync.WithDriftFactor(opts.DriftFactor),
)
- logger.Debugf("Attempting to acquire lock: %s", lockKey)
+ logger.Log(ctx, log.LevelDebug, "attempting to acquire lock", log.String("lock_key", safeLockKey))
// Try to acquire the lock
if err := mutex.LockContext(ctx); err != nil {
- logger.Errorf("Failed to acquire lock %s: %v", lockKey, err)
- opentelemetry.HandleSpanError(&span, "Failed to acquire lock", err)
+ logger.Log(ctx, log.LevelError, "failed to acquire lock", log.String("lock_key", safeLockKey), log.Err(err))
+ opentelemetry.HandleSpanError(span, "Failed to acquire lock", err)
- return fmt.Errorf("failed to acquire lock %s: %w", lockKey, err)
+ return fmt.Errorf("failed to acquire lock %s: %w", safeLockKey, err)
}
- logger.Debugf("Lock acquired: %s", lockKey)
+ logger.Log(ctx, log.LevelDebug, "lock acquired", log.String("lock_key", safeLockKey))
- // Ensure lock is released even if function panics
+ // Ensure lock is released even if function panics.
+ // Use a detached context with a timeout so that the unlock is not blocked
+ // by a cancelled/expired caller context — a failed unlock leaves a dangling
+ // lock until its expiry, which can stall other callers.
defer func() {
- if ok, err := mutex.UnlockContext(ctx); !ok || err != nil {
- logger.Errorf("Failed to release lock %s: ok=%v err=%v", lockKey, ok, err)
+ unlockCtx, unlockCancel := context.WithTimeout(context.Background(), unlockTimeout)
+ defer unlockCancel()
+
+ if ok, unlockErr := mutex.UnlockContext(unlockCtx); !ok || unlockErr != nil {
+ logger.Log(ctx, log.LevelError, "failed to release lock", log.String("lock_key", safeLockKey), log.Bool("unlock_ok", ok), log.Err(unlockErr))
} else {
- logger.Debugf("Lock released: %s", lockKey)
+ logger.Log(ctx, log.LevelDebug, "lock released", log.String("lock_key", safeLockKey))
}
}()
// Execute the function while holding the lock
- logger.Debugf("Executing function under lock: %s", lockKey)
+ logger.Log(ctx, log.LevelDebug, "executing function under lock", log.String("lock_key", safeLockKey))
- if err := fn(); err != nil {
- logger.Errorf("Function execution failed under lock %s: %v", lockKey, err)
- opentelemetry.HandleSpanError(&span, "Function execution failed", err)
+ if err := fn(ctx); err != nil {
+ logger.Log(ctx, log.LevelError, "function execution failed under lock", log.String("lock_key", safeLockKey), log.Err(err))
+ opentelemetry.HandleSpanError(span, "Function execution failed", err)
- return err
+ return fmt.Errorf("distributed lock: function execution: %w", err)
}
- logger.Debugf("Function completed successfully under lock: %s", lockKey)
+ logger.Log(ctx, log.LevelDebug, "function completed successfully under lock", log.String("lock_key", safeLockKey))
return nil
}
// TryLock attempts to acquire a lock without retrying.
-// Returns the mutex and true if lock was acquired, false if lock is busy.
+// Returns the handle and true if lock was acquired, nil and false if lock is busy.
// Returns an error for unexpected failures (network errors, context cancellation, etc.)
//
-// Use this when you want to skip the operation if the lock is busy:
+// Use LockHandle.Unlock to release the lock when done:
//
-// mutex, acquired, err := lock.TryLock(ctx, "lock:cache:refresh")
+// handle, acquired, err := lock.TryLock(ctx, "lock:cache:refresh")
// if err != nil {
// // Unexpected error (network, context cancellation, etc.) - should be propagated
// return fmt.Errorf("failed to attempt lock acquisition: %w", err)
@@ -219,67 +344,109 @@ func (dl *DistributedLock) WithLockOptions(ctx context.Context, lockKey string,
// logger.Info("Lock busy, skipping cache refresh")
// return nil
// }
-// defer lock.Unlock(ctx, mutex)
+// defer handle.Unlock(ctx)
// // Perform cache refresh...
-func (dl *DistributedLock) TryLock(ctx context.Context, lockKey string) (*redsync.Mutex, bool, error) {
+func (dl *RedisLockManager) TryLock(ctx context.Context, lockKey string) (LockHandle, bool, error) {
+ if dl == nil {
+ return nil, false, nilLockAssert(ctx, "TryLock")
+ }
+
+ if dl.redsync == nil {
+ return nil, false, ErrLockNotInitialized
+ }
+
+ if strings.TrimSpace(lockKey) == "" {
+ return nil, false, ErrEmptyLockKey
+ }
+
logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ safeLockKey := safeLockKeyForLogs(lockKey)
- ctx, span := tracer.Start(ctx, "distributed_lock.try_lock")
+ ctx, span := tracer.Start(ctx, "redis.lock.try_lock")
defer span.End()
+ defaultOpts := DefaultLockOptions()
+
mutex := dl.redsync.NewMutex(
lockKey,
- redsync.WithExpiry(10*time.Second),
+ redsync.WithExpiry(defaultOpts.Expiry),
redsync.WithTries(1), // Only try once
)
if err := mutex.LockContext(ctx); err != nil {
- // Check if this is a lock contention error (expected behavior)
- // redsync returns different error messages for lock contention:
- // - "lock already taken" when another process holds the lock
- // - "redsync: failed to acquire lock" as the base error
- errMsg := err.Error()
- isLockContention := errors.Is(err, redsync.ErrFailed) ||
- strings.Contains(errMsg, "lock already taken") ||
- strings.Contains(errMsg, "failed to acquire lock")
+ // Classify lock contention vs infrastructure faults using redsync's
+ // typed sentinels rather than string matching. ErrFailed is returned
+ // when all retries are exhausted and ErrTaken when the lock is held
+ // on a quorum of nodes — both indicate normal contention.
+ var errTaken *redsync.ErrTaken
+
+ isLockContention := errors.Is(err, redsync.ErrFailed) || errors.As(err, &errTaken)
if isLockContention {
- logger.Debugf("Could not acquire lock %s as it is already held by another process", lockKey)
+ logger.Log(ctx, log.LevelDebug, "lock already held by another process", log.String("lock_key", safeLockKey))
return nil, false, nil
}
- // Any other error (e.g., network, context cancellation) is an actual failure
- // and should be propagated to the caller.
- logger.Debugf("Could not acquire lock %s: %v", lockKey, err)
- opentelemetry.HandleSpanError(&span, "Failed to attempt lock acquisition", err)
+ // Any other error (e.g., network, context cancellation, RedisError)
+ // is an actual infrastructure fault and must be propagated.
+ logger.Log(ctx, log.LevelDebug, "could not acquire lock", log.String("lock_key", safeLockKey), log.Err(err))
+ opentelemetry.HandleSpanError(span, "Failed to attempt lock acquisition", err)
- return nil, false, fmt.Errorf("failed to attempt lock acquisition for %s: %w", lockKey, err)
+ return nil, false, fmt.Errorf("failed to attempt lock acquisition for %s: %w", safeLockKey, err)
}
- logger.Debugf("Lock acquired: %s", lockKey)
+ logger.Log(ctx, log.LevelDebug, "lock acquired", log.String("lock_key", safeLockKey))
- return mutex, true, nil
+ return &lockHandle{mutex: mutex, logger: logger}, true, nil
}
// Unlock releases a previously acquired lock.
-// This is only needed if you use TryLock(). WithLock() handles unlocking automatically.
-func (dl *DistributedLock) Unlock(ctx context.Context, mutex *redsync.Mutex) error {
- logger := libCommons.NewLoggerFromContext(ctx)
+//
+// Deprecated: Use LockHandle.Unlock() directly instead. This method is provided
+// for backward compatibility during migration from the old *redsync.Mutex-based API.
+func (dl *RedisLockManager) Unlock(ctx context.Context, handle LockHandle) error {
+ if dl == nil {
+ return nilLockAssert(ctx, "Unlock")
+ }
- if mutex == nil {
- return fmt.Errorf("mutex is nil")
+ if handle == nil {
+ return ErrNilLockHandleOnUnlock
}
- ok, err := mutex.UnlockContext(ctx)
- if err != nil {
- logger.Errorf("Failed to unlock mutex: %v", err)
- return err
+ return handle.Unlock(ctx)
+}
+
+func validateLockOptions(opts LockOptions) error {
+ if opts.Expiry <= 0 {
+ return ErrLockExpiryInvalid
}
- if !ok {
- logger.Warnf("Mutex was not locked or already expired")
- return fmt.Errorf("mutex was not locked")
+ if opts.Tries < 1 {
+ return ErrLockTriesInvalid
+ }
+
+ if opts.Tries > maxLockTries {
+ return ErrLockTriesExceeded
+ }
+
+ if opts.RetryDelay < 0 {
+ return ErrLockRetryDelayNegative
+ }
+
+ if opts.DriftFactor < 0 || opts.DriftFactor >= 1 {
+ return ErrLockDriftFactorInvalid
}
return nil
}
+
+func safeLockKeyForLogs(lockKey string) string {
+ const maxLockKeyLogLength = 128
+
+ safeLockKey := strconv.QuoteToASCII(lockKey)
+ if len(safeLockKey) <= maxLockKeyLogLength {
+ return safeLockKey
+ }
+
+ return safeLockKey[:maxLockKeyLogLength] + "...(truncated)"
+}
diff --git a/commons/redis/lock_integration_test.go b/commons/redis/lock_integration_test.go
new file mode 100644
index 00000000..3c198652
--- /dev/null
+++ b/commons/redis/lock_integration_test.go
@@ -0,0 +1,455 @@
+//go:build integration
+
+package redis
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestIntegration_Lock_MutualExclusion verifies that WithLockOptions enforces
+// mutual exclusion: 10 goroutines compete for the same lock key, but only one
+// at a time may enter the critical section. An atomic counter tracks the
+// maximum observed concurrency inside the lock—must be exactly 1—and total
+// completed executions—must be exactly 10.
+func TestIntegration_Lock_MutualExclusion(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, client.Close()) }()
+
+ lockMgr, err := NewRedisLockManager(client)
+ require.NoError(t, err)
+
+ const goroutines = 10
+ const lockKey = "integration:mutex:exclusion"
+
+ opts := LockOptions{
+ Expiry: 5 * time.Second,
+ Tries: 50,
+ RetryDelay: 50 * time.Millisecond,
+ DriftFactor: 0.01,
+ }
+
+ var (
+ totalExecutions atomic.Int64
+ maxConcurrent atomic.Int64
+ currentInside atomic.Int64
+ wg sync.WaitGroup
+ )
+
+ errs := make(chan error, goroutines)
+
+ wg.Add(goroutines)
+
+ for i := range goroutines {
+ go func(id int) {
+ defer wg.Done()
+
+ lockErr := lockMgr.WithLockOptions(ctx, lockKey, opts, func(_ context.Context) error {
+ // Track how many goroutines are inside the critical section right now.
+ cur := currentInside.Add(1)
+
+ // Atomically update the observed maximum.
+ for {
+ prev := maxConcurrent.Load()
+ if cur <= prev {
+ break
+ }
+
+ if maxConcurrent.CompareAndSwap(prev, cur) {
+ break
+ }
+ }
+
+ // Simulate work so goroutines overlap in wall-clock time.
+ time.Sleep(10 * time.Millisecond)
+
+ currentInside.Add(-1)
+ totalExecutions.Add(1)
+
+ return nil
+ })
+ if lockErr != nil {
+ errs <- fmt.Errorf("goroutine %d: WithLockOptions: %w", id, lockErr)
+ }
+ }(i)
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for e := range errs {
+ t.Error(e)
+ }
+
+ assert.Equal(t, int64(1), maxConcurrent.Load(), "at most 1 goroutine may be inside the critical section at any time")
+ assert.Equal(t, int64(goroutines), totalExecutions.Load(), "all goroutines must complete their execution")
+}
+
+// TestIntegration_Lock_TryLock_Contention verifies the non-blocking TryLock:
+// - Goroutine A acquires the lock; goroutine B's immediate TryLock must fail.
+// - After A unlocks, B retries and succeeds.
+func TestIntegration_Lock_TryLock_Contention(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, client.Close()) }()
+
+ lockMgr, err := NewRedisLockManager(client)
+ require.NoError(t, err)
+
+ const lockKey = "integration:trylock:contention"
+
+ // Goroutine A acquires the lock.
+ handleA, acquiredA, err := lockMgr.TryLock(ctx, lockKey)
+ require.NoError(t, err)
+ require.True(t, acquiredA, "A must acquire the lock")
+ require.NotNil(t, handleA)
+
+ // Goroutine B tries to acquire the same lock — must fail because A holds it.
+ _, acquiredB, err := lockMgr.TryLock(ctx, lockKey)
+ require.NoError(t, err)
+ assert.False(t, acquiredB, "B must NOT acquire the lock while A holds it")
+
+ // A releases the lock.
+ require.NoError(t, handleA.Unlock(ctx))
+
+ // B retries — should succeed now.
+ handleB, acquiredB2, err := lockMgr.TryLock(ctx, lockKey)
+ require.NoError(t, err)
+ assert.True(t, acquiredB2, "B must acquire the lock after A releases it")
+ require.NotNil(t, handleB)
+
+ require.NoError(t, handleB.Unlock(ctx))
+}
+
+// TestIntegration_Lock_Expiry tests two scenarios:
+// 1. WithLockOptions with short expiry: fn completes quickly, lock is released
+// explicitly → re-acquire must succeed immediately.
+// 2. TryLock without explicit unlock: wait beyond the TTL → re-acquire must
+// succeed because the lock auto-expired.
+func TestIntegration_Lock_Expiry(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, client.Close()) }()
+
+ lockMgr, err := NewRedisLockManager(client)
+ require.NoError(t, err)
+
+ // --- Scenario 1: WithLockOptions completes and releases, re-acquire succeeds ---
+ const lockKey1 = "integration:expiry:withopts"
+
+ opts := LockOptions{
+ Expiry: 2 * time.Second,
+ Tries: 1,
+ RetryDelay: 50 * time.Millisecond,
+ DriftFactor: 0.01,
+ }
+
+ err = lockMgr.WithLockOptions(ctx, lockKey1, opts, func(_ context.Context) error {
+ time.Sleep(100 * time.Millisecond)
+ return nil
+ })
+ require.NoError(t, err)
+
+ // Lock was released by WithLockOptions defer — re-acquire must succeed.
+ handle, acquired, err := lockMgr.TryLock(ctx, lockKey1)
+ require.NoError(t, err)
+ assert.True(t, acquired, "re-acquire after WithLockOptions must succeed")
+
+ if handle != nil {
+ require.NoError(t, handle.Unlock(ctx))
+ }
+
+ // --- Scenario 2: TryLock without explicit unlock, wait for TTL expiry ---
+ const lockKey2 = "integration:expiry:ttl"
+
+ handleTTL, acquired, err := lockMgr.TryLock(ctx, lockKey2)
+ require.NoError(t, err)
+ require.True(t, acquired, "first TryLock must succeed")
+ require.NotNil(t, handleTTL)
+
+ // Intentionally do NOT unlock. The default TryLock expiry is 10s.
+ // We need a shorter TTL, so we use WithLockOptions to acquire with 2s expiry,
+ // but TryLock uses defaults. Instead, acquire via WithLockOptions with 2s expiry
+ // and leak the lock by not calling the returned handle.
+ // Since TryLock uses DefaultLockOptions (10s), unlock first, then re-acquire
+ // with a short-lived custom approach.
+ require.NoError(t, handleTTL.Unlock(ctx))
+
+ // Acquire with short expiry via WithLockOptions, but simulate a crash by
+ // setting the fn to do nothing—the defer in WithLockOptions will unlock.
+ // Instead, use a direct TryLock approach: acquire, don't unlock, wait for
+ // the default 10s TTL. That's too long for a test. So we test the concept
+ // by using a raw redsync mutex with short expiry through the public API:
+ // acquire via WithLockOptions where we just sleep past the entire expiry.
+ // Actually, the cleanest approach: acquire via TryLock (10s default expiry),
+ // don't unlock, wait 11s. But that's slow. Let's verify the TTL concept
+ // with a 3-second key using the internal observation that TryLock uses
+ // DefaultLockOptions which has 10s expiry.
+ //
+ // Pragmatic approach: acquire via WithLockOptions with 2s expiry where the fn
+ // takes longer than 2s — the lock will auto-expire in Redis before fn returns.
+ // After fn returns, attempt to re-acquire the same key immediately.
+ const lockKey3 = "integration:expiry:auto"
+
+ shortOpts := LockOptions{
+ Expiry: 2 * time.Second,
+ Tries: 1,
+ RetryDelay: 50 * time.Millisecond,
+ DriftFactor: 0.01,
+ }
+
+ // Acquire with 2s TTL, then deliberately do NOT release (simulate the
+ // unlock failing because the TTL already expired).
+ // We use TryLock indirectly by locking with WithLockOptions where fn
+ // takes 3s — the lock expires after 2s while fn is still running.
+ // WithLockOptions' defer unlock will silently fail (lock expired), and
+ // the error from the fn (nil) propagates.
+ err = lockMgr.WithLockOptions(ctx, lockKey3, shortOpts, func(_ context.Context) error {
+ // Sleep past the 2s expiry — the Redis key will expire mid-fn.
+ time.Sleep(3 * time.Second)
+ return nil
+ })
+ // The fn itself returns nil, but the defer Unlock may log a warning
+ // (lock not held). WithLockOptions returns fn's error, which is nil.
+ require.NoError(t, err)
+
+ // The lock has already auto-expired — re-acquire must succeed.
+ handleAfterExpiry, acquired, err := lockMgr.TryLock(ctx, lockKey3)
+ require.NoError(t, err)
+ assert.True(t, acquired, "re-acquire after TTL expiry must succeed")
+
+ if handleAfterExpiry != nil {
+ require.NoError(t, handleAfterExpiry.Unlock(ctx))
+ }
+}
+
+// TestIntegration_Lock_RateLimiterPreset verifies that RateLimiterLockOptions()
+// produces a usable configuration against real Redis: acquire, execute, release,
+// then re-acquire (proving the short 2s expiry preset is functional).
+func TestIntegration_Lock_RateLimiterPreset(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, client.Close()) }()
+
+ lockMgr, err := NewRedisLockManager(client)
+ require.NoError(t, err)
+
+ const lockKey = "integration:ratelimiter:preset"
+ opts := RateLimiterLockOptions()
+
+ // First acquire + execute + auto-release.
+ executed := false
+
+ err = lockMgr.WithLockOptions(ctx, lockKey, opts, func(_ context.Context) error {
+ executed = true
+ return nil
+ })
+ require.NoError(t, err)
+ assert.True(t, executed, "fn must have been executed under the rate-limiter lock")
+
+ // Second acquire — must succeed because the first was properly released.
+ executed2 := false
+
+ err = lockMgr.WithLockOptions(ctx, lockKey, opts, func(_ context.Context) error {
+ executed2 = true
+ return nil
+ })
+ require.NoError(t, err)
+ assert.True(t, executed2, "second acquire must succeed after first release")
+}
+
+// TestIntegration_Lock_ConcurrentDifferentKeys verifies that locks on distinct
+// keys do not block each other: 5 goroutines, each locking a unique key, must
+// all complete within a tight timeout.
+func TestIntegration_Lock_ConcurrentDifferentKeys(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, client.Close()) }()
+
+ lockMgr, err := NewRedisLockManager(client)
+ require.NoError(t, err)
+
+ const goroutines = 5
+
+ var (
+ wg sync.WaitGroup
+ completions atomic.Int64
+ )
+
+ errs := make(chan error, goroutines)
+
+ wg.Add(goroutines)
+
+ start := time.Now()
+
+ for i := range goroutines {
+ go func(id int) {
+ defer wg.Done()
+
+ key := fmt.Sprintf("integration:concurrent:key:%d", id)
+
+ lockErr := lockMgr.WithLock(ctx, key, func(_ context.Context) error {
+ // Each goroutine does a small amount of work.
+ time.Sleep(50 * time.Millisecond)
+ completions.Add(1)
+
+ return nil
+ })
+ if lockErr != nil {
+ errs <- fmt.Errorf("goroutine %d: WithLock: %w", id, lockErr)
+ }
+ }(i)
+ }
+
+ wg.Wait()
+ close(errs)
+
+ elapsed := time.Since(start)
+
+ for e := range errs {
+ t.Error(e)
+ }
+
+ assert.Equal(t, int64(goroutines), completions.Load(), "all goroutines must complete")
+ assert.Less(t, elapsed, 2*time.Second, "concurrent different-key locks should complete well under 2s")
+}
+
+// TestIntegration_Lock_WithLock_ErrorPropagation verifies that:
+// 1. An error returned by fn propagates through WithLock.
+// 2. The lock is released even when fn returns an error (so another caller can
+// acquire the same key immediately).
+func TestIntegration_Lock_WithLock_ErrorPropagation(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, client.Close()) }()
+
+ lockMgr, err := NewRedisLockManager(client)
+ require.NoError(t, err)
+
+ const lockKey = "integration:errorprop:key"
+
+ sentinelErr := errors.New("business logic failed")
+
+ err = lockMgr.WithLock(ctx, lockKey, func(_ context.Context) error {
+ return sentinelErr
+ })
+ require.Error(t, err)
+ assert.ErrorIs(t, err, sentinelErr, "fn's error must propagate through WithLock")
+
+ // The lock must have been released by WithLockOptions' defer, so TryLock
+ // on the same key must succeed.
+ handle, acquired, err := lockMgr.TryLock(ctx, lockKey)
+ require.NoError(t, err)
+ assert.True(t, acquired, "lock must be released even when fn returns an error")
+
+ if handle != nil {
+ require.NoError(t, handle.Unlock(ctx))
+ }
+}
+
+// TestIntegration_Lock_ContextCancellation verifies that a waiting locker
+// respects context cancellation. Goroutine A holds the lock; goroutine B
+// attempts WithLockOptions with a short-lived context. B should fail with a
+// context-related error before exhausting its retries.
+func TestIntegration_Lock_ContextCancellation(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+ defer func() { require.NoError(t, client.Close()) }()
+
+ lockMgr, err := NewRedisLockManager(client)
+ require.NoError(t, err)
+
+ const lockKey = "integration:ctxcancel:key"
+
+ // Goroutine A: hold the lock for a long time.
+ aReady := make(chan struct{})
+ aDone := make(chan struct{})
+
+ go func() {
+ opts := LockOptions{
+ Expiry: 10 * time.Second,
+ Tries: 1,
+ RetryDelay: 50 * time.Millisecond,
+ DriftFactor: 0.01,
+ }
+
+ lockErr := lockMgr.WithLockOptions(ctx, lockKey, opts, func(_ context.Context) error {
+ close(aReady) // Signal that A holds the lock.
+ <-aDone // Wait until the test tells us to release.
+ return nil
+ })
+ // A might error if the test context is cancelled, which is fine.
+ _ = lockErr
+ }()
+
+ // Wait for A to acquire the lock.
+ select {
+ case <-aReady:
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for goroutine A to acquire the lock")
+ }
+
+ // Goroutine B: attempt to acquire the same lock with a 200ms timeout context.
+ ctxB, cancelB := context.WithTimeout(ctx, 200*time.Millisecond)
+ defer cancelB()
+
+ bOpts := LockOptions{
+ Expiry: 5 * time.Second,
+ Tries: 100,
+ RetryDelay: 50 * time.Millisecond,
+ DriftFactor: 0.01,
+ }
+
+ err = lockMgr.WithLockOptions(ctxB, lockKey, bOpts, func(_ context.Context) error {
+ t.Error("B's fn must never execute — the lock should not be acquired")
+ return nil
+ })
+ require.Error(t, err, "B must fail because the context timed out")
+
+ // Release A so the goroutine can exit cleanly.
+ close(aDone)
+}
diff --git a/commons/redis/lock_interface.go b/commons/redis/lock_interface.go
index f6f4cceb..20f6a2f7 100644
--- a/commons/redis/lock_interface.go
+++ b/commons/redis/lock_interface.go
@@ -1,46 +1,61 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package redis
import (
"context"
-
- "github.com/go-redsync/redsync/v4"
)
-// DistributedLocker provides an interface for distributed locking operations.
+// LockHandle represents an acquired distributed lock.
+// It is obtained from TryLock and must be released via its Unlock method.
+//
+// Example usage:
+//
+// handle, acquired, err := locker.TryLock(ctx, "lock:resource:123")
+// if err != nil {
+// return err
+// }
+// if !acquired {
+// return nil // lock busy, skip
+// }
+// defer handle.Unlock(ctx)
+// // ... critical section ...
+type LockHandle interface {
+ // Unlock releases the distributed lock.
+ Unlock(ctx context.Context) error
+}
+
+// LockManager provides an interface for distributed locking operations.
// This interface allows for easy mocking in tests without requiring a real Redis instance.
//
// Example test implementation:
//
-// type MockDistributedLock struct{}
+// type MockLockManager struct{}
//
-// func (m *MockDistributedLock) WithLock(ctx context.Context, lockKey string, fn func() error) error {
+// func (m *MockLockManager) WithLock(ctx context.Context, lockKey string, fn func(context.Context) error) error {
// // In tests, just execute the function without actual locking
-// return fn()
+// return fn(ctx)
// }
//
-// func (m *MockDistributedLock) WithLockOptions(ctx context.Context, lockKey string, opts LockOptions, fn func() error) error {
-// return fn()
+// func (m *MockLockManager) WithLockOptions(ctx context.Context, lockKey string, opts LockOptions, fn func(context.Context) error) error {
+// return fn(ctx)
// }
-type DistributedLocker interface {
+//
+// func (m *MockLockManager) TryLock(ctx context.Context, lockKey string) (LockHandle, bool, error) {
+// return &mockHandle{}, true, nil
+// }
+type LockManager interface {
// WithLock executes a function while holding a distributed lock with default options.
// The lock is automatically released when the function returns.
- WithLock(ctx context.Context, lockKey string, fn func() error) error
+ WithLock(ctx context.Context, lockKey string, fn func(context.Context) error) error
// WithLockOptions executes a function while holding a distributed lock with custom options.
// Use this for fine-grained control over lock behavior.
- WithLockOptions(ctx context.Context, lockKey string, opts LockOptions, fn func() error) error
+ WithLockOptions(ctx context.Context, lockKey string, opts LockOptions, fn func(context.Context) error) error
// TryLock attempts to acquire a lock without retrying.
- // Returns the mutex and true if lock was acquired, nil and false otherwise.
- TryLock(ctx context.Context, lockKey string) (*redsync.Mutex, bool, error)
-
- // Unlock releases a previously acquired lock (used with TryLock).
- Unlock(ctx context.Context, mutex *redsync.Mutex) error
+ // Returns the handle and true if lock was acquired, nil and false otherwise.
+ // Use LockHandle.Unlock to release the lock when done.
+ TryLock(ctx context.Context, lockKey string) (LockHandle, bool, error)
}
-// Ensure DistributedLock implements DistributedLocker interface at compile time
-var _ DistributedLocker = (*DistributedLock)(nil)
+// Ensure RedisLockManager implements LockManager interface at compile time.
+var _ LockManager = (*RedisLockManager)(nil)
diff --git a/commons/redis/lock_test.go b/commons/redis/lock_test.go
index 9eac5dc4..060a7ded 100644
--- a/commons/redis/lock_test.go
+++ b/commons/redis/lock_test.go
@@ -1,482 +1,1054 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package redis
import (
"context"
- "fmt"
+ "errors"
+ "strings"
"sync"
"sync/atomic"
"testing"
"time"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
"github.com/alicebob/miniredis/v2"
- "github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-// setupTestRedis creates a miniredis server for testing
-func setupTestRedis(t *testing.T) (*RedisConnection, func()) {
+func setupTestClient(t *testing.T) *Client {
+ t.Helper()
+
mr := miniredis.RunT(t)
- conn := &RedisConnection{
- Address: []string{mr.Addr()},
- DB: 0,
- }
+ client, err := New(context.Background(), Config{
+ Topology: Topology{
+ Standalone: &StandaloneTopology{Address: mr.Addr()},
+ },
+ Logger: &log.NopLogger{},
+ })
+ require.NoError(t, err)
- client := redis.NewClient(&redis.Options{
- Addr: mr.Addr(),
+ t.Cleanup(func() {
+ require.NoError(t, client.Close())
+ mr.Close()
})
- conn.Client = client
- conn.Connected = true
+ return client
+}
- cleanup := func() {
- client.Close()
- mr.Close()
- }
+// setupTestLock creates a Redis client and RedisLockManager for testing.
+func setupTestLock(t *testing.T) (*Client, *RedisLockManager) {
+ t.Helper()
- return conn, cleanup
+ client := setupTestClient(t)
+
+ lock, err := NewRedisLockManager(client)
+ require.NoError(t, err)
+
+ return client, lock
}
-// TestDistributedLock_WithLock tests basic locking functionality
-func TestDistributedLock_WithLock(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_WithLock(t *testing.T) {
+ client := setupTestClient(t)
- lock, err := NewDistributedLock(conn)
+ lock, err := NewRedisLockManager(client)
require.NoError(t, err)
- ctx := context.Background()
executed := false
-
- err = lock.WithLock(ctx, "test:lock", func() error {
+ err = lock.WithLock(context.Background(), "test:lock", func(context.Context) error {
executed = true
return nil
})
- assert.NoError(t, err)
- assert.True(t, executed, "function should have been executed")
+ require.NoError(t, err)
+ assert.True(t, executed)
}
-// TestDistributedLock_WithLock_Error tests error propagation
-func TestDistributedLock_WithLock_Error(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_WithLock_ErrorPropagation(t *testing.T) {
+ client := setupTestClient(t)
- lock, err := NewDistributedLock(conn)
+ lock, err := NewRedisLockManager(client)
require.NoError(t, err)
- ctx := context.Background()
- expectedErr := assert.AnError
-
- err = lock.WithLock(ctx, "test:lock", func() error {
+ expectedErr := errors.New("boom")
+ err = lock.WithLock(context.Background(), "test:lock", func(context.Context) error {
return expectedErr
})
- assert.Error(t, err)
- assert.Equal(t, expectedErr, err)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, expectedErr)
}
-// TestDistributedLock_ConcurrentExecution tests that locks prevent concurrent execution
-func TestDistributedLock_ConcurrentExecution(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_ConcurrentExecutionSingleKey(t *testing.T) {
+ client := setupTestClient(t)
- lock, err := NewDistributedLock(conn)
+ lock, err := NewRedisLockManager(client)
require.NoError(t, err)
ctx := context.Background()
- var counter int32
- var maxConcurrent int32
var currentConcurrent int32
+ var maxConcurrent int32
+ var total int32
- const numGoroutines = 10
-
- // Use more patient lock options for testing
opts := LockOptions{
Expiry: 5 * time.Second,
- Tries: 50, // Many retries to ensure all goroutines get a chance
- RetryDelay: 50 * time.Millisecond,
+ Tries: 50,
+ RetryDelay: 20 * time.Millisecond,
DriftFactor: 0.01,
}
+ const workers = 10
+ errCh := make(chan error, workers)
var wg sync.WaitGroup
- wg.Add(numGoroutines)
+ wg.Add(workers)
- for range numGoroutines {
+ for range workers {
go func() {
defer wg.Done()
- err := lock.WithLockOptions(ctx, "test:concurrent:lock", opts, func() error {
- // Track concurrent executions
- concurrent := atomic.AddInt32(¤tConcurrent, 1)
- if concurrent > atomic.LoadInt32(&maxConcurrent) {
- atomic.StoreInt32(&maxConcurrent, concurrent)
+ err := lock.WithLockOptions(ctx, "test:concurrent", opts, func(context.Context) error {
+ active := atomic.AddInt32(¤tConcurrent, 1)
+ if active > atomic.LoadInt32(&maxConcurrent) {
+ atomic.StoreInt32(&maxConcurrent, active)
}
- // Increment counter
- atomic.AddInt32(&counter, 1)
-
- // Simulate work
- time.Sleep(10 * time.Millisecond)
-
- // Decrement concurrent counter
+ atomic.AddInt32(&total, 1)
+ time.Sleep(5 * time.Millisecond)
atomic.AddInt32(¤tConcurrent, -1)
return nil
})
-
- assert.NoError(t, err)
+ if err != nil {
+ errCh <- err
+ }
}()
}
wg.Wait()
+ close(errCh)
+
+ for err := range errCh {
+ require.NoError(t, err)
+ }
- assert.Equal(t, int32(numGoroutines), counter, "all goroutines should have executed")
- assert.Equal(t, int32(1), maxConcurrent, "at most 1 goroutine should execute concurrently")
+ assert.Equal(t, int32(workers), total)
+ assert.Equal(t, int32(1), maxConcurrent)
}
-// TestDistributedLock_TryLock tests non-blocking lock acquisition
-func TestDistributedLock_TryLock(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_TryLock_Contention(t *testing.T) {
+ client := setupTestClient(t)
- lock, err := NewDistributedLock(conn)
+ lock, err := NewRedisLockManager(client)
require.NoError(t, err)
ctx := context.Background()
- // First lock should succeed
- mutex1, acquired1, err1 := lock.TryLock(ctx, "test:trylock")
- assert.NoError(t, err1)
- assert.True(t, acquired1, "first lock should be acquired")
- assert.NotNil(t, mutex1)
-
- if acquired1 {
- defer lock.Unlock(ctx, mutex1)
- }
+ handle1, acquired, err := lock.TryLock(ctx, "test:contention")
+ require.NoError(t, err)
+ require.True(t, acquired)
+ require.NotNil(t, handle1)
+ defer func() {
+ require.NoError(t, handle1.Unlock(ctx))
+ }()
- // Second lock should fail (already held)
- mutex2, acquired2, err2 := lock.TryLock(ctx, "test:trylock")
- assert.NoError(t, err2)
- assert.False(t, acquired2, "second lock should not be acquired")
- assert.Nil(t, mutex2)
+ handle2, acquired, err := lock.TryLock(ctx, "test:contention")
+ require.NoError(t, err)
+ assert.False(t, acquired)
+ assert.Nil(t, handle2)
}
-// TestDistributedLock_WithLockOptions tests custom lock options
-func TestDistributedLock_WithLockOptions(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_PanicRecovery(t *testing.T) {
+ client := setupTestClient(t)
- lock, err := NewDistributedLock(conn)
+ lock, err := NewRedisLockManager(client)
require.NoError(t, err)
ctx := context.Background()
+
+ require.Panics(t, func() {
+ _ = lock.WithLock(ctx, "test:panic", func(context.Context) error {
+ panic("panic inside lock")
+ })
+ })
+
executed := false
+ err = lock.WithLock(ctx, "test:panic", func(context.Context) error {
+ executed = true
+ return nil
+ })
+
+ require.NoError(t, err)
+ assert.True(t, executed)
+}
+
+func TestRedisLockManager_NilAndInitGuards(t *testing.T) {
+ t.Run("new lock with nil client", func(t *testing.T) {
+ lock, err := NewRedisLockManager(nil)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrNilClient)
+ assert.Nil(t, lock)
+ })
+
+ t.Run("nil receiver", func(t *testing.T) {
+ var dl *RedisLockManager
+ ctx := context.Background()
+
+ err := dl.WithLock(ctx, "test:key", func(context.Context) error { return nil })
+ assert.ErrorContains(t, err, "lock manager is nil")
+
+ err = dl.WithLockOptions(ctx, "test:key", DefaultLockOptions(), func(context.Context) error { return nil })
+ assert.ErrorContains(t, err, "lock manager is nil")
+
+ handle, acquired, err := dl.TryLock(ctx, "test:key")
+ assert.ErrorContains(t, err, "lock manager is nil")
+ assert.Nil(t, handle)
+ assert.False(t, acquired)
+
+ err = dl.Unlock(ctx, nil)
+ assert.ErrorContains(t, err, "lock manager is nil")
+ })
+
+ t.Run("zero value lock is rejected", func(t *testing.T) {
+ dl := &RedisLockManager{}
+ ctx := context.Background()
+
+ err := dl.WithLockOptions(ctx, "test:key", DefaultLockOptions(), func(context.Context) error { return nil })
+ assert.ErrorContains(t, err, "distributed lock is not initialized")
+
+ handle, acquired, err := dl.TryLock(ctx, "test:key")
+ assert.ErrorContains(t, err, "distributed lock is not initialized")
+ assert.Nil(t, handle)
+ assert.False(t, acquired)
+ })
+}
+
+func TestRedisLockManager_OptionValidation(t *testing.T) {
+ client := setupTestClient(t)
+
+ lock, err := NewRedisLockManager(client)
+ require.NoError(t, err)
+
+ err = lock.WithLockOptions(context.Background(), "", DefaultLockOptions(), func(context.Context) error { return nil })
+ assert.ErrorContains(t, err, "lock key cannot be empty")
+
+ err = lock.WithLockOptions(context.Background(), "test:key", DefaultLockOptions(), nil)
+ assert.ErrorIs(t, err, ErrNilLockFn)
+
+ err = lock.WithLockOptions(context.Background(), "test:key", LockOptions{
+ Expiry: 0,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ }, func(context.Context) error { return nil })
+ assert.ErrorContains(t, err, "lock expiry must be greater than 0")
+
+ err = lock.WithLockOptions(context.Background(), "test:key", LockOptions{
+ Expiry: time.Second,
+ Tries: 0,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ }, func(context.Context) error { return nil })
+ assert.ErrorContains(t, err, "lock tries must be at least 1")
+
+ err = lock.WithLockOptions(context.Background(), "test:key", LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: -time.Millisecond,
+ DriftFactor: 0.01,
+ }, func(context.Context) error { return nil })
+ assert.ErrorContains(t, err, "lock retry delay cannot be negative")
+
+ err = lock.WithLockOptions(context.Background(), "test:key", LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 1,
+ }, func(context.Context) error { return nil })
+ assert.ErrorContains(t, err, "lock drift factor")
+
+ // Tries exceeding max cap (1001 > maxLockTries=1000)
+ err = lock.WithLockOptions(context.Background(), "test:key", LockOptions{
+ Expiry: time.Second,
+ Tries: 1001,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ }, func(context.Context) error { return nil })
+ assert.ErrorIs(t, err, ErrLockTriesExceeded)
+}
+
+func TestLockOptionFactories(t *testing.T) {
+ defaultOpts := DefaultLockOptions()
+ assert.Equal(t, 10*time.Second, defaultOpts.Expiry)
+ assert.Equal(t, 3, defaultOpts.Tries)
+ assert.Equal(t, 500*time.Millisecond, defaultOpts.RetryDelay)
+ assert.Equal(t, 0.01, defaultOpts.DriftFactor)
+
+ rateLimiterOpts := RateLimiterLockOptions()
+ assert.Equal(t, 2*time.Second, rateLimiterOpts.Expiry)
+ assert.Equal(t, 2, rateLimiterOpts.Tries)
+ assert.Equal(t, 100*time.Millisecond, rateLimiterOpts.RetryDelay)
+ assert.Equal(t, 0.01, rateLimiterOpts.DriftFactor)
+}
+
+func TestSafeLockKeyForLogs(t *testing.T) {
+ safe := safeLockKeyForLogs("lock:tenant\n123")
+ assert.NotContains(t, safe, "\n")
+ assert.Contains(t, safe, "\\n")
+
+ longKey := strings.Repeat("a", 1024)
+ safeLong := safeLockKeyForLogs(longKey)
+ assert.Contains(t, safeLong, "...(truncated)")
+}
+
+// --- New comprehensive test coverage below ---
+
+func TestRedisLockManager_WithLock_ContextPassedToFn(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ type ctxKey string
+ key := ctxKey("trace-id")
+
+ ctx := context.WithValue(context.Background(), key, "abc-123")
+
+ err := lock.WithLock(ctx, "test:ctx-pass", func(ctx context.Context) error {
+ val, ok := ctx.Value(key).(string)
+ assert.True(t, ok)
+ assert.Equal(t, "abc-123", val)
+
+ return nil
+ })
+
+ require.NoError(t, err)
+}
+
+func TestRedisLockManager_WithLockOptions_CustomExpiry(t *testing.T) {
+ _, lock := setupTestLock(t)
opts := LockOptions{
- Expiry: 5 * time.Second,
- Tries: 5,
- RetryDelay: 100 * time.Millisecond,
+ Expiry: 30 * time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
DriftFactor: 0.01,
}
- err = lock.WithLockOptions(ctx, "test:lock:options", opts, func() error {
+ executed := false
+
+ err := lock.WithLockOptions(context.Background(), "test:custom-opts", opts, func(context.Context) error {
executed = true
return nil
})
- assert.NoError(t, err)
- assert.True(t, executed, "function should have been executed")
+ require.NoError(t, err)
+ assert.True(t, executed)
}
-// TestDistributedLock_DefaultLockOptions tests default options
-func TestDistributedLock_DefaultLockOptions(t *testing.T) {
- opts := DefaultLockOptions()
+func TestRedisLockManager_WithLock_WhitespaceOnlyKey(t *testing.T) {
+ _, lock := setupTestLock(t)
- assert.Equal(t, 10*time.Second, opts.Expiry)
- assert.Equal(t, 3, opts.Tries)
- assert.Equal(t, 500*time.Millisecond, opts.RetryDelay)
- assert.Equal(t, 0.01, opts.DriftFactor)
+ err := lock.WithLock(context.Background(), " ", func(context.Context) error {
+ return nil
+ })
+
+ require.Error(t, err)
+ assert.ErrorContains(t, err, "lock key cannot be empty")
}
-// TestDistributedLock_Unlock tests explicit unlocking
-func TestDistributedLock_Unlock(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_WithLock_TabAndNewlineKey(t *testing.T) {
+ _, lock := setupTestLock(t)
- lock, err := NewDistributedLock(conn)
- require.NoError(t, err)
+ err := lock.WithLock(context.Background(), "\t\n", func(context.Context) error {
+ return nil
+ })
+
+ require.Error(t, err)
+ assert.ErrorContains(t, err, "lock key cannot be empty")
+}
+
+func TestRedisLockManager_TryLock_EmptyKey(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ handle, acquired, err := lock.TryLock(context.Background(), "")
+ require.Error(t, err)
+ assert.ErrorContains(t, err, "lock key cannot be empty")
+ assert.False(t, acquired)
+ assert.Nil(t, handle)
+}
+
+func TestRedisLockManager_TryLock_WhitespaceOnlyKey(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ handle, acquired, err := lock.TryLock(context.Background(), " ")
+ require.Error(t, err)
+ assert.ErrorContains(t, err, "lock key cannot be empty")
+ assert.False(t, acquired)
+ assert.Nil(t, handle)
+}
+
+func TestRedisLockManager_TryLock_SuccessfulAcquireAndRelease(t *testing.T) {
+ _, lock := setupTestLock(t)
ctx := context.Background()
- mutex, acquired, err := lock.TryLock(ctx, "test:unlock")
+ handle, acquired, err := lock.TryLock(ctx, "test:try-success")
require.NoError(t, err)
require.True(t, acquired)
- require.NotNil(t, mutex)
+ require.NotNil(t, handle)
- // Unlock should succeed
- err = lock.Unlock(ctx, mutex)
- assert.NoError(t, err)
+ // Release the lock via LockHandle
+ err = handle.Unlock(ctx)
+ require.NoError(t, err)
- // After unlock, another lock should be acquirable
- mutex2, acquired2, err2 := lock.TryLock(ctx, "test:unlock")
- assert.NoError(t, err2)
+ // Lock should be available again
+ handle2, acquired2, err := lock.TryLock(ctx, "test:try-success")
+ require.NoError(t, err)
assert.True(t, acquired2)
- assert.NotNil(t, mutex2)
+ assert.NotNil(t, handle2)
- if acquired2 {
- lock.Unlock(ctx, mutex2)
- }
+ // Clean up
+ require.NoError(t, handle2.Unlock(ctx))
}
-// TestDistributedLock_NilMutexUnlock tests error handling for nil mutex
-func TestDistributedLock_NilMutexUnlock(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_TryLock_DifferentKeysNoContention(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ ctx := context.Background()
- lock, err := NewDistributedLock(conn)
+ handle1, acquired1, err := lock.TryLock(ctx, "test:key-a")
require.NoError(t, err)
+ require.True(t, acquired1)
+ require.NotNil(t, handle1)
+ defer func() { _ = handle1.Unlock(ctx) }()
- ctx := context.Background()
+ // Different key should not contend
+ handle2, acquired2, err := lock.TryLock(ctx, "test:key-b")
+ require.NoError(t, err)
+ assert.True(t, acquired2)
+ assert.NotNil(t, handle2)
+ defer func() { _ = handle2.Unlock(ctx) }()
+}
- err = lock.Unlock(ctx, nil)
- assert.Error(t, err)
- assert.Contains(t, err.Error(), "mutex is nil")
+func TestRedisLockManager_Unlock_NilMutex(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ err := lock.Unlock(context.Background(), nil)
+ require.Error(t, err)
+ assert.ErrorContains(t, err, "lock handle is nil")
}
-// TestDistributedLock_ContextCancellation tests lock behavior with context cancellation
-func TestDistributedLock_ContextCancellation(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_ConcurrentTryLock(t *testing.T) {
+ _, lock := setupTestLock(t)
- lock, err := NewDistributedLock(conn)
- require.NoError(t, err)
+ ctx := context.Background()
- // Create a context that's already cancelled
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
+ const workers = 20
+ var acquired int32
- executed := false
- err = lock.WithLock(ctx, "test:cancelled", func() error {
- executed = true
- return nil
- })
+ var wg sync.WaitGroup
- assert.Error(t, err)
- assert.False(t, executed, "function should not execute with cancelled context")
-}
+ wg.Add(workers)
+
+ for range workers {
+ go func() {
+ defer wg.Done()
-// TestDistributedLock_MultipleLocksDifferentKeys tests multiple locks on different keys
-func TestDistributedLock_MultipleLocksDifferentKeys(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+ handle, ok, err := lock.TryLock(ctx, "test:concurrent-try")
+ if err != nil {
+ return
+ }
- lock, err := NewDistributedLock(conn)
- require.NoError(t, err)
+ if ok {
+ atomic.AddInt32(&acquired, 1)
+ // Hold lock briefly
+ time.Sleep(10 * time.Millisecond)
+ _ = handle.Unlock(ctx)
+ }
+ }()
+ }
+
+ wg.Wait()
+
+ // At least one goroutine must have acquired the lock
+ assert.GreaterOrEqual(t, atomic.LoadInt32(&acquired), int32(1))
+}
+
+func TestRedisLockManager_ConcurrentDifferentKeys(t *testing.T) {
+ _, lock := setupTestLock(t)
ctx := context.Background()
+ const workers = 5
var wg sync.WaitGroup
- var counter1, counter2 int32
- // Two different locks should not interfere with each other
- wg.Add(2)
+ wg.Add(workers)
- go func() {
- defer wg.Done()
- err := lock.WithLock(ctx, "test:lock:1", func() error {
- atomic.AddInt32(&counter1, 1)
- time.Sleep(50 * time.Millisecond)
- return nil
- })
- assert.NoError(t, err)
- }()
+ errCh := make(chan error, workers)
- go func() {
- defer wg.Done()
- err := lock.WithLock(ctx, "test:lock:2", func() error {
- atomic.AddInt32(&counter2, 1)
- time.Sleep(50 * time.Millisecond)
- return nil
- })
- assert.NoError(t, err)
- }()
+ for i := range workers {
+ go func(idx int) {
+ defer wg.Done()
+
+ key := "test:concurrent-diff:" + strings.Repeat("x", idx+1)
+
+ err := lock.WithLock(ctx, key, func(context.Context) error {
+ time.Sleep(5 * time.Millisecond)
+ return nil
+ })
+ if err != nil {
+ errCh <- err
+ }
+ }(i)
+ }
wg.Wait()
+ close(errCh)
- assert.Equal(t, int32(1), counter1)
- assert.Equal(t, int32(1), counter2)
+ for err := range errCh {
+ require.NoError(t, err)
+ }
}
-// TestDistributedLock_PanicRecovery tests that locks are released even on panic
-func TestDistributedLock_PanicRecovery(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_WithLockOptions_NegativeDriftFactor(t *testing.T) {
+ _, lock := setupTestLock(t)
- lock, err := NewDistributedLock(conn)
- require.NoError(t, err)
+ err := lock.WithLockOptions(context.Background(), "test:key", LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: -0.5,
+ }, func(context.Context) error { return nil })
- ctx := context.Background()
+ require.Error(t, err)
+ assert.ErrorContains(t, err, "lock drift factor")
+}
- // First call panics
- func() {
- defer func() {
- if r := recover(); r != nil {
- // Panic recovered as expected
- }
- }()
+func TestRedisLockManager_WithLockOptions_NegativeExpiry(t *testing.T) {
+ _, lock := setupTestLock(t)
- lock.WithLock(ctx, "test:panic", func() error {
- panic("test panic")
- })
- }()
+ err := lock.WithLockOptions(context.Background(), "test:key", LockOptions{
+ Expiry: -time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ }, func(context.Context) error { return nil })
- // Second call should succeed (lock was released despite panic)
+ require.Error(t, err)
+ assert.ErrorContains(t, err, "lock expiry must be greater than 0")
+}
+
+func TestRedisLockManager_WithLockOptions_ZeroRetryDelay(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ // Zero retry delay is valid (no delay between retries)
executed := false
- err = lock.WithLock(ctx, "test:panic", func() error {
+
+ err := lock.WithLockOptions(context.Background(), "test:zero-delay", LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: 0,
+ DriftFactor: 0.01,
+ }, func(context.Context) error {
executed = true
return nil
})
- assert.NoError(t, err)
- assert.True(t, executed, "lock should be available after panic")
+ require.NoError(t, err)
+ assert.True(t, executed)
}
-// TestDistributedLock_ConcurrentDifferentKeys tests high concurrency on different keys
-func TestDistributedLock_ConcurrentDifferentKeys(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_WithLockOptions_DriftFactorBoundary(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ // DriftFactor = 0 is valid (lower bound inclusive)
+ executed := false
+
+ err := lock.WithLockOptions(context.Background(), "test:drift-zero", LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0,
+ }, func(context.Context) error {
+ executed = true
+ return nil
+ })
- lock, err := NewDistributedLock(conn)
require.NoError(t, err)
+ assert.True(t, executed)
+
+ // DriftFactor = 0.99 is valid (just under 1)
+ executed = false
+
+ err = lock.WithLockOptions(context.Background(), "test:drift-high", LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.99,
+ }, func(context.Context) error {
+ executed = true
+ return nil
+ })
+
+ require.NoError(t, err)
+ assert.True(t, executed)
+}
+
+func TestRedisLockManager_WithLock_ContextionExhaustsRetries(t *testing.T) {
+ _, lock := setupTestLock(t)
ctx := context.Background()
- const numKeys = 5
- const numGoroutinesPerKey = 4
- counters := make([]int32, numKeys)
- var wg sync.WaitGroup
+ // Acquire lock and hold it
+ handle, acquired, err := lock.TryLock(ctx, "test:exhaust")
+ require.NoError(t, err)
+ require.True(t, acquired)
+ defer func() { _ = handle.Unlock(ctx) }()
- // Use patient lock options for concurrent scenario
- opts := LockOptions{
+ // Try to acquire the same key with limited retries - should fail
+ err = lock.WithLockOptions(ctx, "test:exhaust", LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ }, func(context.Context) error {
+ t.Fatal("function should not be executed when lock cannot be acquired")
+ return nil
+ })
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to acquire lock")
+}
+
+func TestRedisLockManager_ContextCancellation(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ // Acquire lock to create contention
+ bgCtx := context.Background()
+ handle, acquired, err := lock.TryLock(bgCtx, "test:cancel")
+ require.NoError(t, err)
+ require.True(t, acquired)
+ defer func() { _ = handle.Unlock(bgCtx) }()
+
+ // Create a context that will be cancelled quickly
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ defer cancel()
+
+ // Try to acquire the same lock with the cancellable context
+ err = lock.WithLockOptions(ctx, "test:cancel", LockOptions{
Expiry: 5 * time.Second,
- Tries: 50,
+ Tries: 100,
RetryDelay: 50 * time.Millisecond,
DriftFactor: 0.01,
+ }, func(context.Context) error {
+ t.Fatal("function should not execute when context is cancelled")
+ return nil
+ })
+
+ require.Error(t, err)
+}
+
+func TestRedisLockManager_WithLock_FnReceivesSpanContext(t *testing.T) {
+ // Verify the function receives a context (potentially enriched with span)
+ _, lock := setupTestLock(t)
+
+ err := lock.WithLock(context.Background(), "test:span-ctx", func(ctx context.Context) error {
+ // The context should be non-nil and usable
+ require.NotNil(t, ctx)
+
+ return nil
+ })
+
+ require.NoError(t, err)
+}
+
+func TestSafeLockKeyForLogs_ShortKey(t *testing.T) {
+ safe := safeLockKeyForLogs("lock:simple")
+ // Short keys should be returned as-is (quoted)
+ assert.NotContains(t, safe, "...(truncated)")
+ assert.Contains(t, safe, "lock:simple")
+}
+
+func TestSafeLockKeyForLogs_ExactBoundary(t *testing.T) {
+ // Key that produces a quoted string of exactly 128 characters
+ // QuoteToASCII adds 2 quote characters, so we need 126 inner chars
+ key := strings.Repeat("b", 126)
+ safe := safeLockKeyForLogs(key)
+ assert.NotContains(t, safe, "...(truncated)")
+}
+
+func TestSafeLockKeyForLogs_SpecialCharacters(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ contains string
+ }{
+ {
+ name: "tab character",
+ input: "lock:key\twith\ttabs",
+ contains: "\\t",
+ },
+ {
+ name: "null byte",
+ input: "lock:key\x00null",
+ contains: "\\x00",
+ },
+ {
+ name: "unicode",
+ input: "lock:key:emoji:😀",
+ contains: "lock:key:emoji:",
+ },
}
- // Channel to collect errors from goroutines
- errCh := make(chan error, numKeys*numGoroutinesPerKey)
-
- for keyIdx := range numKeys {
- for range numGoroutinesPerKey {
- wg.Add(1)
- go func(k int) {
- defer wg.Done()
-
- lockKey := fmt.Sprintf("test:concurrent:key:%d", k)
- err := lock.WithLockOptions(ctx, lockKey, opts, func() error {
- atomic.AddInt32(&counters[k], 1)
- time.Sleep(5 * time.Millisecond)
- return nil
- })
- if err != nil {
- errCh <- err
- }
- }(keyIdx)
- }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ safe := safeLockKeyForLogs(tt.input)
+ assert.Contains(t, safe, tt.contains)
+ })
}
+}
- wg.Wait()
- close(errCh)
+func TestSafeLockKeyForLogs_EmptyKey(t *testing.T) {
+ safe := safeLockKeyForLogs("")
+ // QuoteToASCII on empty string returns `""`
+ assert.Equal(t, `""`, safe)
+}
- // Assert errors in main goroutine
- for err := range errCh {
- assert.NoError(t, err)
+func TestValidateLockOptions_AllInvalid(t *testing.T) {
+ tests := []struct {
+ name string
+ opts LockOptions
+ errText string
+ }{
+ {
+ name: "zero expiry",
+ opts: LockOptions{
+ Expiry: 0,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ },
+ errText: "lock expiry must be greater than 0",
+ },
+ {
+ name: "negative expiry",
+ opts: LockOptions{
+ Expiry: -5 * time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ },
+ errText: "lock expiry must be greater than 0",
+ },
+ {
+ name: "zero tries",
+ opts: LockOptions{
+ Expiry: time.Second,
+ Tries: 0,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ },
+ errText: "lock tries must be at least 1",
+ },
+ {
+ name: "negative tries",
+ opts: LockOptions{
+ Expiry: time.Second,
+ Tries: -1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ },
+ errText: "lock tries must be at least 1",
+ },
+ {
+ name: "negative retry delay",
+ opts: LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: -time.Millisecond,
+ DriftFactor: 0.01,
+ },
+ errText: "lock retry delay cannot be negative",
+ },
+ {
+ name: "drift factor equals 1",
+ opts: LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 1.0,
+ },
+ errText: "lock drift factor must be between 0",
+ },
+ {
+ name: "drift factor greater than 1",
+ opts: LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 1.5,
+ },
+ errText: "lock drift factor must be between 0",
+ },
+ {
+ name: "negative drift factor",
+ opts: LockOptions{
+ Expiry: time.Second,
+ Tries: 1,
+ RetryDelay: time.Millisecond,
+ DriftFactor: -0.1,
+ },
+ errText: "lock drift factor must be between 0",
+ },
+ {
+ name: "tries exceeds max cap",
+ opts: LockOptions{
+ Expiry: time.Second,
+ Tries: 1001,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ },
+ errText: "lock tries exceeds maximum",
+ },
}
- // Each counter should have been incremented by numGoroutinesPerKey
- for i, count := range counters {
- assert.Equal(t, int32(numGoroutinesPerKey), count, "counter %d should be %d", i, numGoroutinesPerKey)
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateLockOptions(tt.opts)
+ require.Error(t, err)
+ assert.ErrorContains(t, err, tt.errText)
+ })
}
}
-// TestDistributedLock_ReentrantNotSupported tests that re-entrant locking is not supported
-func TestDistributedLock_ReentrantNotSupported(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestValidateLockOptions_Valid(t *testing.T) {
+ tests := []struct {
+ name string
+ opts LockOptions
+ }{
+ {
+ name: "default options",
+ opts: DefaultLockOptions(),
+ },
+ {
+ name: "rate limiter options",
+ opts: RateLimiterLockOptions(),
+ },
+ {
+ name: "minimal valid",
+ opts: LockOptions{
+ Expiry: time.Millisecond,
+ Tries: 1,
+ RetryDelay: 0,
+ DriftFactor: 0,
+ },
+ },
+ {
+ name: "large values",
+ opts: LockOptions{
+ Expiry: time.Hour,
+ Tries: 1000,
+ RetryDelay: time.Minute,
+ DriftFactor: 0.99,
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateLockOptions(tt.opts)
+ require.NoError(t, err)
+ })
+ }
+}
- lock, err := NewDistributedLock(conn)
- require.NoError(t, err)
+func TestRedisLockManager_WithLock_MultipleSequentialLocks(t *testing.T) {
+ _, lock := setupTestLock(t)
ctx := context.Background()
+ var order []int
- err = lock.WithLock(ctx, "test:reentrant", func() error {
- // Try to acquire the same lock again (this should fail/timeout)
- opts := LockOptions{
- Expiry: 1 * time.Second,
- Tries: 1, // Only try once
- RetryDelay: 100 * time.Millisecond,
- }
+ for i := range 5 {
+ idx := i
- err := lock.WithLockOptions(ctx, "test:reentrant", opts, func() error {
+ err := lock.WithLock(ctx, "test:sequential", func(context.Context) error {
+ order = append(order, idx)
return nil
})
+ require.NoError(t, err)
+ }
+
+ assert.Len(t, order, 5)
+ assert.Equal(t, []int{0, 1, 2, 3, 4}, order)
+}
+
+func TestRedisLockManager_WithLock_LongLockKey(t *testing.T) {
+ _, lock := setupTestLock(t)
- // This should fail because the lock is already held
- assert.Error(t, err)
+ // Redis supports keys up to 512MB; test with a reasonably long key
+ longKey := "test:" + strings.Repeat("x", 500)
+ executed := false
+
+ err := lock.WithLock(context.Background(), longKey, func(context.Context) error {
+ executed = true
return nil
})
- assert.NoError(t, err)
+ require.NoError(t, err)
+ assert.True(t, executed)
}
-// TestDistributedLock_ShortTimeout tests behavior with very short timeout
-func TestDistributedLock_ShortTimeout(t *testing.T) {
- conn, cleanup := setupTestRedis(t)
- defer cleanup()
+func TestRedisLockManager_InterfaceCompliance(t *testing.T) {
+ // Verify compile-time interface compliance
+ var _ LockManager = (*RedisLockManager)(nil)
+}
+
+// --- New tests for LockHandle API ---
+
+func TestRedisLockManager_Unlock_ExpiredMutex(t *testing.T) {
+ // Create miniredis directly so we can control time via FastForward.
+ mr := miniredis.RunT(t)
+
+ client, err := New(context.Background(), Config{
+ Topology: Topology{
+ Standalone: &StandaloneTopology{Address: mr.Addr()},
+ },
+ Logger: &log.NopLogger{},
+ })
+ require.NoError(t, err)
+
+ t.Cleanup(func() {
+ require.NoError(t, client.Close())
+ mr.Close()
+ })
- lock, err := NewDistributedLock(conn)
+ lock, err := NewRedisLockManager(client)
require.NoError(t, err)
ctx := context.Background()
- // First goroutine holds the lock
- var wg sync.WaitGroup
- wg.Add(2)
+ // Acquire a lock with a very short expiry via WithLockOptions + TryLock pattern.
+ // We use the low-level redsync through TryLock (which uses DefaultLockOptions expiry = 10s).
+ // Instead, acquire lock with WithLockOptions so we control expiry:
+ // Actually, TryLock uses DefaultLockOptions internally. We need the handle.
+ // We'll acquire and then fast-forward miniredis to expire the key.
+ handle, acquired, err := lock.TryLock(ctx, "test:expire")
+ require.NoError(t, err)
+ require.True(t, acquired)
+ require.NotNil(t, handle)
- go func() {
- defer wg.Done()
- lock.WithLock(ctx, "test:timeout", func() error {
- time.Sleep(200 * time.Millisecond) // Hold for 200ms
- return nil
- })
- }()
+ // Fast-forward time in miniredis to expire the lock key.
+ // DefaultLockOptions has 10s expiry; fast-forward past it.
+ mr.FastForward(15 * time.Second)
- time.Sleep(50 * time.Millisecond) // Ensure first goroutine has the lock
+ // Attempting to unlock an expired lock should return an error.
+ // redsync returns "failed to unlock, lock was already expired" when the key has expired.
+ err = handle.Unlock(ctx)
+ require.Error(t, err)
+ assert.ErrorContains(t, err, "already expired")
+}
- // Second goroutine tries with short timeout
- go func() {
- defer wg.Done()
+func TestRedisLockManager_Unlock_CancelledContext(t *testing.T) {
+ _, lock := setupTestLock(t)
- opts := LockOptions{
- Expiry: 1 * time.Second,
- Tries: 1, // Give up quickly
- RetryDelay: 50 * time.Millisecond,
- }
+ ctx := context.Background()
- err := lock.WithLockOptions(ctx, "test:timeout", opts, func() error {
- return nil
+ handle, acquired, err := lock.TryLock(ctx, "test:cancel-unlock")
+ require.NoError(t, err)
+ require.True(t, acquired)
+ require.NotNil(t, handle)
+
+ // Cancel the context before unlocking.
+ cancelCtx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ // Unlock with cancelled context should fail with a context-related error.
+ err = handle.Unlock(cancelCtx)
+ require.Error(t, err)
+ assert.ErrorContains(t, err, context.Canceled.Error())
+
+ // The lock should still be releasable with a valid context.
+ require.NoError(t, handle.Unlock(context.Background()))
+}
+
+func TestRedisLockManager_LockHandle_NilHandle(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ ctx := context.Background()
+
+ // Test that calling Unlock with a nil LockHandle returns an error.
+ err := lock.Unlock(ctx, nil)
+ require.Error(t, err)
+ assert.ErrorContains(t, err, "lock handle is nil")
+}
+
+func TestRedisLockManager_ValidateLockOptions_TriesCap(t *testing.T) {
+ tests := []struct {
+ name string
+ tries int
+ wantErr bool
+ errText string
+ }{
+ {
+ name: "at max cap (1000) is valid",
+ tries: 1000,
+ wantErr: false,
+ },
+ {
+ name: "exceeds max cap (1001) is rejected",
+ tries: 1001,
+ wantErr: true,
+ errText: "lock tries exceeds maximum",
+ },
+ {
+ name: "way above max cap",
+ tries: 10000,
+ wantErr: true,
+ errText: "lock tries exceeds maximum",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateLockOptions(LockOptions{
+ Expiry: time.Second,
+ Tries: tt.tries,
+ RetryDelay: time.Millisecond,
+ DriftFactor: 0.01,
+ })
+
+ if tt.wantErr {
+ require.Error(t, err)
+ assert.ErrorContains(t, err, tt.errText)
+ } else {
+ require.NoError(t, err)
+ }
})
+ }
+}
- // Should fail to acquire
- assert.Error(t, err)
- }()
+func TestRedisLockManager_NewRedisLockManager_ErrNilClient(t *testing.T) {
+ lock, err := NewRedisLockManager(nil)
+ require.Error(t, err)
+ assert.Nil(t, lock)
- wg.Wait()
+ // Verify the sentinel error supports errors.Is.
+ assert.ErrorIs(t, err, ErrNilClient)
+ assert.Equal(t, "redis client is nil", err.Error())
+}
+
+func TestRedisLockManager_LockHandle_Interface(t *testing.T) {
+ _, lock := setupTestLock(t)
+
+ ctx := context.Background()
+
+ handle, acquired, err := lock.TryLock(ctx, "test:interface-check")
+ require.NoError(t, err)
+ require.True(t, acquired)
+ require.NotNil(t, handle)
+
+ // Verify that the returned handle satisfies the LockHandle interface.
+ var _ LockHandle = handle
+
+ // Clean up.
+ require.NoError(t, handle.Unlock(ctx))
}
diff --git a/commons/redis/redis.go b/commons/redis/redis.go
index 4133b787..e0872722 100644
--- a/commons/redis/redis.go
+++ b/commons/redis/redis.go
@@ -1,7 +1,3 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package redis
import (
@@ -11,231 +7,594 @@ import (
"encoding/base64"
"errors"
"fmt"
+ "strings"
"sync"
+ "sync/atomic"
"time"
iamcredentials "cloud.google.com/go/iam/credentials/apiv1"
iamcredentialspb "cloud.google.com/go/iam/credentials/apiv1/credentialspb"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/assert"
+ "github.com/LerianStudio/lib-commons/v4/commons/backoff"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
"github.com/redis/go-redis/v9"
- "go.uber.org/zap"
+ "go.opentelemetry.io/otel"
+ "go.opentelemetry.io/otel/attribute"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
"google.golang.org/protobuf/types/known/durationpb"
)
-// Mode define the Redis connection mode supported
-type Mode string
-
const (
- TTL int = 300
- Scope string = "https://www.googleapis.com/auth/cloud-platform"
- PrefixServicesAccounts string = "projects/-/serviceAccounts/"
- ModeStandalone Mode = "standalone"
- ModeSentinel Mode = "sentinel"
- ModeCluster Mode = "cluster"
+ gcpScope = "https://www.googleapis.com/auth/cloud-platform"
+ gcpServiceAccountPrefix = "projects/-/serviceAccounts/"
+
+ defaultTokenLifetime = 1 * time.Hour
+ defaultRefreshEvery = 50 * time.Minute
+ defaultRefreshCheckInterval = 10 * time.Second
+ defaultRefreshOperationTimeout = 15 * time.Second
)
-// RedisConnection represents a Redis connection hub
-type RedisConnection struct {
- Mode Mode
- Address []string
- DB int
- MasterName string
- Password string //#nosec G117 -- Credential field required for Redis connection config
- Protocol int
- UseTLS bool
- Logger log.Logger
- Connected bool
- Client redis.UniversalClient
- CACert string
- UseGCPIAMAuth bool
- GoogleApplicationCredentials string
- ServiceAccount string
- TokenLifeTime time.Duration
- RefreshDuration time.Duration
- token string
- lastRefreshInstant time.Time
- errLastSeen error
- mu sync.RWMutex
- PoolSize int
- MinIdleConns int
- ReadTimeout time.Duration
- WriteTimeout time.Duration
- DialTimeout time.Duration
- PoolTimeout time.Duration
- MaxRetries int
- MinRetryBackoff time.Duration
- MaxRetryBackoff time.Duration
-}
-
-// Connect initializes a Redis connection
-func (rc *RedisConnection) Connect(ctx context.Context) error {
- rc.mu.Lock()
- defer rc.mu.Unlock()
-
- return rc.connectLocked(ctx)
-}
-
-func (rc *RedisConnection) connectLocked(ctx context.Context) error {
- rc.Logger.Info("Connecting to Redis/Valkey...")
-
- rc.InitVariables()
-
- var err error
- if rc.UseGCPIAMAuth {
- rc.token, err = rc.retrieveToken(ctx)
- if err != nil {
- rc.Logger.Infof("initial token retrieval failed: %v", zap.Error(err))
- return err
+var (
+ // ErrNilClient is returned when a redis client receiver is nil.
+ ErrNilClient = errors.New("redis client is nil")
+ // ErrInvalidConfig indicates the provided redis configuration is invalid.
+ ErrInvalidConfig = errors.New("invalid redis config")
+
+ // pkgLogger holds the package-level logger for nil-receiver diagnostics.
+ // Defaults to NopLogger; consumers can override via SetPackageLogger.
+ pkgLogger atomic.Value // stores log.Logger
+)
+
+func init() {
+ pkgLogger.Store(log.Logger(&log.NopLogger{}))
+}
+
+// SetPackageLogger configures a package-level logger used for nil-receiver
+// assertion diagnostics and telemetry reporting. This is typically called
+// once during application bootstrap. If l is nil, a NopLogger is used.
+func SetPackageLogger(l log.Logger) {
+ if l == nil {
+ l = &log.NopLogger{}
+ }
+
+ pkgLogger.Store(l)
+}
+
+func resolvePackageLogger() log.Logger {
+ if v := pkgLogger.Load(); v != nil {
+ if l, ok := v.(log.Logger); ok {
+ return l
}
+ }
+
+ return &log.NopLogger{}
+}
+
+// nilClientAssert fires a nil-receiver assertion and returns ErrNilClient.
+func nilClientAssert(ctx context.Context, operation string) error {
+ a := assert.New(ctx, resolvePackageLogger(), "redis.Client", operation)
+ _ = a.Never(ctx, "nil receiver on *redis.Client")
+
+ return ErrNilClient
+}
+
+// Config defines Redis client topology, auth, TLS, and connection settings.
+type Config struct {
+ Topology Topology
+ TLS *TLSConfig
+ Auth Auth
+ Options ConnectionOptions
+ Logger log.Logger
+ MetricsFactory *metrics.MetricsFactory
+}
+
+// Topology selects exactly one Redis deployment mode.
+type Topology struct {
+ Standalone *StandaloneTopology
+ Sentinel *SentinelTopology
+ Cluster *ClusterTopology
+}
+
+// StandaloneTopology configures single-node Redis access.
+type StandaloneTopology struct {
+ Address string
+}
+
+// SentinelTopology configures Redis Sentinel access.
+type SentinelTopology struct {
+ Addresses []string
+ MasterName string
+}
+
+// ClusterTopology configures Redis cluster access.
+type ClusterTopology struct {
+ Addresses []string
+}
+
+// TLSConfig configures TLS validation for Redis connections.
+type TLSConfig struct {
+ CACertBase64 string
+ MinVersion uint16
+ AllowLegacyMinVersion bool
+}
+
+// Auth selects one Redis authentication strategy.
+type Auth struct {
+ StaticPassword *StaticPasswordAuth
+ GCPIAM *GCPIAMAuth
+}
+
+// StaticPasswordAuth authenticates using a static password.
+type StaticPasswordAuth struct {
+ Password string // #nosec G117 -- field is redacted via String() and GoString() methods
+}
+
+// String returns a redacted representation to prevent accidental credential logging.
+func (StaticPasswordAuth) String() string { return "StaticPasswordAuth{Password:REDACTED}" }
+
+// GoString returns a redacted representation for fmt %#v.
+func (a StaticPasswordAuth) GoString() string { return a.String() }
+
+// GCPIAMAuth authenticates with short-lived GCP IAM access tokens.
+type GCPIAMAuth struct {
+ CredentialsBase64 string
+ ServiceAccount string
+ TokenLifetime time.Duration
+ RefreshEvery time.Duration
+ RefreshCheckInterval time.Duration
+ RefreshOperationTimeout time.Duration
+}
+
+// String returns a redacted representation to prevent accidental credential logging.
+func (a GCPIAMAuth) String() string {
+ return fmt.Sprintf("GCPIAMAuth{ServiceAccount:%s, CredentialsBase64:REDACTED}", a.ServiceAccount)
+}
+
+// GoString returns a redacted representation for fmt %#v.
+func (a GCPIAMAuth) GoString() string { return a.String() }
+
+// ConnectionOptions configures protocol, timeouts, pools, and retries.
+type ConnectionOptions struct {
+ DB int
+ Protocol int
+ PoolSize int
+ MinIdleConns int
+ ReadTimeout time.Duration
+ WriteTimeout time.Duration
+ DialTimeout time.Duration
+ PoolTimeout time.Duration
+ MaxRetries int
+ MinRetryBackoff time.Duration
+ MaxRetryBackoff time.Duration
+}
+
+// Status reports the last known client connectivity and IAM refresh loop health.
+// Fields reflect cached state updated during connect/reconnect/refresh operations,
+// not a live probe of the underlying connection. Use a Redis PING for liveness checks.
+type Status struct {
+ Connected bool
+ LastRefreshError error
+ LastRefreshAt time.Time
+ RefreshLoopRunning bool
+}
+
+// connectionFailuresMetric defines the counter for redis connection failures.
+var connectionFailuresMetric = metrics.Metric{
+ Name: "redis_connection_failures_total",
+ Unit: "1",
+ Description: "Total number of redis connection failures",
+}
- rc.lastRefreshInstant = time.Now()
+// reconnectionsMetric defines the counter for redis reconnection attempts.
+var reconnectionsMetric = metrics.Metric{
+ Name: "redis_reconnections_total",
+ Unit: "1",
+ Description: "Total number of redis reconnection attempts",
+}
+
+// Client wraps a redis.UniversalClient with reconnection and IAM token refresh logic.
+type Client struct {
+ mu sync.RWMutex
+ cfg Config
+ logger log.Logger
+ metricsFactory *metrics.MetricsFactory
+ client redis.UniversalClient
+ connected bool
+ token string
+ lastRefresh time.Time
+ refreshErr error
+
+ refreshCancel context.CancelFunc
+ refreshLoopRunning bool
+ refreshGeneration uint64
+
+ // Reconnect rate-limiting: prevents thundering-herd reconnect storms
+ // when the server is down by enforcing exponential backoff between attempts.
+ lastReconnectAttempt time.Time
+ reconnectAttempts int
+
+ // test hooks
+ tokenRetriever func(ctx context.Context) (string, error)
+ reconnectFn func(ctx context.Context) error
+}
- go rc.refreshTokenLoop(ctx)
+// New validates config, connects to Redis, and returns a ready client.
+func New(ctx context.Context, cfg Config) (*Client, error) {
+ normalized, err := normalizeConfig(cfg)
+ if err != nil {
+ return nil, err
}
- opts := &redis.UniversalOptions{
- Addrs: rc.Address,
- MasterName: rc.MasterName,
- DB: rc.DB,
- Protocol: rc.Protocol,
- PoolSize: rc.PoolSize,
- MinIdleConns: rc.MinIdleConns,
- ReadTimeout: rc.ReadTimeout,
- WriteTimeout: rc.WriteTimeout,
- DialTimeout: rc.DialTimeout,
- PoolTimeout: rc.PoolTimeout,
- MaxRetries: rc.MaxRetries,
- MinRetryBackoff: rc.MinRetryBackoff,
- MaxRetryBackoff: rc.MaxRetryBackoff,
- }
-
- if rc.UseGCPIAMAuth {
- opts.Password = rc.token
- opts.Username = "default"
- } else {
- opts.Password = rc.Password
+ c := &Client{
+ cfg: normalized,
+ logger: normalized.Logger,
+ metricsFactory: normalized.MetricsFactory,
}
- if rc.UseTLS {
- tlsConfig, err := rc.BuildTLSConfig()
- if err != nil {
- rc.Logger.Infof("BuildTLSConfig error: %v", zap.Error(err))
+ if err := c.Connect(ctx); err != nil {
+ return nil, err
+ }
- return err
- }
+ return c, nil
+}
- opts.TLSConfig = tlsConfig
+// Connect establishes a Redis connection using the current client configuration.
+func (c *Client) Connect(ctx context.Context) error {
+ if c == nil {
+ return nilClientAssert(ctx, "Connect")
}
- rdb := redis.NewUniversalClient(opts)
- if _, err := rdb.Ping(ctx).Result(); err != nil {
- rc.Logger.Infof("Ping error: %v", zap.Error(err))
- return err
+ tracer := otel.Tracer("redis")
+
+ ctx, span := tracer.Start(ctx, "redis.connect")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRedis))
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.logger == nil {
+ c.logger = &log.NopLogger{}
}
- rc.Client = rdb
- rc.Connected = true
+ if err := c.connectLocked(ctx); err != nil {
+ c.recordConnectionFailure("connect")
- switch rdb.(type) {
- case *redis.ClusterClient:
- rc.Logger.Info("Connected to Redis/Valkey in CLUSTER mode ✅ \n")
- case *redis.Client:
- rc.Logger.Info("Connected to Redis/Valkey in STANDALONE mode ✅ \n")
- case *redis.Ring:
- rc.Logger.Info("Connected to Redis/Valkey in SENTINEL mode ✅ \n")
- default:
- rc.Logger.Warn("Unknown Redis/Valkey mode ⚠️ \n")
+ libOpentelemetry.HandleSpanError(span, "Failed to connect to redis", err)
+
+ return err
}
return nil
}
-// GetClient always returns a pointer to a Redis client
-func (rc *RedisConnection) GetClient(ctx context.Context) (redis.UniversalClient, error) {
- rc.mu.RLock()
+// reconnectBackoffCap is the maximum delay between reconnect attempts.
+const reconnectBackoffCap = 30 * time.Second
+
+// GetClient returns a connected redis client, reconnecting on demand if needed.
+func (c *Client) GetClient(ctx context.Context) (redis.UniversalClient, error) {
+ if c == nil {
+ return nil, nilClientAssert(ctx, "GetClient")
+ }
+
+ c.mu.RLock()
- if rc.Client != nil {
- client := rc.Client
- rc.mu.RUnlock()
+ if c.client != nil {
+ client := c.client
+ c.mu.RUnlock()
return client, nil
}
- rc.mu.RUnlock()
+ c.mu.RUnlock()
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.logger == nil {
+ c.logger = &log.NopLogger{}
+ }
+
+ if c.client != nil {
+ return c.client, nil
+ }
- rc.mu.Lock()
- defer rc.mu.Unlock()
+ // Rate-limit reconnect attempts: if we've failed recently, enforce a
+ // minimum delay before the next attempt to avoid hammering the server.
+ if c.reconnectAttempts > 0 {
+ delay := min(backoff.ExponentialWithJitter(500*time.Millisecond, c.reconnectAttempts), reconnectBackoffCap)
- if rc.Client != nil {
- return rc.Client, nil
+ if elapsed := time.Since(c.lastReconnectAttempt); elapsed < delay {
+ return nil, fmt.Errorf("redis reconnect: rate-limited (next attempt in %s)", delay-elapsed)
+ }
}
- if err := rc.connectLocked(ctx); err != nil {
- rc.Logger.Infof("Get client connect error %v", zap.Error(err))
+ c.lastReconnectAttempt = time.Now()
+
+ // Only trace when actually reconnecting.
+ tracer := otel.Tracer("redis")
+
+ ctx, span := tracer.Start(ctx, "redis.reconnect")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRedis))
+
+ if err := c.connectLocked(ctx); err != nil {
+ c.reconnectAttempts++
+ c.recordConnectionFailure("reconnect")
+ c.recordReconnection("failure")
+
+ libOpentelemetry.HandleSpanError(span, "Failed to reconnect redis", err)
+
return nil, err
}
- return rc.Client, nil
+ c.reconnectAttempts = 0
+ c.recordReconnection("success")
+
+ return c.client, nil
}
-// Close closes the Redis connection
-func (rc *RedisConnection) Close() error {
- rc.mu.Lock()
- defer rc.mu.Unlock()
+// Close stops background refresh and closes the underlying Redis client.
+func (c *Client) Close() error {
+ if c == nil {
+ return nilClientAssert(context.Background(), "Close")
+ }
+
+ tracer := otel.Tracer("redis")
+
+ _, span := tracer.Start(context.Background(), "redis.close")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRedis))
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
- return rc.closeLocked()
+ c.stopRefreshLoopLocked()
+
+ if err := c.closeClientLocked(); err != nil {
+ libOpentelemetry.HandleSpanError(span, "Failed to close redis client", err)
+
+ return err
+ }
+
+ return nil
+}
+
+// Status returns a snapshot of the last known connectivity and token refresh state.
+// The Connected field is updated during connect/reconnect/close operations and does
+// not probe the server. For a live liveness check, issue a Redis PING via GetClient.
+func (c *Client) Status() (Status, error) {
+ if c == nil {
+ return Status{}, nilClientAssert(context.Background(), "Status")
+ }
+
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return Status{
+ Connected: c.connected,
+ LastRefreshError: c.refreshErr,
+ LastRefreshAt: c.lastRefresh,
+ RefreshLoopRunning: c.refreshLoopRunning,
+ }, nil
+}
+
+// IsConnected reports the last known connection state. It does not probe
+// the server — the value is updated during connect/reconnect/close operations.
+// For a live liveness check, issue a Redis PING via GetClient.
+func (c *Client) IsConnected() (bool, error) {
+ status, err := c.Status()
+ if err != nil {
+ return false, err
+ }
+
+ return status.Connected, nil
}
-// closeLocked closes the Redis connection without acquiring the lock.
-// Caller must hold rc.mu write lock.
-func (rc *RedisConnection) closeLocked() error {
- if rc.Client != nil {
- err := rc.Client.Close()
- rc.Client = nil
- rc.Connected = false
+// LastRefreshError returns the latest IAM refresh/reconnect error.
+func (c *Client) LastRefreshError() error {
+ if c == nil {
+ return nilClientAssert(context.Background(), "LastRefreshError")
+ }
+ status, err := c.Status()
+ if err != nil {
return err
}
+ return status.LastRefreshError
+}
+
+func (c *Client) connectLocked(ctx context.Context) error {
+ // Config validation is performed by New/normalizeConfig at construction time.
+ // Direct Connect() callers should only use properly-constructed Clients.
+ c.logger.Log(ctx, log.LevelInfo, "connecting to Redis/Valkey")
+
+ if c.usesGCPIAM() && c.token == "" {
+ token, err := c.retrieveToken(ctx)
+ if err != nil {
+ c.logger.Log(ctx, log.LevelError, "initial token retrieval failed", log.Err(err))
+
+ return fmt.Errorf("redis connect: token retrieval: %w", err)
+ }
+
+ c.token = token
+ }
+
+ // Create and verify the new client BEFORE touching the old one.
+ // This follows the same create-ping-swap pattern used by reconnectLocked,
+ // preventing a window where a healthy client is closed before its replacement
+ // is confirmed working.
+ if err := c.connectClientLocked(ctx); err != nil {
+ return err
+ }
+
+ if c.usesGCPIAM() {
+ c.lastRefresh = time.Now()
+ c.startRefreshLoopLocked()
+ }
+
return nil
}
-// BuildTLSConfig generates a *tls.Config configuration using ca cert on base64
-func (rc *RedisConnection) BuildTLSConfig() (*tls.Config, error) {
- caCert, err := base64.StdEncoding.DecodeString(rc.CACert)
+func (c *Client) connectClientLocked(ctx context.Context) error {
+ opts, err := c.buildUniversalOptionsLocked()
if err != nil {
- rc.Logger.Infof("Base64 caceret error to decode error: %v", zap.Error(err))
+ return fmt.Errorf("redis connect: build options: %w", err)
+ }
- return nil, err
+ rdb := redis.NewUniversalClient(opts)
+ if _, err := rdb.Ping(ctx).Result(); err != nil {
+ _ = rdb.Close()
+
+ c.logger.Log(ctx, log.LevelError, "redis ping failed", log.Err(err))
+ c.connected = false
+
+ return fmt.Errorf("redis connect: ping: %w", err)
}
- caCertPool := x509.NewCertPool()
- if !caCertPool.AppendCertsFromPEM(caCert) {
- return nil, errors.New("adding CA cert failed")
+ // New client verified. Close old client (if any) AFTER new one is confirmed healthy.
+ oldClient := c.client
+
+ c.client = rdb
+ c.connected = true
+ c.refreshErr = nil
+
+ if oldClient != nil {
+ if err := oldClient.Close(); err != nil {
+ c.logger.Log(ctx, log.LevelWarn, "failed to close previous client after successful connect", log.Err(err))
+ }
}
- tlsCfg := &tls.Config{
- RootCAs: caCertPool,
- MinVersion: tls.VersionTLS12,
+ switch rdb.(type) {
+ case *redis.ClusterClient:
+ c.logger.Log(ctx, log.LevelInfo, "connected to Redis/Valkey in cluster mode")
+ case *redis.Client:
+ c.logger.Log(ctx, log.LevelInfo, "connected to Redis/Valkey in standalone mode")
+ case *redis.Ring:
+ c.logger.Log(ctx, log.LevelInfo, "connected to Redis/Valkey in ring mode")
+ default:
+ c.logger.Log(ctx, log.LevelWarn, "connected to Redis/Valkey in unknown mode")
+ }
+
+ if c.cfg.TLS == nil {
+ c.logger.Log(ctx, log.LevelWarn, "redis connection established without TLS; consider configuring TLS for production use")
+ }
+
+ return nil
+}
+
+func (c *Client) closeClientLocked() error {
+ if c.client == nil {
+ return nil
+ }
+
+ err := c.client.Close()
+ c.client = nil
+ c.connected = false
+
+ return err
+}
+
+func (c *Client) buildUniversalOptionsLocked() (*redis.UniversalOptions, error) {
+ o := c.cfg.Options
+ opts := &redis.UniversalOptions{
+ DB: o.DB,
+ Protocol: o.Protocol,
+ PoolSize: o.PoolSize,
+ MinIdleConns: o.MinIdleConns,
+ ReadTimeout: o.ReadTimeout,
+ WriteTimeout: o.WriteTimeout,
+ DialTimeout: o.DialTimeout,
+ PoolTimeout: o.PoolTimeout,
+ MaxRetries: o.MaxRetries,
+ MinRetryBackoff: o.MinRetryBackoff,
+ MaxRetryBackoff: o.MaxRetryBackoff,
+ }
+
+ if c.cfg.Topology.Standalone != nil {
+ opts.Addrs = []string{c.cfg.Topology.Standalone.Address}
+ }
+
+ if c.cfg.Topology.Sentinel != nil {
+ opts.Addrs = c.cfg.Topology.Sentinel.Addresses
+ opts.MasterName = c.cfg.Topology.Sentinel.MasterName
+ }
+
+ if c.cfg.Topology.Cluster != nil {
+ opts.Addrs = c.cfg.Topology.Cluster.Addresses
+ }
+
+ // Guard against zero-value Config producing Addrs: nil, which causes
+ // go-redis to silently default to localhost:6379. This can happen when
+ // GetClient triggers a reconnect on a Client not created via New().
+ if len(opts.Addrs) == 0 {
+ return nil, configError("no topology configured: at least one address is required")
+ }
+
+ if c.cfg.Auth.StaticPassword != nil {
+ opts.Password = c.cfg.Auth.StaticPassword.Password
+ }
+
+ if c.usesGCPIAM() {
+ opts.Username = "default"
+ opts.Password = c.token
}
- return tlsCfg, nil
+ if c.cfg.TLS != nil {
+ tlsCfg, err := buildTLSConfig(*c.cfg.TLS)
+ if err != nil {
+ return nil, fmt.Errorf("redis: TLS config: %w", err)
+ }
+
+ opts.TLSConfig = tlsCfg
+ }
+
+ return opts, nil
}
-// retrieveToken generates a new GCP IAM token
-func (rc *RedisConnection) retrieveToken(ctx context.Context) (string, error) {
- credentialsJSON, err := base64.StdEncoding.DecodeString(rc.GoogleApplicationCredentials)
+func (c *Client) retrieveToken(ctx context.Context) (string, error) {
+ if c == nil {
+ return "", nilClientAssert(ctx, "retrieveToken")
+ }
+
+ if c.tokenRetriever != nil {
+ return c.tokenRetriever(ctx)
+ }
+
+ auth := c.cfg.Auth.GCPIAM
+ if auth == nil {
+ return "", errors.New("GCP IAM auth is not configured")
+ }
+
+ credentialsJSON, err := base64.StdEncoding.DecodeString(auth.CredentialsBase64)
if err != nil {
- rc.Logger.Infof("Base64 credentials error to decode error: %v", zap.Error(err))
+ c.logger.Log(ctx, log.LevelError, "failed to decode base64 credentials", log.Err(err))
- return "", err
+ return "", fmt.Errorf("redis: generate IAM token: %w", err)
}
+ // Defense-in-depth: zero decoded credentials when done to reduce memory exposure window.
+ defer func() {
+ for i := range credentialsJSON {
+ credentialsJSON[i] = 0
+ }
+ }()
+
creds, err := google.CredentialsFromJSONWithType(ctx, credentialsJSON, google.ServiceAccount)
if err != nil {
- return "", fmt.Errorf("parsing credentials JSON: %w", err)
+ // Wrap error to prevent potential credential fragments in the original error message
+ // from leaking into logs or upstream callers.
+ return "", fmt.Errorf("parsing credentials JSON failed (content redacted): %w",
+ errors.New("invalid service account credentials format"))
}
client, err := iamcredentials.NewIamCredentialsClient(ctx, option.WithCredentials(creds))
@@ -244,56 +603,59 @@ func (rc *RedisConnection) retrieveToken(ctx context.Context) (string, error) {
}
defer client.Close()
- req := &iamcredentialspb.GenerateAccessTokenRequest{
- Name: PrefixServicesAccounts + rc.ServiceAccount,
- Scope: []string{Scope},
- Lifetime: durationpb.New(rc.TokenLifeTime),
+ resp, err := client.GenerateAccessToken(ctx, &iamcredentialspb.GenerateAccessTokenRequest{
+ Name: gcpServiceAccountPrefix + auth.ServiceAccount,
+ Scope: []string{gcpScope},
+ Lifetime: durationpb.New(auth.TokenLifetime),
+ })
+ if err != nil {
+ return "", fmt.Errorf("problem generating access token: %w", err)
}
- resp, err := client.GenerateAccessToken(ctx, req)
- if err != nil {
- return "", fmt.Errorf("problem to generate access token: %w", err)
+ if resp == nil {
+ return "", errors.New("generate access token returned nil response")
}
return resp.AccessToken, nil
}
-// refreshTokenLoop periodically refreshes the GCP IAM token
-func (rc *RedisConnection) refreshTokenLoop(ctx context.Context) {
- ticker := time.NewTicker(10 * time.Second)
+func (c *Client) refreshTokenLoop(ctx context.Context) {
+ if c == nil {
+ return
+ }
+
+ auth := c.cfg.Auth.GCPIAM
+ if auth == nil {
+ // Should never happen in production (startRefreshLoopLocked checks usesGCPIAM()),
+ // but guard defensively against direct invocations.
+ return
+ }
+
+ ticker := time.NewTicker(auth.RefreshCheckInterval)
defer ticker.Stop()
+ var consecutiveFailures int
+
for {
select {
case <-ticker.C:
- rc.mu.RLock()
- last := rc.lastRefreshInstant
- rc.mu.RUnlock()
-
- if time.Now().After(last.Add(rc.RefreshDuration)) {
- token, err := rc.retrieveToken(ctx)
- rc.mu.Lock()
-
- if err != nil {
- rc.errLastSeen = err
- rc.Logger.Infof("IAM token refresh failed: %v", zap.Error(err))
- } else {
- rc.token = token
- rc.lastRefreshInstant = time.Now()
- rc.Logger.Info("IAM token refreshed...")
-
- if closeErr := rc.closeLocked(); closeErr != nil {
- rc.Logger.Infof("warning: close before reconnect failed: %v", closeErr)
- }
-
- if connErr := rc.connectLocked(ctx); connErr != nil {
- rc.errLastSeen = connErr
- rc.Connected = false
- rc.Logger.Errorf("failed to reconnect after IAM token refresh: %v", zap.Error(connErr))
- }
- }
-
- rc.mu.Unlock()
+ if c.refreshTick(ctx, auth) {
+ consecutiveFailures = 0
+
+ continue
+ }
+
+ // On failure, apply exponential backoff before the next attempt.
+ // The ticker continues to fire, but we wait an additional delay
+ // proportional to the number of consecutive failures. The base
+ // derives from the configured check interval so that test configs
+ // with sub-millisecond intervals produce proportionally small delays.
+ consecutiveFailures++
+
+ delay := min(backoff.ExponentialWithJitter(auth.RefreshCheckInterval, consecutiveFailures), reconnectBackoffCap)
+
+ if err := backoff.WaitContext(ctx, delay); err != nil {
+ return
}
case <-ctx.Done():
@@ -302,41 +664,460 @@ func (rc *RedisConnection) refreshTokenLoop(ctx context.Context) {
}
}
-// InitVariables sets default values for RedisConnection
-func (rc *RedisConnection) InitVariables() {
- if rc.PoolSize == 0 {
- rc.PoolSize = 10
+// refreshTick handles a single tick of the IAM token refresh cycle.
+// Returns true if the tick completed successfully (including when no refresh
+// was needed), false if a token retrieval or reconnect failed.
+func (c *Client) refreshTick(ctx context.Context, auth *GCPIAMAuth) bool {
+ c.mu.RLock()
+ lastRefresh := c.lastRefresh
+ c.mu.RUnlock()
+
+ if !time.Now().After(lastRefresh.Add(auth.RefreshEvery)) {
+ return true
+ }
+
+ tracer := otel.Tracer("redis")
+
+ refreshCtx, cancel := context.WithTimeout(ctx, auth.RefreshOperationTimeout)
+ defer cancel()
+
+ refreshCtx, span := tracer.Start(refreshCtx, "redis.iam_refresh")
+ defer span.End()
+
+ span.SetAttributes(attribute.String(constant.AttrDBSystem, constant.DBSystemRedis))
+
+ token, err := c.retrieveToken(refreshCtx)
+ if err != nil {
+ c.mu.Lock()
+ c.refreshErr = err
+ c.logger.Log(refreshCtx, log.LevelWarn, "IAM token refresh failed", log.Err(err))
+ c.mu.Unlock()
+
+ libOpentelemetry.HandleSpanError(span, "IAM token refresh failed", err)
+
+ return false
+ }
+
+ return c.applyTokenAndReconnect(refreshCtx, token)
+}
+
+// applyTokenAndReconnect sets the new token and reconnects the client.
+// On reconnect failure, the old token is restored to keep the existing client usable.
+func (c *Client) applyTokenAndReconnect(ctx context.Context, token string) bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ oldToken := c.token
+ c.token = token
+
+ reconnectFn := c.reconnectFn
+ if reconnectFn == nil {
+ reconnectFn = c.reconnectLocked
+ }
+
+ if err := reconnectFn(ctx); err != nil {
+ c.refreshErr = err
+ // Restore old token: reconnect failed, so the new token is useless
+ // and the old client (if any) is still using the previous token.
+ c.token = oldToken
+ c.logger.Log(ctx, log.LevelError, "failed to reconnect after IAM token refresh, keeping existing client", log.Err(err))
+
+ return false
+ }
+
+ c.lastRefresh = time.Now()
+ c.refreshErr = nil
+ c.logger.Log(ctx, log.LevelInfo, "IAM token refreshed")
+
+ return true
+}
+
+func (c *Client) reconnectLocked(ctx context.Context) error {
+ // Build new client options with the refreshed token.
+ opts, err := c.buildUniversalOptionsLocked()
+ if err != nil {
+ c.logger.Log(ctx, log.LevelError, "failed to build options for reconnect", log.Err(err))
+
+ return err
+ }
+
+ // Create and verify the new client BEFORE touching the old one.
+ newClient := redis.NewUniversalClient(opts)
+
+ if _, err := newClient.Ping(ctx).Result(); err != nil {
+ _ = newClient.Close()
+
+ c.logger.Log(ctx, log.LevelError, "new client ping failed during reconnect, keeping existing client", log.Err(err))
+
+ return err
+ }
+
+ // New client is verified. Swap atomically: close old, assign new.
+ oldClient := c.client
+
+ c.client = newClient
+ c.connected = true
+ c.refreshErr = nil
+
+ if oldClient != nil {
+ if err := oldClient.Close(); err != nil {
+ c.logger.Log(ctx, log.LevelWarn, "failed to close previous client after successful reconnect", log.Err(err))
+ }
+ }
+
+ return nil
+}
+
+func (c *Client) startRefreshLoopLocked() {
+ if !c.usesGCPIAM() || c.refreshLoopRunning {
+ return
+ }
+
+ refreshCtx, cancel := context.WithCancel(context.Background())
+ c.refreshGeneration++
+ generation := c.refreshGeneration
+ c.refreshCancel = cancel
+ c.refreshLoopRunning = true
+
+ runtime.SafeGoWithContextAndComponent(
+ refreshCtx,
+ c.logger,
+ "redis",
+ "iam_refresh_loop",
+ runtime.KeepRunning,
+ func(_ context.Context) {
+ c.refreshTokenLoop(refreshCtx)
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.refreshGeneration == generation {
+ c.refreshCancel = nil
+ c.refreshLoopRunning = false
+ }
+ },
+ )
+}
+
+func (c *Client) stopRefreshLoopLocked() {
+ if c.refreshCancel != nil {
+ c.refreshCancel()
+ c.refreshCancel = nil
+ }
+
+ c.refreshLoopRunning = false
+}
+
+func (c *Client) usesGCPIAM() bool {
+ return c.cfg.Auth.GCPIAM != nil
+}
+
+func normalizeConfig(cfg Config) (Config, error) {
+ normalizeLoggerDefault(&cfg)
+ normalizeConnectionOptionsDefaults(&cfg.Options)
+
+ originalTLSMinVersion := uint16(0)
+ if cfg.TLS != nil {
+ originalTLSMinVersion = cfg.TLS.MinVersion
+ }
+
+ tlsMinVersionUpgraded, legacyTLSAllowed := normalizeTLSDefaults(cfg.TLS)
+ normalizeGCPIAMDefaults(cfg.Auth.GCPIAM)
+
+ if tlsMinVersionUpgraded {
+ if originalTLSMinVersion == 0 {
+ cfg.Logger.Log(
+ context.Background(),
+ log.LevelInfo,
+ "redis TLS MinVersion was not set and has been defaulted to tls.VersionTLS12",
+ )
+ } else {
+ cfg.Logger.Log(
+ context.Background(),
+ log.LevelWarn,
+ "redis TLS MinVersion was below TLS1.2 and has been upgraded to tls.VersionTLS12",
+ )
+ }
+ }
+
+ if legacyTLSAllowed {
+ cfg.Logger.Log(
+ context.Background(),
+ log.LevelWarn,
+ "redis TLS MinVersion below TLS1.2 retained because AllowLegacyMinVersion=true; this is insecure and should be temporary",
+ )
+ }
+
+ if err := validateConfig(cfg); err != nil {
+ return Config{}, err
+ }
+
+ return cfg, nil
+}
+
+func normalizeLoggerDefault(cfg *Config) {
+ if cfg.Logger == nil {
+ cfg.Logger = &log.NopLogger{}
+ }
+}
+
+const (
+ maxPoolSize = 1000
+)
+
+func normalizeConnectionOptionsDefaults(options *ConnectionOptions) {
+ if options.PoolSize == 0 {
+ options.PoolSize = 10
+ }
+
+ if options.PoolSize > maxPoolSize {
+ options.PoolSize = maxPoolSize
+ }
+
+ if options.ReadTimeout == 0 {
+ options.ReadTimeout = 3 * time.Second
+ }
+
+ if options.WriteTimeout == 0 {
+ options.WriteTimeout = 3 * time.Second
+ }
+
+ if options.DialTimeout == 0 {
+ options.DialTimeout = 5 * time.Second
+ }
+
+ if options.PoolTimeout == 0 {
+ options.PoolTimeout = 2 * time.Second
+ }
+
+ if options.MaxRetries == 0 {
+ options.MaxRetries = 3
+ }
+
+ if options.MinRetryBackoff == 0 {
+ options.MinRetryBackoff = 8 * time.Millisecond
+ }
+
+ if options.MaxRetryBackoff == 0 {
+ options.MaxRetryBackoff = 1 * time.Second
+ }
+}
+
+// normalizeTLSDefaults enforces a TLS 1.2 minimum floor. Versions below TLS 1.2
+// (including TLS 1.0 and 1.1) have known vulnerabilities and are rejected by
+// most compliance frameworks. If MinVersion is unset, it is upgraded. If
+// MinVersion is below tls.VersionTLS12, it is upgraded unless
+// AllowLegacyMinVersion is set explicitly.
+//
+// Returns (upgraded, legacyAllowed).
+func normalizeTLSDefaults(tlsCfg *TLSConfig) (bool, bool) {
+ if tlsCfg == nil {
+ return false, false
+ }
+
+ if tlsCfg.MinVersion == 0 {
+ tlsCfg.MinVersion = tls.VersionTLS12
+
+ return true, false
+ }
+
+ if tlsCfg.MinVersion < tls.VersionTLS12 {
+ if tlsCfg.AllowLegacyMinVersion {
+ return false, true
+ }
+
+ tlsCfg.MinVersion = tls.VersionTLS12
+
+ return true, false
+ }
+
+ return false, false
+}
+
+func normalizeGCPIAMDefaults(auth *GCPIAMAuth) {
+ if auth == nil {
+ return
}
- if rc.MinIdleConns == 0 {
- rc.MinIdleConns = 0
+ if auth.TokenLifetime == 0 {
+ auth.TokenLifetime = defaultTokenLifetime
}
- if rc.ReadTimeout == 0 {
- rc.ReadTimeout = 3 * time.Second
+ if auth.RefreshEvery == 0 {
+ auth.RefreshEvery = defaultRefreshEvery
}
- if rc.WriteTimeout == 0 {
- rc.WriteTimeout = 3 * time.Second
+ if auth.RefreshCheckInterval == 0 {
+ auth.RefreshCheckInterval = defaultRefreshCheckInterval
+ }
+
+ if auth.RefreshOperationTimeout == 0 {
+ auth.RefreshOperationTimeout = defaultRefreshOperationTimeout
+ }
+}
+
+func validateConfig(cfg Config) error {
+ if err := validateTopology(cfg.Topology); err != nil {
+ return err
}
- if rc.DialTimeout == 0 {
- rc.DialTimeout = 5 * time.Second
+ if cfg.Auth.StaticPassword != nil && cfg.Auth.GCPIAM != nil {
+ return configError("only one auth strategy can be configured")
}
- if rc.PoolTimeout == 0 {
- rc.PoolTimeout = 2 * time.Second
+ if cfg.TLS != nil && strings.TrimSpace(cfg.TLS.CACertBase64) == "" {
+ return configError("TLS CA cert is required when TLS is configured")
}
- if rc.MaxRetries == 0 {
- rc.MaxRetries = 3
+ if cfg.Auth.GCPIAM == nil {
+ return nil
}
- if rc.MinRetryBackoff == 0 {
- rc.MinRetryBackoff = 8 * time.Millisecond
+ if cfg.TLS == nil {
+ return configError("TLS must be configured when GCP IAM auth is enabled")
}
- if rc.MaxRetryBackoff == 0 {
- rc.MaxRetryBackoff = 1 * time.Second
+ if strings.TrimSpace(cfg.Auth.GCPIAM.ServiceAccount) == "" {
+ return configError("service account is required for GCP IAM auth")
+ }
+
+ if strings.Contains(cfg.Auth.GCPIAM.ServiceAccount, "/") {
+ return configError("service account cannot contain '/' characters")
+ }
+
+ if strings.TrimSpace(cfg.Auth.GCPIAM.CredentialsBase64) == "" {
+ return configError("credentials are required for GCP IAM auth")
+ }
+
+ if cfg.Auth.GCPIAM.RefreshEvery >= cfg.Auth.GCPIAM.TokenLifetime {
+ return configError("RefreshEvery must be less than TokenLifetime to prevent token expiry before refresh")
+ }
+
+ return nil
+}
+
+func validateTopology(topology Topology) error {
+ count := 0
+
+ if topology.Standalone != nil {
+ count++
+
+ if strings.TrimSpace(topology.Standalone.Address) == "" {
+ return configError("standalone address is required")
+ }
+ }
+
+ if topology.Sentinel != nil {
+ count++
+
+ if len(topology.Sentinel.Addresses) == 0 {
+ return configError("sentinel addresses are required")
+ }
+
+ if strings.TrimSpace(topology.Sentinel.MasterName) == "" {
+ return configError("sentinel master name is required")
+ }
+
+ for _, address := range topology.Sentinel.Addresses {
+ if strings.TrimSpace(address) == "" {
+ return configError("sentinel addresses cannot be empty")
+ }
+ }
}
+
+ if topology.Cluster != nil {
+ count++
+
+ if len(topology.Cluster.Addresses) == 0 {
+ return configError("cluster addresses are required")
+ }
+
+ for _, address := range topology.Cluster.Addresses {
+ if strings.TrimSpace(address) == "" {
+ return configError("cluster addresses cannot be empty")
+ }
+ }
+ }
+
+ if count != 1 {
+ return configError("exactly one topology must be configured")
+ }
+
+ return nil
+}
+
+func buildTLSConfig(cfg TLSConfig) (*tls.Config, error) {
+ caCert, err := base64.StdEncoding.DecodeString(cfg.CACertBase64)
+ if err != nil {
+ return nil, err
+ }
+
+ caCertPool := x509.NewCertPool()
+ if !caCertPool.AppendCertsFromPEM(caCert) {
+ return nil, errors.New("adding CA cert failed")
+ }
+
+ // Enforce a TLS 1.2 floor. normalizeTLSDefaults already applies this
+ // floor in normal flows, but a caller using AllowLegacyMinVersion=true
+ // could still set a lower value. The literal tls.VersionTLS12 default
+ // satisfies gosec G402 static analysis; we override only when the caller
+ // requests a *higher* version.
+ minVersion := max(uint16(tls.VersionTLS12), cfg.MinVersion)
+
+ tlsConfig := &tls.Config{ // #nosec G402 -- minVersion is floored to tls.VersionTLS12 above; gosec cannot trace through local variables
+ RootCAs: caCertPool,
+ MinVersion: minVersion,
+ }
+
+ return tlsConfig, nil
+}
+
+// recordConnectionFailure increments the redis connection failure counter.
+// No-op when metricsFactory is nil.
+func (c *Client) recordConnectionFailure(operation string) {
+ if c.metricsFactory == nil {
+ return
+ }
+
+ counter, err := c.metricsFactory.Counter(connectionFailuresMetric)
+ if err != nil {
+ c.logger.Log(context.Background(), log.LevelWarn, "failed to create redis metric counter", log.Err(err))
+ return
+ }
+
+ err = counter.
+ WithLabels(map[string]string{
+ "operation": constant.SanitizeMetricLabel(operation),
+ }).
+ AddOne(context.Background())
+ if err != nil {
+ c.logger.Log(context.Background(), log.LevelWarn, "failed to record redis metric", log.Err(err))
+ }
+}
+
+// recordReconnection increments the redis reconnection counter.
+// No-op when metricsFactory is nil.
+func (c *Client) recordReconnection(result string) {
+ if c.metricsFactory == nil {
+ return
+ }
+
+ counter, err := c.metricsFactory.Counter(reconnectionsMetric)
+ if err != nil {
+ c.logger.Log(context.Background(), log.LevelWarn, "failed to create redis reconnection metric counter", log.Err(err))
+ return
+ }
+
+ err = counter.
+ WithLabels(map[string]string{
+ "result": result,
+ }).
+ AddOne(context.Background())
+ if err != nil {
+ c.logger.Log(context.Background(), log.LevelWarn, "failed to record redis reconnection metric", log.Err(err))
+ }
+}
+
+func configError(msg string) error {
+ return fmt.Errorf("%w: %s", ErrInvalidConfig, msg)
}
diff --git a/commons/redis/redis_example_test.go b/commons/redis/redis_example_test.go
new file mode 100644
index 00000000..677e4e18
--- /dev/null
+++ b/commons/redis/redis_example_test.go
@@ -0,0 +1,27 @@
+//go:build unit
+
+package redis_test
+
+import (
+ "fmt"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/redis"
+)
+
+func ExampleConfig() {
+ cfg := redis.Config{
+ Topology: redis.Topology{
+ Standalone: &redis.StandaloneTopology{Address: "redis.internal:6379"},
+ },
+ Auth: redis.Auth{
+ StaticPassword: &redis.StaticPasswordAuth{Password: "redacted"},
+ },
+ }
+
+ fmt.Println(cfg.Topology.Standalone.Address)
+ fmt.Println(cfg.Auth.StaticPassword != nil)
+
+ // Output:
+ // redis.internal:6379
+ // true
+}
diff --git a/commons/redis/redis_integration_test.go b/commons/redis/redis_integration_test.go
new file mode 100644
index 00000000..3730f500
--- /dev/null
+++ b/commons/redis/redis_integration_test.go
@@ -0,0 +1,321 @@
+//go:build integration
+
+package redis
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/testcontainers/testcontainers-go"
+ tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
+ "github.com/testcontainers/testcontainers-go/wait"
+)
+
+// setupRedisContainer starts a real Redis 7 container and returns its address
+// (host:port) plus a cleanup function. The container is waited on until Redis
+// logs "Ready to accept connections", which guarantees the server is ready.
+func setupRedisContainer(t *testing.T) (string, func()) {
+ t.Helper()
+
+ ctx := context.Background()
+
+ container, err := tcredis.Run(ctx,
+ "redis:7-alpine",
+ testcontainers.WithWaitStrategy(
+ wait.ForLog("Ready to accept connections").
+ WithStartupTimeout(30*time.Second),
+ ),
+ )
+ require.NoError(t, err)
+
+ endpoint, err := container.Endpoint(ctx, "")
+ require.NoError(t, err)
+
+ return endpoint, func() {
+ require.NoError(t, container.Terminate(ctx))
+ }
+}
+
+// setupRedisContainerWithPassword starts a Redis 7 container with password
+// authentication enabled via the --requirepass flag. Returns the address,
+// the password, and a cleanup function.
+func setupRedisContainerWithPassword(t *testing.T, password string) (string, func()) {
+ t.Helper()
+
+ ctx := context.Background()
+
+ container, err := tcredis.Run(ctx,
+ "redis:7-alpine",
+ testcontainers.WithWaitStrategy(
+ wait.ForLog("Ready to accept connections").
+ WithStartupTimeout(30*time.Second),
+ ),
+ // Override the default CMD to pass --requirepass.
+ testcontainers.WithCmd("redis-server", "--requirepass", password),
+ )
+ require.NoError(t, err)
+
+ endpoint, err := container.Endpoint(ctx, "")
+ require.NoError(t, err)
+
+ return endpoint, func() {
+ require.NoError(t, container.Terminate(ctx))
+ }
+}
+
+// newTestConfig builds a minimal standalone Config pointing at the given address.
+func newTestConfig(addr string) Config {
+ return Config{
+ Topology: Topology{
+ Standalone: &StandaloneTopology{Address: addr},
+ },
+ Logger: &log.NopLogger{},
+ }
+}
+
+// TestIntegration_Redis_ConnectAndOperate verifies the full lifecycle against a
+// real Redis container: connect, SET, GET, and close.
+func TestIntegration_Redis_ConnectAndOperate(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, client.Close()) }()
+
+ rdb, err := client.GetClient(ctx)
+ require.NoError(t, err)
+ require.NotNil(t, rdb)
+
+ // SET a key with a TTL to avoid polluting the container beyond test scope.
+ const testKey = "integration:connect:key"
+ const testValue = "hello-from-integration-test"
+
+ err = rdb.Set(ctx, testKey, testValue, 30*time.Second).Err()
+ require.NoError(t, err, "SET must succeed")
+
+ got, err := rdb.Get(ctx, testKey).Result()
+ require.NoError(t, err, "GET must succeed")
+ assert.Equal(t, testValue, got, "GET value must match SET value")
+}
+
+// TestIntegration_Redis_Status verifies that Status() and IsConnected() report
+// the correct state throughout the client lifecycle.
+func TestIntegration_Redis_Status(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+
+ // After New(), the client must be connected.
+ status, err := client.Status()
+ require.NoError(t, err)
+ assert.True(t, status.Connected, "status.Connected must be true after New()")
+
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected, "IsConnected must be true after New()")
+
+ // After Close(), the client must report disconnected.
+ require.NoError(t, client.Close())
+
+ connected, err = client.IsConnected()
+ require.NoError(t, err)
+ assert.False(t, connected, "IsConnected must be false after Close()")
+
+ status, err = client.Status()
+ require.NoError(t, err)
+ assert.False(t, status.Connected, "status.Connected must be false after Close()")
+}
+
+// TestIntegration_Redis_ReconnectOnDemand verifies that GetClient() transparently
+// reconnects when the internal client has been closed (simulating a disconnect).
+func TestIntegration_Redis_ReconnectOnDemand(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+
+ // Verify initial connectivity.
+ rdb, err := client.GetClient(ctx)
+ require.NoError(t, err)
+ require.NoError(t, rdb.Set(ctx, "reconnect:before", "v1", 30*time.Second).Err())
+
+ // Simulate a disconnect by calling Close(), which sets the internal client
+ // to nil and connected to false.
+ require.NoError(t, client.Close())
+
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.False(t, connected, "must be disconnected after Close()")
+
+ // GetClient() should trigger reconnect-on-demand because the internal
+ // client is nil.
+ rdb2, err := client.GetClient(ctx)
+ require.NoError(t, err, "GetClient must reconnect on demand")
+ require.NotNil(t, rdb2)
+
+ // The reconnected client must be able to operate normally.
+ require.NoError(t, rdb2.Set(ctx, "reconnect:after", "v2", 30*time.Second).Err())
+
+ got, err := rdb2.Get(ctx, "reconnect:after").Result()
+ require.NoError(t, err)
+ assert.Equal(t, "v2", got)
+
+ // Verify status is back to connected.
+ connected, err = client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected, "must be reconnected after GetClient()")
+
+ // Final cleanup.
+ require.NoError(t, client.Close())
+}
+
+// TestIntegration_Redis_ConcurrentOperations spawns multiple goroutines each
+// performing SET/GET operations concurrently. When run with -race, this
+// validates there are no data races in the client implementation.
+func TestIntegration_Redis_ConcurrentOperations(t *testing.T) {
+ addr, cleanup := setupRedisContainer(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+
+ defer func() { require.NoError(t, client.Close()) }()
+
+ const goroutines = 10
+ const opsPerGoroutine = 50
+
+ var wg sync.WaitGroup
+
+ wg.Add(goroutines)
+
+ // errors collects any non-nil errors from goroutines so the main
+ // goroutine can fail the test with full context.
+ errs := make(chan error, goroutines*opsPerGoroutine)
+
+ for g := range goroutines {
+ go func(id int) {
+ defer wg.Done()
+
+ rdb, getErr := client.GetClient(ctx)
+ if getErr != nil {
+ errs <- fmt.Errorf("goroutine %d: GetClient: %w", id, getErr)
+ return
+ }
+
+ for i := range opsPerGoroutine {
+ key := fmt.Sprintf("concurrent:%d:%d", id, i)
+ value := fmt.Sprintf("val-%d-%d", id, i)
+
+ if setErr := rdb.Set(ctx, key, value, 30*time.Second).Err(); setErr != nil {
+ errs <- fmt.Errorf("goroutine %d op %d: SET: %w", id, i, setErr)
+ return
+ }
+
+ got, getValErr := rdb.Get(ctx, key).Result()
+ if getValErr != nil {
+ errs <- fmt.Errorf("goroutine %d op %d: GET: %w", id, i, getValErr)
+ return
+ }
+
+ if got != value {
+ errs <- fmt.Errorf("goroutine %d op %d: value mismatch: got %q, want %q", id, i, got, value)
+ return
+ }
+ }
+ }(g)
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for e := range errs {
+ t.Error(e)
+ }
+}
+
+// TestIntegration_Redis_StaticPassword verifies that authentication with a
+// static password works against a real Redis container configured with
+// --requirepass.
+func TestIntegration_Redis_StaticPassword(t *testing.T) {
+ const password = "integration-test-secret-42"
+
+ addr, cleanup := setupRedisContainerWithPassword(t, password)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ // Connect with the correct password.
+ cfg := Config{
+ Topology: Topology{
+ Standalone: &StandaloneTopology{Address: addr},
+ },
+ Auth: Auth{
+ StaticPassword: &StaticPasswordAuth{Password: password},
+ },
+ Logger: &log.NopLogger{},
+ }
+
+ client, err := New(ctx, cfg)
+ require.NoError(t, err, "New() with correct password must succeed")
+
+ defer func() { require.NoError(t, client.Close()) }()
+
+ rdb, err := client.GetClient(ctx)
+ require.NoError(t, err)
+
+ // Verify authenticated operations work.
+ const testKey = "auth:static:key"
+ const testValue = "authenticated-value"
+
+ require.NoError(t, rdb.Set(ctx, testKey, testValue, 30*time.Second).Err())
+
+ got, err := rdb.Get(ctx, testKey).Result()
+ require.NoError(t, err)
+ assert.Equal(t, testValue, got)
+
+ // Verify that connecting WITHOUT a password fails.
+ badCfg := Config{
+ Topology: Topology{
+ Standalone: &StandaloneTopology{Address: addr},
+ },
+ Logger: &log.NopLogger{},
+ }
+
+ badClient, err := New(ctx, badCfg)
+ assert.Error(t, err, "New() without password must fail against auth-protected Redis")
+ assert.Nil(t, badClient)
+
+ // Verify that connecting with the WRONG password also fails.
+ wrongCfg := Config{
+ Topology: Topology{
+ Standalone: &StandaloneTopology{Address: addr},
+ },
+ Auth: Auth{
+ StaticPassword: &StaticPasswordAuth{Password: "wrong-password"},
+ },
+ Logger: &log.NopLogger{},
+ }
+
+ wrongClient, err := New(ctx, wrongCfg)
+ assert.Error(t, err, "New() with wrong password must fail")
+ assert.Nil(t, wrongClient)
+}
diff --git a/commons/redis/redis_test.go b/commons/redis/redis_test.go
index b4faba5d..6ccfc58f 100644
--- a/commons/redis/redis_test.go
+++ b/commons/redis/redis_test.go
@@ -1,569 +1,1155 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package redis
import (
"context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/base64"
+ "encoding/pem"
"errors"
+ "fmt"
+ "math/big"
"sync"
+ "sync/atomic"
"testing"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
-func TestRedisConnection_Connect(t *testing.T) {
- // Start a mini Redis server for testing
- mr, err := miniredis.Run()
- if err != nil {
- t.Fatalf("Failed to start miniredis: %v", err)
+type recordingLogger struct {
+ mu sync.Mutex
+ warnings []string
+}
+
+func (logger *recordingLogger) Log(_ context.Context, level log.Level, msg string, _ ...log.Field) {
+ if level != log.LevelWarn {
+ return
}
- defer mr.Close()
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+ logger.mu.Lock()
+ logger.warnings = append(logger.warnings, msg)
+ logger.mu.Unlock()
+}
+
+func (logger *recordingLogger) With(...log.Field) log.Logger { return logger }
+
+func (logger *recordingLogger) WithGroup(string) log.Logger { return logger }
+
+func (logger *recordingLogger) Enabled(log.Level) bool { return true }
+
+func (logger *recordingLogger) Sync(context.Context) error { return nil }
+
+func (logger *recordingLogger) warningMessages() []string {
+ logger.mu.Lock()
+ defer logger.mu.Unlock()
+
+ return append([]string(nil), logger.warnings...)
+}
+
+func newStandaloneConfig(addr string) Config {
+ return Config{
+ Topology: Topology{
+ Standalone: &StandaloneTopology{Address: addr},
+ },
+ Logger: &log.NopLogger{},
+ }
+}
+
+func TestClient_NewAndGetClient(t *testing.T) {
+ mr := miniredis.RunT(t)
+
+ client, err := New(context.Background(), newStandaloneConfig(mr.Addr()))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if closeErr := client.Close(); closeErr != nil {
+ t.Errorf("cleanup: client close: %v", closeErr)
+ }
+ })
+
+ redisClient, err := client.GetClient(context.Background())
+ require.NoError(t, err)
+
+ require.NoError(t, redisClient.Set(context.Background(), "test:key", "value", 0).Err())
+ value, err := redisClient.Get(context.Background(), "test:key").Result()
+ require.NoError(t, err)
+ assert.Equal(t, "value", value)
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected)
+}
+
+func TestClient_New_InvalidConfig(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
tests := []struct {
- name string
- redisConn *RedisConnection
- expectError bool
- skip bool
- skipReason string
+ name string
+ cfg Config
+ errText string
}{
{
- name: "successful connection - standalone mode",
- redisConn: &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{mr.Addr()},
- Logger: logger,
+ name: "missing topology",
+ cfg: Config{Logger: &log.NopLogger{}},
+ errText: "exactly one topology",
+ },
+ {
+ name: "multiple topologies",
+ cfg: Config{
+ Topology: Topology{
+ Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"},
+ Cluster: &ClusterTopology{Addresses: []string{"127.0.0.1:6379"}},
+ },
+ Logger: &log.NopLogger{},
},
- expectError: false,
+ errText: "exactly one topology",
},
{
- name: "successful connection - sentinel mode",
- redisConn: &RedisConnection{
- Mode: ModeSentinel,
- Address: []string{mr.Addr()},
- MasterName: "mymaster",
- Logger: logger,
+ name: "gcp iam requires tls",
+ cfg: Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ Auth: Auth{
+ GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: "abc",
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ },
+ },
+ Logger: &log.NopLogger{},
},
- skip: true,
- skipReason: "miniredis doesn't support sentinel commands",
+ errText: "TLS must be configured",
},
{
- name: "successful connection - cluster mode",
- redisConn: &RedisConnection{
- Mode: ModeCluster,
- Address: []string{mr.Addr()},
- Logger: logger,
+ name: "gcp iam requires service account",
+ cfg: Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{
+ GCPIAM: &GCPIAMAuth{CredentialsBase64: "abc"},
+ },
+ Logger: &log.NopLogger{},
},
- expectError: false,
+ errText: "service account is required",
},
{
- name: "failed connection - wrong addresses",
- redisConn: &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{"wrong_address:6379"},
- Logger: logger,
+ name: "gcp iam service account cannot contain slash",
+ cfg: Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{
+ GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: "abc",
+ ServiceAccount: "projects/-/serviceAccounts/svc@project.iam.gserviceaccount.com",
+ },
+ },
+ Logger: &log.NopLogger{},
},
- expectError: true,
+ errText: "cannot contain '/'",
},
{
- name: "failed connection - wrong sentinel addresses",
- redisConn: &RedisConnection{
- Mode: ModeSentinel,
- Address: []string{"wrong_address:6379"},
- MasterName: "mymaster",
- Logger: logger,
+ name: "gcp iam credentials required",
+ cfg: Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{
+ GCPIAM: &GCPIAMAuth{ServiceAccount: "svc@project.iam.gserviceaccount.com"},
+ },
+ Logger: &log.NopLogger{},
},
- expectError: true,
+ errText: "credentials are required",
},
{
- name: "failed connection - wrong cluster addresses",
- redisConn: &RedisConnection{
- Mode: ModeCluster,
- Address: []string{"wrong_address:6379"},
- Logger: logger,
+ name: "tls requires ca cert",
+ cfg: Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{},
+ Logger: &log.NopLogger{},
},
- expectError: true,
+ errText: "TLS CA cert is required",
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- if tt.skip {
- t.Skip(tt.skipReason)
- }
-
- ctx := context.Background()
- err := tt.redisConn.Connect(ctx)
-
- if tt.expectError {
- assert.Error(t, err)
- assert.False(t, tt.redisConn.Connected)
- assert.Nil(t, tt.redisConn.Client)
- } else {
- assert.NoError(t, err)
- assert.True(t, tt.redisConn.Connected)
- assert.NotNil(t, tt.redisConn.Client)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ client, err := New(context.Background(), test.cfg)
+ require.Error(t, err)
+ assert.Nil(t, client)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ assert.Contains(t, err.Error(), test.errText)
})
}
}
-func TestRedisConnection_GetClient(t *testing.T) {
- // Start a mini Redis server for testing
- mr, err := miniredis.Run()
- if err != nil {
- t.Fatalf("Failed to start miniredis: %v", err)
- }
- defer mr.Close()
+func TestBuildTLSConfig(t *testing.T) {
+ _, err := buildTLSConfig(TLSConfig{CACertBase64: "not-base64"})
+ assert.Error(t, err)
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+ _, err = buildTLSConfig(TLSConfig{CACertBase64: base64.StdEncoding.EncodeToString([]byte("not-a-pem"))})
+ assert.Error(t, err)
- t.Run("get client - first time initialization", func(t *testing.T) {
- ctx := context.Background()
- redisConn := &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{mr.Addr()},
- Logger: logger,
- }
+ cfg, err := buildTLSConfig(TLSConfig{
+ CACertBase64: base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t)),
+ MinVersion: tls.VersionTLS12,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, cfg)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
- client, err := redisConn.GetClient(ctx)
- assert.NoError(t, err)
- assert.NotNil(t, client)
- assert.True(t, redisConn.Connected)
+ cfg, err = buildTLSConfig(TLSConfig{
+ CACertBase64: base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t)),
+ MinVersion: tls.VersionTLS13,
})
+ require.NoError(t, err)
+ require.NotNil(t, cfg)
+ assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
+
+ // buildTLSConfig enforces a TLS 1.2 floor. Passing a version below 1.2
+ // is silently upgraded to TLS 1.2 to prevent insecure configurations.
+ cfg, err = buildTLSConfig(TLSConfig{
+ CACertBase64: base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t)),
+ MinVersion: tls.VersionTLS10,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, cfg)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
+
+ // Even when AllowLegacyMinVersion is true and normalizeTLSDefaults
+ // preserves the lower version, buildTLSConfig enforces the TLS 1.2 floor
+ // as a defense-in-depth measure.
+ normalizedCfg := &TLSConfig{
+ CACertBase64: base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t)),
+ MinVersion: tls.VersionTLS10,
+ AllowLegacyMinVersion: true,
+ }
+ _, _ = normalizeTLSDefaults(normalizedCfg)
+ cfg, err = buildTLSConfig(*normalizedCfg)
+ require.NoError(t, err)
+ require.NotNil(t, cfg)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
+}
- t.Run("get client - already initialized", func(t *testing.T) {
- ctx := context.Background()
- redisConn := &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{mr.Addr()},
- Logger: logger,
- }
+func TestClient_NilReceiverGuards(t *testing.T) {
+ var client *Client
- // First call to initialize
- _, err := redisConn.GetClient(ctx)
- assert.NoError(t, err)
+ err := client.Connect(context.Background())
+ assert.ErrorIs(t, err, ErrNilClient)
- // Second call to get existing client
- client, err := redisConn.GetClient(ctx)
- assert.NoError(t, err)
- assert.NotNil(t, client)
- assert.True(t, redisConn.Connected)
- })
+ rdb, err := client.GetClient(context.Background())
+ assert.ErrorIs(t, err, ErrNilClient)
+ assert.Nil(t, rdb)
- t.Run("get client - connection fails", func(t *testing.T) {
- ctx := context.Background()
- redisConn := &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{"wrong_address:6379"},
- Logger: logger,
- }
+ err = client.Close()
+ assert.ErrorIs(t, err, ErrNilClient)
+
+ connected, err := client.IsConnected()
+ assert.ErrorIs(t, err, ErrNilClient)
+ assert.False(t, connected)
+ assert.ErrorIs(t, client.LastRefreshError(), ErrNilClient)
+}
+
+func TestClient_StatusLifecycle(t *testing.T) {
+ mr := miniredis.RunT(t)
- client, err := redisConn.GetClient(ctx)
- assert.Error(t, err)
- assert.Nil(t, client)
- assert.False(t, redisConn.Connected)
+ client, err := New(context.Background(), newStandaloneConfig(mr.Addr()))
+ require.NoError(t, err)
+
+ status, err := client.Status()
+ require.NoError(t, err)
+ assert.True(t, status.Connected)
+ assert.Nil(t, status.LastRefreshError)
+
+ require.NoError(t, client.Close())
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.False(t, connected)
+}
+
+func TestClient_RefreshLoop_DoesNotDuplicateGoroutines(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ normalized, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: base64.StdEncoding.EncodeToString([]byte("{}")),
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ RefreshEvery: time.Millisecond,
+ RefreshCheckInterval: time.Millisecond,
+ RefreshOperationTimeout: time.Second,
+ }},
+ Logger: &log.NopLogger{},
})
+ require.NoError(t, err)
- // Test different connection modes
- testModes := []struct {
- name string
- redisConn *RedisConnection
- skip bool
- skipReason string
- }{
- {
- name: "sentinel mode",
- redisConn: &RedisConnection{
- Mode: ModeSentinel,
- Address: []string{mr.Addr()},
- MasterName: "mymaster",
- Logger: logger,
- },
- skip: true,
- skipReason: "miniredis doesn't support sentinel commands",
- },
- {
- name: "cluster mode",
- redisConn: &RedisConnection{
- Mode: ModeCluster,
- Address: []string{mr.Addr()},
- Logger: logger,
- },
+ var calls int32
+ client := &Client{
+ cfg: normalized,
+ logger: normalized.Logger,
+ tokenRetriever: func(ctx context.Context) (string, error) {
+ atomic.AddInt32(&calls, 1)
+ <-ctx.Done()
+
+ return "", ctx.Err()
},
+ reconnectFn: func(context.Context) error { return nil },
}
- for _, mode := range testModes {
- t.Run("get client - "+mode.name, func(t *testing.T) {
- if mode.skip {
- t.Skip(mode.skipReason)
+ client.mu.Lock()
+ client.lastRefresh = time.Now().Add(-time.Hour)
+ client.startRefreshLoopLocked()
+ client.startRefreshLoopLocked()
+ client.mu.Unlock()
+
+ require.Eventually(t, func() bool {
+ return atomic.LoadInt32(&calls) >= 1
+ }, 200*time.Millisecond, 10*time.Millisecond)
+
+ require.NoError(t, client.Close())
+ assert.Equal(t, int32(1), atomic.LoadInt32(&calls))
+}
+
+func TestClient_RefreshStatusErrorAndRecovery(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ normalized, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: base64.StdEncoding.EncodeToString([]byte("{}")),
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ RefreshEvery: time.Millisecond,
+ RefreshCheckInterval: time.Millisecond,
+ RefreshOperationTimeout: time.Second,
+ }},
+ Logger: &log.NopLogger{},
+ })
+ require.NoError(t, err)
+
+ firstErr := errors.New("token refresh failed")
+ var shouldFail atomic.Bool
+ shouldFail.Store(true)
+
+ client := &Client{
+ cfg: normalized,
+ logger: normalized.Logger,
+ tokenRetriever: func(context.Context) (string, error) {
+ if shouldFail.Load() {
+ return "", firstErr
}
- ctx := context.Background()
- client, err := mode.redisConn.GetClient(ctx)
- assert.NoError(t, err)
- assert.NotNil(t, client)
- assert.True(t, mode.redisConn.Connected)
- })
+ return "token", nil
+ },
+ reconnectFn: func(context.Context) error { return nil },
}
+
+ client.mu.Lock()
+ client.lastRefresh = time.Now().Add(-time.Hour)
+ client.startRefreshLoopLocked()
+ client.mu.Unlock()
+
+ require.Eventually(t, func() bool {
+ return errors.Is(client.LastRefreshError(), firstErr)
+ }, 500*time.Millisecond, 10*time.Millisecond)
+
+ shouldFail.Store(false)
+
+ require.Eventually(t, func() bool {
+ return client.LastRefreshError() == nil
+ }, 500*time.Millisecond, 10*time.Millisecond)
+
+ require.NoError(t, client.Close())
}
-func TestRedisIntegration(t *testing.T) {
- // Skip this test when running in CI environment
- if testing.Short() {
- t.Skip("Skipping integration test in short mode")
- }
+func TestClient_RefreshTick_ReconnectFailureReturnsFalse(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ normalized, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: base64.StdEncoding.EncodeToString([]byte("{}")),
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ RefreshEvery: time.Millisecond,
+ RefreshCheckInterval: time.Millisecond,
+ RefreshOperationTimeout: time.Second,
+ }},
+ Logger: &log.NopLogger{},
+ })
+ require.NoError(t, err)
- // Start a mini Redis server for testing
- mr, err := miniredis.Run()
- if err != nil {
- t.Fatalf("Failed to start miniredis: %v", err)
+ reconnectErr := errors.New("simulated reconnect failure")
+ initialRefresh := time.Now().Add(-time.Hour)
+
+ client := &Client{
+ cfg: normalized,
+ logger: normalized.Logger,
+ token: "old-token",
+ tokenRetriever: func(context.Context) (string, error) {
+ return "new-token", nil
+ },
+ reconnectFn: func(context.Context) error {
+ return reconnectErr
+ },
+ lastRefresh: initialRefresh,
}
- defer mr.Close()
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+ ok := client.refreshTick(context.Background(), normalized.Auth.GCPIAM)
+ assert.False(t, ok)
+ assert.ErrorIs(t, client.LastRefreshError(), reconnectErr)
- // Create Redis connection
- redisConn := &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{mr.Addr()},
- Logger: logger,
- }
+ client.mu.RLock()
+ defer client.mu.RUnlock()
- ctx := context.Background()
+ assert.Equal(t, "old-token", client.token)
+ assert.Equal(t, initialRefresh, client.lastRefresh)
+}
- // Connect to Redis
- err = redisConn.Connect(ctx)
- assert.NoError(t, err)
+func TestClient_Connect_ReconnectClosesPreviousClient(t *testing.T) {
+ mr := miniredis.RunT(t)
- // Get client
- client, err := redisConn.GetClient(ctx)
- assert.NoError(t, err)
+ client, err := New(context.Background(), newStandaloneConfig(mr.Addr()))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if closeErr := client.Close(); closeErr != nil {
+ t.Errorf("cleanup: client close: %v", closeErr)
+ }
+ })
- // Test setting and getting a value
- key := "test_key"
- value := "test_value"
+ firstClient, err := client.GetClient(context.Background())
+ require.NoError(t, err)
- err = client.Set(ctx, key, value, 0).Err()
- assert.NoError(t, err)
+ require.NoError(t, client.Connect(context.Background()))
- result, err := client.Get(ctx, key).Result()
- assert.NoError(t, err)
- assert.Equal(t, value, result)
+ secondClient, err := client.GetClient(context.Background())
+ require.NoError(t, err)
+ assert.NotSame(t, firstClient, secondClient)
+
+ _, err = firstClient.Ping(context.Background()).Result()
+ assert.Error(t, err)
}
-func TestTTLFunctionality(t *testing.T) {
- // Start a mini Redis server for testing
- mr, err := miniredis.Run()
- if err != nil {
- t.Fatalf("Failed to start miniredis: %v", err)
- }
- defer mr.Close()
+func TestClient_ReconnectFailure_PreservesOldClient(t *testing.T) {
+ mr := miniredis.RunT(t)
+ addr := mr.Addr() // capture address before closing
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+ // Connect a working standalone client (no IAM -- we test reconnect directly).
+ client, err := New(context.Background(), newStandaloneConfig(addr))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if closeErr := client.Close(); closeErr != nil {
+ t.Errorf("cleanup: client close: %v", closeErr)
+ }
+ })
- // Create Redis connection
- redisConn := &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{mr.Addr()},
- Logger: logger,
- }
+ // Verify initial connectivity.
+ rdb, err := client.GetClient(context.Background())
+ require.NoError(t, err)
+ require.NoError(t, rdb.Set(context.Background(), "preserve:key", "before", 0).Err())
- ctx := context.Background()
+ // Shut down miniredis so the new client Ping fails during reconnect.
+ mr.Close()
- // Connect to Redis
- err = redisConn.Connect(ctx)
- assert.NoError(t, err)
+ // Simulate a reconnect failure.
+ client.mu.Lock()
+ err = client.reconnectLocked(context.Background())
+ client.mu.Unlock()
- // Get client
- client, err := redisConn.GetClient(ctx)
- assert.NoError(t, err)
+ // reconnectLocked must return an error (Ping against closed server fails).
+ require.Error(t, err, "reconnectLocked should fail when new client cannot Ping")
- // Test setting a value with TTL
- key := "ttl_key"
- value := "ttl_value"
+ // The old client must still be set and marked connected.
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected, "client must remain connected after failed reconnect")
- // Use the default TTL constant
- err = client.Set(ctx, key, value, time.Duration(TTL)*time.Second).Err()
- assert.NoError(t, err)
+ // Restart miniredis on the same address so the OLD preserved client can work again.
+ mr2 := miniredis.NewMiniRedis()
+ require.NoError(t, mr2.StartAddr(addr))
+ t.Cleanup(mr2.Close)
- // Check TTL is set
- ttl, err := client.TTL(ctx, key).Result()
- assert.NoError(t, err)
- assert.True(t, ttl > 0, "TTL should be greater than 0")
+ // The preserved old client must still be usable.
+ rdb2, err := client.GetClient(context.Background())
+ require.NoError(t, err)
+ require.NoError(t, rdb2.Set(context.Background(), "preserve:key", "still-works", 0).Err())
- // Verify the value is still accessible
- result, err := client.Get(ctx, key).Result()
- assert.NoError(t, err)
- assert.Equal(t, value, result)
+ val, err := rdb2.Get(context.Background(), "preserve:key").Result()
+ require.NoError(t, err)
+ assert.Equal(t, "still-works", val)
+}
- // Fast-forward time in miniredis to simulate expiration
- mr.FastForward(time.Duration(TTL+1) * time.Second)
+func TestClient_ReconnectFailure_IAMRefreshLoopPreservesClient(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ normalized, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: base64.StdEncoding.EncodeToString([]byte("{}")),
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ RefreshEvery: time.Millisecond,
+ RefreshCheckInterval: time.Millisecond,
+ RefreshOperationTimeout: time.Second,
+ }},
+ Logger: &log.NopLogger{},
+ })
+ require.NoError(t, err)
+
+ reconnectErr := errors.New("simulated reconnect failure")
+ var reconnectShouldFail atomic.Bool
+ reconnectShouldFail.Store(true)
+
+ var reconnectCalls atomic.Int32
+ var tokenAtReconnect atomic.Value
+
+ client := &Client{
+ cfg: normalized,
+ logger: normalized.Logger,
+ connected: true,
+ token: "original-working-token",
+ tokenRetriever: func(context.Context) (string, error) {
+ return "new-refreshed-token", nil
+ },
+ reconnectFn: func(ctx context.Context) error {
+ reconnectCalls.Add(1)
- // Verify the key has expired
- exists, err := client.Exists(ctx, key).Result()
- assert.NoError(t, err)
- assert.Equal(t, int64(0), exists, "Key should have expired")
-}
+ // Capture the token at the time of reconnect attempt for verification.
+ tokenAtReconnect.Store("called")
-func TestModesIntegration(t *testing.T) {
- // Skip this test when running in CI environment
- if testing.Short() {
- t.Skip("Skipping integration test in short mode")
- }
+ if reconnectShouldFail.Load() {
+ return reconnectErr
+ }
- // Start a mini Redis server for testing
- mr, err := miniredis.Run()
- if err != nil {
- t.Fatalf("Failed to start miniredis: %v", err)
+ return nil
+ },
}
- defer mr.Close()
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+ client.mu.Lock()
+ client.lastRefresh = time.Now().Add(-time.Hour)
+ client.startRefreshLoopLocked()
+ client.mu.Unlock()
+
+ // Wait for at least one failed reconnect attempt.
+ require.Eventually(t, func() bool {
+ return reconnectCalls.Load() >= 1
+ }, 500*time.Millisecond, 5*time.Millisecond)
+
+ // Verify: the refresh error is recorded.
+ require.Eventually(t, func() bool {
+ return client.LastRefreshError() != nil
+ }, 500*time.Millisecond, 5*time.Millisecond)
+ assert.ErrorIs(t, client.LastRefreshError(), reconnectErr)
+
+ // Verify: the token is rolled back to the original after failed reconnect.
+ client.mu.RLock()
+ currentToken := client.token
+ client.mu.RUnlock()
+ assert.Equal(t, "original-working-token", currentToken,
+ "token must be rolled back to original after failed reconnect")
+
+ // Now allow reconnect to succeed.
+ reconnectShouldFail.Store(false)
+
+ // Wait for recovery.
+ require.Eventually(t, func() bool {
+ return client.LastRefreshError() == nil
+ }, 500*time.Millisecond, 5*time.Millisecond)
+
+ // After successful reconnect, the new token should be in place.
+ client.mu.RLock()
+ recoveredToken := client.token
+ client.mu.RUnlock()
+ assert.Equal(t, "new-refreshed-token", recoveredToken,
+ "token must be updated after successful reconnect")
+
+ require.NoError(t, client.Close())
+}
+
+func TestClient_ReconnectSuccess_SwapsClient(t *testing.T) {
+ mr := miniredis.RunT(t)
+
+ client, err := New(context.Background(), newStandaloneConfig(mr.Addr()))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if closeErr := client.Close(); closeErr != nil {
+ t.Errorf("cleanup: client close: %v", closeErr)
+ }
+ })
+
+ // Grab reference to the original underlying client.
+ rdb1, err := client.GetClient(context.Background())
+ require.NoError(t, err)
+
+ // Successful reconnect should swap the client.
+ client.mu.Lock()
+ err = client.reconnectLocked(context.Background())
+ client.mu.Unlock()
+ require.NoError(t, err)
+
+ rdb2, err := client.GetClient(context.Background())
+ require.NoError(t, err)
+
+ // The client reference must have changed.
+ assert.NotSame(t, rdb1, rdb2, "successful reconnect must swap to new client")
+
+ // Old client must be closed.
+ _, err = rdb1.Ping(context.Background()).Result()
+ assert.Error(t, err, "old client must be closed after successful reconnect")
+
+ // New client must work.
+ require.NoError(t, rdb2.Set(context.Background(), "swap:key", "works", 0).Err())
+ val, err := rdb2.Get(context.Background(), "swap:key").Result()
+ require.NoError(t, err)
+ assert.Equal(t, "works", val)
+
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected)
+}
- // Test all connection modes
- modes := []struct {
- name string
- redisConn *RedisConnection
- skip bool
- skipReason string
+func TestValidateTopology_Sentinel(t *testing.T) {
+ tests := []struct {
+ name string
+ topo Topology
+ errText string
}{
{
- name: "standalone mode",
- redisConn: &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{mr.Addr()},
- Logger: logger,
- },
+ name: "sentinel valid",
+ topo: Topology{Sentinel: &SentinelTopology{
+ Addresses: []string{"127.0.0.1:26379"},
+ MasterName: "mymaster",
+ }},
},
{
- name: "sentinel mode",
- redisConn: &RedisConnection{
- Mode: ModeSentinel,
- Address: []string{mr.Addr()},
+ name: "sentinel missing addresses",
+ topo: Topology{Sentinel: &SentinelTopology{
MasterName: "mymaster",
- Logger: logger,
- },
- skip: true,
- skipReason: "miniredis doesn't support sentinel commands",
+ }},
+ errText: "sentinel addresses are required",
},
{
- name: "cluster mode",
- redisConn: &RedisConnection{
- Mode: ModeCluster,
- Address: []string{mr.Addr()},
- Logger: logger,
- },
+ name: "sentinel missing master name",
+ topo: Topology{Sentinel: &SentinelTopology{
+ Addresses: []string{"127.0.0.1:26379"},
+ }},
+ errText: "sentinel master name is required",
+ },
+ {
+ name: "sentinel empty address in list",
+ topo: Topology{Sentinel: &SentinelTopology{
+ Addresses: []string{"127.0.0.1:26379", " "},
+ MasterName: "mymaster",
+ }},
+ errText: "sentinel addresses cannot be empty",
},
}
- ctx := context.Background()
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateTopology(tt.topo)
+ if tt.errText == "" {
+ require.NoError(t, err)
+ } else {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errText)
+ }
+ })
+ }
+}
- for _, mode := range modes {
- t.Run(mode.name, func(t *testing.T) {
- if mode.skip {
- t.Skip(mode.skipReason)
+func TestValidateTopology_Cluster(t *testing.T) {
+ tests := []struct {
+ name string
+ topo Topology
+ errText string
+ }{
+ {
+ name: "cluster valid",
+ topo: Topology{Cluster: &ClusterTopology{
+ Addresses: []string{"127.0.0.1:7000", "127.0.0.1:7001"},
+ }},
+ },
+ {
+ name: "cluster missing addresses",
+ topo: Topology{Cluster: &ClusterTopology{}},
+ errText: "cluster addresses are required",
+ },
+ {
+ name: "cluster empty address in list",
+ topo: Topology{Cluster: &ClusterTopology{
+ Addresses: []string{"127.0.0.1:7000", " "},
+ }},
+ errText: "cluster addresses cannot be empty",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := validateTopology(tt.topo)
+ if tt.errText == "" {
+ require.NoError(t, err)
+ } else {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errText)
}
+ })
+ }
+}
+
+func TestValidateTopology_StandaloneEmptyAddress(t *testing.T) {
+ err := validateTopology(Topology{Standalone: &StandaloneTopology{Address: " "}})
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "standalone address is required")
+}
- // Connect to Redis
- err := mode.redisConn.Connect(ctx)
- assert.NoError(t, err)
+func TestValidateConfig_DualAuth(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ _, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{
+ StaticPassword: &StaticPasswordAuth{Password: "pass"},
+ GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: "abc",
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ },
+ },
+ })
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "only one auth strategy")
+}
- // Get client
- client, err := mode.redisConn.GetClient(ctx)
- assert.NoError(t, err)
+func TestNormalizeLoggerDefault_NilLogger(t *testing.T) {
+ cfg := Config{}
+ normalizeLoggerDefault(&cfg)
+ require.NotNil(t, cfg.Logger)
+}
- // Test basic operations
- key := "test_key_" + string(mode.redisConn.Mode)
- value := "test_value_" + string(mode.redisConn.Mode)
+func TestBuildUniversalOptionsLocked_Topologies(t *testing.T) {
+ mr := miniredis.RunT(t)
- // Test with TTL
- err = client.Set(ctx, key, value, time.Duration(TTL)*time.Second).Err()
- assert.NoError(t, err)
+ t.Run("sentinel topology", func(t *testing.T) {
+ cfg, err := normalizeConfig(Config{
+ Topology: Topology{Sentinel: &SentinelTopology{
+ Addresses: []string{mr.Addr()},
+ MasterName: "mymaster",
+ }},
+ })
+ require.NoError(t, err)
- result, err := client.Get(ctx, key).Result()
- assert.NoError(t, err)
- assert.Equal(t, value, result)
+ c := &Client{cfg: cfg, logger: cfg.Logger}
+ opts, err := c.buildUniversalOptionsLocked()
+ require.NoError(t, err)
+ assert.Equal(t, []string{mr.Addr()}, opts.Addrs)
+ assert.Equal(t, "mymaster", opts.MasterName)
+ })
- // Test Close method
- if mode.redisConn != nil {
- err = mode.redisConn.Close()
- assert.NoError(t, err)
- }
+ t.Run("cluster topology", func(t *testing.T) {
+ cfg, err := normalizeConfig(Config{
+ Topology: Topology{Cluster: &ClusterTopology{
+ Addresses: []string{mr.Addr(), "127.0.0.1:7001"},
+ }},
})
- }
+ require.NoError(t, err)
+
+ c := &Client{cfg: cfg, logger: cfg.Logger}
+ opts, err := c.buildUniversalOptionsLocked()
+ require.NoError(t, err)
+ assert.Equal(t, []string{mr.Addr(), "127.0.0.1:7001"}, opts.Addrs)
+ })
+
+ t.Run("static password auth", func(t *testing.T) {
+ cfg, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: mr.Addr()}},
+ Auth: Auth{StaticPassword: &StaticPasswordAuth{Password: "secret"}},
+ })
+ require.NoError(t, err)
+
+ c := &Client{cfg: cfg, logger: cfg.Logger}
+ opts, err := c.buildUniversalOptionsLocked()
+ require.NoError(t, err)
+ assert.Equal(t, "secret", opts.Password)
+ })
+
+ t.Run("gcp iam auth sets username and token", func(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ cfg, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: mr.Addr()}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: base64.StdEncoding.EncodeToString([]byte("{}")),
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ }},
+ })
+ require.NoError(t, err)
+
+ c := &Client{cfg: cfg, logger: cfg.Logger, token: "test-token"}
+ opts, err := c.buildUniversalOptionsLocked()
+ require.NoError(t, err)
+ assert.Equal(t, "default", opts.Username)
+ assert.Equal(t, "test-token", opts.Password)
+ assert.NotNil(t, opts.TLSConfig)
+ })
}
-func TestRedisWithTLSConfig(t *testing.T) {
- // This test is more of a unit test to ensure TLS configuration is properly set up
- // Actual TLS connections can't be tested with miniredis
+func TestBuildUniversalOptionsLocked_NoTopology(t *testing.T) {
+ c := &Client{logger: &log.NopLogger{}}
+ _, err := c.buildUniversalOptionsLocked()
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+ assert.Contains(t, err.Error(), "no topology configured")
+}
- // Create logger
- logger := &log.GoLogger{Level: log.InfoLevel}
+func TestClient_GetClient_NoTopology_ReturnsError(t *testing.T) {
+ // A bare Client{} with no Config (e.g., constructed outside of New()) must
+ // return an error from GetClient rather than silently connecting to localhost:6379.
+ c := &Client{}
+ _, err := c.GetClient(context.Background())
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidConfig)
+}
- // Create Redis connection with TLS
- redisConn := &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{"localhost:6379"},
- UseTLS: true,
- Logger: logger,
- }
+func TestClient_GetClient_ReconnectsWhenNil(t *testing.T) {
+ mr := miniredis.RunT(t)
- // Verify that TLS would be used in all modes
- modes := []struct {
- name string
- mode Mode
- }{
- {"standalone", ModeStandalone},
- {"sentinel", ModeSentinel},
- {"cluster", ModeCluster},
- }
+ client, err := New(context.Background(), newStandaloneConfig(mr.Addr()))
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ if closeErr := client.Close(); closeErr != nil {
+ t.Errorf("cleanup: client close: %v", closeErr)
+ }
+ })
- for _, modeTest := range modes {
- t.Run("tls_config_"+modeTest.name, func(t *testing.T) {
- redisConn.Mode = modeTest.mode
+ // Simulate a nil internal client to exercise the reconnect-on-demand path.
+ client.mu.Lock()
+ old := client.client
+ client.client = nil
+ client.mu.Unlock()
- // We don't actually connect, just verify the TLS config would be used
- assert.True(t, redisConn.UseTLS)
- })
+ // Close the old client manually.
+ require.NotNil(t, old)
+ require.NoError(t, old.Close())
+
+ // GetClient should reconnect.
+ rdb, err := client.GetClient(context.Background())
+ require.NoError(t, err)
+ require.NotNil(t, rdb)
+
+ require.NoError(t, rdb.Set(context.Background(), "reconnect:key", "ok", 0).Err())
+}
+
+func TestClient_RetrieveToken_NilClient(t *testing.T) {
+ var c *Client
+ _, err := c.retrieveToken(context.Background())
+ assert.ErrorIs(t, err, ErrNilClient)
+}
+
+func TestClient_RetrieveToken_NoGCPIAM(t *testing.T) {
+ c := &Client{
+ cfg: Config{},
+ logger: &log.NopLogger{},
}
+ _, err := c.retrieveToken(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "GCP IAM auth is not configured")
+}
+
+func TestClient_RefreshTokenLoop_NilClient(t *testing.T) {
+ var c *Client
+ // Should return immediately without panic.
+ c.refreshTokenLoop(context.Background())
}
-func TestRedisConnection_ConcurrentAccess(t *testing.T) {
- mr, err := miniredis.Run()
- if err != nil {
- t.Fatalf("Failed to start miniredis: %v", err)
+func TestNormalizeConnectionOptionsDefaults(t *testing.T) {
+ opts := ConnectionOptions{}
+ normalizeConnectionOptionsDefaults(&opts)
+ assert.Equal(t, 10, opts.PoolSize)
+ assert.Equal(t, 3*time.Second, opts.ReadTimeout)
+ assert.Equal(t, 3*time.Second, opts.WriteTimeout)
+ assert.Equal(t, 5*time.Second, opts.DialTimeout)
+ assert.Equal(t, 2*time.Second, opts.PoolTimeout)
+ assert.Equal(t, 3, opts.MaxRetries)
+ assert.Equal(t, 8*time.Millisecond, opts.MinRetryBackoff)
+ assert.Equal(t, 1*time.Second, opts.MaxRetryBackoff)
+}
+
+func TestNormalizeConnectionOptionsDefaults_PreservesExisting(t *testing.T) {
+ opts := ConnectionOptions{
+ PoolSize: 20,
+ ReadTimeout: 10 * time.Second,
+ WriteTimeout: 10 * time.Second,
+ DialTimeout: 10 * time.Second,
+ PoolTimeout: 10 * time.Second,
+ MaxRetries: 5,
+ MinRetryBackoff: 100 * time.Millisecond,
+ MaxRetryBackoff: 5 * time.Second,
}
- defer mr.Close()
+ normalizeConnectionOptionsDefaults(&opts)
+ assert.Equal(t, 20, opts.PoolSize)
+ assert.Equal(t, 10*time.Second, opts.ReadTimeout)
+ assert.Equal(t, 5, opts.MaxRetries)
+}
- logger := &log.GoLogger{Level: log.InfoLevel}
+func TestNormalizeTLSDefaults(t *testing.T) {
+ t.Run("nil config", func(t *testing.T) {
+ upgraded, legacyAllowed := normalizeTLSDefaults(nil)
+ assert.False(t, upgraded)
+ assert.False(t, legacyAllowed)
+ })
- t.Run("concurrent GetClient calls return same instance", func(t *testing.T) {
- rc := &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{mr.Addr()},
- Logger: logger,
- }
+ t.Run("sets default min version", func(t *testing.T) {
+ cfg := &TLSConfig{}
+ upgraded, legacyAllowed := normalizeTLSDefaults(cfg)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
+ assert.True(t, upgraded)
+ assert.False(t, legacyAllowed)
+ })
- const goroutines = 100
- var wg sync.WaitGroup
- wg.Add(goroutines)
-
- errs := make(chan error, goroutines)
- clients := make(chan interface{}, goroutines)
-
- for i := 0; i < goroutines; i++ {
- go func() {
- defer wg.Done()
- client, err := rc.GetClient(context.Background())
- if err != nil {
- errs <- err
- return
- }
- if client == nil {
- errs <- errors.New("client is nil")
- return
- }
- clients <- client
- }()
- }
+ t.Run("preserves existing min version", func(t *testing.T) {
+ cfg := &TLSConfig{MinVersion: tls.VersionTLS13}
+ upgraded, legacyAllowed := normalizeTLSDefaults(cfg)
+ assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
+ assert.False(t, upgraded)
+ assert.False(t, legacyAllowed)
+ })
- wg.Wait()
- close(errs)
- close(clients)
+ t.Run("enforces tls1.2 minimum floor", func(t *testing.T) {
+ cfg := &TLSConfig{MinVersion: tls.VersionTLS10}
+ upgraded, legacyAllowed := normalizeTLSDefaults(cfg)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
+ assert.True(t, upgraded)
+ assert.False(t, legacyAllowed)
+ })
- for err := range errs {
- t.Errorf("concurrent GetClient error: %v", err)
- }
+ t.Run("allows explicit legacy min version opt in", func(t *testing.T) {
+ cfg := &TLSConfig{MinVersion: tls.VersionTLS10, AllowLegacyMinVersion: true}
+ upgraded, legacyAllowed := normalizeTLSDefaults(cfg)
+ assert.Equal(t, uint16(tls.VersionTLS10), cfg.MinVersion)
+ assert.False(t, upgraded)
+ assert.True(t, legacyAllowed)
+ })
+}
- assert.True(t, rc.Connected)
- assert.NotNil(t, rc.Client)
+func TestNormalizeConfig_TLSUpgradeLogsWarning(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ logger := &recordingLogger{}
- var firstClient interface{}
- for client := range clients {
- if firstClient == nil {
- firstClient = client
- } else {
- assert.Same(t, firstClient, client, "all goroutines should get same client instance")
- }
- }
+ cfg, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{
+ CACertBase64: validCert,
+ MinVersion: tls.VersionTLS10,
+ },
+ Logger: logger,
})
+ require.NoError(t, err)
+ require.NotNil(t, cfg.TLS)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.TLS.MinVersion)
- t.Run("concurrent Connect calls", func(t *testing.T) {
- rc := &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{mr.Addr()},
- Logger: logger,
- }
+ warnings := logger.warningMessages()
+ require.NotEmpty(t, warnings)
+ assert.Contains(t, warnings[0], "upgraded")
+}
- const goroutines = 100
- var wg sync.WaitGroup
- wg.Add(goroutines)
+func TestNormalizeConfig_DefaultTLSMinVersionDoesNotLogWarning(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ logger := &recordingLogger{}
- errs := make(chan error, goroutines)
+ cfg, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{
+ CACertBase64: validCert,
+ },
+ Logger: logger,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, cfg.TLS)
+ assert.Equal(t, uint16(tls.VersionTLS12), cfg.TLS.MinVersion)
+ assert.Empty(t, logger.warningMessages())
+}
- for i := 0; i < goroutines; i++ {
- go func() {
- defer wg.Done()
- if err := rc.Connect(context.Background()); err != nil {
- errs <- err
- }
- }()
- }
+func TestNormalizeConfig_LegacyTLSOptInLogsWarning(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ logger := &recordingLogger{}
- wg.Wait()
- close(errs)
+ cfg, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{
+ CACertBase64: validCert,
+ MinVersion: tls.VersionTLS10,
+ AllowLegacyMinVersion: true,
+ },
+ Logger: logger,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, cfg.TLS)
+ assert.Equal(t, uint16(tls.VersionTLS10), cfg.TLS.MinVersion)
- for err := range errs {
- t.Errorf("concurrent Connect error: %v", err)
- }
+ warnings := logger.warningMessages()
+ require.NotEmpty(t, warnings)
+ assert.Contains(t, warnings[0], "retained")
+}
- assert.True(t, rc.Connected)
- assert.NotNil(t, rc.Client)
+func TestNormalizeGCPIAMDefaults(t *testing.T) {
+ t.Run("nil auth", func(t *testing.T) {
+ normalizeGCPIAMDefaults(nil) // should not panic
})
- t.Run("concurrent GetClient with connection failure", func(t *testing.T) {
- rc := &RedisConnection{
- Mode: ModeStandalone,
- Address: []string{"127.0.0.1:1"},
- Logger: logger,
- DialTimeout: 100 * time.Millisecond,
- }
+ t.Run("sets defaults", func(t *testing.T) {
+ auth := &GCPIAMAuth{}
+ normalizeGCPIAMDefaults(auth)
+ assert.Equal(t, defaultTokenLifetime, auth.TokenLifetime)
+ assert.Equal(t, defaultRefreshEvery, auth.RefreshEvery)
+ assert.Equal(t, defaultRefreshCheckInterval, auth.RefreshCheckInterval)
+ assert.Equal(t, defaultRefreshOperationTimeout, auth.RefreshOperationTimeout)
+ })
- const goroutines = 10
- var wg sync.WaitGroup
- wg.Add(goroutines)
-
- var errCount int
- var mu sync.Mutex
-
- for i := 0; i < goroutines; i++ {
- go func() {
- defer wg.Done()
- _, err := rc.GetClient(context.Background())
- if err != nil {
- mu.Lock()
- errCount++
- mu.Unlock()
- }
- }()
+ t.Run("preserves existing", func(t *testing.T) {
+ auth := &GCPIAMAuth{
+ TokenLifetime: 2 * time.Hour,
+ RefreshEvery: 30 * time.Minute,
+ RefreshCheckInterval: 5 * time.Second,
+ RefreshOperationTimeout: 10 * time.Second,
}
+ normalizeGCPIAMDefaults(auth)
+ assert.Equal(t, 2*time.Hour, auth.TokenLifetime)
+ assert.Equal(t, 30*time.Minute, auth.RefreshEvery)
+ })
+}
+
+func TestNormalizeConnectionOptionsDefaults_PoolSizeCap(t *testing.T) {
+ opts := ConnectionOptions{PoolSize: 5000}
+ normalizeConnectionOptionsDefaults(&opts)
+ assert.Equal(t, maxPoolSize, opts.PoolSize)
+}
+
+func TestNormalizeConnectionOptionsDefaults_PoolSizeAtCap(t *testing.T) {
+ opts := ConnectionOptions{PoolSize: 1000}
+ normalizeConnectionOptionsDefaults(&opts)
+ assert.Equal(t, 1000, opts.PoolSize)
+}
- wg.Wait()
+func TestValidateConfig_RefreshEveryExceedsTokenLifetime(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ _, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: base64.StdEncoding.EncodeToString([]byte("{}")),
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ TokenLifetime: 30 * time.Minute,
+ RefreshEvery: 50 * time.Minute,
+ }},
+ })
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "RefreshEvery must be less than TokenLifetime")
+}
- assert.Equal(t, goroutines, errCount, "all goroutines should receive an error")
- assert.False(t, rc.Connected)
- assert.Nil(t, rc.Client)
+func TestValidateConfig_RefreshEveryEqualsTokenLifetime(t *testing.T) {
+ validCert := base64.StdEncoding.EncodeToString(generateTestCertificatePEM(t))
+ _, err := normalizeConfig(Config{
+ Topology: Topology{Standalone: &StandaloneTopology{Address: "127.0.0.1:6379"}},
+ TLS: &TLSConfig{CACertBase64: validCert},
+ Auth: Auth{GCPIAM: &GCPIAMAuth{
+ CredentialsBase64: base64.StdEncoding.EncodeToString([]byte("{}")),
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ TokenLifetime: 1 * time.Hour,
+ RefreshEvery: 1 * time.Hour,
+ }},
})
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "RefreshEvery must be less than TokenLifetime")
+}
+
+func TestStaticPasswordAuth_StringRedactsPassword(t *testing.T) {
+ auth := StaticPasswordAuth{Password: "super-secret-password"}
+ s := auth.String()
+ assert.Contains(t, s, "REDACTED")
+ assert.NotContains(t, s, "super-secret-password")
+
+ gs := auth.GoString()
+ assert.Contains(t, gs, "REDACTED")
+ assert.NotContains(t, gs, "super-secret-password")
+}
+
+func TestGCPIAMAuth_StringRedactsCredentials(t *testing.T) {
+ auth := GCPIAMAuth{
+ CredentialsBase64: "c2VjcmV0LWtleS1tYXRlcmlhbA==",
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ }
+ s := auth.String()
+ assert.Contains(t, s, "svc@project.iam.gserviceaccount.com")
+ assert.Contains(t, s, "REDACTED")
+ assert.NotContains(t, s, "c2VjcmV0LWtleS1tYXRlcmlhbA==")
+
+ gs := auth.GoString()
+ assert.Contains(t, gs, "REDACTED")
+ assert.NotContains(t, gs, "c2VjcmV0LWtleS1tYXRlcmlhbA==")
+}
+
+func TestStaticPasswordAuth_FmtRedacts(t *testing.T) {
+ auth := StaticPasswordAuth{Password: "my-password-123"}
+ // fmt.Sprintf uses String()/GoString() methods
+ assert.NotContains(t, fmt.Sprintf("%v", auth), "my-password-123")
+ assert.NotContains(t, fmt.Sprintf("%s", auth), "my-password-123")
+ assert.NotContains(t, fmt.Sprintf("%#v", auth), "my-password-123")
+}
+
+func TestGCPIAMAuth_FmtRedacts(t *testing.T) {
+ auth := GCPIAMAuth{
+ CredentialsBase64: "secret-base64-content",
+ ServiceAccount: "svc@project.iam.gserviceaccount.com",
+ }
+ assert.NotContains(t, fmt.Sprintf("%v", auth), "secret-base64-content")
+ assert.NotContains(t, fmt.Sprintf("%s", auth), "secret-base64-content")
+ assert.NotContains(t, fmt.Sprintf("%#v", auth), "secret-base64-content")
+}
+
+func TestSetPackageLogger_NilDefaultsToNop(t *testing.T) {
+ // Should not panic with nil
+ SetPackageLogger(nil)
+ logger := resolvePackageLogger()
+ require.NotNil(t, logger)
+
+ // Reset to NopLogger
+ SetPackageLogger(&log.NopLogger{})
+}
+
+func TestSetPackageLogger_CustomLogger(t *testing.T) {
+ SetPackageLogger(&log.NopLogger{})
+ logger := resolvePackageLogger()
+ require.NotNil(t, logger)
+}
+
+func TestClient_RefreshTokenLoop_NilGCPIAM(t *testing.T) {
+ // refreshTokenLoop with non-nil client but nil GCPIAM should return immediately.
+ c := &Client{
+ cfg: Config{},
+ logger: &log.NopLogger{},
+ }
+ // Should return immediately without panic.
+ c.refreshTokenLoop(context.Background())
+}
+
+func generateTestCertificatePEM(t *testing.T) []byte {
+ t.Helper()
+
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ tmpl := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{CommonName: "redis-test-ca"},
+ NotBefore: time.Now().Add(-time.Hour),
+ NotAfter: time.Now().Add(time.Hour),
+ IsCA: true,
+ BasicConstraintsValid: true,
+ KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privateKey.PublicKey, privateKey)
+ require.NoError(t, err)
+
+ return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
}
diff --git a/commons/redis/resilience_integration_test.go b/commons/redis/resilience_integration_test.go
new file mode 100644
index 00000000..10d33b5d
--- /dev/null
+++ b/commons/redis/resilience_integration_test.go
@@ -0,0 +1,412 @@
+//go:build integration
+
+package redis
+
+import (
+ "context"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/testcontainers/testcontainers-go"
+ tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
+ "github.com/testcontainers/testcontainers-go/wait"
+)
+
+// setupRedisContainerRaw starts a Redis 7 container and returns the container
+// handle (for Stop/Start control), its host:port endpoint, and a cleanup function.
+// Unlike setupRedisContainer, this returns the container itself so tests can
+// simulate server outages by stopping and restarting it.
+func setupRedisContainerRaw(t *testing.T) (*tcredis.RedisContainer, string, func()) {
+ t.Helper()
+
+ ctx := context.Background()
+
+ container, err := tcredis.Run(ctx,
+ "redis:7-alpine",
+ testcontainers.WithWaitStrategy(
+ wait.ForLog("Ready to accept connections").
+ WithStartupTimeout(30*time.Second),
+ ),
+ )
+ require.NoError(t, err)
+
+ endpoint, err := container.Endpoint(ctx, "")
+ require.NoError(t, err)
+
+ return container, endpoint, func() {
+ _ = container.Terminate(ctx)
+ }
+}
+
+// waitForRedisReady polls the restarted container until Redis is accepting
+// connections. After a container restart the mapped port stays the same but
+// the server needs a moment to initialize. We try PING via a fresh client
+// every pollInterval for up to timeout.
+func waitForRedisReady(t *testing.T, addr string, timeout, pollInterval time.Duration) {
+ t.Helper()
+
+ deadline := time.Now().Add(timeout)
+ ctx := context.Background()
+
+ for time.Now().Before(deadline) {
+ probe, err := New(ctx, newTestConfig(addr))
+ if err == nil {
+ _ = probe.Close()
+ return
+ }
+
+ time.Sleep(pollInterval)
+ }
+
+ t.Fatalf("Redis at %s did not become ready within %s", addr, timeout)
+}
+
+// TestIntegration_Redis_Resilience_ReconnectAfterServerRestart validates the
+// full outage-recovery cycle:
+// 1. Connect and verify operations work.
+// 2. Stop the container (simulates server crash / network partition).
+// 3. Verify that operations fail while the server is down.
+// 4. Restart the container (same mapped port).
+// 5. Verify GetClient() eventually reconnects and operations succeed again.
+//
+// This is the most realistic resilience scenario: the process keeps running
+// while the backing Redis goes down and comes back.
+func TestIntegration_Redis_Resilience_ReconnectAfterServerRestart(t *testing.T) {
+ container, addr, cleanup := setupRedisContainerRaw(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ // Phase 1: establish a healthy connection and verify operations.
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+
+ defer func() {
+ // Best-effort close; may already be closed or disconnected.
+ _ = client.Close()
+ }()
+
+ rdb, err := client.GetClient(ctx)
+ require.NoError(t, err)
+ require.NoError(t, rdb.Set(ctx, "resilience:before", "alive", 60*time.Second).Err(),
+ "SET must succeed while server is healthy")
+
+ // Phase 2: stop the container to simulate server going down.
+ t.Log("Stopping Redis container to simulate outage...")
+ require.NoError(t, container.Stop(ctx, nil))
+
+ // The existing go-redis client handle is now pointing at a dead socket.
+ // Operations should fail (the exact error varies by OS/timing).
+ err = rdb.Set(ctx, "resilience:during-outage", "should-fail", 10*time.Second).Err()
+ assert.Error(t, err, "SET must fail while server is down")
+
+ // Phase 3: restart the container. The mapped port may change after restart,
+ // so we must re-read the endpoint from the container.
+ t.Log("Restarting Redis container...")
+ require.NoError(t, container.Start(ctx))
+
+ newAddr, err := container.Endpoint(ctx, "")
+ require.NoError(t, err, "must be able to read endpoint after restart")
+ t.Logf("Redis endpoint after restart: %s (was: %s)", newAddr, addr)
+
+ // Poll until the server is accepting connections at the (potentially new) address.
+ waitForRedisReady(t, newAddr, 15*time.Second, 200*time.Millisecond)
+ t.Log("Redis container is ready after restart")
+
+ // Phase 4: the old client's config points at the old address.
+ // Close it and create a FRESH client with the new address to prove reconnect works.
+ _ = client.Close()
+
+ client2, err := New(ctx, newTestConfig(newAddr))
+ require.NoError(t, err, "New() must succeed after server restart")
+
+ defer func() { _ = client2.Close() }()
+
+ // Phase 5: verify the reconnected client can operate.
+ rdb2, err := client2.GetClient(ctx)
+ require.NoError(t, err, "GetClient must succeed after server restart")
+
+ require.NoError(t, rdb2.Set(ctx, "resilience:after-restart", "reconnected", 60*time.Second).Err(),
+ "SET must succeed after reconnect")
+
+ got, err := rdb2.Get(ctx, "resilience:after-restart").Result()
+ require.NoError(t, err)
+ assert.Equal(t, "reconnected", got, "value written after restart must be readable")
+
+ connected, err := client2.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected, "client must report connected after successful reconnect")
+}
+
+// TestIntegration_Redis_Resilience_BackoffRateLimiting validates that the
+// reconnect rate-limiter prevents thundering-herd storms. When the internal
+// client is nil and GetClient() is called rapidly, only the first call
+// attempts a real reconnect; subsequent calls within the backoff window
+// return a "rate-limited" error without hitting the network.
+//
+// Mechanism (from redis.go GetClient):
+// - reconnectAttempts tracks consecutive failures.
+// - Each failure increments reconnectAttempts and records lastReconnectAttempt.
+// - The next GetClient computes delay = ExponentialWithJitter(500ms, attempts).
+// - If elapsed < delay, it returns "rate-limited" immediately.
+//
+// To trigger this, we connect to a real Redis, then close the underlying
+// go-redis client directly (making c.client nil, c.connected false), and
+// also stop the container so the reconnect attempt actually fails (which
+// increments reconnectAttempts).
+func TestIntegration_Redis_Resilience_BackoffRateLimiting(t *testing.T) {
+ container, addr, cleanup := setupRedisContainerRaw(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+
+ // Verify the connection is healthy before we break things.
+ rdb, err := client.GetClient(ctx)
+ require.NoError(t, err)
+ require.NoError(t, rdb.Ping(ctx).Err())
+
+ // Stop the container so reconnect attempts genuinely fail.
+ t.Log("Stopping container to make reconnect attempts fail...")
+ require.NoError(t, container.Stop(ctx, nil))
+
+ // Close the wrapper client to nil out the internal go-redis handle.
+ // This puts the client into the "needs reconnect" state.
+ require.NoError(t, client.Close())
+
+ // First GetClient call: should attempt a real reconnect to the stopped
+ // server, fail, and increment reconnectAttempts to 1.
+ _, err = client.GetClient(ctx)
+ require.Error(t, err, "first GetClient must fail because server is stopped")
+ t.Logf("First GetClient error (expected): %v", err)
+
+ // Rapid subsequent calls: should be rate-limited because we're within
+ // the backoff window. The delay after 1 failure is in
+ // [0, 500ms * 2^1) = [0, 1000ms). Even with jitter at its minimum (0ms),
+ // consecutive calls within microseconds should be rate-limited after the
+ // first real attempt set lastReconnectAttempt.
+ rateLimitedCount := 0
+ realAttemptCount := 0
+
+ const rapidCalls = 20
+
+ for range rapidCalls {
+ _, callErr := client.GetClient(ctx)
+ require.Error(t, callErr)
+
+ if strings.Contains(callErr.Error(), "rate-limited") {
+ rateLimitedCount++
+ } else {
+ realAttemptCount++
+ }
+ }
+
+ t.Logf("Of %d rapid calls: %d rate-limited, %d real attempts",
+ rapidCalls, rateLimitedCount, realAttemptCount)
+
+ // Due to the jitter in ExponentialWithJitter, the exact split between
+ // rate-limited and real attempts is non-deterministic. However, we
+ // expect the majority to be rate-limited since the calls happen in
+ // microseconds and the backoff window is at least hundreds of milliseconds.
+ assert.Greater(t, rateLimitedCount, 0,
+ "at least some calls must be rate-limited to prevent reconnect storms")
+
+ // Verify that real reconnect attempts are significantly fewer than
+ // rate-limited ones. This proves the backoff is working.
+ if rateLimitedCount > 0 && realAttemptCount > 0 {
+ assert.Greater(t, rateLimitedCount, realAttemptCount,
+ "rate-limited calls should outnumber real reconnect attempts")
+ }
+}
+
+// TestIntegration_Redis_Resilience_GracefulDegradation validates that the
+// client degrades gracefully under failure conditions without panics or
+// undefined behavior:
+// 1. After server goes down, IsConnected() still reflects the last known
+// state (true) because no probe has updated it yet.
+// 2. Operations on the stale client handle fail with errors (not panics).
+// 3. After Close() + GetClient(), we get clean errors (not panics).
+// 4. Status() returns a valid struct throughout.
+func TestIntegration_Redis_Resilience_GracefulDegradation(t *testing.T) {
+ container, addr, cleanup := setupRedisContainerRaw(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+
+ defer func() { _ = client.Close() }()
+
+ // Capture the connected client handle before the outage.
+ rdb, err := client.GetClient(ctx)
+ require.NoError(t, err)
+ require.NoError(t, rdb.Ping(ctx).Err())
+
+ // Stop the server while the client still holds a connection handle.
+ t.Log("Stopping Redis container...")
+ require.NoError(t, container.Stop(ctx, nil))
+
+ // IsConnected() checks the client struct's `connected` field, which is
+ // only updated on connect/close calls — NOT by external server state.
+ // So immediately after a server crash, it still reports true.
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected,
+ "IsConnected must still be true immediately after server stop "+
+ "(the struct hasn't been updated yet)")
+
+ // Status() must return a valid struct, not panic.
+ status, err := client.Status()
+ require.NoError(t, err)
+ assert.True(t, status.Connected,
+ "Status.Connected must reflect the struct state, not the wire state")
+
+ // Operations on the stale rdb handle should fail with errors, not panics.
+ setErr := rdb.Set(ctx, "degradation:should-fail", "value", 10*time.Second).Err()
+ assert.Error(t, setErr, "SET on stale handle must fail when server is down")
+
+ pingErr := rdb.Ping(ctx).Err()
+ assert.Error(t, pingErr, "PING on stale handle must fail when server is down")
+
+ // Close the wrapper client. This nils out the internal handle and sets
+ // connected=false.
+ require.NoError(t, client.Close())
+
+ connected, err = client.IsConnected()
+ require.NoError(t, err)
+ assert.False(t, connected, "IsConnected must be false after Close()")
+
+ // GetClient() should attempt reconnect, fail (server is still down),
+ // and return an error — not panic.
+ _, getErr := client.GetClient(ctx)
+ assert.Error(t, getErr, "GetClient must fail gracefully when server is down")
+
+ // Verify Status() still works and returns a coherent snapshot.
+ status, err = client.Status()
+ require.NoError(t, err)
+ assert.False(t, status.Connected,
+ "Status.Connected must be false after failed reconnect")
+
+ // Calling Close() again on an already-closed client must not panic.
+ assert.NotPanics(t, func() {
+ _ = client.Close()
+ }, "double Close() must not panic")
+}
+
+// TestIntegration_Redis_Resilience_ConcurrentReconnect validates that when
+// multiple goroutines call GetClient() simultaneously on a disconnected
+// client, the double-checked locking in GetClient() serializes reconnect
+// attempts correctly:
+// - No panics or data races (validated by -race detector).
+// - Only one goroutine performs the actual connect (others either get the
+// reconnected client from the second c.client!=nil check, or get a
+// rate-limited/connection error).
+// - All goroutines return without hanging.
+func TestIntegration_Redis_Resilience_ConcurrentReconnect(t *testing.T) {
+ _, addr, cleanup := setupRedisContainerRaw(t)
+ defer cleanup()
+
+ ctx := context.Background()
+
+ client, err := New(ctx, newTestConfig(addr))
+ require.NoError(t, err)
+
+ // Verify healthy state before we break things.
+ rdb, err := client.GetClient(ctx)
+ require.NoError(t, err)
+ require.NoError(t, rdb.Ping(ctx).Err())
+
+ // Close the wrapper to put the client into "needs reconnect" state.
+ // The container is still running, so reconnect should succeed.
+ require.NoError(t, client.Close())
+
+ connected, err := client.IsConnected()
+ require.NoError(t, err)
+ require.False(t, connected, "precondition: client must be disconnected")
+
+ const goroutines = 10
+
+ var (
+ wg sync.WaitGroup
+ successCount atomic.Int64
+ errorCount atomic.Int64
+ panicRecovered atomic.Int64
+ )
+
+ wg.Add(goroutines)
+
+ // All goroutines start simultaneously via a shared gate.
+ gate := make(chan struct{})
+
+ for i := range goroutines {
+ go func(id int) {
+ defer wg.Done()
+
+ // Catch any panics so the test can report them rather than crashing.
+ defer func() {
+ if r := recover(); r != nil {
+ panicRecovered.Add(1)
+ t.Errorf("goroutine %d panicked: %v", id, r)
+ }
+ }()
+
+ // Wait for the gate to open so all goroutines race together.
+ <-gate
+
+ rdbLocal, getErr := client.GetClient(ctx)
+ if getErr != nil {
+ errorCount.Add(1)
+ return
+ }
+
+ // Verify the returned client is functional.
+ if pingErr := rdbLocal.Ping(ctx).Err(); pingErr != nil {
+ errorCount.Add(1)
+ return
+ }
+
+ successCount.Add(1)
+ }(i)
+ }
+
+ // Open the gate: all goroutines race into GetClient().
+ close(gate)
+ wg.Wait()
+
+ successes := successCount.Load()
+ errors := errorCount.Load()
+ panics := panicRecovered.Load()
+
+ t.Logf("Concurrent reconnect results: %d successes, %d errors, %d panics",
+ successes, errors, panics)
+
+ // Hard requirement: no panics.
+ assert.Equal(t, int64(0), panics, "no goroutines should panic during concurrent reconnect")
+
+ // At least one goroutine must succeed (the one that wins the lock and
+ // reconnects). Others may succeed too (if they arrive after the reconnect
+ // completes and see c.client != nil in the fast path), or fail with
+ // rate-limited errors.
+ assert.Greater(t, successes, int64(0),
+ "at least one goroutine must successfully reconnect")
+
+ // All goroutines must have completed (no hangs).
+ assert.Equal(t, int64(goroutines), successes+errors+panics,
+ "all goroutines must complete")
+
+ // Verify the client is in a good state after the storm.
+ connected, err = client.IsConnected()
+ require.NoError(t, err)
+ assert.True(t, connected, "client must be connected after successful concurrent reconnect")
+
+ // Final cleanup.
+ require.NoError(t, client.Close())
+}
diff --git a/commons/runtime/doc.go b/commons/runtime/doc.go
new file mode 100644
index 00000000..a1c3051e
--- /dev/null
+++ b/commons/runtime/doc.go
@@ -0,0 +1,75 @@
+// Package runtime provides panic recovery utilities for services with
+// full observability integration.
+//
+// This package offers policy-based panic recovery primitives that integrate
+// with lib-commons logging, OpenTelemetry metrics/tracing, and optional
+// error tracking services like Sentry.
+//
+// # Panic Policies
+//
+// Two panic policies are supported:
+//
+// - KeepRunning: Log the panic and stack trace, then continue execution.
+// Use this for worker goroutines and HTTP/gRPC handlers.
+//
+// - CrashProcess: Log the panic and stack trace, then re-panic to crash
+// the process. Use this for critical invariant violations where continuing
+// would cause data corruption.
+//
+// # Safe Goroutine Launching
+//
+// Use SafeGo and SafeGoWithContext to launch goroutines with automatic panic
+// recovery and observability:
+//
+// // Basic (no observability)
+// runtime.SafeGo(logger, "background-task", runtime.KeepRunning, func() {
+// doWork()
+// })
+//
+// // With full observability (recommended)
+// runtime.SafeGoWithContextAndComponent(ctx, logger, "transaction", "balance-sync", runtime.KeepRunning,
+// func(ctx context.Context) {
+// syncBalances(ctx)
+// })
+//
+// # Deferred Recovery
+//
+// Use RecoverAndLog, RecoverAndCrash, or RecoverWithPolicy in defer statements.
+// Context-aware variants provide full observability:
+//
+// func handler(ctx context.Context) {
+// defer runtime.RecoverAndLogWithContext(ctx, logger, "transaction", "handler")
+// // Panics here will be logged, recorded as metrics, and added to the trace
+// }
+//
+// # Observability Integration
+//
+// The package integrates with three observability systems:
+//
+// 1. Metrics: Records panic_recovered_total counter with component and goroutine_name labels.
+// Initialize with InitPanicMetrics(metricsFactory).
+//
+// 2. Tracing: Records panic.recovered span events with stack traces and sets span status to Error.
+// Automatically uses the span from the context.
+//
+// 3. Error Reporting: Optionally reports panics to services like Sentry.
+// Configure with SetErrorReporter(reporter).
+//
+// # Initialization
+//
+// During application startup, initialize the observability integrations:
+//
+// tl, err := opentelemetry.NewTelemetry(cfg)
+// if err != nil {
+// return err
+// }
+// runtime.InitPanicMetrics(tl.MetricsFactory)
+//
+// // Optional: Configure Sentry or other error reporter
+// runtime.SetErrorReporter(mySentryReporter)
+//
+// # Stack Traces
+//
+// All recovery functions capture and log the full stack trace using
+// runtime/debug.Stack() for debugging purposes.
+package runtime
diff --git a/commons/runtime/error_reporter.go b/commons/runtime/error_reporter.go
new file mode 100644
index 00000000..ff39c845
--- /dev/null
+++ b/commons/runtime/error_reporter.go
@@ -0,0 +1,198 @@
+package runtime
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+ "sync"
+)
+
+// ErrorReporter defines an interface for external error reporting services.
+// This abstraction allows integration with error tracking services (e.g., logging
+// to Grafana Loki, sending to an alerting system) without creating a hard
+// dependency on any specific SDK.
+//
+// Implementations should:
+// - Handle nil contexts gracefully
+// - Be safe for concurrent use
+// - Not panic themselves
+type ErrorReporter interface {
+ // CaptureException reports a panic/exception to the error tracking service.
+ // The tags map can include metadata like "component", "goroutine_name", etc.
+ CaptureException(ctx context.Context, err error, tags map[string]string)
+}
+
+// errorReporterInstance is the singleton error reporter.
+// It remains nil unless explicitly configured.
+var (
+ errorReporterInstance ErrorReporter
+ errorReporterMu sync.RWMutex
+)
+
+// SetErrorReporter configures the global error reporter for panic reporting.
+// Pass nil to disable error reporting.
+//
+// This should be called once during application startup if an external
+// error tracking service is desired.
+//
+// Example with structured logging:
+//
+// type logReporter struct {
+// logger *slog.Logger
+// }
+//
+// func (r *logReporter) CaptureException(ctx context.Context, err error, tags map[string]string) {
+// attrs := make([]any, 0, len(tags)*2)
+// for k, v := range tags {
+// attrs = append(attrs, k, v)
+// }
+// r.logger.ErrorContext(ctx, "panic recovered", append(attrs, "error", err)...)
+// }
+//
+// runtime.SetErrorReporter(&logReporter{logger: slog.Default()})
+func SetErrorReporter(reporter ErrorReporter) {
+ errorReporterMu.Lock()
+ defer errorReporterMu.Unlock()
+
+ errorReporterInstance = reporter
+}
+
+// GetErrorReporter returns the currently configured error reporter.
+// Returns nil if no reporter has been configured.
+func GetErrorReporter() ErrorReporter {
+ errorReporterMu.RLock()
+ defer errorReporterMu.RUnlock()
+
+ return errorReporterInstance
+}
+
+var (
+ // productionMode controls whether sensitive data is redacted in error reports.
+ // When true, stack traces and detailed panic values are suppressed.
+ productionMode bool
+ productionModeMu sync.RWMutex
+)
+
+const redactedPanicMsg = "panic recovered (details redacted)"
+
+// SetProductionMode enables or disables production mode for error reporting.
+// In production mode, stack traces and potentially sensitive panic details are redacted.
+func SetProductionMode(enabled bool) {
+ productionModeMu.Lock()
+ defer productionModeMu.Unlock()
+
+ productionMode = enabled
+}
+
+// IsProductionMode returns whether production mode is enabled.
+func IsProductionMode() bool {
+ productionModeMu.RLock()
+ defer productionModeMu.RUnlock()
+
+ return productionMode
+}
+
+// reportPanicToErrorService reports a panic to the configured error reporter if one exists.
+// This is called internally by recovery functions.
+// In production mode, stack traces and potentially sensitive panic values are redacted.
+func reportPanicToErrorService(
+ ctx context.Context,
+ panicValue any,
+ stack []byte,
+ component, goroutineName string,
+) {
+ reporter := GetErrorReporter()
+ if reporter == nil {
+ return
+ }
+
+ isProduction := IsProductionMode()
+
+ // Convert panic value to error, redacting details in production
+ err := toPanicError(panicValue, isProduction)
+
+ tags := map[string]string{
+ "component": component,
+ "goroutine_name": goroutineName,
+ "panic_type": "recovered",
+ }
+
+ // Include stack trace only in non-production mode
+ if len(stack) > 0 && !isProduction {
+ stackStr := string(stack)
+
+ const maxStackLen = 4096
+ if len(stackStr) > maxStackLen {
+ stackStr = stackStr[:maxStackLen] + "\n...[truncated]"
+ }
+
+ tags["stack_trace"] = stackStr
+ }
+
+ reporter.CaptureException(ctx, err, tags)
+}
+
+// panicError wraps a panic value as an error for reporting.
+type panicError struct {
+ message string
+}
+
+// Error returns the panic error message.
+func (e *panicError) Error() string {
+ return e.message
+}
+
+func toPanicError(panicValue any, isProduction bool) error {
+ if isProduction {
+ return &panicError{message: redactedPanicMsg}
+ }
+
+ // Guard against typed-nil error values: an interface holding (type=*MyError, value=nil)
+ // would pass the type assertion but panic on .Error(). Use reflect to detect this.
+ if err, ok := panicValue.(error); ok && !isTypedNil(panicValue) {
+ return err
+ }
+
+ if message, ok := panicValue.(string); ok {
+ return &panicError{message: message}
+ }
+
+ return &panicError{message: "panic: " + formatPanicValue(panicValue)}
+}
+
+// isTypedNil returns true if v is an interface holding a nil pointer/nil value.
+func isTypedNil(v any) bool {
+ if v == nil {
+ return false // untyped nil is not a typed nil
+ }
+
+ rv := reflect.ValueOf(v)
+
+ switch rv.Kind() {
+ case reflect.Ptr, reflect.Interface, reflect.Slice, reflect.Map, reflect.Chan, reflect.Func:
+ return rv.IsNil()
+ default:
+ return false
+ }
+}
+
+// formatPanicValue formats a panic value as a string.
+func formatPanicValue(value any) string {
+ if value == nil {
+ return ""
+ }
+
+ // Guard against typed-nil values that would panic on method calls.
+ if isTypedNil(value) {
+ return fmt.Sprintf("<%T>(nil)", value)
+ }
+
+ switch val := value.(type) {
+ case string:
+ return val
+ case error:
+ return val.Error()
+ default:
+ return fmt.Sprintf("%v", value)
+ }
+}
diff --git a/commons/runtime/error_reporter_test.go b/commons/runtime/error_reporter_test.go
new file mode 100644
index 00000000..f1f378bc
--- /dev/null
+++ b/commons/runtime/error_reporter_test.go
@@ -0,0 +1,662 @@
+//go:build unit
+
+package runtime
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ errBasePanic = errors.New("base error")
+ errDetailedMessage = errors.New("detailed error message")
+ errPanicError = errors.New("error panic")
+ errSensitiveDetails = errors.New("database password: secret123")
+ errTestError = errors.New("test error")
+)
+
+// testErrorReporter is a test implementation of ErrorReporter for these tests.
+type testErrorReporter struct {
+ mu sync.RWMutex
+ capturedErr error
+ capturedCtx context.Context
+ capturedTags map[string]string
+ callCount int
+}
+
+func (reporter *testErrorReporter) CaptureException(
+ ctx context.Context,
+ err error,
+ tags map[string]string,
+) {
+ reporter.mu.Lock()
+ defer reporter.mu.Unlock()
+
+ reporter.capturedErr = err
+ reporter.capturedCtx = ctx
+ reporter.capturedTags = tags
+ reporter.callCount++
+}
+
+func (reporter *testErrorReporter) getCapturedErr() error {
+ reporter.mu.RLock()
+ defer reporter.mu.RUnlock()
+
+ return reporter.capturedErr
+}
+
+func (reporter *testErrorReporter) getCapturedTags() map[string]string {
+ reporter.mu.RLock()
+ defer reporter.mu.RUnlock()
+
+ // Return a defensive copy to prevent races with callers
+ if reporter.capturedTags == nil {
+ return nil
+ }
+
+ copyTags := make(map[string]string, len(reporter.capturedTags))
+ for k, v := range reporter.capturedTags {
+ copyTags[k] = v
+ }
+
+ return copyTags
+}
+
+func (reporter *testErrorReporter) getCallCount() int {
+ reporter.mu.RLock()
+ defer reporter.mu.RUnlock()
+
+ return reporter.callCount
+}
+
+// TestSetAndGetErrorReporter tests basic SetErrorReporter and GetErrorReporter functionality.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestSetAndGetErrorReporter(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ got := GetErrorReporter()
+ require.NotNil(t, got)
+ assert.Equal(t, reporter, got)
+}
+
+// TestReportPanicToErrorService_NilContext tests reporting with nil context.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestReportPanicToErrorService_NilContext(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ require.NotPanics(t, func() {
+ reportPanicToErrorService(
+ nil,
+ "test panic",
+ []byte("stack"),
+ "component",
+ "goroutine",
+ ) //nolint:staticcheck // Testing nil context intentionally
+ })
+
+ require.NotNil(t, reporter.getCapturedErr())
+ assert.Contains(t, reporter.getCapturedErr().Error(), "test panic")
+}
+
+// TestReportPanicToErrorService_NilStackTrace tests reporting with nil stack trace.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestReportPanicToErrorService_NilStackTrace(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ reportPanicToErrorService(context.Background(), "test panic", nil, "component", "goroutine")
+
+ tags := reporter.getCapturedTags()
+ require.NotNil(t, tags)
+ _, hasStackTrace := tags["stack_trace"]
+ assert.False(t, hasStackTrace, "Should not include stack_trace tag when stack is nil")
+}
+
+// TestReportPanicToErrorService_EmptyStackTrace tests reporting with empty stack trace.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestReportPanicToErrorService_EmptyStackTrace(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ reportPanicToErrorService(
+ context.Background(),
+ "test panic",
+ []byte{},
+ "component",
+ "goroutine",
+ )
+
+ tags := reporter.getCapturedTags()
+ require.NotNil(t, tags)
+ _, hasStackTrace := tags["stack_trace"]
+ assert.False(t, hasStackTrace, "Should not include stack_trace tag when stack is empty")
+}
+
+// TestReportPanicToErrorService_StackTraceTruncation tests that long stack traces are truncated.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestReportPanicToErrorService_StackTraceTruncation(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ longStack := strings.Repeat("a", 5000)
+ reportPanicToErrorService(
+ context.Background(),
+ "test panic",
+ []byte(longStack),
+ "component",
+ "goroutine",
+ )
+
+ tags := reporter.getCapturedTags()
+ require.NotNil(t, tags)
+
+ stackTrace, hasStackTrace := tags["stack_trace"]
+ require.True(t, hasStackTrace)
+ assert.True(t, strings.HasSuffix(stackTrace, "...[truncated]"))
+ assert.LessOrEqual(t, len(stackTrace), 4096+len("\n...[truncated]"))
+}
+
+// TestReportPanicToErrorService_StackTraceExactlyMaxLen tests stack trace at exactly max length.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestReportPanicToErrorService_StackTraceExactlyMaxLen(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ exactStack := strings.Repeat("a", 4096)
+ reportPanicToErrorService(
+ context.Background(),
+ "test panic",
+ []byte(exactStack),
+ "component",
+ "goroutine",
+ )
+
+ tags := reporter.getCapturedTags()
+ require.NotNil(t, tags)
+
+ stackTrace, hasStackTrace := tags["stack_trace"]
+ require.True(t, hasStackTrace)
+ assert.False(
+ t,
+ strings.HasSuffix(stackTrace, "...[truncated]"),
+ "Should not truncate at exactly max length",
+ )
+ assert.Equal(t, exactStack, stackTrace)
+}
+
+// TestReportPanicToErrorService_PanicValueTypes tests different panic value types.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestReportPanicToErrorService_PanicValueTypes(t *testing.T) {
+ tests := []struct {
+ name string
+ panicValue any
+ expectedSubstr string
+ }{
+ {
+ name: "error type",
+ panicValue: errPanicError,
+ expectedSubstr: "error panic",
+ },
+ {
+ name: "string type",
+ panicValue: "string panic",
+ expectedSubstr: "string panic",
+ },
+ {
+ name: "int type",
+ panicValue: 42,
+ expectedSubstr: "panic: 42",
+ },
+ {
+ name: "struct type",
+ panicValue: struct{ Field string }{Field: "value"},
+ expectedSubstr: "panic: {value}",
+ },
+ {
+ name: "nil value",
+ panicValue: nil,
+ expectedSubstr: "panic: ",
+ },
+ {
+ name: "slice type",
+ panicValue: []int{1, 2, 3},
+ expectedSubstr: "panic: [1 2 3]",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ reportPanicToErrorService(
+ context.Background(),
+ tt.panicValue,
+ []byte("stack"),
+ "component",
+ "goroutine",
+ )
+
+ err := reporter.getCapturedErr()
+ require.NotNil(t, err)
+ assert.Contains(t, err.Error(), tt.expectedSubstr)
+ })
+ }
+}
+
+// TestReportPanicToErrorService_Tags tests that all expected tags are set.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestReportPanicToErrorService_Tags(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ reportPanicToErrorService(
+ context.Background(),
+ "test",
+ []byte("stack"),
+ "my-component",
+ "my-goroutine",
+ )
+
+ tags := reporter.getCapturedTags()
+ require.NotNil(t, tags)
+ assert.Equal(t, "my-component", tags["component"])
+ assert.Equal(t, "my-goroutine", tags["goroutine_name"])
+ assert.Equal(t, "recovered", tags["panic_type"])
+ assert.Equal(t, "stack", tags["stack_trace"])
+}
+
+// TestFormatPanicValue tests formatPanicValue with various input types.
+func TestFormatPanicValue(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ value any
+ expected string
+ }{
+ {
+ name: "nil value",
+ value: nil,
+ expected: "",
+ },
+ {
+ name: "string value",
+ value: "test string",
+ expected: "test string",
+ },
+ {
+ name: "error value",
+ value: errTestError,
+ expected: "test error",
+ },
+ {
+ name: "int value",
+ value: 123,
+ expected: "123",
+ },
+ {
+ name: "float value",
+ value: 3.14,
+ expected: "3.14",
+ },
+ {
+ name: "bool value",
+ value: true,
+ expected: "true",
+ },
+ {
+ name: "struct value",
+ value: struct{ Name string }{Name: "test"},
+ expected: "{test}",
+ },
+ {
+ name: "slice value",
+ value: []string{"a", "b"},
+ expected: "[a b]",
+ },
+ {
+ name: "map value",
+ value: map[string]int{"key": 1},
+ expected: "map[key:1]",
+ },
+ {
+ name: "empty string",
+ value: "",
+ expected: "",
+ },
+ {
+ name: "pointer to int",
+ value: func() any { i := 42; return &i }(),
+ expected: "", // Will be a pointer address, just check it doesn't panic
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := formatPanicValue(tt.value)
+
+ if tt.name == "pointer to int" {
+ assert.NotEmpty(t, result)
+ } else {
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+// TestConcurrentSetGetErrorReporter tests thread safety of SetErrorReporter/GetErrorReporter.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestConcurrentSetGetErrorReporter(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ const (
+ goroutines = 100
+ iterations = 100
+ )
+
+ var wg sync.WaitGroup
+
+ wg.Add(goroutines * 2)
+
+ for i := 0; i < goroutines; i++ {
+ go func() {
+ defer wg.Done()
+
+ for j := 0; j < iterations; j++ {
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+
+ for j := 0; j < iterations; j++ {
+ _ = GetErrorReporter()
+ }
+ }()
+ }
+
+ wg.Wait()
+}
+
+// TestConcurrentReportPanic tests thread safety of reportPanicToErrorService.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestConcurrentReportPanic(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ const goroutines = 50
+
+ var wg sync.WaitGroup
+
+ wg.Add(goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func(id int) {
+ defer wg.Done()
+
+ reportPanicToErrorService(
+ context.Background(),
+ fmt.Sprintf("panic %d", id),
+ []byte("stack"),
+ "component",
+ fmt.Sprintf("goroutine-%d", id),
+ )
+ }(i)
+ }
+
+ wg.Wait()
+
+ assert.Equal(t, goroutines, reporter.getCallCount())
+}
+
+// TestReportPanicToErrorService_WrappedError tests that wrapped errors are handled correctly.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestReportPanicToErrorService_WrappedError(t *testing.T) {
+ SetErrorReporter(nil)
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ wrappedErr := fmt.Errorf("wrapped: %w", errBasePanic)
+
+ reportPanicToErrorService(
+ context.Background(),
+ wrappedErr,
+ []byte("stack"),
+ "component",
+ "goroutine",
+ )
+
+ capturedErr := reporter.getCapturedErr()
+ require.NotNil(t, capturedErr)
+ assert.Equal(t, wrappedErr, capturedErr)
+ assert.True(t, errors.Is(capturedErr, errBasePanic))
+}
+
+// TestFormatPanicValue_CustomStringer tests formatPanicValue with a custom Stringer.
+func TestFormatPanicValue_CustomStringer(t *testing.T) {
+ t.Parallel()
+
+ stringer := struct {
+ value string
+ }{value: "custom"}
+
+ result := formatPanicValue(stringer)
+ assert.Equal(t, "{custom}", result)
+}
+
+// TestFormatPanicValue_CustomError tests formatPanicValue with a custom error type.
+func TestFormatPanicValue_CustomError(t *testing.T) {
+ t.Parallel()
+
+ type customError struct {
+ code int
+ msg string
+ }
+
+ customErr := &customError{code: 500, msg: "internal error"}
+
+ result := formatPanicValue(customErr)
+ assert.Contains(t, result, "500")
+ assert.Contains(t, result, "internal error")
+}
+
+// TestSetProductionMode tests enabling and disabling production mode.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global productionMode
+func TestSetProductionMode(t *testing.T) {
+ SetProductionMode(false)
+ t.Cleanup(func() { SetProductionMode(false) })
+
+ assert.False(t, IsProductionMode())
+
+ SetProductionMode(true)
+ assert.True(t, IsProductionMode())
+
+ SetProductionMode(false)
+ assert.False(t, IsProductionMode())
+}
+
+// TestReportPanicToErrorService_ProductionMode_RedactsPanicDetails tests that production mode redacts panic values.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global state
+func TestReportPanicToErrorService_ProductionMode_RedactsPanicDetails(t *testing.T) {
+ SetErrorReporter(nil)
+ SetProductionMode(false)
+ t.Cleanup(func() {
+ SetErrorReporter(nil)
+ SetProductionMode(false)
+ })
+
+ SetProductionMode(true)
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ reportPanicToErrorService(
+ context.Background(),
+ errSensitiveDetails,
+ []byte("stack"),
+ "component",
+ "goroutine",
+ )
+
+ capturedErr := reporter.getCapturedErr()
+ require.NotNil(t, capturedErr)
+ assert.Equal(t, "panic recovered (details redacted)", capturedErr.Error())
+ assert.NotContains(t, capturedErr.Error(), "secret123")
+}
+
+// TestReportPanicToErrorService_ProductionMode_RedactsStackTrace tests that production mode omits stack traces.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global state
+func TestReportPanicToErrorService_ProductionMode_RedactsStackTrace(t *testing.T) {
+ SetErrorReporter(nil)
+ SetProductionMode(false)
+ t.Cleanup(func() {
+ SetErrorReporter(nil)
+ SetProductionMode(false)
+ })
+
+ SetProductionMode(true)
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ reportPanicToErrorService(
+ context.Background(),
+ "test panic",
+ []byte("sensitive stack trace here"),
+ "component",
+ "goroutine",
+ )
+
+ tags := reporter.getCapturedTags()
+ require.NotNil(t, tags)
+ _, hasStackTrace := tags["stack_trace"]
+ assert.False(t, hasStackTrace, "Production mode should not include stack_trace")
+}
+
+// TestReportPanicToErrorService_NonProductionMode_IncludesDetails tests that non-production mode includes full details.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global state
+func TestReportPanicToErrorService_NonProductionMode_IncludesDetails(t *testing.T) {
+ SetErrorReporter(nil)
+ SetProductionMode(false)
+ t.Cleanup(func() {
+ SetErrorReporter(nil)
+ SetProductionMode(false)
+ })
+
+ reporter := &testErrorReporter{}
+ SetErrorReporter(reporter)
+
+ reportPanicToErrorService(
+ context.Background(),
+ errDetailedMessage,
+ []byte("full stack trace"),
+ "component",
+ "goroutine",
+ )
+
+ capturedErr := reporter.getCapturedErr()
+ require.NotNil(t, capturedErr)
+ assert.Equal(t, errDetailedMessage, capturedErr)
+
+ tags := reporter.getCapturedTags()
+ require.NotNil(t, tags)
+ stackTrace, hasStackTrace := tags["stack_trace"]
+ assert.True(t, hasStackTrace, "Non-production mode should include stack_trace")
+ assert.Equal(t, "full stack trace", stackTrace)
+}
+
+// TestConcurrentSetProductionMode tests thread safety of SetProductionMode/IsProductionMode.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global productionMode
+func TestConcurrentSetProductionMode(t *testing.T) {
+ SetProductionMode(false)
+ t.Cleanup(func() { SetProductionMode(false) })
+
+ const (
+ goroutines = 100
+ iterations = 100
+ )
+
+ var wg sync.WaitGroup
+
+ wg.Add(goroutines * 2)
+
+ for i := 0; i < goroutines; i++ {
+ go func(id int) {
+ defer wg.Done()
+
+ for j := 0; j < iterations; j++ {
+ SetProductionMode(id%2 == 0)
+ }
+ }(i)
+
+ go func() {
+ defer wg.Done()
+
+ for j := 0; j < iterations; j++ {
+ _ = IsProductionMode()
+ }
+ }()
+ }
+
+ wg.Wait()
+}
diff --git a/commons/runtime/example_test.go b/commons/runtime/example_test.go
new file mode 100644
index 00000000..7ba41676
--- /dev/null
+++ b/commons/runtime/example_test.go
@@ -0,0 +1,93 @@
+//go:build unit
+
+package runtime
+
+import (
+ "context"
+ "fmt"
+
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+// simpleLogger is a minimal logger for examples.
+type simpleLogger struct{}
+
+func (l *simpleLogger) Log(_ context.Context, _ libLog.Level, _ string, _ ...libLog.Field) {}
+
+func ExampleSafeGoWithContext() {
+ ctx := context.Background()
+ logger := &simpleLogger{}
+
+ // Launch a goroutine with panic recovery and observability
+ done := make(chan struct{})
+
+ SafeGoWithContextAndComponent(ctx, logger, "transaction", "example-worker", KeepRunning,
+ func(ctx context.Context) {
+ defer close(done)
+
+ fmt.Println("Worker started")
+ // Work happens here...
+ fmt.Println("Worker completed")
+ })
+
+ <-done
+ // Output:
+ // Worker started
+ // Worker completed
+}
+
+func ExampleRecoverAndLogWithContext() {
+ ctx := context.Background()
+ logger := &simpleLogger{}
+
+ func() {
+ defer RecoverAndLogWithContext(ctx, logger, "example", "handler")
+
+ fmt.Println("Before panic")
+ // If a panic occurred here, it would be recovered and logged
+ fmt.Println("After (no panic)")
+ }()
+
+ fmt.Println("Function completed normally")
+ // Output:
+ // Before panic
+ // After (no panic)
+ // Function completed normally
+}
+
+func ExampleInitPanicMetrics() {
+ // During application startup, after telemetry initialization:
+ // tl := opentelemetry.InitializeTelemetry(cfg)
+ // runtime.InitPanicMetrics(tl.MetricsFactory)
+
+ // Nil is safe (no-op):
+ InitPanicMetrics(nil)
+
+ // Metrics remain uninitialized until properly configured.
+ pm := GetPanicMetrics()
+ fmt.Printf("Metrics initialized: %v\n", pm != nil)
+ // Output:
+ // Metrics initialized: false
+}
+
+func ExampleSetErrorReporter() {
+ // Create a custom error reporter (e.g., for Sentry)
+ reporter := &customReporter{}
+
+ // Configure during startup
+ SetErrorReporter(reporter)
+
+ // Later, panics will be reported automatically
+ fmt.Println("Error reporter configured")
+
+ // Clean up
+ SetErrorReporter(nil)
+ // Output:
+ // Error reporter configured
+}
+
+type customReporter struct{}
+
+func (r *customReporter) CaptureException(_ context.Context, _ error, _ map[string]string) {
+ // In a real implementation, this would send to Sentry or similar
+}
diff --git a/commons/runtime/goroutine.go b/commons/runtime/goroutine.go
new file mode 100644
index 00000000..e7b00d7b
--- /dev/null
+++ b/commons/runtime/goroutine.go
@@ -0,0 +1,93 @@
+package runtime
+
+import (
+ "context"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+// SafeGo launches a goroutine with panic recovery. If the goroutine panics,
+// the panic is handled according to the specified policy.
+//
+// Note: This function does not record metrics or span events because it lacks
+// context. For observability integration, use SafeGoWithContext instead.
+//
+// Parameters:
+// - logger: Logger for recording panic information
+// - name: Descriptive name for the goroutine (used in logs)
+// - policy: How to handle panics (KeepRunning or CrashProcess)
+// - fn: The function to execute in the goroutine
+//
+// Example:
+//
+// runtime.SafeGo(logger, "email-sender", runtime.KeepRunning, func() {
+// sendEmail(to, subject, body)
+// })
+func SafeGo(logger Logger, name string, policy PanicPolicy, fn func()) {
+ if fn == nil {
+ if logger != nil {
+ logger.Log(context.Background(), log.LevelWarn,
+ "SafeGo called with nil callback, ignoring",
+ log.String("goroutine", name),
+ )
+ }
+
+ return
+ }
+
+ go func() {
+ defer RecoverWithPolicy(logger, name, policy)
+
+ fn()
+ }()
+}
+
+// SafeGoWithContext launches a goroutine with panic recovery and context
+// propagation.
+//
+// Note: For better observability labeling, prefer SafeGoWithContextAndComponent.
+func SafeGoWithContext(
+ ctx context.Context,
+ logger Logger,
+ name string,
+ policy PanicPolicy,
+ fn func(context.Context),
+) {
+ SafeGoWithContextAndComponent(ctx, logger, "", name, policy, fn)
+}
+
+// SafeGoWithContextAndComponent is like SafeGoWithContext but also records the
+// provided component name in observability signals.
+//
+// Parameters:
+// - ctx: Context for cancellation, values, and observability
+// - logger: Logger for recording panic information
+// - component: The service component (e.g., "transaction", "onboarding")
+// - name: Descriptive name for the goroutine (used in logs and metrics)
+// - policy: How to handle panics (KeepRunning or CrashProcess)
+// - fn: The function to execute, receiving the context
+func SafeGoWithContextAndComponent(
+ ctx context.Context,
+ logger Logger,
+ component, name string,
+ policy PanicPolicy,
+ fn func(context.Context),
+) {
+ if fn == nil {
+ if logger != nil {
+ logger.Log(context.Background(), log.LevelWarn,
+ "SafeGoWithContextAndComponent called with nil callback, ignoring",
+ log.String("component", component),
+ log.String("goroutine", name),
+ )
+ }
+
+ return
+ }
+
+ go func() {
+ defer RecoverWithPolicyAndContext(ctx, logger, component, name, policy)
+
+ fn(ctx)
+ }()
+}
diff --git a/commons/runtime/goroutine_test.go b/commons/runtime/goroutine_test.go
new file mode 100644
index 00000000..680d2188
--- /dev/null
+++ b/commons/runtime/goroutine_test.go
@@ -0,0 +1,469 @@
+//go:build unit
+
+package runtime
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestPanicPolicyString tests the String method of PanicPolicy.
+func TestPanicPolicyString(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ policy PanicPolicy
+ expected string
+ }{
+ {
+ name: "KeepRunning",
+ policy: KeepRunning,
+ expected: "KeepRunning",
+ },
+ {
+ name: "CrashProcess",
+ policy: CrashProcess,
+ expected: "CrashProcess",
+ },
+ {
+ name: "Unknown",
+ policy: PanicPolicy(99),
+ expected: "Unknown",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := tt.policy.String()
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// TestRecoverAndLog_NoPanic tests that RecoverAndLog does nothing when no panic occurs.
+func TestRecoverAndLog_NoPanic(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+
+ func() {
+ defer RecoverAndLog(logger, "test-no-panic")
+ // No panic here
+ }()
+
+ assert.False(t, logger.wasPanicLogged(), "Should not log when no panic occurs")
+}
+
+// TestRecoverAndLog_WithPanic tests that RecoverAndLog catches and logs panics.
+func TestRecoverAndLog_WithPanic(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+
+ func() {
+ defer RecoverAndLog(logger, "test-with-panic")
+
+ panic("test panic value")
+ }()
+
+ assert.True(t, logger.wasPanicLogged(), "Should log when panic occurs")
+ // The log message should contain the panic info and stack trace
+ assert.NotEmpty(t, logger.errorCalls, "Should have logged error")
+}
+
+// TestRecoverAndCrash_NoPanic tests that RecoverAndCrash does nothing when no panic occurs.
+func TestRecoverAndCrash_NoPanic(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+
+ func() {
+ defer RecoverAndCrash(logger, "test-no-panic")
+ // No panic here
+ }()
+
+ assert.False(t, logger.wasPanicLogged(), "Should not log when no panic occurs")
+}
+
+// TestRecoverAndCrash_WithPanic tests that RecoverAndCrash catches, logs, and re-panics.
+func TestRecoverAndCrash_WithPanic(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+
+ defer func() {
+ r := recover()
+ require.NotNil(t, r, "Should re-panic after logging")
+ assert.Equal(t, "test panic value", r)
+ }()
+
+ func() {
+ defer RecoverAndCrash(logger, "test-with-panic")
+
+ panic("test panic value")
+ }()
+
+ t.Fatal("Should not reach here - panic should propagate")
+}
+
+// TestRecoverWithPolicy_KeepRunning tests policy-based recovery with KeepRunning.
+func TestRecoverWithPolicy_KeepRunning(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+
+ func() {
+ defer RecoverWithPolicy(logger, "test-keep-running", KeepRunning)
+
+ panic("test panic")
+ }()
+
+ assert.True(t, logger.wasPanicLogged(), "Should log the panic")
+ // If we get here, the panic was swallowed (KeepRunning behavior)
+}
+
+// TestRecoverWithPolicy_CrashProcess tests policy-based recovery with CrashProcess.
+func TestRecoverWithPolicy_CrashProcess(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+
+ defer func() {
+ r := recover()
+ require.NotNil(t, r, "Should re-panic with CrashProcess policy")
+ }()
+
+ func() {
+ defer RecoverWithPolicy(logger, "test-crash", CrashProcess)
+
+ panic("test panic")
+ }()
+
+ t.Fatal("Should not reach here")
+}
+
+// TestSafeGo_NoPanic tests SafeGo with a function that doesn't panic.
+func TestSafeGo_NoPanic(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ done := make(chan struct{})
+
+ SafeGo(logger, "test-no-panic", KeepRunning, func() {
+ close(done)
+ })
+
+ select {
+ case <-done:
+ // Success - goroutine completed
+ case <-time.After(time.Second):
+ t.Fatal("Goroutine did not complete in time")
+ }
+
+ // No sleep needed - if no panic occurred, logger won't be called
+ assert.False(t, logger.wasPanicLogged(), "Should not log when no panic occurs")
+}
+
+// TestSafeGo_WithPanic_KeepRunning tests SafeGo catching panics with KeepRunning policy.
+func TestSafeGo_WithPanic_KeepRunning(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ done := make(chan struct{})
+
+ SafeGo(logger, "test-panic-keep-running", KeepRunning, func() {
+ defer close(done)
+
+ panic("goroutine panic")
+ })
+
+ select {
+ case <-done:
+ // Success - goroutine completed (panic was caught)
+ case <-time.After(time.Second):
+ t.Fatal("Goroutine did not complete in time")
+ }
+
+ // Wait for logging via channel instead of arbitrary sleep
+ require.True(t, logger.waitForPanicLog(time.Second), "Should log the panic")
+}
+
+// TestSafeGoWithContext_NoPanic tests SafeGoWithContext with no panic.
+func TestSafeGoWithContext_NoPanic(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+ done := make(chan struct{})
+
+ SafeGoWithContext(ctx, logger, "test-ctx-no-panic", KeepRunning, func(ctx context.Context) {
+ close(done)
+ })
+
+ select {
+ case <-done:
+ // Success
+ case <-time.After(time.Second):
+ t.Fatal("Goroutine did not complete in time")
+ }
+}
+
+// TestSafeGoWithContext_WithCancellation tests context cancellation propagation.
+func TestSafeGoWithContext_WithCancellation(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx, cancel := context.WithCancel(context.Background())
+ done := make(chan struct{})
+
+ SafeGoWithContext(ctx, logger, "test-ctx-cancel", KeepRunning, func(ctx context.Context) {
+ <-ctx.Done()
+ close(done)
+ })
+
+ cancel()
+
+ select {
+ case <-done:
+ // Success - context cancellation was received
+ case <-time.After(time.Second):
+ t.Fatal("Goroutine did not receive cancellation in time")
+ }
+}
+
+// TestSafeGoWithContext_WithPanic tests SafeGoWithContext catching panics.
+func TestSafeGoWithContext_WithPanic(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+ done := make(chan struct{})
+
+ SafeGoWithContext(ctx, logger, "test-ctx-panic", KeepRunning, func(ctx context.Context) {
+ defer close(done)
+
+ panic("context goroutine panic")
+ })
+
+ select {
+ case <-done:
+ // Success - panic was caught
+ case <-time.After(time.Second):
+ t.Fatal("Goroutine did not complete in time")
+ }
+
+ // Wait for logging via channel instead of arbitrary sleep
+ require.True(t, logger.waitForPanicLog(time.Second), "Should log the panic")
+}
+
+// Note: SafeGo with CrashProcess policy is not directly tested because the re-panic
+// would crash the test process. The underlying RecoverWithPolicy is tested with
+// CrashProcess policy in TestRecoverWithPolicy_CrashProcess, which verifies the
+// re-panic behavior. In production, CrashProcess is intended to terminate the
+// process, which is the expected and correct behavior.
+
+// TestSafeGoWithContext_WithComponent tests SafeGoWithContextAndComponent.
+func TestSafeGoWithContext_WithComponent(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+ done := make(chan struct{})
+
+ SafeGoWithContextAndComponent(
+ ctx,
+ logger,
+ "transaction",
+ "test-component",
+ KeepRunning,
+ func(ctx context.Context) {
+ defer close(done)
+
+ panic("component panic")
+ },
+ )
+
+ select {
+ case <-done:
+ // Success - panic was caught
+ case <-time.After(time.Second):
+ t.Fatal("Goroutine did not complete in time")
+ }
+
+ // Wait for logging via channel
+ require.True(t, logger.waitForPanicLog(time.Second), "Should log the panic")
+}
+
+// TestRecoverWithPolicyAndContext_KeepRunning tests context-aware recovery with KeepRunning.
+func TestRecoverWithPolicyAndContext_KeepRunning(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+
+ func() {
+ defer RecoverWithPolicyAndContext(
+ ctx,
+ logger,
+ "test-component",
+ "test-handler",
+ KeepRunning,
+ )
+
+ panic("context panic")
+ }()
+
+ assert.True(t, logger.wasPanicLogged(), "Should log the panic")
+}
+
+// TestRecoverWithPolicyAndContext_CrashProcess tests context-aware recovery with CrashProcess.
+func TestRecoverWithPolicyAndContext_CrashProcess(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+
+ defer func() {
+ r := recover()
+ require.NotNil(t, r, "Should re-panic with CrashProcess policy")
+ }()
+
+ func() {
+ defer RecoverWithPolicyAndContext(ctx, logger, "test-component", "test-crash", CrashProcess)
+
+ panic("crash panic")
+ }()
+
+ t.Fatal("Should not reach here")
+}
+
+// TestRecoverAndLogWithContext tests RecoverAndLogWithContext.
+func TestRecoverAndLogWithContext(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+
+ func() {
+ defer RecoverAndLogWithContext(ctx, logger, "test-component", "test-handler")
+
+ panic("log context panic")
+ }()
+
+ assert.True(t, logger.wasPanicLogged(), "Should log the panic")
+}
+
+// TestRecoverAndCrashWithContext tests RecoverAndCrashWithContext.
+func TestRecoverAndCrashWithContext(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+
+ defer func() {
+ r := recover()
+ require.NotNil(t, r, "Should re-panic after logging")
+ assert.Equal(t, "crash context panic", r)
+ }()
+
+ func() {
+ defer RecoverAndCrashWithContext(ctx, logger, "test-component", "test-crash")
+
+ panic("crash context panic")
+ }()
+
+ t.Fatal("Should not reach here - panic should propagate")
+}
+
+// TestPanicMetrics_NilFactory tests that nil factory doesn't cause panic.
+func TestPanicMetrics_NilFactory(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+
+ // Should not panic even with nil metrics
+ var pm *PanicMetrics
+ pm.RecordPanicRecovered(ctx, "test", "test")
+
+ // Also test the package-level function with no initialization
+ recordPanicMetric(ctx, "test", "test")
+}
+
+// TestErrorReporter_NilReporter tests that nil reporter doesn't cause panic.
+func TestErrorReporter_NilReporter(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+
+ // Ensure no reporter is set
+ SetErrorReporter(nil)
+
+ // Should not panic
+ reportPanicToErrorService(ctx, "test panic", nil, "test", "test")
+
+ assert.Nil(t, GetErrorReporter())
+}
+
+// TestErrorReporter_CustomReporter tests custom error reporter integration.
+// Note: This test cannot run in parallel because it modifies the global error reporter.
+//
+//nolint:paralleltest // Cannot use t.Parallel() - modifies global errorReporterInstance
+func TestErrorReporter_CustomReporter(t *testing.T) {
+ ctx := context.Background()
+
+ var capturedErr error
+
+ var capturedTags map[string]string
+
+ // Create a mock reporter
+ mockReporter := &mockErrorReporter{
+ captureFunc: func(ctx context.Context, err error, tags map[string]string) {
+ capturedErr = err
+ capturedTags = tags
+ },
+ }
+
+ // Clear any existing reporter first, then set our mock
+ SetErrorReporter(nil)
+ SetErrorReporter(mockReporter)
+
+ // Ensure cleanup happens after the test
+ t.Cleanup(func() { SetErrorReporter(nil) })
+
+ // Report a panic
+ reportPanicToErrorService(
+ ctx,
+ "test panic",
+ []byte("test stack trace"),
+ "transaction",
+ "worker",
+ )
+
+ require.NotNil(t, capturedErr)
+ assert.Contains(t, capturedErr.Error(), "test panic")
+ assert.Equal(t, "transaction", capturedTags["component"])
+ assert.Equal(t, "worker", capturedTags["goroutine_name"])
+ assert.Equal(t, "test stack trace", capturedTags["stack_trace"])
+}
+
+// mockErrorReporter is a test implementation of ErrorReporter.
+type mockErrorReporter struct {
+ captureFunc func(ctx context.Context, err error, tags map[string]string)
+}
+
+func (m *mockErrorReporter) CaptureException(
+ ctx context.Context,
+ err error,
+ tags map[string]string,
+) {
+ if m.captureFunc != nil {
+ m.captureFunc(ctx, err, tags)
+ }
+}
diff --git a/commons/runtime/helpers_test.go b/commons/runtime/helpers_test.go
new file mode 100644
index 00000000..e81232b6
--- /dev/null
+++ b/commons/runtime/helpers_test.go
@@ -0,0 +1,56 @@
+//go:build unit
+
+package runtime
+
+import (
+ "context"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+// testLogger is a test logger that captures log calls.
+// It is shared across all runtime test files.
+type testLogger struct {
+ mu sync.Mutex
+ errorCalls []string
+ lastMessage string
+ panicLogged atomic.Bool
+ logged chan struct{} // Signals when a panic was logged
+}
+
+func newTestLogger() *testLogger {
+ return &testLogger{
+ logged: make(chan struct{}, 1), // Buffered to avoid blocking
+ }
+}
+
+func (logger *testLogger) Log(_ context.Context, _ log.Level, msg string, _ ...log.Field) {
+ logger.mu.Lock()
+ defer logger.mu.Unlock()
+
+ logger.errorCalls = append(logger.errorCalls, msg)
+ logger.lastMessage = msg
+ logger.panicLogged.Store(true)
+
+ // Signal that logging occurred (non-blocking)
+ select {
+ case logger.logged <- struct{}{}:
+ default:
+ }
+}
+
+func (logger *testLogger) wasPanicLogged() bool {
+ return logger.panicLogged.Load()
+}
+
+func (logger *testLogger) waitForPanicLog(timeout time.Duration) bool {
+ select {
+ case <-logger.logged:
+ return true
+ case <-time.After(timeout):
+ return false
+ }
+}
diff --git a/commons/runtime/log_mode_link_test.go b/commons/runtime/log_mode_link_test.go
new file mode 100644
index 00000000..609ba830
--- /dev/null
+++ b/commons/runtime/log_mode_link_test.go
@@ -0,0 +1,48 @@
+//go:build unit
+
+package runtime
+
+import (
+ "bytes"
+ "context"
+ slog "log"
+ "sync"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+)
+
+var runtimeLoggerOutputMu sync.Mutex
+
+func withRuntimeLoggerOutput(t *testing.T, output *bytes.Buffer) {
+ t.Helper()
+
+ runtimeLoggerOutputMu.Lock()
+ defer t.Cleanup(func() {
+ runtimeLoggerOutputMu.Unlock()
+ })
+
+ originalOutput := slog.Writer()
+ slog.SetOutput(output)
+ t.Cleanup(func() { slog.SetOutput(originalOutput) })
+}
+
+func TestLogProductionModeResolverRegistration(t *testing.T) {
+ var buf bytes.Buffer
+ withRuntimeLoggerOutput(t, &buf)
+
+ logger := &log.GoLogger{Level: log.LevelInfo}
+ initialMode := IsProductionMode()
+ t.Cleanup(func() { SetProductionMode(initialMode) })
+
+ SetProductionMode(false)
+ log.SafeError(logger, context.Background(), "runtime integration", assert.AnError, IsProductionMode())
+ assert.Contains(t, buf.String(), "general error")
+
+ buf.Reset()
+ SetProductionMode(true)
+ log.SafeError(logger, context.Background(), "runtime integration", assert.AnError, IsProductionMode())
+ assert.Contains(t, buf.String(), "error_type=*errors.errorString")
+ assert.NotContains(t, buf.String(), "general error")
+}
diff --git a/commons/runtime/metrics.go b/commons/runtime/metrics.go
new file mode 100644
index 00000000..9d5e4f08
--- /dev/null
+++ b/commons/runtime/metrics.go
@@ -0,0 +1,136 @@
+package runtime
+
+import (
+ "context"
+ "sync"
+
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
+)
+
+// PanicMetrics provides panic-related metrics using OpenTelemetry.
+// It wraps lib-commons' MetricsFactory for consistent metric handling.
+type PanicMetrics struct {
+ factory *metrics.MetricsFactory
+ logger Logger
+}
+
+// panicRecoveredMetric defines the metric for counting recovered panics.
+var panicRecoveredMetric = metrics.Metric{
+ Name: constant.MetricPanicRecoveredTotal,
+ Unit: "1",
+ Description: "Total number of recovered panics",
+}
+
+// panicMetricsInstance is the singleton instance for panic metrics.
+// It is initialized lazily via InitPanicMetrics.
+var (
+ panicMetricsInstance *PanicMetrics
+ panicMetricsMu sync.RWMutex
+)
+
+// InitPanicMetrics initializes panic metrics with the provided MetricsFactory.
+//
+// Backward compatibility:
+// - InitPanicMetrics(factory)
+// - InitPanicMetrics(factory, logger)
+//
+// The logger is optional and used only for metric recording diagnostics.
+// This should be called once during application startup after telemetry is initialized.
+// It is safe to call multiple times; subsequent calls are no-ops.
+//
+// Example:
+//
+// tl, err := opentelemetry.NewTelemetry(cfg)
+// if err != nil {
+// log.Fatalf("failed to init telemetry: %v", err)
+// }
+// tl.ApplyGlobals()
+// runtime.InitPanicMetrics(tl.MetricsFactory)
+func InitPanicMetrics(factory *metrics.MetricsFactory, logger ...Logger) {
+ panicMetricsMu.Lock()
+ defer panicMetricsMu.Unlock()
+
+ if factory == nil {
+ return
+ }
+
+ if panicMetricsInstance != nil {
+ return // Already initialized
+ }
+
+ var l Logger
+ if len(logger) > 0 {
+ l = logger[0]
+ }
+
+ panicMetricsInstance = &PanicMetrics{
+ factory: factory,
+ logger: l,
+ }
+}
+
+// GetPanicMetrics returns the singleton PanicMetrics instance.
+// Returns nil if InitPanicMetrics has not been called.
+func GetPanicMetrics() *PanicMetrics {
+ panicMetricsMu.RLock()
+ defer panicMetricsMu.RUnlock()
+
+ return panicMetricsInstance
+}
+
+// ResetPanicMetrics clears the panic metrics singleton.
+// This is primarily intended for testing to ensure test isolation.
+// In production, this should generally not be called.
+func ResetPanicMetrics() {
+ panicMetricsMu.Lock()
+ defer panicMetricsMu.Unlock()
+
+ panicMetricsInstance = nil
+}
+
+// RecordPanicRecovered increments the panic_recovered_total counter with the given labels.
+// If metrics are not initialized, this is a no-op.
+//
+// Parameters:
+// - ctx: Context for metric recording (may contain trace correlation)
+// - component: The component where the panic occurred (e.g., "transaction", "onboarding", "crm")
+// - goroutineName: The name of the goroutine or handler (e.g., "http_handler", "rabbitmq_worker")
+func (pm *PanicMetrics) RecordPanicRecovered(ctx context.Context, component, goroutineName string) {
+ if pm == nil || pm.factory == nil {
+ return
+ }
+
+ counter, err := pm.factory.Counter(panicRecoveredMetric)
+ if err != nil {
+ if pm.logger != nil {
+ pm.logger.Log(ctx, log.LevelWarn, "failed to create panic metric counter", log.Err(err))
+ }
+
+ return
+ }
+
+ err = counter.
+ WithLabels(map[string]string{
+ "component": constant.SanitizeMetricLabel(component),
+ "goroutine_name": constant.SanitizeMetricLabel(goroutineName),
+ }).
+ AddOne(ctx)
+ if err != nil {
+ if pm.logger != nil {
+ pm.logger.Log(ctx, log.LevelWarn, "failed to record panic metric", log.Err(err))
+ }
+
+ return
+ }
+}
+
+// recordPanicMetric is a package-level helper that records a panic metric if metrics are initialized.
+// This is called internally by recovery functions.
+func recordPanicMetric(ctx context.Context, component, goroutineName string) {
+ pm := GetPanicMetrics()
+ if pm != nil {
+ pm.RecordPanicRecovered(ctx, component, goroutineName)
+ }
+}
diff --git a/commons/runtime/metrics_test.go b/commons/runtime/metrics_test.go
new file mode 100644
index 00000000..701928cd
--- /dev/null
+++ b/commons/runtime/metrics_test.go
@@ -0,0 +1,58 @@
+//go:build unit
+
+package runtime
+
+import (
+ "strings"
+ "testing"
+
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/stretchr/testify/assert"
+)
+
+// TestSanitizeMetricLabel tests the shared constant.SanitizeMetricLabel function.
+func TestSanitizeMetricLabel(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ expected string
+ }{
+ {
+ name: "empty string",
+ input: "",
+ expected: "",
+ },
+ {
+ name: "short string",
+ input: "component",
+ expected: "component",
+ },
+ {
+ name: "exactly max length",
+ input: strings.Repeat("a", constant.MaxMetricLabelLength),
+ expected: strings.Repeat("a", constant.MaxMetricLabelLength),
+ },
+ {
+ name: "exceeds max length",
+ input: strings.Repeat("b", constant.MaxMetricLabelLength+10),
+ expected: strings.Repeat("b", constant.MaxMetricLabelLength),
+ },
+ {
+ name: "much longer than max",
+ input: strings.Repeat("c", 200),
+ expected: strings.Repeat("c", constant.MaxMetricLabelLength),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := constant.SanitizeMetricLabel(tt.input)
+ assert.Equal(t, tt.expected, result)
+ assert.LessOrEqual(t, len(result), constant.MaxMetricLabelLength)
+ })
+ }
+}
diff --git a/commons/runtime/policy.go b/commons/runtime/policy.go
new file mode 100644
index 00000000..6d8dd1fd
--- /dev/null
+++ b/commons/runtime/policy.go
@@ -0,0 +1,28 @@
+package runtime
+
+// PanicPolicy determines how a recovered panic should be handled.
+type PanicPolicy int
+
+const (
+ // KeepRunning logs the panic and stack trace, then continues execution.
+ // Use for HTTP/gRPC handlers and worker goroutines where crashing would
+ // affect other requests or tasks.
+ KeepRunning PanicPolicy = iota
+
+ // CrashProcess logs the panic and stack trace, then re-panics to crash
+ // the process. Use for critical invariant violations where continuing
+ // would cause data corruption or undefined behavior.
+ CrashProcess
+)
+
+// String returns the string representation of the PanicPolicy.
+func (p PanicPolicy) String() string {
+ switch p {
+ case KeepRunning:
+ return "KeepRunning"
+ case CrashProcess:
+ return "CrashProcess"
+ default:
+ return "Unknown"
+ }
+}
diff --git a/commons/runtime/policy_test.go b/commons/runtime/policy_test.go
new file mode 100644
index 00000000..64833143
--- /dev/null
+++ b/commons/runtime/policy_test.go
@@ -0,0 +1,110 @@
+//go:build unit
+
+package runtime
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TestPanicPolicy_String tests the String method for all PanicPolicy values.
+func TestPanicPolicy_String(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ policy PanicPolicy
+ expected string
+ }{
+ {
+ name: "KeepRunning returns correct string",
+ policy: KeepRunning,
+ expected: "KeepRunning",
+ },
+ {
+ name: "CrashProcess returns correct string",
+ policy: CrashProcess,
+ expected: "CrashProcess",
+ },
+ {
+ name: "Unknown positive value returns Unknown",
+ policy: PanicPolicy(99),
+ expected: "Unknown",
+ },
+ {
+ name: "Negative value returns Unknown",
+ policy: PanicPolicy(-1),
+ expected: "Unknown",
+ },
+ {
+ name: "Large value returns Unknown",
+ policy: PanicPolicy(1000),
+ expected: "Unknown",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := tt.policy.String()
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// TestPanicPolicy_IotaOrdering verifies the iota constant ordering.
+func TestPanicPolicy_IotaOrdering(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ policy PanicPolicy
+ expectedValue int
+ }{
+ {
+ name: "KeepRunning is 0 (first iota)",
+ policy: KeepRunning,
+ expectedValue: 0,
+ },
+ {
+ name: "CrashProcess is 1 (second iota)",
+ policy: CrashProcess,
+ expectedValue: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, tt.expectedValue, int(tt.policy))
+ })
+ }
+}
+
+// TestPanicPolicy_TypeSafety verifies type conversion behavior.
+func TestPanicPolicy_TypeSafety(t *testing.T) {
+ t.Parallel()
+
+ t.Run("explicit int conversion works", func(t *testing.T) {
+ t.Parallel()
+
+ p := KeepRunning
+ assert.Equal(t, 0, int(p))
+
+ p = CrashProcess
+ assert.Equal(t, 1, int(p))
+ })
+
+ t.Run("policy from int conversion", func(t *testing.T) {
+ t.Parallel()
+
+ p := PanicPolicy(0)
+ assert.Equal(t, KeepRunning, p)
+
+ p = PanicPolicy(1)
+ assert.Equal(t, CrashProcess, p)
+ })
+}
diff --git a/commons/runtime/recover.go b/commons/runtime/recover.go
new file mode 100644
index 00000000..b3ed6473
--- /dev/null
+++ b/commons/runtime/recover.go
@@ -0,0 +1,229 @@
+package runtime
+
+import (
+ "context"
+ "runtime/debug"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+// Logger defines the minimal logging interface required by runtime.
+// This interface is satisfied by github.com/LerianStudio/lib-commons/v4/commons/log.Logger.
+type Logger interface {
+ Log(ctx context.Context, level log.Level, msg string, fields ...log.Field)
+}
+
+// RecoverAndLog recovers from a panic, logs it with the stack trace, and
+// continues execution. Use this in defer statements for handlers and workers
+// where you want to prevent crashes.
+//
+// Note: This function does not record metrics or span events because it lacks
+// context. For observability integration, use RecoverAndLogWithContext instead.
+//
+// Example:
+//
+// func worker() {
+// defer runtime.RecoverAndLog(logger, "worker")
+// // ...
+// }
+func RecoverAndLog(logger Logger, name string) {
+ if r := recover(); r != nil {
+ logPanic(logger, name, r)
+ }
+}
+
+// RecoverAndLogWithContext is like RecoverAndLog but with full observability integration.
+// It records metrics, span events, and reports to error tracking services.
+//
+// Parameters:
+// - ctx: Context for observability (metrics, tracing, error reporting)
+// - logger: Logger for structured logging
+// - component: The service component (e.g., "transaction", "onboarding")
+// - name: Descriptive name for the goroutine or handler
+//
+// Example:
+//
+// func handler(ctx context.Context) {
+// defer runtime.RecoverAndLogWithContext(ctx, logger, "transaction", "create_handler")
+// // ...
+// }
+func RecoverAndLogWithContext(ctx context.Context, logger Logger, component, name string) {
+ if r := recover(); r != nil {
+ stack := debug.Stack()
+ logPanicWithStack(logger, name, r, stack)
+ recordPanicObservability(ctx, r, stack, component, name)
+ }
+}
+
+// RecoverAndCrash recovers from a panic, logs it with the stack trace, and
+// re-panics to crash the process. Use this in defer statements for critical
+// operations where continuing after a panic would be dangerous.
+//
+// Example:
+//
+// func criticalOperation() {
+// defer runtime.RecoverAndCrash(logger, "critical-op")
+// // ...
+// }
+func RecoverAndCrash(logger Logger, name string) {
+ if r := recover(); r != nil {
+ logPanic(logger, name, r)
+ panic(r)
+ }
+}
+
+// RecoverAndCrashWithContext is like RecoverAndCrash but with full observability integration.
+// It records metrics and span events before re-panicking.
+//
+// Parameters:
+// - ctx: Context for observability (metrics, tracing, error reporting)
+// - logger: Logger for structured logging
+// - component: The service component (e.g., "transaction", "onboarding")
+// - name: Descriptive name for the goroutine or handler
+func RecoverAndCrashWithContext(ctx context.Context, logger Logger, component, name string) {
+ if r := recover(); r != nil {
+ stack := debug.Stack()
+ logPanicWithStack(logger, name, r, stack)
+ recordPanicObservability(ctx, r, stack, component, name)
+ panic(r)
+ }
+}
+
+// RecoverWithPolicy recovers from a panic and handles it according to the
+// specified policy. Use this when the recovery behavior needs to be determined
+// at runtime.
+//
+// Note: This function does not record metrics or span events because it lacks
+// context. For observability integration, use RecoverWithPolicyAndContext instead.
+//
+// Example:
+//
+// func flexibleHandler(policy runtime.PanicPolicy) {
+// defer runtime.RecoverWithPolicy(logger, "handler", policy)
+// // ...
+// }
+func RecoverWithPolicy(logger Logger, name string, policy PanicPolicy) {
+ if r := recover(); r != nil {
+ logPanic(logger, name, r)
+
+ if policy == CrashProcess {
+ panic(r)
+ }
+ }
+}
+
+// RecoverWithPolicyAndContext is like RecoverWithPolicy but with full observability integration.
+// It records metrics, span events, and reports to error tracking services.
+//
+// Parameters:
+// - ctx: Context for observability (metrics, tracing, error reporting)
+// - logger: Logger for structured logging
+// - component: The service component (e.g., "transaction", "onboarding")
+// - name: Descriptive name for the goroutine or handler
+// - policy: How to handle the panic after logging/recording
+//
+// Example:
+//
+// func worker(ctx context.Context, policy runtime.PanicPolicy) {
+// defer runtime.RecoverWithPolicyAndContext(ctx, logger, "transaction", "balance_worker", policy)
+// // ...
+// }
+func RecoverWithPolicyAndContext(
+ ctx context.Context,
+ logger Logger,
+ component, name string,
+ policy PanicPolicy,
+) {
+ if recovered := recover(); recovered != nil {
+ stack := debug.Stack()
+ logPanicWithStack(logger, name, recovered, stack)
+ recordPanicObservability(ctx, recovered, stack, component, name)
+
+ if policy == CrashProcess {
+ panic(recovered)
+ }
+ }
+}
+
+// logPanic logs the panic value and stack trace using the provided logger.
+// This is the legacy function that captures stack internally.
+func logPanic(logger Logger, name string, panicValue any) {
+ stack := debug.Stack()
+ logPanicWithStack(logger, name, panicValue, stack)
+}
+
+// logPanicWithStack logs the panic with a pre-captured stack trace.
+// In production mode, panic values are redacted to prevent leaking sensitive data.
+func logPanicWithStack(logger Logger, name string, panicValue any, stack []byte) {
+ if logger == nil {
+ // Last resort fallback - should never happen in production
+ return
+ }
+
+ if IsProductionMode() {
+ logger.Log(context.Background(), log.LevelError,
+ "panic recovered",
+ log.String("source", name),
+ log.String("value", redactedPanicMsg),
+ )
+
+ return
+ }
+
+ logger.Log(context.Background(), log.LevelError,
+ "panic recovered",
+ log.String("source", name),
+ log.Any("value", panicValue),
+ log.String("stack_trace", string(stack)),
+ )
+}
+
+// recordPanicObservability records panic information to all configured observability systems.
+// This includes metrics, distributed tracing, and error reporting services.
+func recordPanicObservability(
+ ctx context.Context,
+ panicValue any,
+ stack []byte,
+ component, name string,
+) {
+ // Record metric
+ recordPanicMetric(ctx, component, name)
+
+ // Record span event
+ RecordPanicToSpanWithComponent(ctx, panicValue, stack, component, name)
+
+ // Report to error tracking service (e.g., Sentry) if configured
+ reportPanicToErrorService(ctx, panicValue, stack, component, name)
+}
+
+// HandlePanicValue processes a panic value that was already recovered by an external
+// mechanism (e.g., Fiber's recover middleware). This function logs and records
+// observability data without calling recover() itself.
+//
+// Use this when integrating with frameworks that provide their own panic recovery
+// but still need our observability pipeline.
+//
+// Parameters:
+// - ctx: Context for observability (metrics, tracing, error reporting)
+// - logger: Logger for structured logging
+// - panicValue: The panic value recovered by the external mechanism
+// - component: The service component (e.g., "matcher", "ingestion")
+// - name: Descriptive name for the handler (e.g., "http_handler")
+//
+// Example (Fiber middleware):
+//
+// recover.New(recover.Config{
+// StackTraceHandler: func(c *fiber.Ctx, panicValue any) {
+// ctx := extractContext(c)
+// runtime.HandlePanicValue(ctx, logger, panicValue, "matcher", "http_handler")
+// },
+// })
+func HandlePanicValue(ctx context.Context, logger Logger, panicValue any, component, name string) {
+ if panicValue == nil {
+ return
+ }
+
+ stack := debug.Stack()
+ logPanicWithStack(logger, name, panicValue, stack)
+ recordPanicObservability(ctx, panicValue, stack, component, name)
+}
diff --git a/commons/runtime/recover_test.go b/commons/runtime/recover_test.go
new file mode 100644
index 00000000..6ac714d8
--- /dev/null
+++ b/commons/runtime/recover_test.go
@@ -0,0 +1,490 @@
+//go:build unit
+
+package runtime
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+var (
+ errTestPanicRecover = errors.New("test error")
+ errOriginalPanicRecover = errors.New("original error")
+)
+
+// TestLogPanicWithStack_NilLogger tests that nil logger doesn't cause panic.
+func TestLogPanicWithStack_NilLogger(t *testing.T) {
+ t.Parallel()
+
+ require.NotPanics(t, func() {
+ logPanicWithStack(nil, "test", "panic value", []byte("stack trace"))
+ })
+}
+
+// TestLogPanicWithStack_ValidLogger tests logging with a valid logger.
+func TestLogPanicWithStack_ValidLogger(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ stack := []byte("goroutine 1 [running]:\nmain.main()\n\t/path/to/file.go:10")
+
+ logPanicWithStack(logger, "test-handler", "test panic", stack)
+
+ assert.True(t, logger.wasPanicLogged())
+ assert.NotEmpty(t, logger.errorCalls)
+}
+
+// TestLogPanicWithStack_DifferentPanicTypes tests various panic value types.
+func TestLogPanicWithStack_DifferentPanicTypes(t *testing.T) {
+ t.Parallel()
+
+ type customStruct struct {
+ Field string
+ Code int
+ }
+
+ tests := []struct {
+ name string
+ panicValue any
+ }{
+ {
+ name: "string panic value",
+ panicValue: "something went wrong",
+ },
+ {
+ name: "error panic value",
+ panicValue: errTestPanicRecover,
+ },
+ {
+ name: "int panic value",
+ panicValue: 42,
+ },
+ {
+ name: "struct panic value",
+ panicValue: customStruct{Field: "test", Code: 500},
+ },
+ {
+ name: "nil panic value",
+ panicValue: nil,
+ },
+ {
+ name: "bool panic value",
+ panicValue: true,
+ },
+ {
+ name: "float panic value",
+ panicValue: 3.14159,
+ },
+ {
+ name: "slice panic value",
+ panicValue: []string{"a", "b", "c"},
+ },
+ {
+ name: "map panic value",
+ panicValue: map[string]int{"key": 123},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ stack := []byte("test stack")
+
+ require.NotPanics(t, func() {
+ logPanicWithStack(logger, "test", tt.panicValue, stack)
+ })
+
+ assert.True(t, logger.wasPanicLogged())
+ })
+ }
+}
+
+// TestRecoverAndLog_NilLogger tests RecoverAndLog with nil logger.
+func TestRecoverAndLog_NilLogger(t *testing.T) {
+ t.Parallel()
+
+ require.NotPanics(t, func() {
+ func() {
+ defer RecoverAndLog(nil, "test-nil-logger")
+
+ panic("test panic")
+ }()
+ })
+}
+
+// TestRecoverAndLogWithContext_NilLogger tests RecoverAndLogWithContext with nil logger.
+func TestRecoverAndLogWithContext_NilLogger(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+
+ require.NotPanics(t, func() {
+ func() {
+ defer RecoverAndLogWithContext(ctx, nil, "component", "test-nil-logger")
+
+ panic("test panic")
+ }()
+ })
+}
+
+// TestRecoverAndCrash_NilLogger tests RecoverAndCrash with nil logger still re-panics.
+func TestRecoverAndCrash_NilLogger(t *testing.T) {
+ t.Parallel()
+
+ defer func() {
+ r := recover()
+ require.NotNil(t, r, "Should re-panic even with nil logger")
+ assert.Equal(t, "test panic", r)
+ }()
+
+ func() {
+ defer RecoverAndCrash(nil, "test-nil-logger")
+
+ panic("test panic")
+ }()
+
+ t.Fatal("Should not reach here")
+}
+
+// TestRecoverWithPolicy_NilLogger tests RecoverWithPolicy with nil logger.
+func TestRecoverWithPolicy_NilLogger(t *testing.T) {
+ t.Parallel()
+
+ t.Run("KeepRunning with nil logger", func(t *testing.T) {
+ t.Parallel()
+
+ require.NotPanics(t, func() {
+ func() {
+ defer RecoverWithPolicy(nil, "test", KeepRunning)
+
+ panic("test panic")
+ }()
+ })
+ })
+
+ t.Run("CrashProcess with nil logger still re-panics", func(t *testing.T) {
+ t.Parallel()
+
+ defer func() {
+ r := recover()
+ require.NotNil(t, r, "Should re-panic with CrashProcess")
+ }()
+
+ func() {
+ defer RecoverWithPolicy(nil, "test", CrashProcess)
+
+ panic("test panic")
+ }()
+
+ t.Fatal("Should not reach here")
+ })
+}
+
+// TestRecoverWithPolicyAndContext_NilLogger tests context variant with nil logger.
+func TestRecoverWithPolicyAndContext_NilLogger(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+
+ t.Run("KeepRunning with nil logger", func(t *testing.T) {
+ t.Parallel()
+
+ require.NotPanics(t, func() {
+ func() {
+ defer RecoverWithPolicyAndContext(ctx, nil, "component", "test", KeepRunning)
+
+ panic("test panic")
+ }()
+ })
+ })
+
+ t.Run("CrashProcess with nil logger still re-panics", func(t *testing.T) {
+ t.Parallel()
+
+ defer func() {
+ r := recover()
+ require.NotNil(t, r, "Should re-panic with CrashProcess")
+ }()
+
+ func() {
+ defer RecoverWithPolicyAndContext(ctx, nil, "component", "test", CrashProcess)
+
+ panic("test panic")
+ }()
+
+ t.Fatal("Should not reach here")
+ })
+}
+
+// TestLogPanic_CallsLogPanicWithStack tests that logPanic delegates correctly.
+func TestLogPanic_CallsLogPanicWithStack(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+
+ logPanic(logger, "test-handler", "panic value")
+
+ assert.True(t, logger.wasPanicLogged())
+ assert.NotEmpty(t, logger.errorCalls)
+}
+
+// TestRecoverAndLog_PreservesPanicValue tests panic value is correctly captured.
+func TestRecoverAndLog_PreservesPanicValue(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ panicValue any
+ }{
+ {
+ name: "string value",
+ panicValue: "panic message",
+ },
+ {
+ name: "error value",
+ panicValue: errOriginalPanicRecover,
+ },
+
+ {
+ name: "integer value",
+ panicValue: 12345,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+
+ func() {
+ defer RecoverAndLog(logger, "test")
+
+ panic(tt.panicValue)
+ }()
+
+ assert.True(t, logger.wasPanicLogged())
+ })
+ }
+}
+
+// TestRecoverAndCrash_PreservesPanicValue tests re-panicked value is preserved.
+func TestRecoverAndCrash_PreservesPanicValue(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ panicValue any
+ }{
+ {
+ name: "string value",
+ panicValue: "original panic",
+ },
+ {
+ name: "error value",
+ panicValue: errOriginalPanicRecover,
+ },
+ {
+ name: "integer value",
+ panicValue: 99999,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+
+ defer func() {
+ r := recover()
+ require.NotNil(t, r)
+ assert.Equal(t, tt.panicValue, r)
+ }()
+
+ func() {
+ defer RecoverAndCrash(logger, "test")
+
+ panic(tt.panicValue)
+ }()
+
+ t.Fatal("Should not reach here")
+ })
+ }
+}
+
+// TestRecoverAndCrashWithContext_PreservesPanicValue tests context variant preserves value.
+func TestRecoverAndCrashWithContext_PreservesPanicValue(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ logger := newTestLogger()
+ expectedValue := "context panic value"
+
+ defer func() {
+ r := recover()
+ require.NotNil(t, r)
+ assert.Equal(t, expectedValue, r)
+ }()
+
+ func() {
+ defer RecoverAndCrashWithContext(ctx, logger, "component", "handler")
+
+ panic(expectedValue)
+ }()
+
+ t.Fatal("Should not reach here")
+}
+
+// TestRecoverFunctions_NoPanic tests all recover functions when no panic occurs.
+func TestRecoverFunctions_NoPanic(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ logger := newTestLogger()
+
+ t.Run("RecoverAndLog no panic", func(t *testing.T) {
+ t.Parallel()
+
+ testLogger := newTestLogger()
+
+ func() {
+ defer RecoverAndLog(testLogger, "test")
+ }()
+
+ assert.False(t, testLogger.wasPanicLogged())
+ })
+
+ t.Run("RecoverAndLogWithContext no panic", func(t *testing.T) {
+ t.Parallel()
+
+ testLogger := newTestLogger()
+
+ func() {
+ defer RecoverAndLogWithContext(ctx, testLogger, "component", "test")
+ }()
+
+ assert.False(t, testLogger.wasPanicLogged())
+ })
+
+ t.Run("RecoverAndCrash no panic", func(t *testing.T) {
+ t.Parallel()
+
+ func() {
+ defer RecoverAndCrash(logger, "test")
+ }()
+
+ assert.False(t, logger.wasPanicLogged())
+ })
+
+ t.Run("RecoverAndCrashWithContext no panic", func(t *testing.T) {
+ t.Parallel()
+
+ testLogger := newTestLogger()
+
+ func() {
+ defer RecoverAndCrashWithContext(ctx, testLogger, "component", "test")
+ }()
+
+ assert.False(t, testLogger.wasPanicLogged())
+ })
+
+ t.Run("RecoverWithPolicy no panic", func(t *testing.T) {
+ t.Parallel()
+
+ testLogger := newTestLogger()
+
+ func() {
+ defer RecoverWithPolicy(testLogger, "test", KeepRunning)
+ }()
+
+ assert.False(t, testLogger.wasPanicLogged())
+ })
+
+ t.Run("RecoverWithPolicyAndContext no panic", func(t *testing.T) {
+ t.Parallel()
+
+ testLogger := newTestLogger()
+
+ func() {
+ defer RecoverWithPolicyAndContext(ctx, testLogger, "component", "test", KeepRunning)
+ }()
+
+ assert.False(t, testLogger.wasPanicLogged())
+ })
+}
+
+// TestHandlePanicValue tests the HandlePanicValue function for external recovery integration.
+func TestHandlePanicValue(t *testing.T) {
+ t.Parallel()
+
+ t.Run("logs and records observability for panic value", func(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+
+ HandlePanicValue(ctx, logger, "test panic", "matcher", "http_handler")
+
+ assert.True(t, logger.wasPanicLogged())
+ assert.NotEmpty(t, logger.errorCalls)
+ })
+
+ t.Run("handles nil panic value gracefully", func(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+
+ require.NotPanics(t, func() {
+ HandlePanicValue(ctx, logger, nil, "matcher", "http_handler")
+ })
+
+ assert.False(t, logger.wasPanicLogged())
+ })
+
+ t.Run("handles nil logger gracefully", func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+
+ require.NotPanics(t, func() {
+ HandlePanicValue(ctx, nil, "test panic", "matcher", "http_handler")
+ })
+ })
+
+ t.Run("handles various panic value types", func(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ panicValue any
+ }{
+ {"string", "panic message"},
+ {"error", errTestPanicRecover},
+ {"integer", 42},
+ {"struct", struct{ Code int }{Code: 500}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ logger := newTestLogger()
+ ctx := context.Background()
+
+ require.NotPanics(t, func() {
+ HandlePanicValue(ctx, logger, tt.panicValue, "matcher", "handler")
+ })
+
+ assert.True(t, logger.wasPanicLogged())
+ })
+ }
+ })
+}
diff --git a/commons/runtime/tracing.go b/commons/runtime/tracing.go
new file mode 100644
index 00000000..6d3b7c68
--- /dev/null
+++ b/commons/runtime/tracing.go
@@ -0,0 +1,137 @@
+package runtime
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "regexp"
+
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "go.opentelemetry.io/otel/attribute"
+ "go.opentelemetry.io/otel/codes"
+ "go.opentelemetry.io/otel/trace"
+)
+
+// maxPanicValueLen is the maximum length for a panic value string exported to spans.
+const maxPanicValueLen = 1024
+
+// maxStackTraceLen is the maximum length for a stack trace string exported to spans.
+const maxStackTraceLen = 4096
+
+// sensitivePattern matches common sensitive data patterns for redaction in span attributes.
+// Covers passwords, tokens, secrets, API keys, credentials, and connection strings.
+var sensitivePattern = regexp.MustCompile(
+ `(?i)(password|passwd|pwd|token|secret|api[_-]?key|credential|bearer|authorization)[=:]\s*\S+`,
+)
+
+// sensitiveRedaction is the replacement string for redacted sensitive data.
+const sensitiveRedaction = "[REDACTED]"
+
+// sanitizePanicValue truncates and redacts sensitive patterns from a panic value string.
+func sanitizePanicValue(raw string) string {
+ sanitized := sensitivePattern.ReplaceAllString(raw, sensitiveRedaction)
+
+ if len(sanitized) > maxPanicValueLen {
+ return sanitized[:maxPanicValueLen] + "...[truncated]"
+ }
+
+ return sanitized
+}
+
+// sanitizeStackTrace truncates a stack trace for safe span export.
+func sanitizeStackTrace(stack []byte) string {
+ s := string(stack)
+
+ if len(s) > maxStackTraceLen {
+ return s[:maxStackTraceLen] + "\n...[truncated]"
+ }
+
+ return s
+}
+
+// ErrPanic is the sentinel error for recovered panics recorded to spans.
+var ErrPanic = errors.New("panic")
+
+// PanicSpanEventName is the event name used when recording panic events on spans.
+const PanicSpanEventName = constant.EventPanicRecovered
+
+// RecordPanicToSpan records a recovered panic as an error event on the current span.
+// This enriches distributed traces with panic information for debugging.
+//
+// The function:
+// - Adds a "panic.recovered" event with panic value, stack trace, and goroutine name
+// - Records the panic as an error using span.RecordError
+// - Sets the span status to Error with a descriptive message
+//
+// Parameters:
+// - ctx: Context containing the active span
+// - panicValue: The value passed to panic()
+// - stack: The stack trace captured via debug.Stack()
+// - goroutineName: The name of the goroutine where the panic occurred
+//
+// If there is no active span in the context, this function is a no-op.
+func RecordPanicToSpan(ctx context.Context, panicValue any, stack []byte, goroutineName string) {
+ recordPanicToSpanInternal(ctx, panicValue, stack, "", goroutineName)
+}
+
+// RecordPanicToSpanWithComponent is like RecordPanicToSpan but also includes the component name.
+// This is useful for HTTP/gRPC handlers where both component and handler name are relevant.
+//
+// Parameters:
+// - ctx: Context containing the active span
+// - panicValue: The value passed to panic()
+// - stack: The stack trace captured via debug.Stack()
+// - component: The service component (e.g., "transaction", "onboarding")
+// - goroutineName: The name of the handler or goroutine
+func RecordPanicToSpanWithComponent(
+ ctx context.Context,
+ panicValue any,
+ stack []byte,
+ component, goroutineName string,
+) {
+ recordPanicToSpanInternal(ctx, panicValue, stack, component, goroutineName)
+}
+
+// recordPanicToSpanInternal is the shared implementation for recording panic events.
+// Panic values and stack traces are sanitized to prevent leaking sensitive data
+// into distributed tracing backends.
+func recordPanicToSpanInternal(
+ ctx context.Context,
+ panicValue any,
+ stack []byte,
+ component, goroutineName string,
+) {
+ span := trace.SpanFromContext(ctx)
+ if !span.IsRecording() {
+ return
+ }
+
+ panicStr := sanitizePanicValue(fmt.Sprintf("%v", panicValue))
+ stackStr := sanitizeStackTrace(stack)
+
+ // Build attributes list
+ attrs := []attribute.KeyValue{
+ attribute.String("panic.value", panicStr),
+ attribute.String("panic.stack", stackStr),
+ attribute.String("panic.goroutine_name", goroutineName),
+ }
+
+ // Add component if provided
+ if component != "" {
+ attrs = append(attrs, attribute.String("panic.component", component))
+ }
+
+ // Add detailed event with all panic information
+ span.AddEvent(PanicSpanEventName, trace.WithAttributes(attrs...))
+
+ // Record sanitized error for error-tracking integrations
+ span.RecordError(fmt.Errorf("%w: %s", ErrPanic, panicStr))
+
+ // Set span status to Error
+ statusMsg := "panic recovered in " + goroutineName
+ if component != "" {
+ statusMsg = fmt.Sprintf("panic recovered in %s/%s", component, goroutineName)
+ }
+
+ span.SetStatus(codes.Error, statusMsg)
+}
diff --git a/commons/runtime/tracing_test.go b/commons/runtime/tracing_test.go
new file mode 100644
index 00000000..3eb9042c
--- /dev/null
+++ b/commons/runtime/tracing_test.go
@@ -0,0 +1,541 @@
+//go:build unit
+
+package runtime
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/codes"
+ sdktrace "go.opentelemetry.io/otel/sdk/trace"
+ "go.opentelemetry.io/otel/sdk/trace/tracetest"
+ "go.opentelemetry.io/otel/trace"
+)
+
+func newTestTracerProvider(t *testing.T) (*sdktrace.TracerProvider, *tracetest.SpanRecorder) {
+ t.Helper()
+
+ recorder := tracetest.NewSpanRecorder()
+ provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
+
+ t.Cleanup(func() {
+ _ = provider.Shutdown(context.Background())
+ })
+
+ return provider, recorder
+}
+
+func TestErrPanic(t *testing.T) {
+ t.Parallel()
+
+ assert.NotNil(t, ErrPanic)
+ assert.Equal(t, "panic", ErrPanic.Error())
+}
+
+func TestPanicSpanEventName(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, "panic.recovered", PanicSpanEventName)
+}
+
+func TestRecordPanicToSpan(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ panicValue any
+ stack []byte
+ goroutineName string
+ wantEvent bool
+ wantStatus codes.Code
+ wantMessage string
+ }{
+ {
+ name: "string panic value",
+ panicValue: "something went wrong",
+ stack: []byte("goroutine 1 [running]:\nmain.main()"),
+ goroutineName: "worker-1",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in worker-1",
+ },
+ {
+ name: "error panic value",
+ panicValue: assert.AnError,
+ stack: []byte("stack trace here"),
+ goroutineName: "handler",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in handler",
+ },
+ {
+ name: "integer panic value",
+ panicValue: 42,
+ stack: []byte(""),
+ goroutineName: "processor",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in processor",
+ },
+ {
+ name: "nil panic value",
+ panicValue: nil,
+ stack: []byte("some stack"),
+ goroutineName: "main",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in main",
+ },
+ {
+ name: "empty goroutine name",
+ panicValue: "panic!",
+ stack: []byte("trace"),
+ goroutineName: "",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in ",
+ },
+ {
+ name: "empty stack trace",
+ panicValue: "error",
+ stack: nil,
+ goroutineName: "worker",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in worker",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ provider, recorder := newTestTracerProvider(t)
+ tracer := provider.Tracer("test")
+ ctx, span := tracer.Start(context.Background(), "test-span")
+
+ RecordPanicToSpan(ctx, tt.panicValue, tt.stack, tt.goroutineName)
+ span.End()
+
+ spans := recorder.Ended()
+ require.Len(t, spans, 1)
+
+ recordedSpan := spans[0]
+
+ if tt.wantEvent {
+ require.NotEmpty(t, recordedSpan.Events(), "expected panic event to be recorded")
+
+ var foundPanicEvent bool
+
+ for _, event := range recordedSpan.Events() {
+ if event.Name == PanicSpanEventName {
+ foundPanicEvent = true
+
+ attrMap := make(map[string]string)
+ for _, attr := range event.Attributes {
+ attrMap[string(attr.Key)] = attr.Value.AsString()
+ }
+
+ assert.Contains(t, attrMap, "panic.value")
+ assert.Contains(t, attrMap, "panic.stack")
+ assert.Contains(t, attrMap, "panic.goroutine_name")
+ assert.Equal(t, tt.goroutineName, attrMap["panic.goroutine_name"])
+ assert.NotContains(
+ t,
+ attrMap,
+ "panic.component",
+ "component should not be present for RecordPanicToSpan",
+ )
+ }
+ }
+
+ assert.True(t, foundPanicEvent, "panic.recovered event not found")
+ }
+
+ assert.Equal(t, tt.wantStatus, recordedSpan.Status().Code)
+ assert.Equal(t, tt.wantMessage, recordedSpan.Status().Description)
+ })
+ }
+}
+
+func TestRecordPanicToSpanWithComponent(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ panicValue any
+ stack []byte
+ component string
+ goroutineName string
+ wantEvent bool
+ wantStatus codes.Code
+ wantMessage string
+ }{
+ {
+ name: "with component",
+ panicValue: "panic error",
+ stack: []byte("stack trace"),
+ component: "transaction",
+ goroutineName: "CreateTransaction",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in transaction/CreateTransaction",
+ },
+ {
+ name: "empty component",
+ panicValue: "error",
+ stack: []byte("trace"),
+ component: "",
+ goroutineName: "handler",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in handler",
+ },
+ {
+ name: "empty goroutine name with component",
+ panicValue: "error",
+ stack: []byte("trace"),
+ component: "auth",
+ goroutineName: "",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in auth/",
+ },
+ {
+ name: "both empty",
+ panicValue: "panic",
+ stack: []byte(""),
+ component: "",
+ goroutineName: "",
+ wantEvent: true,
+ wantStatus: codes.Error,
+ wantMessage: "panic recovered in ",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ provider, recorder := newTestTracerProvider(t)
+ tracer := provider.Tracer("test")
+ ctx, span := tracer.Start(context.Background(), "test-span")
+
+ RecordPanicToSpanWithComponent(
+ ctx,
+ tt.panicValue,
+ tt.stack,
+ tt.component,
+ tt.goroutineName,
+ )
+ span.End()
+
+ spans := recorder.Ended()
+ require.Len(t, spans, 1)
+
+ recordedSpan := spans[0]
+
+ if tt.wantEvent {
+ require.NotEmpty(t, recordedSpan.Events(), "expected panic event to be recorded")
+
+ var foundPanicEvent bool
+
+ for _, event := range recordedSpan.Events() {
+ if event.Name == PanicSpanEventName {
+ foundPanicEvent = true
+
+ attrMap := make(map[string]string)
+ for _, attr := range event.Attributes {
+ attrMap[string(attr.Key)] = attr.Value.AsString()
+ }
+
+ assert.Contains(t, attrMap, "panic.value")
+ assert.Contains(t, attrMap, "panic.stack")
+ assert.Contains(t, attrMap, "panic.goroutine_name")
+ assert.Equal(t, tt.goroutineName, attrMap["panic.goroutine_name"])
+
+ if tt.component != "" {
+ assert.Contains(t, attrMap, "panic.component")
+ assert.Equal(t, tt.component, attrMap["panic.component"])
+ }
+ }
+ }
+
+ assert.True(t, foundPanicEvent, "panic.recovered event not found")
+ }
+
+ assert.Equal(t, tt.wantStatus, recordedSpan.Status().Code)
+ assert.Equal(t, tt.wantMessage, recordedSpan.Status().Description)
+ })
+ }
+}
+
+func TestRecordPanicToSpan_NoActiveSpan(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+
+ require.NotPanics(t, func() {
+ RecordPanicToSpan(ctx, "panic value", []byte("stack"), "goroutine")
+ })
+}
+
+func TestRecordPanicToSpanWithComponent_NoActiveSpan(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+
+ require.NotPanics(t, func() {
+ RecordPanicToSpanWithComponent(
+ ctx,
+ "panic value",
+ []byte("stack"),
+ "component",
+ "goroutine",
+ )
+ })
+}
+
+func TestRecordPanicToSpan_NonRecordingSpan(t *testing.T) {
+ t.Parallel()
+
+ provider := sdktrace.NewTracerProvider()
+
+ t.Cleanup(func() {
+ _ = provider.Shutdown(context.Background())
+ })
+
+ tracer := provider.Tracer("test")
+ _, nonRecordingSpan := tracer.Start(
+ context.Background(),
+ "test-span",
+ trace.WithSpanKind(trace.SpanKindInternal),
+ )
+ nonRecordingSpan.End()
+
+ ctx := trace.ContextWithSpan(context.Background(), nonRecordingSpan)
+
+ require.NotPanics(t, func() {
+ RecordPanicToSpan(ctx, "panic value", []byte("stack"), "goroutine")
+ })
+}
+
+func TestRecordPanicToSpan_NilContext(t *testing.T) {
+ t.Parallel()
+
+ require.NotPanics(t, func() {
+ RecordPanicToSpan(context.TODO(), "panic value", []byte("stack"), "goroutine")
+ })
+}
+
+func TestRecordPanicToSpanWithComponent_NilContext(t *testing.T) {
+ t.Parallel()
+
+ require.NotPanics(t, func() {
+ RecordPanicToSpanWithComponent(
+ context.TODO(),
+ "panic value",
+ []byte("stack"),
+ "component",
+ "goroutine",
+ )
+ })
+}
+
+func TestRecordPanicToSpan_VerifyErrorRecorded(t *testing.T) {
+ t.Parallel()
+
+ provider, recorder := newTestTracerProvider(t)
+ tracer := provider.Tracer("test")
+ ctx, span := tracer.Start(context.Background(), "test-span")
+
+ panicValue := "test panic"
+ RecordPanicToSpan(ctx, panicValue, []byte("stack trace"), "worker")
+ span.End()
+
+ spans := recorder.Ended()
+ require.Len(t, spans, 1)
+
+ recordedSpan := spans[0]
+ events := recordedSpan.Events()
+
+ var (
+ hasExceptionEvent bool
+ hasPanicEvent bool
+ )
+
+ for _, event := range events {
+ if event.Name == "exception" {
+ hasExceptionEvent = true
+
+ attrMap := make(map[string]string)
+ for _, attr := range event.Attributes {
+ attrMap[string(attr.Key)] = attr.Value.AsString()
+ }
+
+ assert.Contains(t, attrMap["exception.message"], "panic")
+ assert.Contains(t, attrMap["exception.message"], panicValue)
+ }
+
+ if event.Name == PanicSpanEventName {
+ hasPanicEvent = true
+ }
+ }
+
+ assert.True(t, hasExceptionEvent, "expected exception event from RecordError")
+ assert.True(t, hasPanicEvent, "expected panic.recovered event")
+}
+
+func TestRecordPanicToSpan_VerifySpanAttributes(t *testing.T) {
+ t.Parallel()
+
+ provider, recorder := newTestTracerProvider(t)
+ tracer := provider.Tracer("test")
+ ctx, span := tracer.Start(context.Background(), "test-span")
+
+ panicValue := "detailed panic message"
+ stackTrace := []byte("goroutine 1 [running]:\nmain.main()\n\t/path/to/file.go:42")
+ goroutineName := "main-worker"
+
+ RecordPanicToSpan(ctx, panicValue, stackTrace, goroutineName)
+ span.End()
+
+ spans := recorder.Ended()
+ require.Len(t, spans, 1)
+
+ recordedSpan := spans[0]
+
+ var panicEvent *sdktrace.Event
+
+ for i := range recordedSpan.Events() {
+ if recordedSpan.Events()[i].Name == PanicSpanEventName {
+ panicEvent = &recordedSpan.Events()[i]
+
+ break
+ }
+ }
+
+ require.NotNil(t, panicEvent, "panic event not found")
+
+ attrMap := make(map[string]string)
+ for _, attr := range panicEvent.Attributes {
+ attrMap[string(attr.Key)] = attr.Value.AsString()
+ }
+
+ assert.Equal(t, panicValue, attrMap["panic.value"])
+ assert.Equal(t, string(stackTrace), attrMap["panic.stack"])
+ assert.Equal(t, goroutineName, attrMap["panic.goroutine_name"])
+}
+
+func TestRecordPanicToSpanWithComponent_VerifyComponentAttribute(t *testing.T) {
+ t.Parallel()
+
+ provider, recorder := newTestTracerProvider(t)
+ tracer := provider.Tracer("test")
+ ctx, span := tracer.Start(context.Background(), "test-span")
+
+ panicValue := "component panic"
+ stackTrace := []byte("stack")
+ component := "reconciliation"
+ goroutineName := "ProcessBatch"
+
+ RecordPanicToSpanWithComponent(ctx, panicValue, stackTrace, component, goroutineName)
+ span.End()
+
+ spans := recorder.Ended()
+ require.Len(t, spans, 1)
+
+ recordedSpan := spans[0]
+
+ var panicEvent *sdktrace.Event
+
+ for i := range recordedSpan.Events() {
+ if recordedSpan.Events()[i].Name == PanicSpanEventName {
+ panicEvent = &recordedSpan.Events()[i]
+
+ break
+ }
+ }
+
+ require.NotNil(t, panicEvent, "panic event not found")
+
+ attrMap := make(map[string]string)
+ for _, attr := range panicEvent.Attributes {
+ attrMap[string(attr.Key)] = attr.Value.AsString()
+ }
+
+ assert.Equal(t, panicValue, attrMap["panic.value"])
+ assert.Equal(t, string(stackTrace), attrMap["panic.stack"])
+ assert.Equal(t, goroutineName, attrMap["panic.goroutine_name"])
+ assert.Equal(t, component, attrMap["panic.component"])
+}
+
+func TestRecordPanicToSpan_ComplexPanicValues(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ panicValue any
+ wantValue string
+ }{
+ {
+ name: "struct panic value",
+ panicValue: struct{ Message string }{Message: "error"},
+ wantValue: "{error}",
+ },
+ {
+ name: "slice panic value",
+ panicValue: []string{"a", "b", "c"},
+ wantValue: "[a b c]",
+ },
+ {
+ name: "map panic value",
+ panicValue: map[string]int{"key": 1},
+ wantValue: "map[key:1]",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ provider, recorder := newTestTracerProvider(t)
+ tracer := provider.Tracer("test")
+ ctx, span := tracer.Start(context.Background(), "test-span")
+
+ RecordPanicToSpan(ctx, tt.panicValue, []byte("stack"), "goroutine")
+ span.End()
+
+ spans := recorder.Ended()
+ require.Len(t, spans, 1)
+
+ recordedSpan := spans[0]
+
+ var panicEvent *sdktrace.Event
+
+ for i := range recordedSpan.Events() {
+ if recordedSpan.Events()[i].Name == PanicSpanEventName {
+ panicEvent = &recordedSpan.Events()[i]
+
+ break
+ }
+ }
+
+ require.NotNil(t, panicEvent)
+
+ var panicValueAttr string
+
+ for _, attr := range panicEvent.Attributes {
+ if string(attr.Key) == "panic.value" {
+ panicValueAttr = attr.Value.AsString()
+
+ break
+ }
+ }
+
+ assert.Equal(t, tt.wantValue, panicValueAttr)
+ })
+ }
+}
diff --git a/commons/safe/doc.go b/commons/safe/doc.go
new file mode 100644
index 00000000..df658cb5
--- /dev/null
+++ b/commons/safe/doc.go
@@ -0,0 +1,8 @@
+// Package safe provides panic-free helpers for math, slices, and regex operations.
+//
+// Core APIs include decimal division helpers (Divide, Percentage), bounds-checked
+// slice accessors (First, Last, At), and regex compilation/matching with caching.
+//
+// Functions that can fail return explicit errors instead of panicking, so callers
+// can handle failures predictably in production paths.
+package safe
diff --git a/commons/safe/math.go b/commons/safe/math.go
new file mode 100644
index 00000000..d0d70489
--- /dev/null
+++ b/commons/safe/math.go
@@ -0,0 +1,140 @@
+package safe
+
+import (
+ "errors"
+
+ "github.com/shopspring/decimal"
+)
+
+// ErrDivisionByZero is returned when attempting to divide by zero.
+var ErrDivisionByZero = errors.New("division by zero")
+
+// percentageMultiplier is the multiplier for percentage calculations.
+const percentageMultiplier = 100
+
+// hundredDecimal is the pre-allocated decimal multiplier for percentage calculations.
+var hundredDecimal = decimal.NewFromInt(percentageMultiplier)
+
+// Divide performs decimal division with zero check.
+// Returns ErrDivisionByZero if denominator is zero.
+//
+// Example:
+//
+// result, err := safe.Divide(numerator, denominator)
+// if err != nil {
+// return fmt.Errorf("calculate ratio: %w", err)
+// }
+func Divide(numerator, denominator decimal.Decimal) (decimal.Decimal, error) {
+ if denominator.IsZero() {
+ return decimal.Zero, ErrDivisionByZero
+ }
+
+ return numerator.Div(denominator), nil
+}
+
+// DivideRound performs decimal division with rounding and zero check.
+// Returns ErrDivisionByZero if denominator is zero.
+//
+// Example:
+//
+// result, err := safe.DivideRound(numerator, denominator, 2)
+// if err != nil {
+// return fmt.Errorf("calculate percentage: %w", err)
+// }
+func DivideRound(numerator, denominator decimal.Decimal, places int32) (decimal.Decimal, error) {
+ if denominator.IsZero() {
+ return decimal.Zero, ErrDivisionByZero
+ }
+
+ return numerator.DivRound(denominator, places), nil
+}
+
+// DivideOrZero performs decimal division, returning zero if denominator is zero.
+// Use when zero is an acceptable fallback (e.g., percentage calculations where
+// zero total means zero percentage).
+//
+// Example:
+//
+// percentage := safe.DivideOrZero(matched, total).Mul(hundred)
+func DivideOrZero(numerator, denominator decimal.Decimal) decimal.Decimal {
+ if denominator.IsZero() {
+ return decimal.Zero
+ }
+
+ return numerator.Div(denominator)
+}
+
+// DivideOrDefault performs decimal division, returning defaultValue if denominator is zero.
+// Use when a specific fallback value is needed.
+//
+// Example:
+//
+// rate := safe.DivideOrDefault(resolved, total, decimal.NewFromInt(100))
+func DivideOrDefault(numerator, denominator, defaultValue decimal.Decimal) decimal.Decimal {
+ if denominator.IsZero() {
+ return defaultValue
+ }
+
+ return numerator.Div(denominator)
+}
+
+// Percentage calculates (numerator / denominator) * 100 with zero check.
+// Returns ErrDivisionByZero if denominator is zero.
+//
+// Example:
+//
+// pct, err := safe.Percentage(matched, total)
+// if err != nil {
+// return fmt.Errorf("calculate match rate: %w", err)
+// }
+func Percentage(numerator, denominator decimal.Decimal) (decimal.Decimal, error) {
+ if denominator.IsZero() {
+ return decimal.Zero, ErrDivisionByZero
+ }
+
+ return numerator.Div(denominator).Mul(hundredDecimal), nil
+}
+
+// PercentageOrZero calculates (numerator / denominator) * 100, returning zero if
+// denominator is zero. This is the common pattern for rate calculations.
+//
+// Example:
+//
+// matchRate := safe.PercentageOrZero(matched, total)
+func PercentageOrZero(numerator, denominator decimal.Decimal) decimal.Decimal {
+ if denominator.IsZero() {
+ return decimal.Zero
+ }
+
+ return numerator.Div(denominator).Mul(hundredDecimal)
+}
+
+// DivideFloat64 performs float64 division with zero check.
+// Returns ErrDivisionByZero if denominator is zero.
+//
+// Example:
+//
+// ratio, err := safe.DivideFloat64(failures, total)
+// if err != nil {
+// return fmt.Errorf("calculate failure ratio: %w", err)
+// }
+func DivideFloat64(numerator, denominator float64) (float64, error) {
+ if denominator == 0 {
+ return 0, ErrDivisionByZero
+ }
+
+ return numerator / denominator, nil
+}
+
+// DivideFloat64OrZero performs float64 division, returning zero if denominator is zero.
+//
+// Example:
+//
+// ratio := safe.DivideFloat64OrZero(failures, total)
+func DivideFloat64OrZero(numerator, denominator float64) float64 {
+ if denominator == 0 {
+ return 0
+ }
+
+ return numerator / denominator
+}
diff --git a/commons/safe/math_test.go b/commons/safe/math_test.go
new file mode 100644
index 00000000..ecaffe4e
--- /dev/null
+++ b/commons/safe/math_test.go
@@ -0,0 +1,326 @@
+//go:build unit
+
+package safe
+
+import (
+ "testing"
+
+ "github.com/shopspring/decimal"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestDivide(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ numerator decimal.Decimal
+ denominator decimal.Decimal
+ want decimal.Decimal
+ wantErr error
+ }{
+ {
+ name: "success",
+ numerator: decimal.NewFromInt(100),
+ denominator: decimal.NewFromInt(4),
+ want: decimal.NewFromInt(25),
+ wantErr: nil,
+ },
+ {
+ name: "zero denominator",
+ numerator: decimal.NewFromInt(100),
+ denominator: decimal.Zero,
+ want: decimal.Zero,
+ wantErr: ErrDivisionByZero,
+ },
+ {
+ name: "zero numerator",
+ numerator: decimal.Zero,
+ denominator: decimal.NewFromInt(4),
+ want: decimal.Zero,
+ wantErr: nil,
+ },
+ {
+ name: "negative numbers",
+ numerator: decimal.NewFromInt(-100),
+ denominator: decimal.NewFromInt(4),
+ want: decimal.NewFromInt(-25),
+ wantErr: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := Divide(tt.numerator, tt.denominator)
+
+ if tt.wantErr != nil {
+ assert.ErrorIs(t, err, tt.wantErr)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ assert.True(t, result.Equal(tt.want), "expected %s, got %s", tt.want, result)
+ })
+ }
+}
+
+func TestDivideRound_Success(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(100)
+ denominator := decimal.NewFromInt(3)
+
+ result, err := DivideRound(numerator, denominator, 2)
+
+ assert.NoError(t, err)
+ expected := decimal.NewFromFloat(33.33)
+ assert.True(t, result.Equal(expected), "expected %s, got %s", expected, result)
+}
+
+func TestDivideRound_ZeroDenominator(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(100)
+ denominator := decimal.Zero
+
+ result, err := DivideRound(numerator, denominator, 2)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrDivisionByZero)
+ assert.True(t, result.IsZero())
+}
+
+func TestDivideOrZero_Success(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(100)
+ denominator := decimal.NewFromInt(4)
+
+ result := DivideOrZero(numerator, denominator)
+
+ assert.True(t, result.Equal(decimal.NewFromInt(25)))
+}
+
+func TestDivideOrZero_ZeroDenominator(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(100)
+ denominator := decimal.Zero
+
+ result := DivideOrZero(numerator, denominator)
+
+ assert.True(t, result.IsZero())
+}
+
+func TestDivideOrDefault_Success(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(100)
+ denominator := decimal.NewFromInt(4)
+ defaultVal := decimal.NewFromInt(999)
+
+ result := DivideOrDefault(numerator, denominator, defaultVal)
+
+ assert.True(t, result.Equal(decimal.NewFromInt(25)))
+}
+
+func TestDivideOrDefault_ZeroDenominator(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(100)
+ denominator := decimal.Zero
+ defaultVal := decimal.NewFromInt(999)
+
+ result := DivideOrDefault(numerator, denominator, defaultVal)
+
+ assert.True(t, result.Equal(defaultVal))
+}
+
+func TestPercentage_Success(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(25)
+ denominator := decimal.NewFromInt(100)
+
+ result, err := Percentage(numerator, denominator)
+
+ assert.NoError(t, err)
+ assert.True(t, result.Equal(decimal.NewFromInt(25)))
+}
+
+func TestPercentage_ZeroDenominator(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(25)
+ denominator := decimal.Zero
+
+ result, err := Percentage(numerator, denominator)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrDivisionByZero)
+ assert.True(t, result.IsZero())
+}
+
+func TestPercentage_FullPercentage(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(100)
+ denominator := decimal.NewFromInt(100)
+
+ result, err := Percentage(numerator, denominator)
+
+ assert.NoError(t, err)
+ assert.True(t, result.Equal(decimal.NewFromInt(100)))
+}
+
+func TestPercentageOrZero_Success(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(50)
+ denominator := decimal.NewFromInt(100)
+
+ result := PercentageOrZero(numerator, denominator)
+
+ assert.True(t, result.Equal(decimal.NewFromInt(50)))
+}
+
+func TestPercentageOrZero_ZeroDenominator(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.NewFromInt(50)
+ denominator := decimal.Zero
+
+ result := PercentageOrZero(numerator, denominator)
+
+ assert.True(t, result.IsZero())
+}
+
+func TestPercentageOrZero_ZeroNumerator(t *testing.T) {
+ t.Parallel()
+
+ numerator := decimal.Zero
+ denominator := decimal.NewFromInt(100)
+
+ result := PercentageOrZero(numerator, denominator)
+
+ assert.True(t, result.IsZero())
+}
+
+func TestDivideFloat64(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ numerator float64
+ denominator float64
+ want float64
+ wantErr error
+ }{
+ {
+ name: "success",
+ numerator: 100,
+ denominator: 4,
+ want: 25,
+ wantErr: nil,
+ },
+ {
+ name: "zero denominator",
+ numerator: 100,
+ denominator: 0,
+ want: 0,
+ wantErr: ErrDivisionByZero,
+ },
+ {
+ name: "zero numerator",
+ numerator: 0,
+ denominator: 4,
+ want: 0,
+ wantErr: nil,
+ },
+ {
+ name: "negative numerator",
+ numerator: -100,
+ denominator: 4,
+ want: -25,
+ wantErr: nil,
+ },
+ {
+ name: "negative denominator",
+ numerator: 100,
+ denominator: -4,
+ want: -25,
+ wantErr: nil,
+ },
+ {
+ name: "both negative",
+ numerator: -100,
+ denominator: -4,
+ want: 25,
+ wantErr: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := DivideFloat64(tt.numerator, tt.denominator)
+
+ if tt.wantErr != nil {
+ assert.ErrorIs(t, err, tt.wantErr)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ assert.InDelta(t, tt.want, result, 1e-9)
+ })
+ }
+}
+
+func TestDivideFloat64OrZero(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ numerator float64
+ denominator float64
+ want float64
+ }{
+ {
+ name: "success",
+ numerator: 100,
+ denominator: 4,
+ want: 25,
+ },
+ {
+ name: "zero denominator",
+ numerator: 100,
+ denominator: 0,
+ want: 0,
+ },
+ {
+ name: "zero numerator",
+ numerator: 0,
+ denominator: 4,
+ want: 0,
+ },
+ {
+ name: "negative numerator",
+ numerator: -100,
+ denominator: 4,
+ want: -25,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := DivideFloat64OrZero(tt.numerator, tt.denominator)
+
+ assert.InDelta(t, tt.want, result, 1e-9)
+ })
+ }
+}
diff --git a/commons/safe/regex.go b/commons/safe/regex.go
new file mode 100644
index 00000000..e90db7fc
--- /dev/null
+++ b/commons/safe/regex.go
@@ -0,0 +1,163 @@
+package safe
+
+import (
+ "errors"
+ "fmt"
+ "regexp"
+ "sync"
+)
+
+// ErrInvalidRegex is returned when a regex pattern cannot be compiled.
+var ErrInvalidRegex = errors.New("invalid regular expression")
+
+// maxCacheSize is the upper bound for cached compiled regex patterns.
+// When this limit is reached, the entire cache is cleared to prevent
+// unbounded memory growth from dynamic user-provided patterns.
+const maxCacheSize = 1024
+
+// regexCache caches compiled regex patterns for performance.
+// Protected by regexMu; bounded to maxCacheSize entries.
+var (
+ regexMu sync.RWMutex
+ regexCache = make(map[string]*regexp.Regexp)
+)
+
+// cacheLoad returns a cached regex and true if it exists, or nil and false.
+func cacheLoad(key string) (*regexp.Regexp, bool) {
+ regexMu.RLock()
+ defer regexMu.RUnlock()
+
+ re, ok := regexCache[key]
+
+ return re, ok
+}
+
+// evictionFraction is the proportion of entries to evict when the cache is full.
+// 25% eviction provides a balance between reclaiming space and preserving hot entries.
+const evictionFraction = 4 // 1/4 = 25%
+
+// cacheStore stores a compiled regex, evicting a random subset if the cache is full.
+// When at capacity, approximately 25% of entries are evicted (random map iteration order).
+func cacheStore(key string, re *regexp.Regexp) {
+ regexMu.Lock()
+ defer regexMu.Unlock()
+
+ if len(regexCache) >= maxCacheSize {
+ evictCount := len(regexCache) / evictionFraction
+ if evictCount == 0 {
+ evictCount = 1
+ }
+
+ evicted := 0
+
+ for k := range regexCache {
+ delete(regexCache, k)
+
+ evicted++
+
+ if evicted >= evictCount {
+ break
+ }
+ }
+ }
+
+ regexCache[key] = re
+}
+
+// Compile compiles a regex pattern with error return instead of panic.
+// Compiled patterns are cached for performance.
+//
+// Use this for dynamic patterns (e.g., user-provided patterns).
+// For static compile-time patterns, use regexp.MustCompile directly.
+//
+// Example:
+//
+// re, err := safe.Compile(userPattern)
+// if err != nil {
+// return fmt.Errorf("invalid pattern: %w", err)
+// }
+// matches := re.FindAllString(input, -1)
+func Compile(pattern string) (*regexp.Regexp, error) {
+ if cached, ok := cacheLoad(pattern); ok {
+ return cached, nil
+ }
+
+ re, err := regexp.Compile(pattern)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrInvalidRegex, err)
+ }
+
+ cacheStore(pattern, re)
+
+ return re, nil
+}
+
+// CompilePOSIX compiles a POSIX ERE regex pattern with error return.
+// Compiled patterns are cached for performance.
+//
+// Example:
+//
+// re, err := safe.CompilePOSIX(userPattern)
+// if err != nil {
+// return fmt.Errorf("invalid POSIX pattern: %w", err)
+// }
+func CompilePOSIX(pattern string) (*regexp.Regexp, error) {
+ cacheKey := "posix:" + pattern
+
+ if cached, ok := cacheLoad(cacheKey); ok {
+ return cached, nil
+ }
+
+ re, err := regexp.CompilePOSIX(pattern)
+ if err != nil {
+ return nil, fmt.Errorf("%w: %w", ErrInvalidRegex, err)
+ }
+
+ cacheStore(cacheKey, re)
+
+ return re, nil
+}
+
+// MatchString compiles and matches a pattern against input in one call.
+// Returns false if the pattern is invalid.
+//
+// Example:
+//
+// matched, err := safe.MatchString(`^\d{4}-\d{2}-\d{2}$`, dateStr)
+// if err != nil {
+// return fmt.Errorf("invalid date pattern: %w", err)
+// }
+func MatchString(pattern, input string) (bool, error) {
+ re, err := Compile(pattern)
+ if err != nil {
+ return false, err
+ }
+
+ return re.MatchString(input), nil
+}
+
+// FindString compiles and finds the first match.
+// Returns ("", error) if the pattern is invalid, or ("", nil) if no match is found.
+//
+// Example:
+//
+// match, err := safe.FindString(`[a-z]+`, input)
+// if err != nil {
+// return fmt.Errorf("invalid pattern: %w", err)
+// }
+func FindString(pattern, input string) (string, error) {
+ re, err := Compile(pattern)
+ if err != nil {
+ return "", err
+ }
+
+ return re.FindString(input), nil
+}
+
+// ClearCache clears the regex cache. Useful for testing.
+func ClearCache() {
+ regexMu.Lock()
+ defer regexMu.Unlock()
+
+ regexCache = make(map[string]*regexp.Regexp)
+}
diff --git a/commons/safe/regex_example_test.go b/commons/safe/regex_example_test.go
new file mode 100644
index 00000000..ac9dae9e
--- /dev/null
+++ b/commons/safe/regex_example_test.go
@@ -0,0 +1,19 @@
+//go:build unit
+
+package safe_test
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/safe"
+)
+
+func ExampleCompile_errorHandling() {
+ _, err := safe.Compile("[")
+
+ fmt.Println(errors.Is(err, safe.ErrInvalidRegex))
+
+ // Output:
+ // true
+}
diff --git a/commons/safe/regex_test.go b/commons/safe/regex_test.go
new file mode 100644
index 00000000..d4e8edc9
--- /dev/null
+++ b/commons/safe/regex_test.go
@@ -0,0 +1,243 @@
+//go:build unit
+
+package safe
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// testCacheLen returns the current number of entries in the regex cache.
+// This is a test-only helper to verify cache behavior without exporting
+// the function from the production code.
+func testCacheLen() int {
+ regexMu.RLock()
+ defer regexMu.RUnlock()
+
+ return len(regexCache)
+}
+
+// TestCompile verifies safe regex compilation and caching behavior.
+// t.Parallel() is intentionally omitted because this test mutates the
+// package-level regexCache via ClearCache, which would race with other
+// cache-dependent tests running concurrently.
+func TestCompile(t *testing.T) {
+ ClearCache()
+
+ t.Run("valid pattern", func(t *testing.T) {
+ re, err := Compile(`^\d{4}-\d{2}-\d{2}$`)
+
+ assert.NoError(t, err)
+ assert.NotNil(t, re)
+ assert.True(t, re.MatchString("2026-01-27"))
+ assert.False(t, re.MatchString("invalid"))
+ })
+
+ t.Run("invalid pattern", func(t *testing.T) {
+ re, err := Compile(`[invalid(`)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidRegex)
+ assert.Nil(t, re)
+ })
+
+ t.Run("caching", func(t *testing.T) {
+ ClearCache()
+
+ pattern := `^\d+$`
+
+ re1, err1 := Compile(pattern)
+ re2, err2 := Compile(pattern)
+
+ assert.NoError(t, err1)
+ assert.NoError(t, err2)
+ assert.Same(t, re1, re2)
+ })
+
+ t.Run("empty pattern", func(t *testing.T) {
+ re, err := Compile("")
+
+ assert.NoError(t, err)
+ assert.NotNil(t, re)
+ assert.True(t, re.MatchString("anything"))
+ })
+}
+
+// TestCompilePOSIX verifies POSIX regex compilation and caching.
+// t.Parallel() is intentionally omitted because this test mutates the
+// package-level regexCache via ClearCache, which would race with other
+// cache-dependent tests running concurrently.
+func TestCompilePOSIX(t *testing.T) {
+ ClearCache()
+
+ t.Run("valid pattern", func(t *testing.T) {
+ re, err := CompilePOSIX(`^[0-9]+$`)
+
+ assert.NoError(t, err)
+ assert.NotNil(t, re)
+ assert.True(t, re.MatchString("12345"))
+ assert.False(t, re.MatchString("abc"))
+ })
+
+ t.Run("invalid pattern", func(t *testing.T) {
+ re, err := CompilePOSIX(`[invalid(`)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidRegex)
+ assert.Nil(t, re)
+ })
+
+ t.Run("caching", func(t *testing.T) {
+ ClearCache()
+
+ pattern := `^[a-z]+$`
+
+ re1, err1 := CompilePOSIX(pattern)
+ re2, err2 := CompilePOSIX(pattern)
+
+ assert.NoError(t, err1)
+ assert.NoError(t, err2)
+ assert.Same(t, re1, re2)
+ })
+}
+
+// TestMatchString verifies the convenience MatchString wrapper.
+// t.Parallel() is intentionally omitted because this test mutates the
+// package-level regexCache via ClearCache, which would race with other
+// cache-dependent tests running concurrently.
+func TestMatchString(t *testing.T) {
+ ClearCache()
+
+ t.Run("valid pattern match", func(t *testing.T) {
+ matched, err := MatchString(`^\d{4}-\d{2}-\d{2}$`, "2026-01-27")
+
+ assert.NoError(t, err)
+ assert.True(t, matched)
+ })
+
+ t.Run("valid pattern no match", func(t *testing.T) {
+ matched, err := MatchString(`^\d{4}-\d{2}-\d{2}$`, "invalid-date")
+
+ assert.NoError(t, err)
+ assert.False(t, matched)
+ })
+
+ t.Run("invalid pattern", func(t *testing.T) {
+ matched, err := MatchString(`[invalid(`, "test")
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidRegex)
+ assert.False(t, matched)
+ })
+}
+
+// TestFindString verifies the convenience FindString wrapper.
+// t.Parallel() is intentionally omitted because this test mutates the
+// package-level regexCache via ClearCache, which would race with other
+// cache-dependent tests running concurrently.
+func TestFindString(t *testing.T) {
+ ClearCache()
+
+ t.Run("valid pattern match", func(t *testing.T) {
+ match, err := FindString(`[a-z]+`, "123abc456")
+
+ assert.NoError(t, err)
+ assert.Equal(t, "abc", match)
+ })
+
+ t.Run("valid pattern no match", func(t *testing.T) {
+ match, err := FindString(`[a-z]+`, "123456")
+
+ assert.NoError(t, err)
+ assert.Empty(t, match)
+ })
+
+ t.Run("invalid pattern", func(t *testing.T) {
+ match, err := FindString(`[invalid(`, "test")
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrInvalidRegex)
+ assert.Empty(t, match)
+ })
+}
+
+// TestClearCache verifies that ClearCache removes all cached entries and
+// subsequent compilations produce new regex instances.
+// t.Parallel() is intentionally omitted because this test mutates the
+// package-level regexCache via ClearCache, which would race with other
+// cache-dependent tests running concurrently.
+func TestClearCache(t *testing.T) {
+ pattern := `^test$`
+
+ re1, _ := Compile(pattern)
+ ClearCache()
+
+ re2, _ := Compile(pattern)
+
+ assert.NotSame(t, re1, re2)
+}
+
+// TestCacheBoundedSize verifies that the regex cache does not grow beyond
+// maxCacheSize entries. When the cache is full, storing a new entry evicts
+// approximately 25% of entries to reclaim space while preserving hot entries.
+// t.Parallel() is intentionally omitted because this test mutates the
+// package-level regexCache via ClearCache, which would race with other
+// cache-dependent tests running concurrently.
+func TestCacheBoundedSize(t *testing.T) {
+ ClearCache()
+
+ // Fill the cache to maxCacheSize.
+ for i := range maxCacheSize {
+ pattern := fmt.Sprintf(`^pattern_%d$`, i)
+
+ _, err := Compile(pattern)
+ require.NoError(t, err)
+ }
+
+ require.Equal(t, maxCacheSize, testCacheLen(), "cache should be full at maxCacheSize")
+
+ // One more entry should trigger ~25% eviction + store the new entry.
+ _, err := Compile(`^overflow_pattern$`)
+ require.NoError(t, err)
+
+ cacheLen := testCacheLen()
+ // After evicting ~256 entries (25% of 1024) and adding 1, we expect ~769.
+ // Allow a range to account for the exact eviction count.
+ require.Less(t, cacheLen, maxCacheSize, "cache should be smaller than maxCacheSize after eviction")
+ require.Greater(t, cacheLen, maxCacheSize/2, "cache should retain majority of entries (random 25%% eviction)")
+
+ // Re-compiled patterns should still be cached after compilation.
+ re1, _ := Compile(`^pattern_fresh$`)
+ re2, _ := Compile(`^pattern_fresh$`)
+ assert.Same(t, re1, re2, "same pattern compiled twice should return cached instance")
+}
+
+// TestCacheBoundedSizePOSIX verifies the same bounded cache behavior for
+// POSIX patterns, which share the same cache with a "posix:" key prefix.
+// t.Parallel() is intentionally omitted because this test mutates the
+// package-level regexCache via ClearCache, which would race with other
+// cache-dependent tests running concurrently.
+func TestCacheBoundedSizePOSIX(t *testing.T) {
+ ClearCache()
+
+ // Fill the cache to maxCacheSize with POSIX patterns.
+ for i := range maxCacheSize {
+ pattern := fmt.Sprintf(`^posix_%d$`, i)
+
+ _, err := CompilePOSIX(pattern)
+ require.NoError(t, err)
+ }
+
+ require.Equal(t, maxCacheSize, testCacheLen(), "cache should be full at maxCacheSize")
+
+ // One more POSIX entry should trigger ~25% eviction.
+ _, err := CompilePOSIX(`^posix_overflow$`)
+ require.NoError(t, err)
+
+ cacheLen := testCacheLen()
+ require.Less(t, cacheLen, maxCacheSize, "cache should be smaller than maxCacheSize after eviction")
+ require.Greater(t, cacheLen, maxCacheSize/2, "cache should retain majority of entries")
+}
diff --git a/commons/safe/safe_example_test.go b/commons/safe/safe_example_test.go
new file mode 100644
index 00000000..68d333e4
--- /dev/null
+++ b/commons/safe/safe_example_test.go
@@ -0,0 +1,21 @@
+//go:build unit
+
+package safe_test
+
+import (
+ "fmt"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/safe"
+ "github.com/shopspring/decimal"
+)
+
+func ExampleDivide() {
+ result, err := safe.Divide(decimal.NewFromInt(25), decimal.NewFromInt(5))
+
+ fmt.Println(err == nil)
+ fmt.Println(result.String())
+
+ // Output:
+ // true
+ // 5
+}
diff --git a/commons/safe/slice.go b/commons/safe/slice.go
new file mode 100644
index 00000000..c49c710e
--- /dev/null
+++ b/commons/safe/slice.go
@@ -0,0 +1,108 @@
+package safe
+
+import (
+ "errors"
+ "fmt"
+)
+
+// ErrEmptySlice is returned when attempting to access elements of an empty slice.
+var ErrEmptySlice = errors.New("empty slice")
+
+// ErrIndexOutOfBounds is returned when an index is outside the valid range.
+var ErrIndexOutOfBounds = errors.New("index out of bounds")
+
+// First returns the first element of a slice.
+// Returns ErrEmptySlice if the slice is empty.
+//
+// Example:
+//
+// first, err := safe.First(items)
+// if err != nil {
+// return fmt.Errorf("get first item: %w", err)
+// }
+func First[T any](slice []T) (T, error) {
+ var zero T
+
+ if len(slice) == 0 {
+ return zero, ErrEmptySlice
+ }
+
+ return slice[0], nil
+}
+
+// Last returns the last element of a slice.
+// Returns ErrEmptySlice if the slice is empty.
+//
+// Example:
+//
+// last, err := safe.Last(items)
+// if err != nil {
+// return fmt.Errorf("get last item: %w", err)
+// }
+func Last[T any](slice []T) (T, error) {
+ var zero T
+
+ if len(slice) == 0 {
+ return zero, ErrEmptySlice
+ }
+
+ return slice[len(slice)-1], nil
+}
+
+// At returns the element at the specified index.
+// Returns ErrIndexOutOfBounds if the index is out of range.
+//
+// Example:
+//
+// item, err := safe.At(items, 5)
+// if err != nil {
+// return fmt.Errorf("get item at index 5: %w", err)
+// }
+func At[T any](slice []T, index int) (T, error) {
+ var zero T
+
+ if index < 0 || index >= len(slice) {
+ return zero, fmt.Errorf("%w: index %d, length %d", ErrIndexOutOfBounds, index, len(slice))
+ }
+
+ return slice[index], nil
+}
+
+// FirstOrDefault returns the first element of a slice, or defaultValue if empty.
+//
+// Example:
+//
+// first := safe.FirstOrDefault(items, defaultItem)
+func FirstOrDefault[T any](slice []T, defaultValue T) T {
+ if len(slice) == 0 {
+ return defaultValue
+ }
+
+ return slice[0]
+}
+
+// LastOrDefault returns the last element of a slice, or defaultValue if empty.
+//
+// Example:
+//
+// last := safe.LastOrDefault(items, defaultItem)
+func LastOrDefault[T any](slice []T, defaultValue T) T {
+ if len(slice) == 0 {
+ return defaultValue
+ }
+
+ return slice[len(slice)-1]
+}
+
+// AtOrDefault returns the element at index, or defaultValue if out of bounds.
+//
+// Example:
+//
+// item := safe.AtOrDefault(items, 5, defaultItem)
+func AtOrDefault[T any](slice []T, index int, defaultValue T) T {
+ if index < 0 || index >= len(slice) {
+ return defaultValue
+ }
+
+ return slice[index]
+}
diff --git a/commons/safe/slice_test.go b/commons/safe/slice_test.go
new file mode 100644
index 00000000..c13a5aeb
--- /dev/null
+++ b/commons/safe/slice_test.go
@@ -0,0 +1,216 @@
+//go:build unit
+
+package safe
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestFirst_Success(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{1, 2, 3}
+
+ result, err := First(slice)
+
+ assert.NoError(t, err)
+ assert.Equal(t, 1, result)
+}
+
+func TestFirst_EmptySlice(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{}
+
+ result, err := First(slice)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrEmptySlice)
+ assert.Equal(t, 0, result)
+}
+
+func TestFirst_SingleElement(t *testing.T) {
+ t.Parallel()
+
+ slice := []string{"only"}
+
+ result, err := First(slice)
+
+ assert.NoError(t, err)
+ assert.Equal(t, "only", result)
+}
+
+func TestLast_Success(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{1, 2, 3}
+
+ result, err := Last(slice)
+
+ assert.NoError(t, err)
+ assert.Equal(t, 3, result)
+}
+
+func TestLast_EmptySlice(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{}
+
+ result, err := Last(slice)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrEmptySlice)
+ assert.Equal(t, 0, result)
+}
+
+func TestLast_SingleElement(t *testing.T) {
+ t.Parallel()
+
+ slice := []string{"only"}
+
+ result, err := Last(slice)
+
+ assert.NoError(t, err)
+ assert.Equal(t, "only", result)
+}
+
+func TestAt_Success(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{10, 20, 30}
+
+ result, err := At(slice, 1)
+
+ assert.NoError(t, err)
+ assert.Equal(t, 20, result)
+}
+
+func TestAt_FirstIndex(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{10, 20, 30}
+
+ result, err := At(slice, 0)
+
+ assert.NoError(t, err)
+ assert.Equal(t, 10, result)
+}
+
+func TestAt_LastIndex(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{10, 20, 30}
+
+ result, err := At(slice, 2)
+
+ assert.NoError(t, err)
+ assert.Equal(t, 30, result)
+}
+
+func TestAt_NegativeIndex(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{10, 20, 30}
+
+ result, err := At(slice, -1)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrIndexOutOfBounds)
+ assert.Equal(t, 0, result)
+}
+
+func TestAt_IndexTooLarge(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{10, 20, 30}
+
+ result, err := At(slice, 5)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrIndexOutOfBounds)
+ assert.Equal(t, 0, result)
+}
+
+func TestAt_EmptySlice(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{}
+
+ result, err := At(slice, 0)
+
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, ErrIndexOutOfBounds)
+ assert.Equal(t, 0, result)
+}
+
+func TestFirstOrDefault_Success(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{1, 2, 3}
+
+ result := FirstOrDefault(slice, 99)
+
+ assert.Equal(t, 1, result)
+}
+
+func TestFirstOrDefault_EmptySlice(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{}
+
+ result := FirstOrDefault(slice, 99)
+
+ assert.Equal(t, 99, result)
+}
+
+func TestLastOrDefault_Success(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{1, 2, 3}
+
+ result := LastOrDefault(slice, 99)
+
+ assert.Equal(t, 3, result)
+}
+
+func TestLastOrDefault_EmptySlice(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{}
+
+ result := LastOrDefault(slice, 99)
+
+ assert.Equal(t, 99, result)
+}
+
+func TestAtOrDefault_Success(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{10, 20, 30}
+
+ result := AtOrDefault(slice, 1, 99)
+
+ assert.Equal(t, 20, result)
+}
+
+func TestAtOrDefault_OutOfBounds(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{10, 20, 30}
+
+ result := AtOrDefault(slice, 5, 99)
+
+ assert.Equal(t, 99, result)
+}
+
+func TestAtOrDefault_NegativeIndex(t *testing.T) {
+ t.Parallel()
+
+ slice := []int{10, 20, 30}
+
+ result := AtOrDefault(slice, -1, 99)
+
+ assert.Equal(t, 99, result)
+}
diff --git a/commons/secretsmanager/m2m.go b/commons/secretsmanager/m2m.go
new file mode 100644
index 00000000..ae8f200c
--- /dev/null
+++ b/commons/secretsmanager/m2m.go
@@ -0,0 +1,280 @@
+// Copyright Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+// Package secretsmanager provides functions for retrieving M2M (machine-to-machine)
+// credentials from AWS Secrets Manager.
+//
+// This package is designed to be self-contained with no dependency on internal packages,
+// making it suitable for migration to lib-commons.
+//
+// # M2M Credentials
+//
+// M2M credentials are OAuth2 client credentials stored in AWS Secrets Manager
+// following the path convention:
+//
+// tenants/{env}/{tenantOrgID}/{applicationName}/m2m/{targetService}/credentials
+//
+// # Usage
+//
+// A plugin retrieves credentials to authenticate with a product service:
+//
+// // Create AWS Secrets Manager client
+// cfg, err := awsconfig.LoadDefaultConfig(ctx)
+// if err != nil {
+// // handle error
+// }
+// client := secretsmanager.NewFromConfig(cfg)
+//
+// // Fetch M2M credentials
+// creds, err := secretsmanager.GetM2MCredentials(ctx, client, "staging", tenantOrgID, "plugin-pix", "ledger")
+// if err != nil {
+// // handle error
+// }
+//
+// // Use credentials to obtain an access token via client_credentials grant
+// // Post to the token endpoint with grant_type=client_credentials
+// // Authorization: Basic(creds.ClientID, creds.ClientSecret)
+//
+// # Thread Safety
+//
+// All functions in this package are safe for concurrent use.
+// No package-level mutable state is maintained.
+package secretsmanager
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/secretsmanager"
+ smtypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types"
+ smithy "github.com/aws/smithy-go"
+)
+
+// Sentinel errors for M2M credential operations.
+var (
+ // ErrM2MCredentialsNotFound is returned when M2M credentials cannot be found at the expected path.
+ ErrM2MCredentialsNotFound = errors.New("M2M credentials not found")
+
+ // ErrM2MVaultAccessDenied is returned when access to the vault is denied (missing IAM permissions or expired tokens).
+ ErrM2MVaultAccessDenied = errors.New("vault access denied")
+
+ // ErrM2MRetrievalFailed is returned when M2M credential retrieval fails due to infrastructure issues.
+ ErrM2MRetrievalFailed = errors.New("failed to retrieve M2M credentials")
+
+ // ErrM2MUnmarshalFailed is returned when the secret value cannot be deserialized into M2MCredentials.
+ ErrM2MUnmarshalFailed = errors.New("failed to unmarshal M2M credentials")
+
+ // ErrM2MInvalidInput is returned when required input parameters are missing.
+ ErrM2MInvalidInput = errors.New("invalid input")
+
+ // ErrM2MInvalidCredentials is returned when retrieved credentials are incomplete (missing required fields).
+ ErrM2MInvalidCredentials = errors.New("incomplete M2M credentials")
+
+ // ErrM2MBinarySecretNotSupported is returned when the secret is stored as binary data rather than a string.
+ ErrM2MBinarySecretNotSupported = errors.New("binary secrets are not supported for M2M credentials")
+
+ // ErrM2MInvalidPathSegment is returned when a path segment contains path traversal characters.
+ ErrM2MInvalidPathSegment = errors.New("invalid path segment")
+)
+
+// validatePathSegment checks that a path segment is safe for use in secret paths.
+// It rejects segments containing path traversal characters (/, .., \) and
+// trims leading/trailing whitespace.
+func validatePathSegment(name, value string) (string, error) {
+ trimmed := strings.TrimSpace(value)
+ if trimmed == "" {
+ return "", fmt.Errorf("%w: %s is required", ErrM2MInvalidInput, name)
+ }
+
+ if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") || strings.Contains(trimmed, "..") {
+ return "", fmt.Errorf("%w: %s contains path traversal characters", ErrM2MInvalidPathSegment, name)
+ }
+
+ return trimmed, nil
+}
+
+// redactPath returns a safe representation of a secret path for error messages.
+// It includes only the last path segment and a truncated hash of the full path.
+func redactPath(secretPath string) string {
+ parts := strings.Split(secretPath, "/")
+ lastSegment := parts[len(parts)-1]
+
+ h := sha256.Sum256([]byte(secretPath))
+ shortHash := hex.EncodeToString(h[:4]) // 8 hex chars
+
+ return fmt.Sprintf(".../%s [%s]", lastSegment, shortHash)
+}
+
+// isNilInterface returns true if the interface value is nil or holds a typed nil.
+func isNilInterface(i any) bool {
+ if i == nil {
+ return true
+ }
+
+ v := reflect.ValueOf(i)
+
+ return v.Kind() == reflect.Ptr && v.IsNil()
+}
+
+// M2MCredentials holds credentials retrieved from the Secret Vault.
+// These credentials are used for OAuth2 client_credentials grant
+// to authenticate plugins with product services.
+type M2MCredentials struct {
+ ClientID string `json:"clientId"`
+ ClientSecret string `json:"clientSecret"` // #nosec G117 -- secret payload is intentionally deserialized from AWS Secrets Manager and redacted by String/GoString
+}
+
+// String redacts secret material from formatted output.
+func (c M2MCredentials) String() string {
+ return fmt.Sprintf("M2MCredentials{ClientID:%q, ClientSecret:REDACTED}", c.ClientID)
+}
+
+// GoString redacts secret material from Go-syntax formatted output.
+func (c M2MCredentials) GoString() string {
+ return c.String()
+}
+
+// SecretsManagerClient abstracts AWS Secrets Manager operations.
+// This interface allows for easier testing with mocks.
+type SecretsManagerClient interface {
+ GetSecretValue(ctx context.Context, params *secretsmanager.GetSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error)
+}
+
+// GetM2MCredentials fetches M2M credentials from AWS Secrets Manager.
+//
+// Parameters:
+// - ctx: context for cancellation and tracing
+// - client: AWS Secrets Manager client (must not be nil)
+// - env: deployment environment (e.g., "staging", "production"); empty string is accepted for backward compatibility
+// - tenantOrgID: resolved from request context (JWT owner claim); must not be empty
+// - applicationName: the plugin name (e.g., "plugin-pix"); must not be empty
+// - targetService: the product service name (e.g., "ledger"); must not be empty
+//
+// Path convention:
+//
+// tenants/{env}/{tenantOrgID}/{applicationName}/m2m/{targetService}/credentials
+//
+// Returns descriptive errors when:
+// - client is nil
+// - required parameters are missing
+// - secret not found at path
+// - vault credentials are missing or expired
+// - secret value is not valid JSON
+//
+// Safe for concurrent use (no shared mutable state).
+func GetM2MCredentials(ctx context.Context, client SecretsManagerClient, env, tenantOrgID, applicationName, targetService string) (*M2MCredentials, error) {
+ // Validate inputs - check for typed-nil client using reflect
+ if isNilInterface(client) {
+ return nil, fmt.Errorf("%w: client is required", ErrM2MInvalidInput)
+ }
+
+ // Validate and sanitize path segments (trims whitespace, rejects traversal chars)
+ cleanTenantOrgID, err := validatePathSegment("tenantOrgID", tenantOrgID)
+ if err != nil {
+ return nil, err
+ }
+
+ cleanAppName, err := validatePathSegment("applicationName", applicationName)
+ if err != nil {
+ return nil, err
+ }
+
+ cleanTargetService, err := validatePathSegment("targetService", targetService)
+ if err != nil {
+ return nil, err
+ }
+
+ // env is optional (empty for backward compat) but must be safe if provided
+ cleanEnv := strings.TrimSpace(env)
+ if cleanEnv != "" {
+ if strings.Contains(cleanEnv, "/") || strings.Contains(cleanEnv, "\\") || strings.Contains(cleanEnv, "..") {
+ return nil, fmt.Errorf("%w: env contains path traversal characters", ErrM2MInvalidPathSegment)
+ }
+ }
+
+ // Build the secret path
+ secretPath := buildM2MSecretPath(cleanEnv, cleanTenantOrgID, cleanAppName, cleanTargetService)
+ redacted := redactPath(secretPath)
+
+ // Fetch the secret from AWS Secrets Manager
+ input := &secretsmanager.GetSecretValueInput{
+ SecretId: aws.String(secretPath),
+ }
+
+ output, err := client.GetSecretValue(ctx, input)
+ if err != nil {
+ return nil, classifyAWSError(err, secretPath)
+ }
+
+ // Check for binary secret FIRST (before attempting JSON unmarshal)
+ if output == nil || output.SecretString == nil {
+ return nil, fmt.Errorf("%w: secret at %s is binary or nil", ErrM2MBinarySecretNotSupported, redacted)
+ }
+
+ // Unmarshal the JSON credentials
+ var creds M2MCredentials
+ if err := json.Unmarshal([]byte(*output.SecretString), &creds); err != nil {
+ return nil, fmt.Errorf("%w: secret at %s: %w", ErrM2MUnmarshalFailed, redacted, err)
+ }
+
+ // Validate required credential fields
+ var missing []string
+ if creds.ClientID == "" {
+ missing = append(missing, "clientId")
+ }
+
+ if creds.ClientSecret == "" {
+ missing = append(missing, "clientSecret")
+ }
+
+ if len(missing) > 0 {
+ return nil, fmt.Errorf("%w: secret at %s: missing fields: %s", ErrM2MInvalidCredentials, redacted, strings.Join(missing, ", "))
+ }
+
+ return &creds, nil
+}
+
+// buildM2MSecretPath constructs the secret path for M2M credentials.
+//
+// Format: tenants/{env}/{tenantOrgID}/{applicationName}/m2m/{targetService}/credentials
+//
+// When env is empty, the path omits the environment segment for backward compatibility:
+//
+// tenants/{tenantOrgID}/{applicationName}/m2m/{targetService}/credentials
+func buildM2MSecretPath(env, tenantOrgID, applicationName, targetService string) string {
+ envPrefix := ""
+ if env != "" {
+ envPrefix = env + "/"
+ }
+
+ return fmt.Sprintf("tenants/%s%s/%s/m2m/%s/credentials", envPrefix, tenantOrgID, applicationName, targetService)
+}
+
+// classifyAWSError maps AWS SDK errors to domain-specific sentinel errors.
+// Secret paths are redacted in returned errors to prevent information leakage.
+func classifyAWSError(err error, secretPath string) error {
+ redacted := redactPath(secretPath)
+
+ var notFoundErr *smtypes.ResourceNotFoundException
+ if errors.As(err, ¬FoundErr) {
+ return fmt.Errorf("%w at %s", ErrM2MCredentialsNotFound, redacted)
+ }
+
+ var apiErr smithy.APIError
+ if errors.As(err, &apiErr) {
+ switch apiErr.ErrorCode() {
+ case "AccessDeniedException", "ExpiredTokenException":
+ return fmt.Errorf("%w: %w", ErrM2MVaultAccessDenied, err)
+ }
+ }
+
+ return fmt.Errorf("%w: %s: %w", ErrM2MRetrievalFailed, redacted, err)
+}
diff --git a/commons/secretsmanager/m2m_test.go b/commons/secretsmanager/m2m_test.go
new file mode 100644
index 00000000..41d68e6b
--- /dev/null
+++ b/commons/secretsmanager/m2m_test.go
@@ -0,0 +1,828 @@
+// Copyright Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+package secretsmanager
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sync"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/secretsmanager"
+ smtypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types"
+ smithy "github.com/aws/smithy-go"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// mockBinarySecretsManagerClient returns a nil SecretString to simulate binary secrets.
+type mockBinarySecretsManagerClient struct{}
+
+func (m *mockBinarySecretsManagerClient) GetSecretValue(
+ _ context.Context,
+ _ *secretsmanager.GetSecretValueInput,
+ _ ...func(*secretsmanager.Options),
+) (*secretsmanager.GetSecretValueOutput, error) {
+ return &secretsmanager.GetSecretValueOutput{
+ SecretBinary: []byte{0x01, 0x02, 0x03},
+ SecretString: nil,
+ }, nil
+}
+
+// mockSecretsManagerClient implements SecretsManagerClient for testing.
+type mockSecretsManagerClient struct {
+ secrets map[string]string
+ errors map[string]error
+}
+
+func (m *mockSecretsManagerClient) GetSecretValue(
+ ctx context.Context,
+ params *secretsmanager.GetSecretValueInput,
+ optFns ...func(*secretsmanager.Options),
+) (*secretsmanager.GetSecretValueOutput, error) {
+ if params.SecretId == nil {
+ return nil, errors.New("InvalidParameterException: secret ID is required")
+ }
+
+ secretPath := *params.SecretId
+
+ if err, ok := m.errors[secretPath]; ok {
+ return nil, err
+ }
+
+ if secret, ok := m.secrets[secretPath]; ok {
+ return &secretsmanager.GetSecretValueOutput{
+ SecretString: aws.String(secret),
+ }, nil
+ }
+
+ return nil, &smtypes.ResourceNotFoundException{
+ Message: aws.String("Secrets Manager can't find the specified secret. path=" + secretPath),
+ }
+}
+
+// ============================================================================
+// Test: BuildM2MSecretPath (path construction)
+// ============================================================================
+
+func TestBuildM2MSecretPath(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ env string
+ tenantOrgID string
+ applicationName string
+ targetService string
+ expectedPath string
+ }{
+ {
+ name: "standard path with all parameters",
+ env: "staging",
+ tenantOrgID: "org_01KHVKQQP6D2N4RDJK0ADEKQX1",
+ applicationName: "plugin-pix",
+ targetService: "ledger",
+ expectedPath: "tenants/staging/org_01KHVKQQP6D2N4RDJK0ADEKQX1/plugin-pix/m2m/ledger/credentials",
+ },
+ {
+ name: "production environment",
+ env: "production",
+ tenantOrgID: "org_02ABCDEF",
+ applicationName: "plugin-auth",
+ targetService: "midaz",
+ expectedPath: "tenants/production/org_02ABCDEF/plugin-auth/m2m/midaz/credentials",
+ },
+ {
+ name: "empty env for backward compatibility",
+ env: "",
+ tenantOrgID: "org_01ABC",
+ applicationName: "plugin-crm",
+ targetService: "ledger",
+ expectedPath: "tenants/org_01ABC/plugin-crm/m2m/ledger/credentials",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act
+ path := buildM2MSecretPath(tt.env, tt.tenantOrgID, tt.applicationName, tt.targetService)
+
+ // Assert
+ assert.Equal(t, tt.expectedPath, path)
+ })
+ }
+}
+
+// ============================================================================
+// Test: GetM2MCredentials - valid JSON deserialization
+// ============================================================================
+
+func TestGetM2MCredentials_ValidJSON(t *testing.T) {
+ t.Parallel()
+
+ validCreds := M2MCredentials{
+ ClientID: "plg_01KHVKQQP6D2N4RDJK0ADEKQX1",
+ ClientSecret: "sec_super-secret-value",
+ }
+
+ credsJSON, err := json.Marshal(validCreds)
+ require.NoError(t, err, "test setup: marshalling valid credentials should not fail")
+
+ secretPath := "tenants/staging/org_01ABC/plugin-pix/m2m/ledger/credentials"
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{
+ secretPath: string(credsJSON),
+ },
+ errors: map[string]error{},
+ }
+
+ tests := []struct {
+ name string
+ env string
+ tenantOrgID string
+ applicationName string
+ targetService string
+ expectedClientID string
+ expectedSecret string
+ }{
+ {
+ name: "deserializes all fields correctly",
+ env: "staging",
+ tenantOrgID: "org_01ABC",
+ applicationName: "plugin-pix",
+ targetService: "ledger",
+ expectedClientID: "plg_01KHVKQQP6D2N4RDJK0ADEKQX1",
+ expectedSecret: "sec_super-secret-value",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act
+ creds, err := GetM2MCredentials(context.Background(), mock, tt.env, tt.tenantOrgID, tt.applicationName, tt.targetService)
+
+ // Assert
+ require.NoError(t, err)
+ require.NotNil(t, creds)
+ assert.Equal(t, tt.expectedClientID, creds.ClientID)
+ assert.Equal(t, tt.expectedSecret, creds.ClientSecret)
+ })
+ }
+}
+
+// ============================================================================
+// Test: GetM2MCredentials - invalid JSON deserialization
+// ============================================================================
+
+func TestGetM2MCredentials_InvalidJSON(t *testing.T) {
+ t.Parallel()
+
+ secretPath := "tenants/staging/org_01ABC/plugin-pix/m2m/ledger/credentials"
+
+ tests := []struct {
+ name string
+ secretValue string
+ expectedErr error
+ }{
+ {
+ name: "malformed JSON",
+ secretValue: `{invalid-json`,
+ expectedErr: ErrM2MUnmarshalFailed,
+ },
+ {
+ name: "empty string",
+ secretValue: ``,
+ expectedErr: ErrM2MUnmarshalFailed,
+ },
+ {
+ name: "plain text instead of JSON",
+ secretValue: `not-json-at-all`,
+ expectedErr: ErrM2MUnmarshalFailed,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{
+ secretPath: tt.secretValue,
+ },
+ errors: map[string]error{},
+ }
+
+ // Act
+ creds, err := GetM2MCredentials(context.Background(), mock, "staging", "org_01ABC", "plugin-pix", "ledger")
+
+ // Assert
+ require.ErrorIs(t, err, tt.expectedErr)
+ assert.Nil(t, creds)
+ })
+ }
+}
+
+// ============================================================================
+// Test: GetM2MCredentials - incomplete credentials (missing required fields)
+// ============================================================================
+
+func TestGetM2MCredentials_IncompleteCredentials(t *testing.T) {
+ t.Parallel()
+
+ secretPath := "tenants/staging/org_01ABC/plugin-pix/m2m/ledger/credentials"
+
+ tests := []struct {
+ name string
+ secretValue string
+ expectedErr error
+ }{
+ {
+ name: "empty JSON object - all fields missing",
+ secretValue: `{}`,
+ expectedErr: ErrM2MInvalidCredentials,
+ },
+ {
+ name: "only clientId present",
+ secretValue: `{"clientId":"id1"}`,
+ expectedErr: ErrM2MInvalidCredentials,
+ },
+ {
+ name: "only clientSecret missing",
+ secretValue: `{"clientId":"id1"}`,
+ expectedErr: ErrM2MInvalidCredentials,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{
+ secretPath: tt.secretValue,
+ },
+ errors: map[string]error{},
+ }
+
+ // Act
+ creds, err := GetM2MCredentials(context.Background(), mock, "staging", "org_01ABC", "plugin-pix", "ledger")
+
+ // Assert
+ require.ErrorIs(t, err, tt.expectedErr)
+ assert.Nil(t, creds)
+ })
+ }
+}
+
+// ============================================================================
+// Test: GetM2MCredentials - secret not found error
+// ============================================================================
+
+func TestGetM2MCredentials_SecretNotFound(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ env string
+ tenantOrgID string
+ applicationName string
+ targetService string
+ expectedErr error
+ }{
+ {
+ name: "secret does not exist in vault",
+ env: "staging",
+ tenantOrgID: "org_nonexistent",
+ applicationName: "plugin-pix",
+ targetService: "ledger",
+ expectedErr: ErrM2MCredentialsNotFound,
+ },
+ {
+ name: "different tenant not provisioned",
+ env: "production",
+ tenantOrgID: "org_notprovisioned",
+ applicationName: "plugin-auth",
+ targetService: "midaz",
+ expectedErr: ErrM2MCredentialsNotFound,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{},
+ errors: map[string]error{},
+ }
+
+ // Act
+ creds, err := GetM2MCredentials(context.Background(), mock, tt.env, tt.tenantOrgID, tt.applicationName, tt.targetService)
+
+ // Assert
+ require.ErrorIs(t, err, tt.expectedErr)
+ assert.Nil(t, creds)
+ })
+ }
+}
+
+// ============================================================================
+// Test: GetM2MCredentials - AWS credentials/access missing
+// ============================================================================
+
+func TestGetM2MCredentials_AWSCredentialsMissing(t *testing.T) {
+ t.Parallel()
+
+ secretPath := "tenants/staging/org_01ABC/plugin-pix/m2m/ledger/credentials"
+
+ tests := []struct {
+ name string
+ awsError error
+ expectedErr error
+ }{
+ {
+ name: "access denied - missing IAM permissions",
+ awsError: &smithy.GenericAPIError{
+ Code: "AccessDeniedException",
+ Message: "User is not authorized to access this resource",
+ },
+ expectedErr: ErrM2MVaultAccessDenied,
+ },
+ {
+ name: "credentials expired",
+ awsError: &smithy.GenericAPIError{
+ Code: "ExpiredTokenException",
+ Message: "The security token included in the request is expired",
+ },
+ expectedErr: ErrM2MVaultAccessDenied,
+ },
+ {
+ name: "generic AWS error",
+ awsError: errors.New("InternalServiceError: service unavailable"),
+ expectedErr: ErrM2MRetrievalFailed,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{},
+ errors: map[string]error{
+ secretPath: tt.awsError,
+ },
+ }
+
+ // Act
+ creds, err := GetM2MCredentials(context.Background(), mock, "staging", "org_01ABC", "plugin-pix", "ledger")
+
+ // Assert
+ require.ErrorIs(t, err, tt.expectedErr)
+ assert.Nil(t, creds)
+ })
+ }
+}
+
+// ============================================================================
+// Test: GetM2MCredentials - input validation
+// ============================================================================
+
+func TestGetM2MCredentials_InputValidation(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{},
+ errors: map[string]error{},
+ }
+
+ tests := []struct {
+ name string
+ env string
+ tenantOrgID string
+ applicationName string
+ targetService string
+ expectedErr error
+ }{
+ {
+ name: "empty tenantOrgID",
+ env: "staging",
+ tenantOrgID: "",
+ applicationName: "plugin-pix",
+ targetService: "ledger",
+ expectedErr: ErrM2MInvalidInput,
+ },
+ {
+ name: "empty applicationName",
+ env: "staging",
+ tenantOrgID: "org_01ABC",
+ applicationName: "",
+ targetService: "ledger",
+ expectedErr: ErrM2MInvalidInput,
+ },
+ {
+ name: "empty targetService",
+ env: "staging",
+ tenantOrgID: "org_01ABC",
+ applicationName: "plugin-pix",
+ targetService: "",
+ expectedErr: ErrM2MInvalidInput,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Act
+ creds, err := GetM2MCredentials(context.Background(), mock, tt.env, tt.tenantOrgID, tt.applicationName, tt.targetService)
+
+ // Assert
+ require.ErrorIs(t, err, tt.expectedErr)
+ assert.Nil(t, creds)
+ })
+ }
+}
+
+// ============================================================================
+// Test: GetM2MCredentials - nil client
+// ============================================================================
+
+func TestGetM2MCredentials_NilClient(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil client returns descriptive error", func(t *testing.T) {
+ t.Parallel()
+
+ // Act
+ creds, err := GetM2MCredentials(context.Background(), nil, "staging", "org_01ABC", "plugin-pix", "ledger")
+
+ // Assert
+ require.ErrorIs(t, err, ErrM2MInvalidInput)
+ assert.Nil(t, creds)
+ })
+}
+
+// ============================================================================
+// Test: GetM2MCredentials - concurrent safety
+// ============================================================================
+
+func TestGetM2MCredentials_ConcurrentSafety(t *testing.T) {
+ t.Parallel()
+
+ validCreds := M2MCredentials{
+ ClientID: "plg_concurrent_test",
+ ClientSecret: "sec_concurrent_secret",
+ }
+
+ credsJSON, err := json.Marshal(validCreds)
+ require.NoError(t, err, "test setup: marshalling valid credentials should not fail")
+
+ secretPath := "tenants/staging/org_concurrent/plugin-pix/m2m/ledger/credentials"
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{
+ secretPath: string(credsJSON),
+ },
+ errors: map[string]error{},
+ }
+
+ const goroutineCount = 50
+
+ t.Run("concurrent calls do not race or panic", func(t *testing.T) {
+ t.Parallel()
+
+ var wg sync.WaitGroup
+ wg.Add(goroutineCount)
+
+ results := make([]*M2MCredentials, goroutineCount)
+ errs := make([]error, goroutineCount)
+
+ for i := range goroutineCount {
+ go func(idx int) {
+ defer wg.Done()
+ results[idx], errs[idx] = GetM2MCredentials(
+ context.Background(),
+ mock,
+ "staging",
+ "org_concurrent",
+ "plugin-pix",
+ "ledger",
+ )
+ }(i)
+ }
+
+ wg.Wait()
+
+ // Assert: all goroutines should succeed with identical results
+ for i := range goroutineCount {
+ require.NoError(t, errs[i], "goroutine %d should not error", i)
+ require.NotNil(t, results[i], "goroutine %d should return credentials", i)
+ assert.Equal(t, "plg_concurrent_test", results[i].ClientID, "goroutine %d should have correct clientId", i)
+ assert.Equal(t, "sec_concurrent_secret", results[i].ClientSecret, "goroutine %d should have correct clientSecret", i)
+ }
+ })
+}
+
+// ============================================================================
+// Test: M2MCredentials struct JSON tags
+// ============================================================================
+
+func TestM2MCredentials_JSONTags(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ json string
+ expected M2MCredentials
+ }{
+ {
+ name: "standard camelCase JSON fields",
+ json: `{"clientId":"id1","clientSecret":"sec1"}`,
+ expected: M2MCredentials{
+ ClientID: "id1",
+ ClientSecret: "sec1",
+ },
+ },
+ {
+ name: "extra fields are ignored",
+ json: `{"clientId":"id2","clientSecret":"sec2","tokenUrl":"https://example.com/token","tenantId":"t1","targetService":"ledger"}`,
+ expected: M2MCredentials{
+ ClientID: "id2",
+ ClientSecret: "sec2",
+ },
+ },
+ {
+ name: "missing fields default to empty strings",
+ json: `{}`,
+ expected: M2MCredentials{},
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ var creds M2MCredentials
+ err := json.Unmarshal([]byte(tt.json), &creds)
+
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, creds)
+ })
+ }
+}
+
+func TestM2MCredentials_StringRedactsSecret(t *testing.T) {
+ t.Parallel()
+
+ creds := M2MCredentials{
+ ClientID: "client-visible-id",
+ ClientSecret: "sec_super-secret-value",
+ }
+
+ formatted := fmt.Sprintf("%v", creds)
+ goFormatted := fmt.Sprintf("%#v", creds)
+
+ assert.Contains(t, formatted, "ClientSecret:REDACTED")
+ assert.Contains(t, goFormatted, "ClientSecret:REDACTED")
+ assert.NotContains(t, formatted, creds.ClientSecret)
+ assert.NotContains(t, goFormatted, creds.ClientSecret)
+ assert.Contains(t, formatted, creds.ClientID)
+ assert.Contains(t, goFormatted, creds.ClientID)
+}
+
+// ============================================================================
+// Test: Path traversal prevention
+// ============================================================================
+
+func TestGetM2MCredentials_PathTraversal(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{},
+ errors: map[string]error{},
+ }
+
+ tests := []struct {
+ name string
+ env string
+ tenantOrgID string
+ applicationName string
+ targetService string
+ expectedErr error
+ }{
+ {
+ name: "tenantOrgID with slash",
+ env: "staging",
+ tenantOrgID: "org/../admin",
+ applicationName: "plugin-pix",
+ targetService: "ledger",
+ expectedErr: ErrM2MInvalidPathSegment,
+ },
+ {
+ name: "applicationName with backslash",
+ env: "staging",
+ tenantOrgID: "org_01ABC",
+ applicationName: "plugin\\pix",
+ targetService: "ledger",
+ expectedErr: ErrM2MInvalidPathSegment,
+ },
+ {
+ name: "targetService with dot-dot",
+ env: "staging",
+ tenantOrgID: "org_01ABC",
+ applicationName: "plugin-pix",
+ targetService: "..secret",
+ expectedErr: ErrM2MInvalidPathSegment,
+ },
+ {
+ name: "env with slash",
+ env: "staging/../../admin",
+ tenantOrgID: "org_01ABC",
+ applicationName: "plugin-pix",
+ targetService: "ledger",
+ expectedErr: ErrM2MInvalidPathSegment,
+ },
+ {
+ name: "whitespace-only tenantOrgID",
+ env: "staging",
+ tenantOrgID: " ",
+ applicationName: "plugin-pix",
+ targetService: "ledger",
+ expectedErr: ErrM2MInvalidInput,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ creds, err := GetM2MCredentials(context.Background(), mock, tt.env, tt.tenantOrgID, tt.applicationName, tt.targetService)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, tt.expectedErr)
+ assert.Nil(t, creds)
+ })
+ }
+}
+
+// ============================================================================
+// Test: Binary secret detection
+// ============================================================================
+
+func TestGetM2MCredentials_BinarySecret(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockBinarySecretsManagerClient{}
+
+ creds, err := GetM2MCredentials(context.Background(), mock, "staging", "org_01ABC", "plugin-pix", "ledger")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrM2MBinarySecretNotSupported)
+ assert.Nil(t, creds)
+}
+
+// ============================================================================
+// Test: Error path redaction
+// ============================================================================
+
+func TestGetM2MCredentials_ErrorsDoNotLeakFullPath(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{},
+ errors: map[string]error{},
+ }
+
+ // Secret not found → error should contain redacted path, not full path
+ _, err := GetM2MCredentials(context.Background(), mock, "staging", "org_01ABC", "plugin-pix", "ledger")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrM2MCredentialsNotFound)
+ // Full path should not appear in the error
+ assert.NotContains(t, err.Error(), "tenants/staging/org_01ABC/plugin-pix/m2m/ledger/credentials")
+ // Redacted path should contain the last segment
+ assert.Contains(t, err.Error(), "credentials")
+}
+
+// ============================================================================
+// Test: Typed-nil client detection
+// ============================================================================
+
+func TestGetM2MCredentials_TypedNilClient(t *testing.T) {
+ t.Parallel()
+
+ // A typed-nil interface value should be caught.
+ var typedNil *mockSecretsManagerClient
+
+ creds, err := GetM2MCredentials(context.Background(), typedNil, "staging", "org_01ABC", "plugin-pix", "ledger")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, ErrM2MInvalidInput)
+ assert.Nil(t, creds)
+}
+
+// ============================================================================
+// Test: Whitespace trimming in segments
+// ============================================================================
+
+func TestGetM2MCredentials_WhitespaceTrimming(t *testing.T) {
+ t.Parallel()
+
+ validCreds := M2MCredentials{
+ ClientID: "plg_trimmed",
+ ClientSecret: "sec_trimmed",
+ }
+
+ credsJSON, err := json.Marshal(validCreds)
+ require.NoError(t, err)
+
+ // The trimmed path should be used
+ secretPath := "tenants/staging/org_01ABC/plugin-pix/m2m/ledger/credentials"
+
+ mock := &mockSecretsManagerClient{
+ secrets: map[string]string{
+ secretPath: string(credsJSON),
+ },
+ errors: map[string]error{},
+ }
+
+ // Segments with leading/trailing whitespace should be trimmed
+ creds, err := GetM2MCredentials(context.Background(), mock, " staging ", " org_01ABC ", " plugin-pix ", " ledger ")
+ require.NoError(t, err)
+ require.NotNil(t, creds)
+ assert.Equal(t, "plg_trimmed", creds.ClientID)
+}
+
+// ============================================================================
+// Test: redactPath helper
+// ============================================================================
+
+func TestRedactPath(t *testing.T) {
+ t.Parallel()
+
+ result := redactPath("tenants/staging/org_01ABC/plugin-pix/m2m/ledger/credentials")
+
+ // Should contain the last segment
+ assert.Contains(t, result, "credentials")
+ // Should NOT contain the full path
+ assert.NotContains(t, result, "tenants/staging")
+ // Should contain a hash marker
+ assert.Contains(t, result, "[")
+ assert.Contains(t, result, "]")
+}
+
+// ============================================================================
+// Test: validatePathSegment helper
+// ============================================================================
+
+func TestValidatePathSegment(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ value string
+ expectErr bool
+ expectedErr error
+ expected string
+ }{
+ {name: "valid segment", value: "org_01ABC", expectErr: false, expected: "org_01ABC"},
+ {name: "trimmed segment", value: " org_01ABC ", expectErr: false, expected: "org_01ABC"},
+ {name: "empty", value: "", expectErr: true, expectedErr: ErrM2MInvalidInput},
+ {name: "whitespace only", value: " ", expectErr: true, expectedErr: ErrM2MInvalidInput},
+ {name: "contains slash", value: "org/admin", expectErr: true, expectedErr: ErrM2MInvalidPathSegment},
+ {name: "contains backslash", value: "org\\admin", expectErr: true, expectedErr: ErrM2MInvalidPathSegment},
+ {name: "contains dot-dot", value: "..admin", expectErr: true, expectedErr: ErrM2MInvalidPathSegment},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := validatePathSegment("test", tt.value)
+ if tt.expectErr {
+ require.Error(t, err)
+ assert.ErrorIs(t, err, tt.expectedErr)
+ } else {
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
diff --git a/commons/security/doc.go b/commons/security/doc.go
new file mode 100644
index 00000000..34ce0698
--- /dev/null
+++ b/commons/security/doc.go
@@ -0,0 +1,5 @@
+// Package security provides helpers for handling sensitive fields and data safety.
+//
+// It is primarily used by logging and telemetry packages to detect and obfuscate
+// secrets before data leaves process boundaries.
+package security
diff --git a/commons/security/sensitive_fields.go b/commons/security/sensitive_fields.go
index 84a8531e..46210978 100644
--- a/commons/security/sensitive_fields.go
+++ b/commons/security/sensitive_fields.go
@@ -1,12 +1,9 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package security
import (
"maps"
"regexp"
+ "slices"
"strings"
"sync"
"unicode"
@@ -30,12 +27,67 @@ var defaultSensitiveFields = []string{
"accesstoken",
"refresh_token",
"refreshtoken",
+ "bearer",
+ "jwt",
+ "session_id",
+ "sessionid",
+ "cookie",
"private_key",
"privatekey",
"clientid",
"client_id",
"clientsecret",
"client_secret",
+ "passwd",
+ "passphrase",
+ "card_number",
+ "cardnumber",
+ "cvv",
+ "cvc",
+ "ssn",
+ "social_security",
+ "pin",
+ "otp",
+ "account_number",
+ "accountnumber",
+ "routing_number",
+ "routingnumber",
+ "iban",
+ "swift",
+ "swift_code",
+ "bic",
+ "pan",
+ "expiry",
+ "expiry_date",
+ "expiration_date",
+ "card_expiry",
+ "date_of_birth",
+ "dob",
+ "tax_id",
+ "taxid",
+ "tin",
+ "national_id",
+ "sort_code",
+ "bsb",
+ "security_answer",
+ "security_question",
+ "mother_maiden_name",
+ "mfa_code",
+ "totp",
+ "biometric",
+ "fingerprint",
+ "certificate",
+ "connection_string",
+ "database_url",
+ // PII fields
+ "email",
+ "phone",
+ "phone_number",
+ "address",
+ "street",
+ "city",
+ "zip",
+ "postal_code",
}
var (
@@ -43,15 +95,18 @@ var (
sensitiveFieldsMap map[string]bool
)
+// DefaultSensitiveFields returns a copy of the default sensitive field names.
+// The returned slice is a clone — callers cannot mutate shared state.
func DefaultSensitiveFields() []string {
- return defaultSensitiveFields
+ clone := make([]string, len(defaultSensitiveFields))
+ copy(clone, defaultSensitiveFields)
+
+ return clone
}
-// DefaultSensitiveFieldsMap provides a map version of DefaultSensitiveFields
-// for lookup operations. All field names are lowercase for
-// case-insensitive matching. The underlying cache is initialized only once;
-// each call returns a shallow clone so callers cannot mutate shared state.
-func DefaultSensitiveFieldsMap() map[string]bool {
+// ensureSensitiveFieldsMap returns the internal map directly (no clone).
+// For internal use only where we just need read access.
+func ensureSensitiveFieldsMap() map[string]bool {
sensitiveFieldsMapOnce.Do(func() {
sensitiveFieldsMap = make(map[string]bool, len(defaultSensitiveFields))
for _, field := range defaultSensitiveFields {
@@ -59,8 +114,17 @@ func DefaultSensitiveFieldsMap() map[string]bool {
}
})
- clone := make(map[string]bool, len(sensitiveFieldsMap))
- maps.Copy(clone, sensitiveFieldsMap)
+ return sensitiveFieldsMap
+}
+
+// DefaultSensitiveFieldsMap provides a map version of DefaultSensitiveFields
+// for lookup operations. All field names are lowercase for
+// case-insensitive matching. The underlying cache is initialized only once;
+// each call returns a shallow clone so callers cannot mutate shared state.
+func DefaultSensitiveFieldsMap() map[string]bool {
+ m := ensureSensitiveFieldsMap()
+ clone := make(map[string]bool, len(m))
+ maps.Copy(clone, m)
return clone
}
@@ -70,6 +134,19 @@ func DefaultSensitiveFieldsMap() map[string]bool {
var shortSensitiveTokens = map[string]bool{
"key": true,
"auth": true,
+ "pin": true,
+ "otp": true,
+ "cvv": true,
+ "cvc": true,
+ "ssn": true,
+ "pan": true,
+ "bic": true,
+ "bsb": true,
+ "dob": true,
+ "tin": true,
+ "jwt": true,
+ "zip": true,
+ "city": true,
}
// tokenSplitRegex splits field names by non-alphanumeric characters.
@@ -112,16 +189,17 @@ func normalizeFieldName(fieldName string) string {
// Short tokens (like "key", "auth") use exact token matching to avoid false
// positives, while longer patterns use word-boundary matching.
func IsSensitiveField(fieldName string) bool {
+ m := ensureSensitiveFieldsMap()
lowerField := strings.ToLower(fieldName)
// Check exact match with lowercase
- if DefaultSensitiveFieldsMap()[lowerField] {
+ if m[lowerField] {
return true
}
// Also check with camelCase normalization (e.g., "sessionToken" -> "session_token")
normalized := normalizeFieldName(fieldName)
- if normalized != lowerField && DefaultSensitiveFieldsMap()[normalized] {
+ if normalized != lowerField && m[normalized] {
return true
}
@@ -130,10 +208,8 @@ func IsSensitiveField(fieldName string) bool {
for _, sensitive := range defaultSensitiveFields {
if shortSensitiveTokens[sensitive] {
- for _, token := range tokens {
- if token == sensitive {
- return true
- }
+ if slices.Contains(tokens, sensitive) {
+ return true
}
} else {
if matchesWordBoundary(normalized, sensitive) {
@@ -152,6 +228,10 @@ func IsSensitiveField(fieldName string) bool {
// matchesWordBoundary checks if the pattern appears in the field with word boundaries.
// A word boundary is either the start/end of string or a non-alphanumeric character.
func matchesWordBoundary(field, pattern string) bool {
+ if len(pattern) == 0 {
+ return false
+ }
+
idx := strings.Index(field, pattern)
if idx == -1 {
return false
diff --git a/commons/security/sensitive_fields_test.go b/commons/security/sensitive_fields_test.go
index d604198f..33d5037b 100644
--- a/commons/security/sensitive_fields_test.go
+++ b/commons/security/sensitive_fields_test.go
@@ -1,6 +1,4 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package security
@@ -14,6 +12,8 @@ import (
)
func TestDefaultSensitiveFields(t *testing.T) {
+ t.Parallel()
+
// Test that the slice is not empty
assert.NotEmpty(t, DefaultSensitiveFields(), "DefaultSensitiveFields should not be empty")
@@ -37,6 +37,8 @@ func TestDefaultSensitiveFields(t *testing.T) {
}
func TestDefaultSensitiveFieldsMap(t *testing.T) {
+ t.Parallel()
+
// Test that the map is not empty
assert.NotEmpty(t, DefaultSensitiveFieldsMap(), "DefaultSensitiveFieldsMap should not be empty")
@@ -57,6 +59,8 @@ func TestDefaultSensitiveFieldsMap(t *testing.T) {
}
func TestIsSensitiveField(t *testing.T) {
+ t.Parallel()
+
tests := []struct {
name string
fieldName string
@@ -129,9 +133,9 @@ func TestIsSensitiveField(t *testing.T) {
},
{
- name: "non-sensitive field - email",
+ name: "sensitive field - email (PII)",
fieldName: "email",
- expected: false,
+ expected: true,
},
{
name: "non-sensitive field - id",
@@ -166,7 +170,9 @@ func TestIsSensitiveField(t *testing.T) {
}
for _, tt := range tests {
+ tt := tt
t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
result := IsSensitiveField(tt.fieldName)
assert.Equal(t, tt.expected, result,
"IsSensitiveField(%s) should return %v", tt.fieldName, tt.expected)
@@ -175,6 +181,8 @@ func TestIsSensitiveField(t *testing.T) {
}
func TestIsSensitiveFieldCaseInsensitive(t *testing.T) {
+ t.Parallel()
+
// Test that case-insensitive matching works for all default fields
for _, field := range DefaultSensitiveFields() {
// Test lowercase
@@ -194,6 +202,8 @@ func TestIsSensitiveFieldCaseInsensitive(t *testing.T) {
}
func TestConsistencyBetweenSliceAndMap(t *testing.T) {
+ t.Parallel()
+
// Ensure that the slice and map are consistent
// Every field in the slice should be in the map
for _, field := range DefaultSensitiveFields() {
@@ -211,17 +221,334 @@ func TestConsistencyBetweenSliceAndMap(t *testing.T) {
}
func TestDefaultFieldsAreExpected(t *testing.T) {
- // Test that we have the expected number of fields (this helps catch accidental additions/removals)
- expectedCount := 23
- actualCount := len(DefaultSensitiveFields())
- assert.Equal(t, expectedCount, actualCount,
- "Expected %d default sensitive fields, but found %d. If this is intentional, update the test.",
- expectedCount, actualCount)
+ t.Parallel()
+
+ fields := DefaultSensitiveFields()
+
+ // Assert that required categories of sensitive fields are present,
+ // rather than asserting an exact count (which is brittle as the catalog grows).
+ requiredFields := []string{
+ // Auth credentials
+ "password", "token", "secret", "api_key", "bearer",
+ // Financial
+ "card_number", "cvv", "account_number", "iban",
+ // PII
+ "ssn", "date_of_birth", "email", "phone", "address",
+ // Infrastructure secrets
+ "connection_string", "database_url", "private_key",
+ }
+
+ for _, required := range requiredFields {
+ assert.Contains(t, fields, required,
+ "DefaultSensitiveFields must contain %q", required)
+ }
+
+ // Sanity-check minimum size — the catalog should never shrink below baseline.
+ assert.GreaterOrEqual(t, len(fields), len(requiredFields),
+ "DefaultSensitiveFields must have at least %d entries", len(requiredFields))
}
func TestNoEmptyFields(t *testing.T) {
+ t.Parallel()
+
// Ensure no empty strings in the default fields
for i, field := range DefaultSensitiveFields() {
assert.NotEmpty(t, field, "Field at index %d should not be empty", i)
}
}
+
+func TestDefaultSensitiveFields_ReturnsClone(t *testing.T) {
+ t.Parallel()
+
+ original := DefaultSensitiveFields()
+ original[0] = "MUTATED"
+
+ // The mutation should not affect subsequent calls
+ fresh := DefaultSensitiveFields()
+ assert.NotEqual(t, "MUTATED", fresh[0], "DefaultSensitiveFields must return a clone")
+}
+
+func TestIsSensitiveField_FinancialFields(t *testing.T) {
+ t.Parallel()
+
+ financialFields := []struct {
+ name string
+ expected bool
+ }{
+ {"card_number", true},
+ {"cardnumber", true},
+ {"cvv", true},
+ {"cvc", true},
+ {"ssn", true},
+ {"social_security", true},
+ {"pin", true},
+ {"otp", true},
+ {"account_number", true},
+ {"accountnumber", true},
+ {"routing_number", true},
+ {"routingnumber", true},
+ {"iban", true},
+ {"swift", true},
+ {"swift_code", true},
+ {"bic", true},
+ {"pan", true},
+ {"expiry", true},
+ {"expiry_date", true},
+ {"expiration_date", true},
+ {"card_expiry", true},
+ {"date_of_birth", true},
+ {"dob", true},
+ {"tax_id", true},
+ {"taxid", true},
+ {"tin", true},
+ {"national_id", true},
+ {"sort_code", true},
+ {"bsb", true},
+ {"security_answer", true},
+ {"security_question", true},
+ {"mother_maiden_name", true},
+ {"mfa_code", true},
+ {"totp", true},
+ {"biometric", true},
+ {"fingerprint", true},
+ // False positives for short tokens
+ {"spinning", false},
+ {"opinion", false},
+ {"pineapple", false},
+ {"cotton", false},
+ {"panther", false},
+ }
+
+ for _, tt := range financialFields {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result := IsSensitiveField(tt.name)
+ assert.Equal(t, tt.expected, result,
+ "IsSensitiveField(%q) = %v, want %v", tt.name, result, tt.expected)
+ })
+ }
+}
+
+func TestShortSensitiveTokens_ExactMatch(t *testing.T) {
+ t.Parallel()
+
+ // These short tokens should match exactly but not as substrings
+ tests := []struct {
+ field string
+ expected bool
+ }{
+ {"pin", true},
+ {"otp", true},
+ {"cvv", true},
+ {"cvc", true},
+ {"ssn", true},
+ {"pan", true},
+ {"bic", true},
+ {"bsb", true},
+ {"dob", true},
+ {"tin", true},
+ // CamelCase variants
+ {"userPin", true},
+ {"otpCode", true},
+ {"userSsn", true},
+ // Should NOT match as substrings in larger words
+ {"spinning", false},
+ {"option", false},
+ {"panther", false},
+ {"basic", false},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.field, func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, tt.expected, IsSensitiveField(tt.field),
+ "IsSensitiveField(%q)", tt.field)
+ })
+ }
+}
+
+func TestNormalizeFieldName(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {"sessionToken", "session_token"},
+ {"APIKey", "api_key"},
+ {"myPrivateKey", "my_private_key"},
+ {"DateOfBirth", "date_of_birth"},
+ {"simple", "simple"},
+ {"already_snake", "already_snake"},
+ {"HTTPSProxy", "https_proxy"},
+ {"userID", "user_id"},
+ {"", ""},
+ {"X", "x"},
+ {"ABC", "abc"},
+ {"getHTTPResponse", "get_http_response"},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.input, func(t *testing.T) {
+ t.Parallel()
+ result := normalizeFieldName(tt.input)
+ assert.Equal(t, tt.expected, result, "normalizeFieldName(%q)", tt.input)
+ })
+ }
+}
+
+func TestIsSensitiveField_WordBoundaryPositivePath(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ field string
+ expected bool
+ }{
+ // Word-boundary matches (pattern found with non-alphanumeric boundaries)
+ {"my_secret_value", true}, // "secret" with underscore boundaries
+ {"x-authorization-header", true}, // "authorization" with hyphen boundaries
+ {"user_password_hash", true}, // "password" with underscore boundaries
+ {"db_credential_store", true}, // "credential" with underscore boundaries
+ {"old_token_backup", true}, // "token" with underscore boundaries
+ // CamelCase that normalizes to word-boundary matchable form
+ {"SessionToken", true}, // -> "session_token" -> "token" boundary match
+ {"ExpiryDate", true}, // -> "expiry_date" -> exact map match via normalization
+ {"AccountNumber", true}, // -> "account_number" -> exact map match via normalization
+ {"CardNumber", true}, // -> "card_number" -> exact map match via normalization
+ {"PrivateKeyData", true}, // -> "private_key_data" -> "private_key" boundary match
+ // Should NOT match
+ {"mysecretvalue", false}, // no word boundaries around "secret"
+ {"deauthorize", false}, // "authorization" not present
+ {"repass", false}, // "password" not present
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.field, func(t *testing.T) {
+ t.Parallel()
+ result := IsSensitiveField(tt.field)
+ assert.Equal(t, tt.expected, result, "IsSensitiveField(%q)", tt.field)
+ })
+ }
+}
+
+func TestDefaultSensitiveFieldsMap_ReturnsClone(t *testing.T) {
+ t.Parallel()
+
+ original := DefaultSensitiveFieldsMap()
+ // Mutate the returned map
+ original["password"] = false
+ original["INJECTED"] = true
+
+ // Fresh call should be unaffected
+ fresh := DefaultSensitiveFieldsMap()
+ assert.True(t, fresh["password"], "Map mutation must not affect shared state")
+ assert.False(t, fresh["INJECTED"], "Map mutation must not inject into shared state")
+}
+
+func TestIsSensitiveField_ConcurrentAccess(t *testing.T) {
+ t.Parallel()
+
+ const goroutines = 100
+
+ type result struct {
+ password bool
+ sessionToken bool
+ secretValue bool
+ userPin bool
+ harmless bool
+ }
+
+ results := make(chan result, goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func() {
+ // Exercise all code paths concurrently and collect results
+ r := result{
+ password: IsSensitiveField("password"),
+ sessionToken: IsSensitiveField("SessionToken"),
+ secretValue: IsSensitiveField("my_secret_value"),
+ userPin: IsSensitiveField("userPin"),
+ harmless: IsSensitiveField("harmless"),
+ }
+ _ = DefaultSensitiveFields()
+ _ = DefaultSensitiveFieldsMap()
+ results <- r
+ }()
+ }
+
+ for i := 0; i < goroutines; i++ {
+ r := <-results
+ assert.True(t, r.password, "concurrent: password should be sensitive")
+ assert.True(t, r.sessionToken, "concurrent: SessionToken should be sensitive")
+ assert.True(t, r.secretValue, "concurrent: my_secret_value should be sensitive")
+ assert.True(t, r.userPin, "concurrent: userPin should be sensitive")
+ assert.False(t, r.harmless, "concurrent: harmless should not be sensitive")
+ }
+}
+
+func TestMatchesWordBoundary_EmptyPattern(t *testing.T) {
+ t.Parallel()
+
+ // Empty pattern must return false, not loop forever
+ assert.False(t, matchesWordBoundary("anything", ""), "Empty pattern must return false")
+ assert.False(t, matchesWordBoundary("", ""), "Both empty must return false")
+}
+
+func TestIsSensitiveField_PIIFields(t *testing.T) {
+ t.Parallel()
+
+ piiFields := []struct {
+ name string
+ expected bool
+ }{
+ {"email", true},
+ {"phone", true},
+ {"phone_number", true},
+ {"address", true},
+ {"street", true},
+ {"city", true},
+ {"zip", true},
+ {"postal_code", true},
+ // CamelCase variants
+ {"EmailAddress", true},
+ {"PhoneNumber", true},
+ {"PostalCode", true},
+ // False positives for short tokens
+ {"unzip", false},
+ {"capacity", false},
+ {"felicity", false},
+ }
+
+ for _, tt := range piiFields {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ result := IsSensitiveField(tt.name)
+ assert.Equal(t, tt.expected, result,
+ "IsSensitiveField(%q) = %v, want %v", tt.name, result, tt.expected)
+ })
+ }
+}
+
+func TestIsSensitiveField_NewV2Fields(t *testing.T) {
+ t.Parallel()
+
+ newFields := []string{
+ "passwd", "passphrase", "bearer", "jwt",
+ "session_id", "sessionid", "cookie",
+ "certificate", "connection_string", "database_url",
+ }
+
+ for _, field := range newFields {
+ field := field
+ t.Run(field, func(t *testing.T) {
+ t.Parallel()
+ assert.True(t, IsSensitiveField(field),
+ "IsSensitiveField(%q) should return true for v2 field", field)
+ })
+ }
+}
diff --git a/commons/server/doc.go b/commons/server/doc.go
new file mode 100644
index 00000000..6afe4c0f
--- /dev/null
+++ b/commons/server/doc.go
@@ -0,0 +1,5 @@
+// Package server provides server lifecycle and graceful shutdown helpers.
+//
+// Use this package to coordinate signal handling, shutdown deadlines, and ordered
+// resource cleanup for HTTP/gRPC service processes.
+package server
diff --git a/commons/server/grpc_test.go b/commons/server/grpc_test.go
deleted file mode 100644
index f2ed30fe..00000000
--- a/commons/server/grpc_test.go
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
-package server_test
-
-import (
- "testing"
-
- "github.com/LerianStudio/lib-commons/v2/commons/server"
- "github.com/stretchr/testify/assert"
- "google.golang.org/grpc"
-)
-
-func TestGracefulShutdownWithGRPCServer(t *testing.T) {
- // Create a new gRPC server for testing
- grpcServer := grpc.NewServer()
-
- // Create a graceful shutdown handler with the gRPC server
- gs := server.NewGracefulShutdown(nil, grpcServer, nil, nil, nil)
-
- // Assert that the graceful shutdown handler was created successfully
- assert.NotNil(t, gs, "NewGracefulShutdown should return a non-nil instance with gRPC server")
-
- // Test that we can create the shutdown handler without panicking
- // We don't test the actual signal handling as that would require OS signals
- assert.NotPanics(t, func() {
- // Just ensure the shutdown handler can be created and doesn't panic
- _ = server.NewGracefulShutdown(nil, grpcServer, nil, nil, nil)
- }, "Creating GracefulShutdown with gRPC server should not panic")
-}
-
-func TestServerManagerWithGRPCServer(t *testing.T) {
- grpcServer := grpc.NewServer()
-
- sm := server.NewServerManager(nil, nil, nil).
- WithGRPCServer(grpcServer, ":50051")
-
- assert.NotNil(t, sm, "ServerManager with gRPC server should not be nil")
-}
diff --git a/commons/server/shutdown.go b/commons/server/shutdown.go
index 406b3d39..a9601499 100644
--- a/commons/server/shutdown.go
+++ b/commons/server/shutdown.go
@@ -1,10 +1,7 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package server
import (
+ "context"
"errors"
"fmt"
"net"
@@ -12,10 +9,12 @@ import (
"os/signal"
"sync"
"syscall"
+ "time"
- "github.com/LerianStudio/lib-commons/v2/commons/license"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/license"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
"github.com/gofiber/fiber/v2"
"google.golang.org/grpc"
)
@@ -26,34 +25,70 @@ var ErrNoServersConfigured = errors.New("no servers configured: use WithHTTPServ
// ServerManager handles the graceful shutdown of multiple server types.
// It can manage HTTP servers, gRPC servers, or both simultaneously.
type ServerManager struct {
- httpServer *fiber.App
- grpcServer *grpc.Server
- licenseClient *license.ManagerShutdown
- telemetry *opentelemetry.Telemetry
- logger log.Logger
+ httpServer *fiber.App
+ grpcServer *grpc.Server
+ licenseClient *license.ManagerShutdown
+ telemetry *opentelemetry.Telemetry
+ logger log.Logger
httpAddress string
grpcAddress string
serversStarted chan struct{}
serversStartedOnce sync.Once
shutdownChan <-chan struct{}
+ shutdownOnce sync.Once
+ shutdownTimeout time.Duration
+ startupErrors chan error
+ shutdownHooks []func(context.Context) error
+}
+
+// ensureRuntimeDefaults initializes zero-value fields so exported lifecycle
+// methods remain nil-safe even when ServerManager is manually instantiated.
+func (sm *ServerManager) ensureRuntimeDefaults() {
+ if sm == nil {
+ return
+ }
+
+ if sm.logger == nil {
+ sm.logger = log.NewNop()
+ }
+
+ if sm.serversStarted == nil {
+ sm.serversStarted = make(chan struct{})
+ }
+
+ if sm.startupErrors == nil {
+ sm.startupErrors = make(chan error, 2)
+ }
}
// NewServerManager creates a new instance of ServerManager.
+// If logger is nil, a no-op logger is used to ensure nil-safe operation
+// throughout the server lifecycle.
func NewServerManager(
licenseClient *license.ManagerShutdown,
telemetry *opentelemetry.Telemetry,
logger log.Logger,
) *ServerManager {
+ if logger == nil {
+ logger = log.NewNop()
+ }
+
return &ServerManager{
- licenseClient: licenseClient,
- telemetry: telemetry,
- logger: logger,
- serversStarted: make(chan struct{}),
+ licenseClient: licenseClient,
+ telemetry: telemetry,
+ logger: logger,
+ serversStarted: make(chan struct{}),
+ shutdownTimeout: 30 * time.Second,
+ startupErrors: make(chan error, 2),
}
}
// WithHTTPServer configures the HTTP server for the ServerManager.
func (sm *ServerManager) WithHTTPServer(app *fiber.App, address string) *ServerManager {
+ if sm == nil {
+ return nil
+ }
+
sm.httpServer = app
sm.httpAddress = address
@@ -62,6 +97,10 @@ func (sm *ServerManager) WithHTTPServer(app *fiber.App, address string) *ServerM
// WithGRPCServer configures the gRPC server for the ServerManager.
func (sm *ServerManager) WithGRPCServer(server *grpc.Server, address string) *ServerManager {
+ if sm == nil {
+ return nil
+ }
+
sm.grpcServer = server
sm.grpcAddress = address
@@ -71,15 +110,54 @@ func (sm *ServerManager) WithGRPCServer(server *grpc.Server, address string) *Se
// WithShutdownChannel configures a custom shutdown channel for the ServerManager.
// This allows tests to trigger shutdown deterministically instead of relying on OS signals.
func (sm *ServerManager) WithShutdownChannel(ch <-chan struct{}) *ServerManager {
+ if sm == nil {
+ return nil
+ }
+
sm.shutdownChan = ch
return sm
}
+// WithShutdownTimeout configures the maximum duration to wait for gRPC GracefulStop
+// before forcing a hard stop. Defaults to 30 seconds.
+func (sm *ServerManager) WithShutdownTimeout(d time.Duration) *ServerManager {
+ if sm == nil {
+ return nil
+ }
+
+ sm.shutdownTimeout = d
+
+ return sm
+}
+
+// WithShutdownHook registers a function to be called during graceful shutdown.
+// Hooks are executed in registration order, AFTER HTTP server shutdown and
+// BEFORE telemetry shutdown. Each hook receives a context bounded by the
+// shutdown timeout. Errors from hooks are logged but do not prevent subsequent
+// hooks or the rest of the shutdown sequence from running (best-effort cleanup).
+func (sm *ServerManager) WithShutdownHook(hook func(context.Context) error) *ServerManager {
+ if sm == nil || hook == nil {
+ return sm
+ }
+
+ sm.shutdownHooks = append(sm.shutdownHooks, hook)
+
+ return sm
+}
+
// ServersStarted returns a channel that is closed when server goroutines have been launched.
// Note: This signals that goroutines were spawned, not that sockets are bound and ready to accept connections.
// This is useful for tests to coordinate shutdown timing after server launch.
+// Returns a closed channel on nil receiver to prevent callers from blocking forever.
func (sm *ServerManager) ServersStarted() <-chan struct{} {
+ if sm == nil {
+ ch := make(chan struct{})
+ close(ch)
+
+ return ch
+ }
+
return sm.serversStarted
}
@@ -94,9 +172,7 @@ func (sm *ServerManager) validateConfiguration() error {
// initServers validates configuration and starts servers without blocking.
// Returns an error if validation fails. Does not call Fatal.
func (sm *ServerManager) initServers() error {
- if sm.serversStarted == nil {
- sm.serversStarted = make(chan struct{})
- }
+ sm.ensureRuntimeDefaults()
if err := sm.validateConfiguration(); err != nil {
return err
@@ -111,13 +187,17 @@ func (sm *ServerManager) initServers() error {
// Returns an error if no servers are configured instead of calling Fatal.
// Blocks until shutdown signal is received or shutdown channel is closed.
func (sm *ServerManager) StartWithGracefulShutdownWithError() error {
+ if sm == nil {
+ return ErrNoServersConfigured
+ }
+
+ sm.ensureRuntimeDefaults()
+
if err := sm.initServers(); err != nil {
return err
}
- sm.handleShutdown()
-
- return nil
+ return sm.handleShutdown()
}
// StartWithGracefulShutdown initializes all configured servers and sets up graceful shutdown.
@@ -125,6 +205,13 @@ func (sm *ServerManager) StartWithGracefulShutdownWithError() error {
// Note: On configuration error, logFatal always terminates the process regardless of logger availability.
// Use StartWithGracefulShutdownWithError() for proper error handling without process termination.
func (sm *ServerManager) StartWithGracefulShutdown() {
+ if sm == nil {
+ fmt.Println("no servers configured: use WithHTTPServer() or WithGRPCServer()")
+ os.Exit(1)
+ }
+
+ sm.ensureRuntimeDefaults()
+
if err := sm.initServers(); err != nil {
// logFatal exits the process via os.Exit(1); code below is unreachable on error
sm.logFatal(err.Error())
@@ -133,11 +220,7 @@ func (sm *ServerManager) StartWithGracefulShutdown() {
// Run everything in a recover block
defer func() {
if r := recover(); r != nil {
- if sm.logger != nil {
- sm.logger.Errorf("Fatal error (panic): %v", r)
- } else {
- fmt.Printf("Fatal error (panic): %v\n", r)
- }
+ runtime.HandlePanicValue(context.Background(), sm.logger, r, "server", "StartWithGracefulShutdown")
sm.executeShutdown()
@@ -145,7 +228,7 @@ func (sm *ServerManager) StartWithGracefulShutdown() {
}
}()
- sm.handleShutdown()
+ _ = sm.handleShutdown()
}
// startServers starts all configured servers in separate goroutines.
@@ -157,37 +240,67 @@ func (sm *ServerManager) startServers() {
// Start HTTP server if configured
if sm.httpServer != nil {
- go func() {
- sm.logInfof("Starting HTTP server on %s", sm.httpAddress)
-
- if err := sm.httpServer.Listen(sm.httpAddress); err != nil {
- sm.logErrorf("HTTP server error: %v", err)
- }
- }()
+ runtime.SafeGoWithContextAndComponent(
+ context.Background(),
+ sm.logger,
+ "server",
+ "start_http_server",
+ runtime.KeepRunning,
+ func(_ context.Context) {
+ sm.logger.Log(context.Background(), log.LevelInfo, "starting HTTP server", log.String("address", sm.httpAddress))
+
+ if err := sm.httpServer.Listen(sm.httpAddress); err != nil {
+ sm.logger.Log(context.Background(), log.LevelError, "HTTP server error", log.Err(err))
+
+ select {
+ case sm.startupErrors <- fmt.Errorf("HTTP server: %w", err):
+ default:
+ }
+ }
+ },
+ )
started++
}
// Start gRPC server if configured
if sm.grpcServer != nil {
- go func() {
- sm.logInfof("Starting gRPC server on %s", sm.grpcAddress)
-
- listener, err := net.Listen("tcp", sm.grpcAddress)
- if err != nil {
- sm.logErrorf("Failed to listen on gRPC address: %v", err)
- return
- }
-
- if err := sm.grpcServer.Serve(listener); err != nil {
- sm.logErrorf("gRPC server error: %v", err)
- }
- }()
+ runtime.SafeGoWithContextAndComponent(
+ context.Background(),
+ sm.logger,
+ "server",
+ "start_grpc_server",
+ runtime.KeepRunning,
+ func(_ context.Context) {
+ sm.logger.Log(context.Background(), log.LevelInfo, "starting gRPC server", log.String("address", sm.grpcAddress))
+
+ listener, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", sm.grpcAddress)
+ if err != nil {
+ sm.logger.Log(context.Background(), log.LevelError, "failed to listen on gRPC address", log.Err(err))
+
+ select {
+ case sm.startupErrors <- fmt.Errorf("gRPC listen: %w", err):
+ default:
+ }
+
+ return
+ }
+
+ if err := sm.grpcServer.Serve(listener); err != nil {
+ sm.logger.Log(context.Background(), log.LevelError, "gRPC server error", log.Err(err))
+
+ select {
+ case sm.startupErrors <- fmt.Errorf("gRPC serve: %w", err):
+ default:
+ }
+ }
+ },
+ )
started++
}
- sm.logInfof("Launched %d server goroutine(s)", started)
+ sm.logger.Log(context.Background(), log.LevelInfo, "launched server goroutines", log.Int("count", started))
// Signal that server goroutines have been launched (not that sockets are bound).
sm.serversStartedOnce.Do(func() {
@@ -198,21 +311,7 @@ func (sm *ServerManager) startServers() {
// logInfo safely logs an info message if logger is available
func (sm *ServerManager) logInfo(msg string) {
if sm.logger != nil {
- sm.logger.Info(msg)
- }
-}
-
-// logInfof safely logs a formatted info message if logger is available
-func (sm *ServerManager) logInfof(format string, args ...any) {
- if sm.logger != nil {
- sm.logger.Infof(format, args...)
- }
-}
-
-// logErrorf safely logs an error message if logger is available
-func (sm *ServerManager) logErrorf(format string, args ...any) {
- if sm.logger != nil {
- sm.logger.Errorf(format, args...)
+ sm.logger.Log(context.Background(), log.LevelInfo, msg)
}
}
@@ -221,7 +320,7 @@ func (sm *ServerManager) logErrorf(format string, args ...any) {
// that may or may not call os.Exit(1) in their Fatal method.
func (sm *ServerManager) logFatal(msg string) {
if sm.logger != nil {
- sm.logger.Error(msg)
+ sm.logger.Log(context.Background(), log.LevelError, msg)
} else {
fmt.Println(msg)
}
@@ -230,149 +329,134 @@ func (sm *ServerManager) logFatal(msg string) {
}
// handleShutdown sets up signal handling and executes the shutdown sequence
-// when a termination signal is received or when the shutdown channel is closed.
-func (sm *ServerManager) handleShutdown() {
+// when a termination signal is received, when the shutdown channel is closed,
+// or when a server startup error is detected.
+// Returns the first startup error if one caused the shutdown, nil otherwise.
+func (sm *ServerManager) handleShutdown() error {
+ sm.ensureRuntimeDefaults()
+
+ var startupErr error
+
if sm.shutdownChan != nil {
- <-sm.shutdownChan
+ select {
+ case <-sm.shutdownChan:
+ case err := <-sm.startupErrors:
+ sm.logger.Log(context.Background(), log.LevelError, "server startup failed", log.Err(err))
+
+ startupErr = err
+ }
} else {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
- <-c
+
+ select {
+ case <-c:
+ signal.Stop(c)
+ case err := <-sm.startupErrors:
+ sm.logger.Log(context.Background(), log.LevelError, "server startup failed", log.Err(err))
+
+ startupErr = err
+ }
}
sm.logInfo("Gracefully shutting down all servers...")
sm.executeShutdown()
+
+ return startupErr
}
// executeShutdown performs the actual shutdown operations in the correct order for ServerManager.
+// It is idempotent: multiple calls are safe, but only the first invocation executes the shutdown sequence.
func (sm *ServerManager) executeShutdown() {
- // Use a non-blocking read to check if servers have started.
- // This prevents a deadlock if a panic occurs before startServers() completes.
- select {
- case <-sm.serversStarted:
- // Servers started, proceed with normal shutdown.
- default:
- // Servers did not start (or start was interrupted).
- sm.logInfo("Shutdown initiated before servers were fully started.")
- }
-
- // Shutdown the HTTP server if available
- if sm.httpServer != nil {
- sm.logInfo("Shutting down HTTP server...")
-
- if err := sm.httpServer.Shutdown(); err != nil {
- sm.logErrorf("Error during HTTP server shutdown: %v", err)
+ sm.ensureRuntimeDefaults()
+
+ sm.shutdownOnce.Do(func() {
+ // Use a non-blocking read to check if servers have started.
+ // This prevents a deadlock if a panic occurs before startServers() completes.
+ select {
+ case <-sm.serversStarted:
+ // Servers started, proceed with normal shutdown.
+ default:
+ // Servers did not start (or start was interrupted).
+ sm.logInfo("Shutdown initiated before servers were fully started.")
}
- }
- // Shutdown telemetry BEFORE gRPC server to allow metrics export
- if sm.telemetry != nil {
- sm.logInfo("Shutting down telemetry...")
- sm.telemetry.ShutdownTelemetry()
- }
-
- // Shutdown the gRPC server if available
- if sm.grpcServer != nil {
- sm.logInfo("Shutting down gRPC server...")
-
- // Use GracefulStop which waits for all RPCs to finish
- sm.grpcServer.GracefulStop()
- sm.logInfo("gRPC server stopped gracefully")
- }
-
- // Sync logger if available
- if sm.logger != nil {
- sm.logInfo("Syncing logger...")
+ // Shutdown the HTTP server if available
+ if sm.httpServer != nil {
+ sm.logInfo("Shutting down HTTP server...")
- if err := sm.logger.Sync(); err != nil {
- sm.logErrorf("Failed to sync logger: %v", err)
+ if err := sm.httpServer.Shutdown(); err != nil {
+ sm.logger.Log(context.Background(), log.LevelError, "error during HTTP server shutdown", log.Err(err))
+ }
}
- }
-
- sm.logInfo("Graceful shutdown completed")
-}
-
-// GracefulShutdown handles the graceful shutdown of application components.
-// It's designed to be reusable across different services.
-// Deprecated: Use ServerManager instead for better coordination.
-type GracefulShutdown struct {
- app *fiber.App
- grpcServer *grpc.Server
- licenseClient *license.ManagerShutdown
- telemetry *opentelemetry.Telemetry
- logger log.Logger
-}
-
-// NewGracefulShutdown creates a new instance of GracefulShutdown.
-// Deprecated: Use NewServerManager instead for better coordination.
-func NewGracefulShutdown(
- app *fiber.App,
- grpcServer *grpc.Server,
- licenseClient *license.ManagerShutdown,
- telemetry *opentelemetry.Telemetry,
- logger log.Logger,
-) *GracefulShutdown {
- return &GracefulShutdown{
- app: app,
- grpcServer: grpcServer,
- licenseClient: licenseClient,
- telemetry: telemetry,
- logger: logger,
- }
-}
-
-// HandleShutdown sets up signal handling and executes the shutdown sequence
-// when a termination signal is received.
-// Deprecated: Use ServerManager.StartWithGracefulShutdown() instead.
-func (gs *GracefulShutdown) HandleShutdown() {
- // Create channel for shutdown signals
- c := make(chan os.Signal, 1)
- signal.Notify(c, os.Interrupt, syscall.SIGTERM)
-
- // Block until we receive a signal
- <-c
- gs.logger.Info("Gracefully shutting down...")
- // Execute shutdown sequence
- gs.executeShutdown()
-}
-
-// executeShutdown performs the actual shutdown operations in the correct order.
-// Deprecated: Use ServerManager.executeShutdown() for better coordination.
-func (gs *GracefulShutdown) executeShutdown() {
- // Shutdown the HTTP server if available
- if gs.app != nil {
- gs.logger.Info("Shutting down HTTP server...")
+ // Execute shutdown hooks (best-effort, between HTTP and telemetry shutdown).
+ // Each hook gets its own context with an independent timeout to prevent
+ // one slow hook from consuming the entire budget.
+ for i, hook := range sm.shutdownHooks {
+ hookCtx, hookCancel := context.WithTimeout(context.Background(), sm.shutdownTimeout)
+
+ if err := hook(hookCtx); err != nil {
+ sm.logger.Log(context.Background(), log.LevelError, "shutdown hook failed",
+ log.Int("hook_index", i),
+ log.Err(err),
+ )
+ }
- if err := gs.app.Shutdown(); err != nil {
- gs.logger.Errorf("Error during HTTP server shutdown: %v", err)
+ hookCancel()
}
- }
- // Shutdown the gRPC server if available
- if gs.grpcServer != nil {
- gs.logger.Info("Shutting down gRPC server...")
+ // Shutdown the gRPC server BEFORE telemetry to allow in-flight RPCs
+ // to complete and emit their final spans/metrics before the telemetry
+ // pipeline is torn down.
+ if sm.grpcServer != nil {
+ sm.logInfo("Shutting down gRPC server...")
+
+ done := make(chan struct{})
+
+ runtime.SafeGoWithContextAndComponent(
+ context.Background(),
+ sm.logger,
+ "server",
+ "grpc_graceful_stop",
+ runtime.KeepRunning,
+ func(_ context.Context) {
+ sm.grpcServer.GracefulStop()
+ close(done)
+ },
+ )
+
+ select {
+ case <-done:
+ sm.logInfo("gRPC server stopped gracefully")
+ case <-time.After(sm.shutdownTimeout):
+ sm.logInfo("gRPC graceful stop timed out, forcing stop...")
+ sm.grpcServer.Stop()
+ }
+ }
- // Use GracefulStop which waits for all RPCs to finish
- gs.grpcServer.GracefulStop()
- gs.logger.Info("gRPC server stopped gracefully")
- }
+ // Shutdown telemetry AFTER servers have drained, so final spans/metrics are exported.
+ if sm.telemetry != nil {
+ sm.logInfo("Shutting down telemetry...")
+ sm.telemetry.ShutdownTelemetry()
+ }
- // Shutdown telemetry if available
- if gs.telemetry != nil {
- gs.logger.Info("Shutting down telemetry...")
- gs.telemetry.ShutdownTelemetry()
- }
+ // Sync logger if available
+ if sm.logger != nil {
+ sm.logInfo("Syncing logger...")
- // Sync logger if available
- if gs.logger != nil {
- gs.logger.Info("Syncing logger...")
+ if err := sm.logger.Sync(context.Background()); err != nil {
+ sm.logger.Log(context.Background(), log.LevelError, "failed to sync logger", log.Err(err))
+ }
+ }
- if err := gs.logger.Sync(); err != nil {
- gs.logger.Errorf("Failed to sync logger: %v", err)
+ // Shutdown license background refresh if available
+ if sm.licenseClient != nil {
+ sm.logInfo("Shutting down license background refresh...")
+ sm.licenseClient.Terminate("shutdown")
}
- }
- gs.logger.Info("Graceful shutdown completed")
+ sm.logInfo("Graceful shutdown completed")
+ })
}
diff --git a/commons/server/shutdown_example_test.go b/commons/server/shutdown_example_test.go
new file mode 100644
index 00000000..0f559eee
--- /dev/null
+++ b/commons/server/shutdown_example_test.go
@@ -0,0 +1,20 @@
+//go:build unit
+
+package server_test
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/server"
+)
+
+func ExampleServerManager_StartWithGracefulShutdownWithError_validation() {
+ sm := server.NewServerManager(nil, nil, nil)
+ err := sm.StartWithGracefulShutdownWithError()
+
+ fmt.Println(errors.Is(err, server.ErrNoServersConfigured))
+
+ // Output:
+ // true
+}
diff --git a/commons/server/shutdown_integration_test.go b/commons/server/shutdown_integration_test.go
new file mode 100644
index 00000000..1e92102f
--- /dev/null
+++ b/commons/server/shutdown_integration_test.go
@@ -0,0 +1,373 @@
+//go:build integration
+
+package server_test
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/server"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+)
+
+// getFreePort allocates a free TCP port from the OS, closes the listener, and
+// returns the port as a ":PORT" string suitable for Fiber's Listen or gRPC's
+// net.Listen. There is a small TOCTOU window, but for integration tests on
+// localhost this is reliable enough.
+func getFreePort(t *testing.T) string {
+ t.Helper()
+
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+
+ port := l.Addr().(*net.TCPAddr).Port
+
+ require.NoError(t, l.Close())
+
+ return fmt.Sprintf(":%d", port)
+}
+
+// waitForTCP polls a TCP address until it accepts connections or the timeout
+// expires. This bridges the gap between ServersStarted() (goroutine launched)
+// and the socket actually being bound and ready.
+func waitForTCP(t *testing.T, addr string, timeout time.Duration) {
+ t.Helper()
+
+ deadline := time.Now().Add(timeout)
+
+ for time.Now().Before(deadline) {
+ conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
+ if err == nil {
+ require.NoError(t, conn.Close())
+ return
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ t.Fatalf("TCP address %s did not become available within %s", addr, timeout)
+}
+
+// TestIntegration_ServerManager_HTTPLifecycle verifies the full HTTP server
+// lifecycle: start → serve requests → graceful shutdown → clean exit.
+func TestIntegration_ServerManager_HTTPLifecycle(t *testing.T) {
+ addr := getFreePort(t)
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
+
+ app.Get("/ping", func(c *fiber.Ctx) error {
+ return c.SendString("pong")
+ })
+
+ shutdownChan := make(chan struct{})
+ logger := log.NewNop()
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithHTTPServer(app, addr).
+ WithShutdownChannel(shutdownChan).
+ WithShutdownTimeout(5 * time.Second)
+
+ resultCh := make(chan error, 1)
+
+ go func() {
+ resultCh <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ // Wait for the goroutine launch signal.
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for ServersStarted signal")
+ }
+
+ // Wait for the socket to actually accept connections.
+ hostAddr := "127.0.0.1" + addr
+ waitForTCP(t, hostAddr, 5*time.Second)
+
+ // Verify the HTTP endpoint is serving correctly.
+ resp, err := http.Get("http://" + hostAddr + "/ping")
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.Equal(t, "pong", string(body))
+
+ // Trigger graceful shutdown.
+ close(shutdownChan)
+
+ // Verify clean exit.
+ select {
+ case err := <-resultCh:
+ assert.NoError(t, err, "StartWithGracefulShutdownWithError should return nil on clean shutdown")
+ case <-time.After(10 * time.Second):
+ t.Fatal("timed out waiting for server to shut down")
+ }
+
+ // Verify the server is no longer accepting connections.
+ _, err = net.DialTimeout("tcp", hostAddr, 200*time.Millisecond)
+ assert.Error(t, err, "server should no longer accept connections after shutdown")
+}
+
+// TestIntegration_ServerManager_ShutdownHooksExecuted verifies that registered
+// shutdown hooks are invoked during graceful shutdown, in registration order.
+func TestIntegration_ServerManager_ShutdownHooksExecuted(t *testing.T) {
+ addr := getFreePort(t)
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
+
+ app.Get("/health", func(c *fiber.Ctx) error {
+ return c.SendStatus(fiber.StatusOK)
+ })
+
+ shutdownChan := make(chan struct{})
+ logger := log.NewNop()
+
+ var hook1Called atomic.Int64
+ var hook2Called atomic.Int64
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithHTTPServer(app, addr).
+ WithShutdownChannel(shutdownChan).
+ WithShutdownTimeout(5 * time.Second).
+ WithShutdownHook(func(_ context.Context) error {
+ hook1Called.Add(1)
+ return nil
+ }).
+ WithShutdownHook(func(_ context.Context) error {
+ hook2Called.Add(1)
+ return nil
+ })
+
+ resultCh := make(chan error, 1)
+
+ go func() {
+ resultCh <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for ServersStarted signal")
+ }
+
+ hostAddr := "127.0.0.1" + addr
+ waitForTCP(t, hostAddr, 5*time.Second)
+
+ // Confirm hooks haven't fired prematurely.
+ assert.Equal(t, int64(0), hook1Called.Load(), "hook1 should not fire before shutdown")
+ assert.Equal(t, int64(0), hook2Called.Load(), "hook2 should not fire before shutdown")
+
+ // Trigger shutdown.
+ close(shutdownChan)
+
+ select {
+ case err := <-resultCh:
+ assert.NoError(t, err)
+ case <-time.After(10 * time.Second):
+ t.Fatal("timed out waiting for shutdown")
+ }
+
+ // Both hooks must have been called exactly once.
+ assert.Equal(t, int64(1), hook1Called.Load(), "hook1 should be called exactly once")
+ assert.Equal(t, int64(1), hook2Called.Load(), "hook2 should be called exactly once")
+}
+
+// TestIntegration_ServerManager_GRPCLifecycle verifies the full gRPC server
+// lifecycle: start → accept connections → graceful shutdown → clean exit.
+func TestIntegration_ServerManager_GRPCLifecycle(t *testing.T) {
+ addr := getFreePort(t)
+
+ grpcServer := grpc.NewServer()
+ shutdownChan := make(chan struct{})
+ logger := log.NewNop()
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithGRPCServer(grpcServer, addr).
+ WithShutdownChannel(shutdownChan).
+ WithShutdownTimeout(5 * time.Second)
+
+ resultCh := make(chan error, 1)
+
+ go func() {
+ resultCh <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for ServersStarted signal")
+ }
+
+ hostAddr := "127.0.0.1" + addr
+ waitForTCP(t, hostAddr, 5*time.Second)
+
+ // Verify gRPC connectivity by establishing a client connection.
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ conn, err := grpc.NewClient(
+ hostAddr,
+ grpc.WithTransportCredentials(insecure.NewCredentials()),
+ )
+ require.NoError(t, err, "should be able to create a gRPC client")
+
+ defer conn.Close()
+
+ // Verify the transport layer is reachable by waiting for the connection
+ // state to become ready or idle (both indicate the server accepted the TCP
+ // handshake). We use WaitForStateChange to avoid spinning.
+ conn.Connect()
+
+ // Give the connection a moment to transition from IDLE.
+ state := conn.GetState()
+ if state.String() == "IDLE" {
+ conn.WaitForStateChange(ctx, state)
+ }
+
+ currentState := conn.GetState()
+ assert.NotEqual(t, "SHUTDOWN", currentState.String(),
+ "gRPC connection should not be in SHUTDOWN state while server is running")
+
+ // Trigger graceful shutdown.
+ close(shutdownChan)
+
+ select {
+ case err := <-resultCh:
+ assert.NoError(t, err, "gRPC server should shut down cleanly")
+ case <-time.After(10 * time.Second):
+ t.Fatal("timed out waiting for gRPC server to shut down")
+ }
+}
+
+// TestIntegration_ServerManager_NoServersError verifies that starting a
+// ServerManager with no configured servers returns ErrNoServersConfigured
+// immediately and synchronously.
+func TestIntegration_ServerManager_NoServersError(t *testing.T) {
+ logger := log.NewNop()
+ sm := server.NewServerManager(nil, nil, logger)
+
+ err := sm.StartWithGracefulShutdownWithError()
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, server.ErrNoServersConfigured,
+ "expected ErrNoServersConfigured when no servers are configured")
+}
+
+// TestIntegration_ServerManager_InFlightRequestsDrained verifies that the
+// graceful shutdown waits for in-flight HTTP requests to complete before the
+// server exits. This is the fundamental property of graceful shutdown: no
+// request is dropped mid-flight.
+func TestIntegration_ServerManager_InFlightRequestsDrained(t *testing.T) {
+ addr := getFreePort(t)
+
+ const slowEndpointDuration = 500 * time.Millisecond
+
+ var requestCompleted atomic.Bool
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
+
+ app.Get("/slow", func(c *fiber.Ctx) error {
+ time.Sleep(slowEndpointDuration)
+ requestCompleted.Store(true)
+
+ return c.SendString("done")
+ })
+
+ shutdownChan := make(chan struct{})
+ logger := log.NewNop()
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithHTTPServer(app, addr).
+ WithShutdownChannel(shutdownChan).
+ WithShutdownTimeout(10 * time.Second)
+
+ serverResultCh := make(chan error, 1)
+
+ go func() {
+ serverResultCh <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for ServersStarted signal")
+ }
+
+ hostAddr := "127.0.0.1" + addr
+ waitForTCP(t, hostAddr, 5*time.Second)
+
+ // Launch the slow request in a background goroutine.
+ requestResultCh := make(chan *http.Response, 1)
+ requestErrCh := make(chan error, 1)
+
+ go func() {
+ client := &http.Client{Timeout: 10 * time.Second}
+
+ resp, err := client.Get("http://" + hostAddr + "/slow")
+ if err != nil {
+ requestErrCh <- err
+ return
+ }
+
+ requestResultCh <- resp
+ }()
+
+ // Give the request a moment to arrive at the server and begin processing.
+ // This ensures the request is genuinely in-flight before we trigger shutdown.
+ time.Sleep(100 * time.Millisecond)
+
+ // Trigger shutdown while the request is still being processed.
+ assert.False(t, requestCompleted.Load(),
+ "slow request should still be in-flight when shutdown is triggered")
+
+ close(shutdownChan)
+
+ // Wait for the in-flight request to complete.
+ select {
+ case resp := <-requestResultCh:
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Equal(t, http.StatusOK, resp.StatusCode,
+ "in-flight request should receive a successful response")
+ assert.Equal(t, "done", string(body),
+ "in-flight request should receive the complete response body")
+ case err := <-requestErrCh:
+ t.Fatalf("in-flight request failed (request was dropped during shutdown): %v", err)
+ case <-time.After(15 * time.Second):
+ t.Fatal("timed out waiting for in-flight request to complete")
+ }
+
+ // Verify the request handler ran to completion.
+ assert.True(t, requestCompleted.Load(),
+ "slow request handler should have completed before server exited")
+
+ // Verify the server exited cleanly.
+ select {
+ case err := <-serverResultCh:
+ assert.NoError(t, err, "server should exit cleanly after draining in-flight requests")
+ case <-time.After(10 * time.Second):
+ t.Fatal("timed out waiting for server to exit after shutdown")
+ }
+}
diff --git a/commons/server/shutdown_test.go b/commons/server/shutdown_test.go
index d6abd7cb..7e9b1e18 100644
--- a/commons/server/shutdown_test.go
+++ b/commons/server/shutdown_test.go
@@ -1,28 +1,51 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package server_test
import (
+ "context"
"errors"
+ "net"
+ "sync"
"testing"
"time"
- "github.com/LerianStudio/lib-commons/v2/commons/server"
+ "github.com/LerianStudio/lib-commons/v4/commons/license"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/server"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
-func TestNewGracefulShutdown(t *testing.T) {
- gs := server.NewGracefulShutdown(nil, nil, nil, nil, nil)
- assert.NotNil(t, gs, "NewGracefulShutdown should return a non-nil instance")
+// recordingLogger is a Logger that records messages and can return a Sync error.
+type recordingLogger struct {
+ mu sync.Mutex
+ messages []string
+ syncErr error
}
-func TestNewGracefulShutdownWithGRPC(t *testing.T) {
- gs := server.NewGracefulShutdown(nil, nil, nil, nil, nil)
- assert.NotNil(t, gs, "NewGracefulShutdown should return a non-nil instance with gRPC server")
+func (l *recordingLogger) Log(_ context.Context, _ log.Level, msg string, _ ...log.Field) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ l.messages = append(l.messages, msg)
+}
+
+func (l *recordingLogger) With(_ ...log.Field) log.Logger { return l }
+func (l *recordingLogger) WithGroup(_ string) log.Logger { return l }
+func (l *recordingLogger) Enabled(_ log.Level) bool { return true }
+func (l *recordingLogger) Sync(_ context.Context) error { return l.syncErr }
+func (l *recordingLogger) getMessages() []string {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ cp := make([]string, len(l.messages))
+ copy(cp, l.messages)
+
+ return cp
}
func TestNewServerManager(t *testing.T) {
@@ -31,7 +54,9 @@ func TestNewServerManager(t *testing.T) {
}
func TestServerManagerWithHTTPOnly(t *testing.T) {
- app := fiber.New()
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
sm := server.NewServerManager(nil, nil, nil).
WithHTTPServer(app, ":8080")
assert.NotNil(t, sm, "ServerManager with HTTP server should return a non-nil instance")
@@ -45,7 +70,9 @@ func TestServerManagerWithGRPCOnly(t *testing.T) {
}
func TestServerManagerWithBothServers(t *testing.T) {
- app := fiber.New()
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
grpcServer := grpc.NewServer()
sm := server.NewServerManager(nil, nil, nil).
WithHTTPServer(app, ":8080").
@@ -54,7 +81,9 @@ func TestServerManagerWithBothServers(t *testing.T) {
}
func TestServerManagerChaining(t *testing.T) {
- app := fiber.New()
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
grpcServer := grpc.NewServer()
// Test method chaining
@@ -79,7 +108,9 @@ func TestErrNoServersConfigured(t *testing.T) {
}
func TestStartWithGracefulShutdownWithError_HTTPServer_Success(t *testing.T) {
- app := fiber.New()
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
shutdownChan := make(chan struct{})
sm := server.NewServerManager(nil, nil, nil).
@@ -139,7 +170,9 @@ func TestStartWithGracefulShutdownWithError_GRPCServer_Success(t *testing.T) {
}
func TestStartWithGracefulShutdownWithError_BothServers_Success(t *testing.T) {
- app := fiber.New()
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
grpcServer := grpc.NewServer()
shutdownChan := make(chan struct{})
@@ -176,3 +209,671 @@ func TestWithShutdownChannel(t *testing.T) {
WithShutdownChannel(shutdownChan)
assert.NotNil(t, sm, "WithShutdownChannel should return a non-nil instance")
}
+
+func TestWithShutdownTimeout(t *testing.T) {
+ sm := server.NewServerManager(nil, nil, nil).
+ WithShutdownTimeout(10 * time.Second)
+ assert.NotNil(t, sm, "WithShutdownTimeout should return a non-nil instance")
+}
+
+func TestStartWithGracefulShutdownWithError_HTTPStartupError(t *testing.T) {
+ // Bind a port so the HTTP server will fail to listen
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ assert.NoError(t, err)
+ defer ln.Close()
+
+ occupiedAddr := ln.Addr().String()
+
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
+
+ sm := server.NewServerManager(nil, nil, nil).
+ WithHTTPServer(app, occupiedAddr)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ // The startup error should propagate through the return value.
+ select {
+ case err := <-done:
+ require.Error(t, err, "StartWithGracefulShutdownWithError should propagate startup error")
+ assert.Contains(t, err.Error(), "HTTP server")
+ case <-time.After(10 * time.Second):
+ t.Fatal("Test timed out: startup error was not propagated")
+ }
+}
+
+func TestExecuteShutdown_Idempotent(t *testing.T) {
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
+ shutdownChan := make(chan struct{})
+
+ sm := server.NewServerManager(nil, nil, nil).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Test timed out waiting for servers to start")
+ }
+
+ // Trigger shutdown
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Test timed out waiting for shutdown")
+ }
+
+ // Second shutdown call should be safe (no-op due to sync.Once)
+ assert.NotPanics(t, func() {
+ // Call StartWithGracefulShutdownWithError again - it will call executeShutdown
+ // but sync.Once ensures the shutdown body runs only once
+ // We can't call it directly since executeShutdown is unexported,
+ // but we can verify the manager is stable after shutdown
+ _ = sm.StartWithGracefulShutdownWithError()
+ }, "Second invocation after shutdown should not panic")
+}
+
+func TestStartWithGracefulShutdownWithError_GRPCShutdownTimeout(t *testing.T) {
+ grpcServer := grpc.NewServer()
+ shutdownChan := make(chan struct{})
+
+ sm := server.NewServerManager(nil, nil, nil).
+ WithGRPCServer(grpcServer, ":0").
+ WithShutdownChannel(shutdownChan).
+ WithShutdownTimeout(1 * time.Second)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Test timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err, "Shutdown with timeout should complete without error")
+ case <-time.After(10 * time.Second):
+ t.Fatal("Test timed out: gRPC shutdown timeout did not work")
+ }
+}
+
+func TestServerManager_NilLoggerSafe(t *testing.T) {
+ app := fiber.New(fiber.Config{
+ DisableStartupMessage: true,
+ })
+ shutdownChan := make(chan struct{})
+
+ // Explicitly pass nil logger
+ sm := server.NewServerManager(nil, nil, nil).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Test timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err, "Nil logger should not cause panics during lifecycle")
+ case <-time.After(5 * time.Second):
+ t.Fatal("Test timed out")
+ }
+}
+
+func TestStartWithGracefulShutdownWithError_ManualZeroValueManager_NoPanic(t *testing.T) {
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ shutdownChan := make(chan struct{})
+ close(shutdownChan)
+
+ // Use a manually instantiated zero-value manager to verify nil-safe defaults.
+ sm := (&server.ServerManager{}).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ assert.NotPanics(t, func() {
+ err := sm.StartWithGracefulShutdownWithError()
+ assert.NoError(t, err)
+ })
+}
+
+func TestExecuteShutdown_WithTelemetry(t *testing.T) {
+ logger := &recordingLogger{}
+
+ tel, err := opentelemetry.NewTelemetry(opentelemetry.TelemetryConfig{
+ EnableTelemetry: false,
+ Logger: logger,
+ LibraryName: "test",
+ })
+ require.NoError(t, err)
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ shutdownChan := make(chan struct{})
+
+ sm := server.NewServerManager(nil, tel, logger).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for shutdown")
+ }
+
+ msgs := logger.getMessages()
+ assert.Contains(t, msgs, "Shutting down telemetry...")
+}
+
+func TestExecuteShutdown_WithLicenseClient(t *testing.T) {
+ logger := &recordingLogger{}
+ lc := license.New()
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ shutdownChan := make(chan struct{})
+
+ sm := server.NewServerManager(lc, nil, logger).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for shutdown")
+ }
+
+ msgs := logger.getMessages()
+ assert.Contains(t, msgs, "Shutting down license background refresh...")
+}
+
+func TestExecuteShutdown_LoggerSyncError(t *testing.T) {
+ logger := &recordingLogger{syncErr: errors.New("sync failed")}
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ shutdownChan := make(chan struct{})
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for shutdown")
+ }
+
+ msgs := logger.getMessages()
+ assert.Contains(t, msgs, "failed to sync logger")
+}
+
+func TestExecuteShutdown_WithAllComponents(t *testing.T) {
+ logger := &recordingLogger{}
+
+ tel, err := opentelemetry.NewTelemetry(opentelemetry.TelemetryConfig{
+ EnableTelemetry: false,
+ Logger: logger,
+ LibraryName: "test",
+ })
+ require.NoError(t, err)
+
+ lc := license.New()
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ grpcServer := grpc.NewServer()
+ shutdownChan := make(chan struct{})
+
+ sm := server.NewServerManager(lc, tel, logger).
+ WithHTTPServer(app, ":0").
+ WithGRPCServer(grpcServer, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(10 * time.Second):
+ t.Fatal("Timed out waiting for shutdown")
+ }
+
+ msgs := logger.getMessages()
+ assert.Contains(t, msgs, "Shutting down telemetry...")
+ assert.Contains(t, msgs, "Shutting down license background refresh...")
+ assert.Contains(t, msgs, "Graceful shutdown completed")
+}
+
+func TestStartWithGracefulShutdownWithError_GRPCStartupError(t *testing.T) {
+ // Bind a port so the gRPC server will fail to listen
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ defer ln.Close()
+
+ occupiedAddr := ln.Addr().String()
+
+ logger := &recordingLogger{}
+ grpcServer := grpc.NewServer()
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithGRPCServer(grpcServer, occupiedAddr)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case err := <-done:
+ require.Error(t, err, "StartWithGracefulShutdownWithError should propagate gRPC startup error")
+ assert.Contains(t, err.Error(), "gRPC listen")
+ case <-time.After(10 * time.Second):
+ t.Fatal("Timed out: gRPC startup error was not propagated")
+ }
+}
+
+func TestExecuteShutdown_HTTPShutdownError(t *testing.T) {
+ logger := &recordingLogger{}
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ shutdownChan := make(chan struct{})
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for servers to start")
+ }
+
+ // Shut down HTTP server manually before triggering shutdown to cause error
+ _ = app.Shutdown()
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for shutdown")
+ }
+}
+
+func TestStartWithGracefulShutdownWithError_WithRealLogger(t *testing.T) {
+ logger := &recordingLogger{}
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ shutdownChan := make(chan struct{})
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for shutdown")
+ }
+
+ msgs := logger.getMessages()
+ assert.Contains(t, msgs, "Gracefully shutting down all servers...")
+ assert.Contains(t, msgs, "Syncing logger...")
+ assert.Contains(t, msgs, "Graceful shutdown completed")
+}
+
+func TestStartWithGracefulShutdownWithError_StartupErrorViaOSSignalPath(t *testing.T) {
+ // Exercise the OS-signal path in handleShutdown with a startup error
+ // (no shutdown channel, so it hits the else branch with signal.Notify).
+ logger := &recordingLogger{}
+
+ // Use an occupied port so the HTTP server fails immediately.
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ defer ln.Close()
+
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithHTTPServer(app, ln.Addr().String())
+ // No WithShutdownChannel — uses the OS signal path.
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case err := <-done:
+ require.Error(t, err, "StartWithGracefulShutdownWithError should propagate startup error via OS signal path")
+ assert.Contains(t, err.Error(), "HTTP server")
+ case <-time.After(10 * time.Second):
+ t.Fatal("Timed out: startup error via OS signal path was not propagated")
+ }
+}
+
+// --- Shutdown Hook Tests ---
+
+func TestShutdownHook_NilFunctionIgnored(t *testing.T) {
+ t.Parallel()
+
+ sm := server.NewServerManager(nil, nil, nil)
+ result := sm.WithShutdownHook(nil)
+
+ // WithShutdownHook(nil) must return the same manager without appending.
+ assert.Same(t, sm, result, "WithShutdownHook(nil) should return the same ServerManager")
+
+ // Prove no hook was registered: run a full shutdown lifecycle and confirm
+ // only the standard messages appear (no "shutdown hook failed" noise).
+ logger := &recordingLogger{}
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ shutdownChan := make(chan struct{})
+
+ sm2 := server.NewServerManager(nil, nil, logger).
+ WithShutdownHook(nil). // nil hook — should be silently ignored
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan)
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm2.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm2.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for shutdown")
+ }
+
+ msgs := logger.getMessages()
+ for _, msg := range msgs {
+ assert.NotContains(t, msg, "shutdown hook failed",
+ "no hooks should have executed when only nil was registered")
+ }
+}
+
+func TestShutdownHook_NilServerManager(t *testing.T) {
+ t.Parallel()
+
+ var sm *server.ServerManager
+
+ // Calling WithShutdownHook on a nil receiver must not panic
+ // and must return nil.
+ assert.NotPanics(t, func() {
+ result := sm.WithShutdownHook(func(_ context.Context) error { return nil })
+ assert.Nil(t, result, "WithShutdownHook on nil receiver should return nil")
+ }, "WithShutdownHook on nil receiver must not panic")
+}
+
+func TestShutdownHook_StartWithGracefulShutdownWithError_NilReceiver(t *testing.T) {
+ t.Parallel()
+
+ var sm *server.ServerManager
+
+ err := sm.StartWithGracefulShutdownWithError()
+ require.ErrorIs(t, err, server.ErrNoServersConfigured,
+ "nil receiver should return ErrNoServersConfigured")
+}
+
+func TestShutdownHook_ExecuteInOrder(t *testing.T) {
+ t.Parallel()
+
+ logger := &recordingLogger{}
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ shutdownChan := make(chan struct{})
+
+ // mu + order track hook execution sequence.
+ var mu sync.Mutex
+
+ var order []int
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan).
+ WithShutdownHook(func(_ context.Context) error {
+ mu.Lock()
+ defer mu.Unlock()
+
+ order = append(order, 1)
+
+ return nil
+ }).
+ WithShutdownHook(func(_ context.Context) error {
+ mu.Lock()
+ defer mu.Unlock()
+
+ order = append(order, 2)
+
+ return nil
+ }).
+ WithShutdownHook(func(_ context.Context) error {
+ mu.Lock()
+ defer mu.Unlock()
+
+ order = append(order, 3)
+
+ return nil
+ })
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for shutdown")
+ }
+
+ mu.Lock()
+ defer mu.Unlock()
+
+ require.Len(t, order, 3, "all three hooks must execute")
+ assert.Equal(t, []int{1, 2, 3}, order, "hooks must execute in registration order")
+}
+
+func TestShutdownHook_ErrorDoesNotStopSubsequentHooks(t *testing.T) {
+ t.Parallel()
+
+ logger := &recordingLogger{}
+ app := fiber.New(fiber.Config{DisableStartupMessage: true})
+ shutdownChan := make(chan struct{})
+
+ hookErr := errors.New("hook1 intentional failure")
+
+ var mu sync.Mutex
+
+ var executed []int
+
+ sm := server.NewServerManager(nil, nil, logger).
+ WithHTTPServer(app, ":0").
+ WithShutdownChannel(shutdownChan).
+ WithShutdownHook(func(_ context.Context) error {
+ mu.Lock()
+ defer mu.Unlock()
+
+ executed = append(executed, 1)
+
+ return hookErr
+ }).
+ WithShutdownHook(func(_ context.Context) error {
+ mu.Lock()
+ defer mu.Unlock()
+
+ executed = append(executed, 2)
+
+ return nil
+ }).
+ WithShutdownHook(func(_ context.Context) error {
+ mu.Lock()
+ defer mu.Unlock()
+
+ executed = append(executed, 3)
+
+ return nil
+ })
+
+ done := make(chan error, 1)
+
+ go func() {
+ done <- sm.StartWithGracefulShutdownWithError()
+ }()
+
+ select {
+ case <-sm.ServersStarted():
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for servers to start")
+ }
+
+ close(shutdownChan)
+
+ select {
+ case err := <-done:
+ assert.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("Timed out waiting for shutdown")
+ }
+
+ // All three hooks must have run despite hook1 returning an error.
+ mu.Lock()
+ defer mu.Unlock()
+
+ require.Len(t, executed, 3, "all three hooks must execute even when one fails")
+ assert.Equal(t, []int{1, 2, 3}, executed,
+ "hooks must execute in order regardless of prior errors")
+
+ // Verify the error from hook1 was logged.
+ msgs := logger.getMessages()
+ assert.Contains(t, msgs, "shutdown hook failed",
+ "failing hook error should be logged")
+}
diff --git a/commons/shell/logo.txt b/commons/shell/logo.txt
index e179b1f2..47c966df 100644
--- a/commons/shell/logo.txt
+++ b/commons/shell/logo.txt
@@ -1,7 +1,7 @@
_ _ _
- | (_) |__ ___ ___ _ __ ___ _ __ ___ ___ _ __ ___
- | | | '_ \ _____ / __/ _ \| '_ ` _ \| '_ ` _ \ / _ \| '_ \/ __|
- | | | |_) |_____| (_| (_) | | | | | | | | | | | (_) | | | \__ \
- |_|_|_.__/ \___\___/|_| |_| |_|_| |_| |_|\___/|_| |_|___/
+ | (_) |__ _ _ _ __ ___ ___ _ __ ___ _ __ ___ ___ _ __ ___
+ | | | '_ \ _____ | | | | '_ \ / __/ _ \| '_ ` _ \| '_ ` _ \ / _ \| '_ \/ __|
+ | | | |_) |_____|| |_| | | | | (_| (_) | | | | | | | | | | | (_) | | | \__ \
+ |_|_|_.__/ \__,_|_| |_|\___\___/|_| |_| |_|_| |_| |_|\___/|_| |_|___/
- LERIAN.STUDIO ENGINEERING TEAM 🚀
\ No newline at end of file
+ LERIAN.STUDIO ENGINEERING TEAM
diff --git a/commons/shell/makefile_colors.mk b/commons/shell/makefile_colors.mk
index ffa9b0a5..dd8a5afb 100644
--- a/commons/shell/makefile_colors.mk
+++ b/commons/shell/makefile_colors.mk
@@ -3,7 +3,7 @@
# to be included in all component Makefiles
# ANSI color codes
-BLUE := \033[36m
+BLUE := \033[34m
NC := \033[0m
BOLD := \033[1m
RED := \033[31m
diff --git a/commons/shell/makefile_utils.mk b/commons/shell/makefile_utils.mk
index 4dedb9df..6d1b9641 100644
--- a/commons/shell/makefile_utils.mk
+++ b/commons/shell/makefile_utils.mk
@@ -1,39 +1,11 @@
-# Shell utility functions for Makefiles
-# This file contains standardized shell utility functions
-# to be included in all component Makefiles
-
-# Docker version detection
-DOCKER_VERSION := $(shell docker version --format '{{.Server.Version}}' 2>/dev/null || echo "0.0.0")
-DOCKER_MIN_VERSION := 20.10.13
-
-DOCKER_CMD := $(shell \
- if [ "$(shell printf '%s\n' "$(DOCKER_MIN_VERSION)" "$(DOCKER_VERSION)" | sort -V | head -n1)" = "$(DOCKER_MIN_VERSION)" ]; then \
- echo "docker compose"; \
- else \
- echo "docker-compose"; \
- fi \
-)
-
-# Border function for creating section headers
-define border
- @echo ""; \
- len=$$(echo "$(1)" | wc -c); \
- for i in $$(seq 1 $$((len + 4))); do \
- printf "-"; \
- done; \
- echo ""; \
- echo " $(1) "; \
- for i in $$(seq 1 $$((len + 4))); do \
- printf "-"; \
- done; \
- echo ""
-endef
-
-# Title function with emoji
-define title1
- @$(call border, "📝 $(1)")
-endef
-
-define title2
- @$(call border, "🔍 $(1)")
+# Makefile utility functions for lib-commons
+# Included by the root Makefile
+
+# Check that a command exists, or print install instructions and exit.
+# Usage: $(call check_command,,)
+define check_command
+ @command -v $(1) >/dev/null 2>&1 || { \
+ echo "$(RED)$(BOLD)Error:$(NC) '$(1)' is not installed. $(2)"; \
+ exit 1; \
+ }
endef
diff --git a/commons/stringUtils.go b/commons/stringUtils.go
index 68e7012c..c3af402a 100644
--- a/commons/stringUtils.go
+++ b/commons/stringUtils.go
@@ -1,22 +1,22 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package commons
import (
"bytes"
"crypto/sha256"
"encoding/hex"
- "golang.org/x/text/runes"
- "golang.org/x/text/transform"
- "golang.org/x/text/unicode/norm"
+ "net"
"regexp"
"strconv"
"strings"
"unicode"
+
+ "golang.org/x/text/runes"
+ "golang.org/x/text/transform"
+ "golang.org/x/text/unicode/norm"
)
+var uuidPattern = regexp.MustCompile(`[0-9a-fA-F-]{36}`)
+
// RemoveAccents removes accents of a given word and returns it
func RemoveAccents(word string) (string, error) {
t := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
@@ -60,7 +60,7 @@ func CamelToSnakeCase(str string) string {
buffer.WriteRune(unicode.ToLower(character))
} else {
- buffer.WriteString(string(character))
+ buffer.WriteRune(character)
}
}
@@ -135,26 +135,28 @@ func RegexIgnoreAccents(regex string) string {
"C": "C",
"Ç": "C",
}
- s := ""
+
+ var b strings.Builder
+ b.Grow(len(regex) * 2) // Pre-allocate: rough estimate, builder will grow if needed
for _, ch := range regex {
c := string(ch)
if v1, found := m2[c]; found {
if v2, found2 := m1[v1]; found2 {
- s += v2
+ b.WriteString(v2)
continue
}
}
- s += string(ch)
+ b.WriteRune(ch)
}
- return s
+ return b.String()
}
// RemoveChars from a string
func RemoveChars(str string, chars map[string]bool) string {
- s := ""
+ var b strings.Builder
for _, ch := range str {
c := string(ch)
@@ -162,23 +164,33 @@ func RemoveChars(str string, chars map[string]bool) string {
continue
}
- s += string(ch)
+ b.WriteRune(ch)
}
- return s
+ return b.String()
}
// ReplaceUUIDWithPlaceholder replaces UUIDs with a placeholder in a given path string.
func ReplaceUUIDWithPlaceholder(path string) string {
- re := regexp.MustCompile(`[0-9a-fA-F-]{36}`)
-
- return re.ReplaceAllString(path, ":id")
+ return uuidPattern.ReplaceAllString(path, ":id")
}
-// ValidateServerAddress checks if the value matches the pattern : and returns the value if it does.
+// ValidateServerAddress checks if the value is a valid host:port address.
+// It accepts IPv4 ("host:port"), IPv6 ("[::1]:port"), and hostname forms.
+// The port must be numeric and in the range [1, 65535].
+// Returns the original value when valid, or "" when invalid.
func ValidateServerAddress(value string) string {
- matched, _ := regexp.MatchString(`^[^:]+:\d+$`, value)
- if !matched {
+ host, portStr, err := net.SplitHostPort(value)
+ if err != nil {
+ return ""
+ }
+
+ if host == "" {
+ return ""
+ }
+
+ port, err := strconv.Atoi(portStr)
+ if err != nil || port < 1 || port > 65535 {
return ""
}
@@ -191,11 +203,18 @@ func HashSHA256(input string) string {
return hex.EncodeToString(hash[:])
}
-// StringToInt func that convert string to int.
+// StringToInt converts a string to an int, returning 100 on failure.
+//
+// Deprecated: Use StringToIntOrDefault for explicit default values.
func StringToInt(s string) int {
+ return StringToIntOrDefault(s, 100)
+}
+
+// StringToIntOrDefault converts a string to an int, returning defaultVal on parse failure.
+func StringToIntOrDefault(s string, defaultVal int) int {
i, err := strconv.Atoi(s)
if err != nil {
- return 100
+ return defaultVal
}
return i
diff --git a/commons/stringUtils_test.go b/commons/stringUtils_test.go
new file mode 100644
index 00000000..2d202ee2
--- /dev/null
+++ b/commons/stringUtils_test.go
@@ -0,0 +1,191 @@
+//go:build unit
+
+package commons
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestRemoveAccents(t *testing.T) {
+ t.Parallel()
+
+ t.Run("accented", func(t *testing.T) {
+ t.Parallel()
+
+ result, err := RemoveAccents("café résumé")
+ require.NoError(t, err)
+ assert.Equal(t, "cafe resume", result)
+ })
+
+ t.Run("plain_text", func(t *testing.T) {
+ t.Parallel()
+
+ result, err := RemoveAccents("hello world")
+ require.NoError(t, err)
+ assert.Equal(t, "hello world", result)
+ })
+}
+
+func TestRemoveSpaces(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {"spaces", "a b c", "abc"},
+ {"tabs", "a\tb\tc", "abc"},
+ {"mixed", " a \t b \n c ", "abc"},
+ {"empty", "", ""},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, tc.want, RemoveSpaces(tc.input))
+ })
+ }
+}
+
+func TestIsNilOrEmpty(t *testing.T) {
+ t.Parallel()
+
+ s := func(v string) *string { return &v }
+
+ tests := []struct {
+ name string
+ val *string
+ want bool
+ }{
+ {"nil", nil, true},
+ {"empty", s(""), true},
+ {"whitespace", s(" "), true},
+ {"null_string", s("null"), true},
+ {"nil_string", s("nil"), true},
+ {"valid", s("hello"), false},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, tc.want, IsNilOrEmpty(tc.val))
+ })
+ }
+}
+
+func TestCamelToSnakeCase(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {"simple", "CamelCase", "camel_case"},
+ {"lower", "already", "already"},
+ {"multiple_upper", "HTTPServer", "h_t_t_p_server"},
+ {"empty", "", ""},
+ {"single_upper", "A", "a"},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, tc.want, CamelToSnakeCase(tc.input))
+ })
+ }
+}
+
+func TestRegexIgnoreAccents(t *testing.T) {
+ t.Parallel()
+
+ t.Run("accented_input", func(t *testing.T) {
+ t.Parallel()
+
+ result := RegexIgnoreAccents("café")
+ assert.Contains(t, result, "[cç]")
+ assert.Contains(t, result, "[aáàãâ]")
+ assert.Contains(t, result, "[eéèê]")
+ })
+
+ t.Run("plain_input", func(t *testing.T) {
+ t.Parallel()
+
+ result := RegexIgnoreAccents("abc")
+ assert.Contains(t, result, "[aáàãâ]")
+ assert.Contains(t, result, "[cç]")
+ })
+}
+
+func TestRemoveChars(t *testing.T) {
+ t.Parallel()
+
+ chars := map[string]bool{"-": true, ".": true}
+ assert.Equal(t, "abc", RemoveChars("a-b.c", chars))
+}
+
+func TestReplaceUUIDWithPlaceholder(t *testing.T) {
+ t.Parallel()
+
+ path := "/api/v1/550e8400-e29b-41d4-a716-446655440000/items"
+ assert.Equal(t, "/api/v1/:id/items", ReplaceUUIDWithPlaceholder(path))
+}
+
+func TestValidateServerAddress(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ input string
+ want string
+ }{
+ {"valid_hostname", "localhost:8080", "localhost:8080"},
+ {"valid_ip", "192.168.1.1:443", "192.168.1.1:443"},
+ {"valid_ipv6_bracketed", "[::1]:8080", "[::1]:8080"},
+ {"valid_ipv6_full", "[2001:db8::1]:9090", "[2001:db8::1]:9090"},
+ {"valid_port_1", "host:1", "host:1"},
+ {"valid_port_65535", "host:65535", "host:65535"},
+ {"invalid_no_port", "localhost", ""},
+ {"invalid_empty", "", ""},
+ {"invalid_port_0", "host:0", ""},
+ {"invalid_port_65536", "host:65536", ""},
+ {"invalid_port_negative", "host:-1", ""},
+ {"invalid_port_non_numeric", "host:abc", ""},
+ {"invalid_no_host", ":8080", ""},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, tc.want, ValidateServerAddress(tc.input))
+ })
+ }
+}
+
+func TestHashSHA256(t *testing.T) {
+ t.Parallel()
+
+ h1 := HashSHA256("hello")
+ h2 := HashSHA256("hello")
+
+ assert.Equal(t, h1, h2)
+ assert.Len(t, h1, 64) // SHA-256 hex is 64 chars
+}
+
+func TestStringToInt(t *testing.T) {
+ t.Parallel()
+
+ t.Run("valid", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, 42, StringToInt("42"))
+ })
+
+ t.Run("invalid_returns_100", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, 100, StringToInt("not_a_number"))
+ })
+}
diff --git a/commons/tenant-manager/cache/config_cache.go b/commons/tenant-manager/cache/config_cache.go
new file mode 100644
index 00000000..894b19cf
--- /dev/null
+++ b/commons/tenant-manager/cache/config_cache.go
@@ -0,0 +1,37 @@
+// Copyright (c) 2026 Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+// Package cache provides caching interfaces and implementations for tenant
+// configuration data, reducing HTTP roundtrips to the Tenant Manager service.
+package cache
+
+import (
+ "context"
+ "errors"
+ "time"
+)
+
+// ErrCacheMiss is returned when a requested key is not found in the cache
+// or has expired.
+var ErrCacheMiss = errors.New("cache miss")
+
+// ConfigCache is the interface for tenant config caching.
+// Implementations must be safe for concurrent use by multiple goroutines.
+//
+// Available implementations:
+// - InMemoryCache (default): Zero-dependency, process-local cache with TTL
+// - Custom implementations can be provided via client.WithCache()
+type ConfigCache interface {
+ // Get retrieves a cached value by key.
+ // Returns ErrCacheMiss if the key is not found or has expired.
+ Get(ctx context.Context, key string) (string, error)
+
+ // Set stores a value with the given TTL.
+ // A TTL of zero or negative means the entry never expires.
+ Set(ctx context.Context, key string, value string, ttl time.Duration) error
+
+ // Del removes a key from the cache.
+ // Returns nil if the key does not exist.
+ Del(ctx context.Context, key string) error
+}
diff --git a/commons/tenant-manager/cache/memory.go b/commons/tenant-manager/cache/memory.go
new file mode 100644
index 00000000..578b8005
--- /dev/null
+++ b/commons/tenant-manager/cache/memory.go
@@ -0,0 +1,151 @@
+// Copyright (c) 2026 Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+package cache
+
+import (
+ "context"
+ "sync"
+ "time"
+)
+
+// cleanupInterval is the interval at which the background goroutine evicts
+// expired entries to prevent unbounded memory growth.
+const cleanupInterval = 5 * time.Minute
+
+// cacheEntry holds a cached value together with its absolute expiration time.
+type cacheEntry struct {
+ value string
+ expiresAt time.Time
+}
+
+// isExpired reports whether the entry has passed its expiration time.
+// An entry with a zero expiresAt never expires.
+func (e cacheEntry) isExpired() bool {
+ if e.expiresAt.IsZero() {
+ return false
+ }
+
+ return time.Now().After(e.expiresAt)
+}
+
+// InMemoryCache is a thread-safe, process-local cache with per-key TTL.
+// It uses lazy expiration on Get (expired entries are deleted on access) and
+// a background goroutine that periodically sweeps all expired entries.
+//
+// Call Close to stop the background cleanup goroutine when the cache is no
+// longer needed. Failing to call Close will leak the goroutine.
+type InMemoryCache struct {
+ mu sync.RWMutex
+ entries map[string]cacheEntry
+ done chan struct{}
+}
+
+// NewInMemoryCache creates a new InMemoryCache and starts a background
+// goroutine that evicts expired entries every 5 minutes.
+func NewInMemoryCache() *InMemoryCache {
+ c := &InMemoryCache{
+ entries: make(map[string]cacheEntry),
+ done: make(chan struct{}),
+ }
+
+ go c.cleanupLoop()
+
+ return c
+}
+
+// Get retrieves a cached value by key.
+// If the key exists but has expired, it is deleted (lazy expiration) and
+// ErrCacheMiss is returned.
+func (c *InMemoryCache) Get(_ context.Context, key string) (string, error) {
+ c.mu.RLock()
+ entry, ok := c.entries[key]
+ c.mu.RUnlock()
+
+ if !ok {
+ return "", ErrCacheMiss
+ }
+
+ if entry.isExpired() {
+ // Lazy eviction: promote to write lock and delete
+ c.mu.Lock()
+ // Re-check under write lock to avoid deleting a fresher entry
+ if current, stillExists := c.entries[key]; stillExists && current.isExpired() {
+ delete(c.entries, key)
+ }
+ c.mu.Unlock()
+
+ return "", ErrCacheMiss
+ }
+
+ return entry.value, nil
+}
+
+// Set stores a value with the given TTL.
+// A TTL of zero or negative means the entry never expires.
+func (c *InMemoryCache) Set(_ context.Context, key string, value string, ttl time.Duration) error {
+ entry := cacheEntry{
+ value: value,
+ }
+
+ if ttl > 0 {
+ entry.expiresAt = time.Now().Add(ttl)
+ }
+
+ c.mu.Lock()
+ c.entries[key] = entry
+ c.mu.Unlock()
+
+ return nil
+}
+
+// Del removes a key from the cache. Returns nil if the key does not exist.
+func (c *InMemoryCache) Del(_ context.Context, key string) error {
+ c.mu.Lock()
+ delete(c.entries, key)
+ c.mu.Unlock()
+
+ return nil
+}
+
+// Close stops the background cleanup goroutine. After Close returns, no more
+// cleanup sweeps will run. Close is safe to call multiple times.
+func (c *InMemoryCache) Close() error {
+ select {
+ case <-c.done:
+ // Already closed
+ default:
+ close(c.done)
+ }
+
+ return nil
+}
+
+// cleanupLoop runs in a background goroutine and periodically evicts expired
+// entries to prevent unbounded memory growth.
+func (c *InMemoryCache) cleanupLoop() {
+ ticker := time.NewTicker(cleanupInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-c.done:
+ return
+ case <-ticker.C:
+ c.evictExpired()
+ }
+ }
+}
+
+// evictExpired removes all expired entries from the cache.
+func (c *InMemoryCache) evictExpired() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ for key, entry := range c.entries {
+ if entry.isExpired() {
+ delete(c.entries, key)
+ }
+ }
+}
diff --git a/commons/tenant-manager/cache/memory_test.go b/commons/tenant-manager/cache/memory_test.go
new file mode 100644
index 00000000..91654849
--- /dev/null
+++ b/commons/tenant-manager/cache/memory_test.go
@@ -0,0 +1,306 @@
+// Copyright (c) 2026 Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+package cache
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/goleak"
+)
+
+func TestMain(m *testing.M) {
+ goleak.VerifyTestMain(m)
+}
+
+func TestInMemoryCache_Get(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(c *InMemoryCache)
+ key string
+ wantValue string
+ wantErr error
+ wantErrString string
+ }{
+ {
+ name: "returns ErrCacheMiss for non-existent key",
+ setup: func(_ *InMemoryCache) {},
+ key: "missing-key",
+ wantErr: ErrCacheMiss,
+ },
+ {
+ name: "returns cached value for existing key",
+ setup: func(c *InMemoryCache) {
+ require.NoError(t, c.Set(context.Background(), "my-key", "my-value", time.Hour))
+ },
+ key: "my-key",
+ wantValue: "my-value",
+ },
+ {
+ name: "returns ErrCacheMiss for expired key",
+ setup: func(c *InMemoryCache) {
+ require.NoError(t, c.Set(context.Background(), "expired-key", "old-value", time.Millisecond))
+ time.Sleep(5 * time.Millisecond)
+ },
+ key: "expired-key",
+ wantErr: ErrCacheMiss,
+ },
+ {
+ name: "returns value for key with zero TTL (never expires)",
+ setup: func(c *InMemoryCache) {
+ require.NoError(t, c.Set(context.Background(), "forever-key", "forever-value", 0))
+ },
+ key: "forever-key",
+ wantValue: "forever-value",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ c := NewInMemoryCache()
+ defer func() { require.NoError(t, c.Close()) }()
+
+ tt.setup(c)
+
+ value, err := c.Get(context.Background(), tt.key)
+
+ if tt.wantErr != nil {
+ assert.ErrorIs(t, err, tt.wantErr)
+ assert.Empty(t, value)
+
+ return
+ }
+
+ require.NoError(t, err)
+ assert.Equal(t, tt.wantValue, value)
+ })
+ }
+}
+
+func TestInMemoryCache_Set(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ value string
+ ttl time.Duration
+ }{
+ {
+ name: "stores value with positive TTL",
+ key: "key-1",
+ value: "value-1",
+ ttl: time.Hour,
+ },
+ {
+ name: "stores value with zero TTL (never expires)",
+ key: "key-2",
+ value: "value-2",
+ ttl: 0,
+ },
+ {
+ name: "stores value with negative TTL (never expires)",
+ key: "key-3",
+ value: "value-3",
+ ttl: -1 * time.Second,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ c := NewInMemoryCache()
+ defer func() { require.NoError(t, c.Close()) }()
+
+ err := c.Set(context.Background(), tt.key, tt.value, tt.ttl)
+ require.NoError(t, err)
+
+ got, getErr := c.Get(context.Background(), tt.key)
+ require.NoError(t, getErr)
+ assert.Equal(t, tt.value, got)
+ })
+ }
+}
+
+func TestInMemoryCache_Set_Overwrites(t *testing.T) {
+ c := NewInMemoryCache()
+ defer func() { require.NoError(t, c.Close()) }()
+
+ ctx := context.Background()
+
+ require.NoError(t, c.Set(ctx, "key", "original", time.Hour))
+ require.NoError(t, c.Set(ctx, "key", "updated", time.Hour))
+
+ got, err := c.Get(ctx, "key")
+ require.NoError(t, err)
+ assert.Equal(t, "updated", got)
+}
+
+func TestInMemoryCache_Del(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(c *InMemoryCache)
+ key string
+ }{
+ {
+ name: "deletes existing key",
+ setup: func(c *InMemoryCache) {
+ require.NoError(t, c.Set(context.Background(), "del-key", "value", time.Hour))
+ },
+ key: "del-key",
+ },
+ {
+ name: "returns nil for non-existent key",
+ setup: func(_ *InMemoryCache) {},
+ key: "no-such-key",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ c := NewInMemoryCache()
+ defer func() { require.NoError(t, c.Close()) }()
+
+ tt.setup(c)
+
+ err := c.Del(context.Background(), tt.key)
+ require.NoError(t, err)
+
+ // Verify key is gone
+ _, getErr := c.Get(context.Background(), tt.key)
+ assert.ErrorIs(t, getErr, ErrCacheMiss)
+ })
+ }
+}
+
+func TestInMemoryCache_TTLExpiration(t *testing.T) {
+ c := NewInMemoryCache()
+ defer func() { require.NoError(t, c.Close()) }()
+
+ ctx := context.Background()
+
+ // Set with a very short TTL
+ require.NoError(t, c.Set(ctx, "short-lived", "value", 10*time.Millisecond))
+
+ // Should be available immediately
+ got, err := c.Get(ctx, "short-lived")
+ require.NoError(t, err)
+ assert.Equal(t, "value", got)
+
+ // Wait for TTL to expire
+ time.Sleep(20 * time.Millisecond)
+
+ // Should now be expired (lazy eviction)
+ _, err = c.Get(ctx, "short-lived")
+ assert.ErrorIs(t, err, ErrCacheMiss)
+}
+
+func TestInMemoryCache_ConcurrentAccess(t *testing.T) {
+ c := NewInMemoryCache()
+ defer func() { require.NoError(t, c.Close()) }()
+
+ ctx := context.Background()
+ const goroutines = 50
+ const iterations = 100
+
+ var wg sync.WaitGroup
+
+ wg.Add(goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func(id int) {
+ defer wg.Done()
+
+ for j := 0; j < iterations; j++ {
+ key := "key"
+ value := "value"
+
+ // Mix of Set, Get, Del operations
+ switch j % 3 {
+ case 0:
+ _ = c.Set(ctx, key, value, time.Hour)
+ case 1:
+ _, _ = c.Get(ctx, key)
+ case 2:
+ _ = c.Del(ctx, key)
+ }
+ }
+ }(i)
+ }
+
+ wg.Wait()
+}
+
+func TestInMemoryCache_Close(t *testing.T) {
+ t.Run("stops cleanup goroutine", func(t *testing.T) {
+ c := NewInMemoryCache()
+
+ err := c.Close()
+ require.NoError(t, err)
+ })
+
+ t.Run("double close is safe", func(t *testing.T) {
+ c := NewInMemoryCache()
+
+ require.NoError(t, c.Close())
+ require.NoError(t, c.Close())
+ })
+}
+
+func TestInMemoryCache_EvictExpired(t *testing.T) {
+ c := NewInMemoryCache()
+ defer func() { require.NoError(t, c.Close()) }()
+
+ ctx := context.Background()
+
+ // Add entries: one expired, one still valid
+ require.NoError(t, c.Set(ctx, "expired", "value", time.Millisecond))
+ require.NoError(t, c.Set(ctx, "valid", "value", time.Hour))
+
+ time.Sleep(5 * time.Millisecond)
+
+ // Manually trigger eviction
+ c.evictExpired()
+
+ // Expired entry should be gone
+ c.mu.RLock()
+ _, expiredExists := c.entries["expired"]
+ _, validExists := c.entries["valid"]
+ c.mu.RUnlock()
+
+ assert.False(t, expiredExists, "expired entry should have been evicted")
+ assert.True(t, validExists, "valid entry should still exist")
+}
+
+func TestCacheEntry_IsExpired(t *testing.T) {
+ tests := []struct {
+ name string
+ entry cacheEntry
+ wantExpd bool
+ }{
+ {
+ name: "zero expiresAt never expires",
+ entry: cacheEntry{value: "v", expiresAt: time.Time{}},
+ wantExpd: false,
+ },
+ {
+ name: "future expiresAt is not expired",
+ entry: cacheEntry{value: "v", expiresAt: time.Now().Add(time.Hour)},
+ wantExpd: false,
+ },
+ {
+ name: "past expiresAt is expired",
+ entry: cacheEntry{value: "v", expiresAt: time.Now().Add(-time.Second)},
+ wantExpd: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.wantExpd, tt.entry.isExpired())
+ })
+ }
+}
diff --git a/commons/tenant-manager/client/client.go b/commons/tenant-manager/client/client.go
new file mode 100644
index 00000000..02093104
--- /dev/null
+++ b/commons/tenant-manager/client/client.go
@@ -0,0 +1,669 @@
+// Package client provides an HTTP client for interacting with the Tenant Manager service.
+// It handles tenant-specific database connection retrieval for multi-tenant architectures.
+package client
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "sync"
+ "time"
+ "unicode/utf8"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/cache"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "go.opentelemetry.io/otel/trace"
+)
+
+// maxResponseBodySize is the maximum allowed response body size (10 MB).
+// This prevents unbounded memory allocation from malicious or malformed responses.
+const maxResponseBodySize = 10 * 1024 * 1024
+
+// defaultCacheTTL is the default time-to-live for cached tenant config entries.
+const defaultCacheTTL = 1 * time.Hour
+
+// cacheKeyPrefix matches the tenant-manager key format for debugging clarity.
+const cacheKeyPrefix = "tenant-connections"
+
+// cbState represents the circuit breaker state.
+type cbState int
+
+const (
+ // cbClosed is the normal operating state. All requests are allowed through.
+ cbClosed cbState = iota
+ // cbOpen means the circuit breaker has tripped. Requests fail fast with ErrCircuitBreakerOpen.
+ cbOpen
+ // cbHalfOpen allows a single test request through to probe whether the service has recovered.
+ cbHalfOpen
+)
+
+// TenantSummary represents a minimal tenant information for listing.
+type TenantSummary struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Status string `json:"status"`
+}
+
+// Client is an HTTP client for the Tenant Manager service.
+// It fetches tenant-specific database configurations from the Tenant Manager API.
+// An optional circuit breaker can be enabled via WithCircuitBreaker to fail fast
+// when the Tenant Manager service is unresponsive.
+type Client struct {
+ baseURL string
+ httpClient *http.Client
+ logger libLog.Logger
+ serviceAPIKey string
+ cache cache.ConfigCache
+ cacheTTL time.Duration
+
+ // allowInsecureHTTP permits http:// URLs when set to true.
+ // By default, only https:// URLs are accepted unless explicitly opted in
+ // via WithAllowInsecureHTTP().
+ allowInsecureHTTP bool
+
+ // Circuit breaker fields. When cbThreshold is 0, the circuit breaker is disabled (default).
+ cbMu sync.Mutex
+ cbFailures int
+ cbLastFailure time.Time
+ cbState cbState
+ cbThreshold int // consecutive failures before opening (0 = disabled)
+ cbTimeout time.Duration // how long to stay open before transitioning to half-open
+}
+
+// getConfigOpts holds options for a single GetTenantConfig call.
+type getConfigOpts struct {
+ skipCache bool
+}
+
+// GetConfigOption is a functional option for individual GetTenantConfig calls.
+type GetConfigOption func(*getConfigOpts)
+
+// WithSkipCache forces GetTenantConfig to bypass the cache and fetch directly.
+func WithSkipCache() GetConfigOption {
+ return func(o *getConfigOpts) {
+ o.skipCache = true
+ }
+}
+
+// ClientOption is a functional option for configuring the Client.
+type ClientOption func(*Client)
+
+// WithHTTPClient sets a custom HTTP client for the Client.
+// If client is nil, the option is a no-op (the default HTTP client is preserved).
+func WithHTTPClient(client *http.Client) ClientOption {
+ return func(c *Client) {
+ if client != nil {
+ c.httpClient = client
+ }
+ }
+}
+
+// WithTimeout sets the HTTP client timeout.
+// If the HTTP client has not been initialized yet, a new default client is created.
+func WithTimeout(timeout time.Duration) ClientOption {
+ return func(c *Client) {
+ if c.httpClient == nil {
+ c.httpClient = &http.Client{}
+ }
+
+ c.httpClient.Timeout = timeout
+ }
+}
+
+// WithCircuitBreaker enables the circuit breaker on the Client.
+// After threshold consecutive service failures (network errors or HTTP 5xx),
+// the circuit breaker opens and subsequent requests fail fast with ErrCircuitBreakerOpen.
+// After timeout elapses, one probe request is allowed through (half-open state).
+// If the probe succeeds, the circuit breaker closes; if it fails, it reopens.
+//
+// A threshold of 0 disables the circuit breaker (default behavior).
+// HTTP 4xx responses (400, 403, 404) are NOT counted as failures because they
+// represent valid responses from the Tenant Manager, not service unavailability.
+func WithCircuitBreaker(threshold int, timeout time.Duration) ClientOption {
+ return func(c *Client) {
+ c.cbThreshold = threshold
+ c.cbTimeout = timeout
+ }
+}
+
+// WithCache sets a custom cache implementation for tenant config responses.
+// Returns an error during NewClient if the cache is a typed-nil interface
+// (e.g., (*InMemoryCache)(nil)), since that would cause nil-pointer panics.
+func WithCache(cc cache.ConfigCache) ClientOption {
+ return func(c *Client) {
+ if cc != nil {
+ c.cache = cc
+ }
+ }
+}
+
+// withCacheValidated is the internal validation that runs during NewClient
+// after all options are applied. It detects typed-nil caches.
+func withCacheValidated(c *Client) error {
+ if c.cache != nil && core.IsNilInterface(c.cache) {
+ return fmt.Errorf("client.NewClient: %w", core.ErrNilCache)
+ }
+
+ return nil
+}
+
+// WithCacheTTL sets the TTL for cached tenant config entries.
+func WithCacheTTL(ttl time.Duration) ClientOption {
+ return func(c *Client) {
+ c.cacheTTL = ttl
+ }
+}
+
+// WithAllowInsecureHTTP permits the use of http:// (plaintext) URLs for the
+// Tenant Manager base URL. By default, only https:// is accepted. Use this
+// option only for local development or testing environments.
+func WithAllowInsecureHTTP() ClientOption {
+ return func(c *Client) {
+ c.allowInsecureHTTP = true
+ }
+}
+
+// WithServiceAPIKey sets the API key sent as X-API-Key header on all HTTP
+// requests to the Tenant Manager. The key MUST be non-empty; NewClient returns
+// an error if no key is provided or the key is empty. Typically sourced from
+// the MULTI_TENANT_SERVICE_API_KEY environment variable.
+func WithServiceAPIKey(key string) ClientOption {
+ return func(c *Client) {
+ c.serviceAPIKey = key
+ }
+}
+
+// NewClient creates a new Tenant Manager client.
+// Parameters:
+// - baseURL: The base URL of the Tenant Manager service (e.g., "https://tenant-manager:8080")
+// - logger: Logger for request/response logging
+// - opts: Optional configuration options
+//
+// The baseURL is validated at construction time to ensure it is a well-formed URL with a scheme.
+// This prevents SSRF risks by ensuring only trusted, pre-configured URLs are used for HTTP requests.
+// By default, only https:// URLs are accepted. Use WithAllowInsecureHTTP() to permit http://.
+func NewClient(baseURL string, logger libLog.Logger, opts ...ClientOption) (*Client, error) {
+ if logger == nil {
+ logger = libLog.NewNop()
+ }
+
+ // Validate baseURL to ensure it is a well-formed URL with a scheme.
+ // This is a defense-in-depth measure: the baseURL is configured at deployment time
+ // (not user-controlled), but we validate it to fail fast on misconfiguration.
+ parsedURL, err := url.Parse(baseURL)
+ if err != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
+ logger.Log(context.Background(), libLog.LevelError, "invalid tenant manager baseURL",
+ libLog.String("base_url", baseURL),
+ )
+
+ return nil, fmt.Errorf("invalid tenant manager baseURL: %q", baseURL)
+ }
+
+ c := &Client{
+ baseURL: baseURL,
+ httpClient: &http.Client{
+ Timeout: 30 * time.Second,
+ },
+ logger: logger,
+ cacheTTL: defaultCacheTTL,
+ }
+
+ for _, opt := range opts {
+ opt(c)
+ }
+
+ // Enforce HTTPS by default. Allow http:// only with explicit opt-in.
+ if parsedURL.Scheme == "http" && !c.allowInsecureHTTP {
+ return nil, fmt.Errorf("client.NewClient: %w: got %q", core.ErrInsecureHTTP, baseURL)
+ }
+
+ // Validate that a non-empty service API key was provided.
+ if c.serviceAPIKey == "" {
+ return nil, fmt.Errorf("client.NewClient: %w", core.ErrServiceAPIKeyRequired)
+ }
+
+ // Validate that the cache is not a typed-nil interface.
+ if err := withCacheValidated(c); err != nil {
+ return nil, err
+ }
+
+ if c.cache == nil {
+ c.cache = cache.NewInMemoryCache()
+ }
+
+ return c, nil
+}
+
+// checkCircuitBreaker checks if the circuit breaker allows a request to proceed.
+// Returns ErrCircuitBreakerOpen if the circuit breaker is open and the timeout has not elapsed.
+// Transitions from open to half-open when the timeout expires.
+// When the circuit breaker is disabled (cbThreshold == 0), this is a no-op.
+func (c *Client) checkCircuitBreaker() error {
+ if c.cbThreshold <= 0 {
+ return nil
+ }
+
+ c.cbMu.Lock()
+ defer c.cbMu.Unlock()
+
+ switch c.cbState {
+ case cbOpen:
+ if time.Since(c.cbLastFailure) > c.cbTimeout {
+ c.cbState = cbHalfOpen
+ return nil
+ }
+
+ return core.ErrCircuitBreakerOpen
+ default:
+ return nil
+ }
+}
+
+// recordSuccess resets the circuit breaker to the closed state with zero failures.
+// Called after a successful response from the Tenant Manager.
+func (c *Client) recordSuccess() {
+ if c.cbThreshold <= 0 {
+ return
+ }
+
+ c.cbMu.Lock()
+ defer c.cbMu.Unlock()
+
+ c.cbFailures = 0
+ c.cbState = cbClosed
+}
+
+// recordFailure increments the failure counter and opens the circuit breaker
+// when the threshold is reached. Only service-level failures (network errors,
+// HTTP 5xx) should trigger this - not client errors (4xx).
+func (c *Client) recordFailure() {
+ if c.cbThreshold <= 0 {
+ return
+ }
+
+ c.cbMu.Lock()
+ defer c.cbMu.Unlock()
+
+ c.cbFailures++
+ c.cbLastFailure = time.Now()
+
+ if c.cbFailures >= c.cbThreshold {
+ c.cbState = cbOpen
+ }
+}
+
+// isServerError returns true if the HTTP status code indicates a server-side failure
+// that should count toward the circuit breaker threshold.
+// Only 5xx status codes are considered failures. 4xx responses (400, 403, 404)
+// are valid responses from the Tenant Manager and do NOT indicate service unavailability.
+func isServerError(statusCode int) bool {
+ return statusCode >= http.StatusInternalServerError
+}
+
+// truncateBody returns the body as a string, truncated to maxLen bytes with a
+// "...(truncated)" suffix if the body exceeds maxLen. This prevents large
+// response bodies from being logged or included in error messages.
+// The truncation point is adjusted to the last valid UTF-8 rune boundary
+// to avoid splitting multi-byte characters.
+func truncateBody(body []byte, maxLen int) string {
+ if len(body) <= maxLen {
+ return string(body)
+ }
+
+ // Find the last valid rune boundary at or before maxLen to avoid
+ // splitting multi-byte UTF-8 sequences.
+ truncated := body[:maxLen]
+ for len(truncated) > 0 && !utf8.Valid(truncated) {
+ truncated = truncated[:len(truncated)-1]
+ }
+
+ return string(truncated) + "...(truncated)"
+}
+
+func (c *Client) getCachedTenantConfig(ctx context.Context, cacheKey, tenantID, service string) (*core.TenantConfig, bool) {
+ if c.cache == nil {
+ return nil, false
+ }
+
+ cached, err := c.cache.Get(ctx, cacheKey)
+ if err != nil {
+ return nil, false
+ }
+
+ var config core.TenantConfig
+ if jsonErr := json.Unmarshal([]byte(cached), &config); jsonErr == nil {
+ c.logger.Log(ctx, libLog.LevelDebug, "tenant config cache hit",
+ libLog.String("tenant_id", tenantID),
+ libLog.String("service", service),
+ )
+
+ return &config, true
+ }
+
+ // Malformed cache entry: evict before refetching to prevent repeated
+ // deserialization failures on the same corrupt data.
+ c.logger.Log(ctx, libLog.LevelWarn, "invalid tenant config cache entry; evicting before refetch",
+ libLog.String("tenant_id", tenantID),
+ libLog.String("service", service),
+ )
+
+ _ = c.cache.Del(ctx, cacheKey)
+
+ return nil, false
+}
+
+func (c *Client) handleGetTenantConfigStatus(
+ ctx context.Context,
+ span trace.Span,
+ tenantID, service string,
+ statusCode int,
+ body []byte,
+) error {
+ switch statusCode {
+ case http.StatusOK:
+ return nil
+ case http.StatusNotFound:
+ c.recordSuccess()
+ c.logger.Log(ctx, libLog.LevelWarn, "tenant not found",
+ libLog.String("tenant_id", tenantID),
+ libLog.String("service", service),
+ )
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "Tenant not found", core.ErrTenantNotFound)
+
+ return core.ErrTenantNotFound
+ case http.StatusForbidden:
+ c.recordSuccess()
+ c.logger.Log(ctx, libLog.LevelWarn, "tenant service access denied",
+ libLog.String("tenant_id", tenantID),
+ libLog.String("service", service),
+ )
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "Tenant service suspended or purged", core.ErrTenantServiceAccessDenied)
+
+ // All 403 responses wrap ErrTenantServiceAccessDenied so callers can
+ // use errors.Is(err, core.ErrTenantServiceAccessDenied) reliably.
+ // When the JSON body includes a status field, we enrich the error
+ // with a TenantSuspendedError for more specific handling.
+ var errResp struct {
+ Code string `json:"code"`
+ Error string `json:"error"`
+ Status string `json:"status"`
+ }
+
+ if jsonErr := json.Unmarshal(body, &errResp); jsonErr == nil && errResp.Status != "" {
+ return fmt.Errorf("%w: %w", core.ErrTenantServiceAccessDenied, &core.TenantSuspendedError{
+ TenantID: tenantID,
+ Status: errResp.Status,
+ Message: errResp.Error,
+ })
+ }
+
+ // Non-JSON or missing status: still wrap ErrTenantServiceAccessDenied
+ return fmt.Errorf("tenant %s: %w", tenantID, core.ErrTenantServiceAccessDenied)
+ default:
+ if isServerError(statusCode) {
+ c.recordFailure()
+ }
+
+ c.logger.Log(ctx, libLog.LevelError, "tenant manager returned error",
+ libLog.Int("status", statusCode),
+ libLog.String("body", truncateBody(body, 512)),
+ )
+ libOpentelemetry.HandleSpanError(span, "Tenant Manager returned error", fmt.Errorf("status %d", statusCode))
+
+ return fmt.Errorf("tenant manager returned status %d for tenant %s", statusCode, tenantID)
+ }
+}
+
+func (c *Client) cacheTenantConfig(ctx context.Context, cacheKey string, config *core.TenantConfig) {
+ if c.cache == nil {
+ return
+ }
+
+ if configJSON, marshalErr := json.Marshal(config); marshalErr == nil {
+ _ = c.cache.Set(ctx, cacheKey, string(configJSON), c.cacheTTL)
+ }
+}
+
+// GetTenantConfig fetches tenant configuration from the Tenant Manager API.
+// The API endpoint is: GET {baseURL}/v1/tenants/{tenantID}/associations/{service}/connections.
+// Successful responses are cached unless WithSkipCache is used.
+func (c *Client) GetTenantConfig(ctx context.Context, tenantID, service string, opts ...GetConfigOption) (*core.TenantConfig, error) {
+ if c.httpClient == nil {
+ c.httpClient = &http.Client{Timeout: 30 * time.Second}
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "tenantmanager.client.get_tenant_config")
+ defer span.End()
+
+ callOpts := &getConfigOpts{}
+ for _, opt := range opts {
+ opt(callOpts)
+ }
+
+ cacheKey := fmt.Sprintf("%s:%s:%s", cacheKeyPrefix, tenantID, service)
+ if !callOpts.skipCache {
+ if cachedConfig, ok := c.getCachedTenantConfig(ctx, cacheKey, tenantID, service); ok {
+ return cachedConfig, nil
+ }
+ }
+
+ // Check circuit breaker before making the HTTP request
+ if err := c.checkCircuitBreaker(); err != nil {
+ logger.Log(ctx, libLog.LevelWarn, "circuit breaker open, failing fast",
+ libLog.String("tenant_id", tenantID),
+ libLog.String("service", service),
+ )
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "Circuit breaker open", err)
+
+ return nil, err
+ }
+
+ // Build the URL with properly escaped path parameters to prevent path traversal
+ requestURL := fmt.Sprintf("%s/v1/tenants/%s/associations/%s/connections",
+ c.baseURL, url.PathEscape(tenantID), url.PathEscape(service))
+
+ logger.Log(ctx, libLog.LevelInfo, "fetching tenant config",
+ libLog.String("tenant_id", tenantID),
+ libLog.String("service", service),
+ )
+
+ // Create request with context
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
+ if err != nil {
+ logger.Log(ctx, libLog.LevelError, "failed to create request", libLog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "Failed to create HTTP request", err)
+
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+
+ if c.serviceAPIKey != "" {
+ req.Header.Set("X-API-Key", c.serviceAPIKey)
+ }
+
+ // Inject trace context into outgoing HTTP headers for distributed tracing
+ libOpentelemetry.InjectHTTPContext(ctx, req.Header)
+
+ // Execute request
+ // #nosec G704 -- baseURL is validated at construction time and not user-controlled
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ c.recordFailure()
+ logger.Log(ctx, libLog.LevelError, "failed to execute request", libLog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "HTTP request failed", err)
+
+ return nil, fmt.Errorf("failed to execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ // Read response body with size limit to prevent unbounded memory allocation
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize))
+ if err != nil {
+ c.recordFailure()
+ logger.Log(ctx, libLog.LevelError, "failed to read response body", libLog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "Failed to read response body", err)
+
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ // Check response status
+ // 404 and 403 are valid business responses - do NOT count as circuit breaker failures
+ if err := c.handleGetTenantConfigStatus(ctx, span, tenantID, service, resp.StatusCode, body); err != nil {
+ return nil, err
+ }
+
+ // Parse response
+ var config core.TenantConfig
+ if err := json.Unmarshal(body, &config); err != nil {
+ logger.Log(ctx, libLog.LevelError, "failed to parse response", libLog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "Failed to parse response", err)
+
+ return nil, fmt.Errorf("failed to parse response: %w", err)
+ }
+
+ c.recordSuccess()
+ logger.Log(ctx, libLog.LevelInfo, "successfully fetched tenant config",
+ libLog.String("tenant_id", tenantID),
+ libLog.String("slug", config.TenantSlug),
+ )
+
+ c.cacheTenantConfig(ctx, cacheKey, &config)
+
+ return &config, nil
+}
+
+// InvalidateConfig removes the cached tenant config for the given tenant and service.
+func (c *Client) InvalidateConfig(ctx context.Context, tenantID, service string) error {
+ if c.cache == nil {
+ return nil
+ }
+
+ cacheKey := fmt.Sprintf("%s:%s:%s", cacheKeyPrefix, tenantID, service)
+
+ return c.cache.Del(ctx, cacheKey)
+}
+
+// Close releases any resources held by the cache implementation.
+func (c *Client) Close() error {
+ type closer interface {
+ Close() error
+ }
+
+ if cc, ok := c.cache.(closer); ok {
+ return cc.Close()
+ }
+
+ return nil
+}
+
+// GetActiveTenantsByService fetches active tenants for a service from Tenant Manager.
+// This is used as a fallback when Redis cache is unavailable.
+// The API endpoint is: GET {baseURL}/v1/tenants/active?service={service}
+func (c *Client) GetActiveTenantsByService(ctx context.Context, service string) ([]*TenantSummary, error) {
+ if c.httpClient == nil {
+ c.httpClient = &http.Client{Timeout: 30 * time.Second}
+ }
+
+ logger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+
+ ctx, span := tracer.Start(ctx, "tenantmanager.client.get_active_tenants")
+ defer span.End()
+
+ // Check circuit breaker before making the HTTP request
+ if err := c.checkCircuitBreaker(); err != nil {
+ logger.Log(ctx, libLog.LevelWarn, "circuit breaker open, failing fast", libLog.String("service", service))
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "Circuit breaker open", err)
+
+ return nil, err
+ }
+
+ // Build the URL with properly escaped query parameter to prevent injection
+
+ requestURL := fmt.Sprintf("%s/v1/tenants/active?service=%s", c.baseURL, url.QueryEscape(service))
+
+ logger.Log(ctx, libLog.LevelInfo, "fetching active tenants", libLog.String("service", service))
+
+ // Create request with context
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
+ if err != nil {
+ logger.Log(ctx, libLog.LevelError, "failed to create request", libLog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "Failed to create HTTP request", err)
+
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+
+ if c.serviceAPIKey != "" {
+ req.Header.Set("X-API-Key", c.serviceAPIKey)
+ }
+
+ // Inject trace context into outgoing HTTP headers for distributed tracing
+ libOpentelemetry.InjectHTTPContext(ctx, req.Header)
+
+ // Execute request
+ // #nosec G704 -- baseURL is validated at construction time and not user-controlled
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ c.recordFailure()
+ logger.Log(ctx, libLog.LevelError, "failed to execute request", libLog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "HTTP request failed", err)
+
+ return nil, fmt.Errorf("failed to execute request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ // Read response body with size limit to prevent unbounded memory allocation
+ body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize))
+ if err != nil {
+ c.recordFailure()
+ logger.Log(ctx, libLog.LevelError, "failed to read response body", libLog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "Failed to read response body", err)
+
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ // Check response status
+ if resp.StatusCode != http.StatusOK {
+ // Only record failure for server errors (5xx), not client errors (4xx)
+ if isServerError(resp.StatusCode) {
+ c.recordFailure()
+ }
+
+ logger.Log(ctx, libLog.LevelError, "tenant manager returned error",
+ libLog.Int("status", resp.StatusCode),
+ libLog.String("body", truncateBody(body, 512)),
+ )
+ libOpentelemetry.HandleSpanError(span, "Tenant Manager returned error", fmt.Errorf("status %d", resp.StatusCode))
+
+ return nil, fmt.Errorf("tenant manager returned status %d for service %s", resp.StatusCode, service)
+ }
+
+ // Parse response
+ var tenants []*TenantSummary
+ if err := json.Unmarshal(body, &tenants); err != nil {
+ logger.Log(ctx, libLog.LevelError, "failed to parse response", libLog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "Failed to parse response", err)
+
+ return nil, fmt.Errorf("failed to parse response: %w", err)
+ }
+
+ c.recordSuccess()
+ logger.Log(ctx, libLog.LevelInfo, "successfully fetched active tenants",
+ libLog.Int("count", len(tenants)),
+ libLog.String("service", service),
+ )
+
+ return tenants, nil
+}
diff --git a/commons/tenant-manager/client/client_test.go b/commons/tenant-manager/client/client_test.go
new file mode 100644
index 00000000..249a8c09
--- /dev/null
+++ b/commons/tenant-manager/client/client_test.go
@@ -0,0 +1,916 @@
+package client
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// newMinimalTenantConfig returns a TenantConfig with only essential fields set.
+// Used by circuit breaker tests that do not inspect database configuration.
+func newMinimalTenantConfig() core.TenantConfig {
+ return core.TenantConfig{
+ ID: "tenant-123",
+ TenantSlug: "test-tenant",
+ Service: "ledger",
+ Status: "active",
+ }
+}
+
+// newTestTenantConfig returns a fully populated TenantConfig for test assertions.
+// Callers can override fields after construction for specific test scenarios.
+func newTestTenantConfig() core.TenantConfig {
+ return core.TenantConfig{
+ ID: "tenant-123",
+ TenantSlug: "test-tenant",
+ TenantName: "Test Tenant",
+ Service: "ledger",
+ Status: "active",
+ IsolationMode: "database",
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Database: "test_db",
+ Username: "user",
+ Password: "pass",
+ SSLMode: "disable",
+ },
+ },
+ },
+ }
+}
+
+func mustNewClient(t *testing.T, baseURL string, opts ...ClientOption) *Client {
+ t.Helper()
+
+ // Tests use httptest servers which are http://, so allow insecure by default.
+ // A default test API key is provided so tests that don't care about API key auth still pass.
+ allOpts := append([]ClientOption{WithAllowInsecureHTTP(), WithServiceAPIKey("test-api-key")}, opts...)
+
+ c, err := NewClient(baseURL, testutil.NewMockLogger(), allOpts...)
+ require.NoError(t, err)
+
+ return c
+}
+
+func TestNewClient(t *testing.T) {
+ t.Run("creates client with defaults", func(t *testing.T) {
+ client := mustNewClient(t, "http://localhost:8080")
+
+ assert.NotNil(t, client)
+ assert.Equal(t, "http://localhost:8080", client.baseURL)
+ assert.Equal(t, 30*time.Second, client.httpClient.Timeout)
+ })
+
+ t.Run("creates client with custom timeout", func(t *testing.T) {
+ client := mustNewClient(t, "http://localhost:8080", WithTimeout(60*time.Second))
+
+ assert.Equal(t, 60*time.Second, client.httpClient.Timeout)
+ })
+
+ t.Run("creates client with custom http client", func(t *testing.T) {
+ customClient := &http.Client{Timeout: 10 * time.Second}
+ client := mustNewClient(t, "http://localhost:8080", WithHTTPClient(customClient))
+
+ assert.Equal(t, customClient, client.httpClient)
+ })
+
+ t.Run("WithHTTPClient_nil_preserves_default", func(t *testing.T) {
+ client := mustNewClient(t, "http://localhost:8080", WithHTTPClient(nil))
+
+ assert.NotNil(t, client.httpClient, "nil HTTPClient should be ignored, default preserved")
+ assert.Equal(t, 30*time.Second, client.httpClient.Timeout)
+ })
+
+ t.Run("WithTimeout_after_nil_HTTPClient_does_not_panic", func(t *testing.T) {
+ assert.NotPanics(t, func() {
+ _, _ = NewClient("http://localhost:8080", testutil.NewMockLogger(),
+ WithAllowInsecureHTTP(),
+ WithServiceAPIKey("test-key"),
+ WithHTTPClient(nil),
+ WithTimeout(45*time.Second),
+ )
+ })
+ })
+
+ t.Run("rejects http URL without WithAllowInsecureHTTP", func(t *testing.T) {
+ _, err := NewClient("http://localhost:8080", testutil.NewMockLogger(),
+ WithServiceAPIKey("test-key"),
+ )
+ require.Error(t, err)
+ assert.ErrorIs(t, err, core.ErrInsecureHTTP)
+ })
+
+ t.Run("accepts https URL by default", func(t *testing.T) {
+ c, err := NewClient("https://localhost:8080", testutil.NewMockLogger(),
+ WithServiceAPIKey("test-key"),
+ )
+ require.NoError(t, err)
+ assert.NotNil(t, c)
+ })
+}
+
+func TestNewClient_ValidationErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ baseURL string
+ expectErr bool
+ }{
+ {
+ name: "empty baseURL returns error",
+ baseURL: "",
+ expectErr: true,
+ },
+ {
+ name: "URL without scheme returns error",
+ baseURL: "localhost:8080",
+ expectErr: true,
+ },
+ {
+ name: "URL without host returns error",
+ baseURL: "http://",
+ expectErr: true,
+ },
+ {
+ name: "invalid URL syntax returns error",
+ baseURL: "://bad-url",
+ expectErr: true,
+ },
+ {
+ name: "http URL without opt-in returns error",
+ baseURL: "http://localhost:8080",
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Pass nil logger to also verify the nil-logger defaulting path
+ client, err := NewClient(tt.baseURL, nil)
+
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Nil(t, client)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, client)
+ }
+ })
+ }
+}
+
+func TestClient_GetTenantConfig(t *testing.T) {
+ t.Run("successful response", func(t *testing.T) {
+ config := newTestTenantConfig()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "/v1/tenants/tenant-123/associations/ledger/connections", r.URL.Path)
+
+ w.Header().Set("Content-Type", "application/json")
+ require.NoError(t, json.NewEncoder(w).Encode(config))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+
+ require.NoError(t, err)
+ assert.Equal(t, "tenant-123", result.ID)
+ assert.Equal(t, "test-tenant", result.TenantSlug)
+ pgConfig := result.GetPostgreSQLConfig("ledger", "onboarding")
+ assert.NotNil(t, pgConfig)
+ assert.Equal(t, "localhost", pgConfig.Host)
+ })
+
+ t.Run("tenant not found", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetTenantConfig(ctx, "non-existent", "ledger")
+
+ assert.Nil(t, result)
+ assert.ErrorIs(t, err, core.ErrTenantNotFound)
+ })
+
+ t.Run("server error", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ w.Write([]byte("internal error"))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+
+ assert.Nil(t, result)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "500")
+ })
+
+ t.Run("tenant service suspended returns TenantSuspendedError wrapping ErrTenantServiceAccessDenied", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusForbidden)
+ require.NoError(t, json.NewEncoder(w).Encode(map[string]string{
+ "code": "TS-SUSPENDED",
+ "error": "service ledger is suspended for this tenant",
+ "status": "suspended",
+ }))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+
+ assert.Nil(t, result)
+ require.Error(t, err)
+ // All 403s should be detectable via ErrTenantServiceAccessDenied
+ assert.ErrorIs(t, err, core.ErrTenantServiceAccessDenied)
+ // Enriched 403s also carry TenantSuspendedError
+ assert.True(t, core.IsTenantSuspendedError(err))
+
+ var suspErr *core.TenantSuspendedError
+ require.ErrorAs(t, err, &suspErr)
+ assert.Equal(t, "tenant-123", suspErr.TenantID)
+ assert.Equal(t, "suspended", suspErr.Status)
+ assert.Equal(t, "service ledger is suspended for this tenant", suspErr.Message)
+ })
+
+ t.Run("tenant service purged returns TenantSuspendedError wrapping ErrTenantServiceAccessDenied", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusForbidden)
+ require.NoError(t, json.NewEncoder(w).Encode(map[string]string{
+ "code": "TS-SUSPENDED",
+ "error": "service ledger is purged for this tenant",
+ "status": "purged",
+ }))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+
+ assert.Nil(t, result)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, core.ErrTenantServiceAccessDenied)
+
+ var suspErr *core.TenantSuspendedError
+ require.ErrorAs(t, err, &suspErr)
+ assert.Equal(t, "purged", suspErr.Status)
+ })
+
+ t.Run("403 with unparseable body wraps ErrTenantServiceAccessDenied", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusForbidden)
+ w.Write([]byte("not json"))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+
+ assert.Nil(t, result)
+ require.Error(t, err)
+ // Non-parseable 403 should still wrap ErrTenantServiceAccessDenied
+ assert.ErrorIs(t, err, core.ErrTenantServiceAccessDenied)
+ // But should NOT be a TenantSuspendedError (no status info)
+ assert.False(t, core.IsTenantSuspendedError(err))
+ })
+
+ t.Run("403 with empty status wraps ErrTenantServiceAccessDenied", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusForbidden)
+ require.NoError(t, json.NewEncoder(w).Encode(map[string]string{
+ "code": "SOME-OTHER",
+ "error": "something else",
+ }))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+
+ assert.Nil(t, result)
+ require.Error(t, err)
+ // All 403s wrap ErrTenantServiceAccessDenied
+ assert.ErrorIs(t, err, core.ErrTenantServiceAccessDenied)
+ // Empty status = no TenantSuspendedError
+ assert.False(t, core.IsTenantSuspendedError(err))
+ })
+}
+
+func TestNewClient_WithCircuitBreaker(t *testing.T) {
+ t.Run("creates client with circuit breaker option", func(t *testing.T) {
+ client := mustNewClient(t, "http://localhost:8080",
+ WithCircuitBreaker(5, 30*time.Second),
+ )
+
+ assert.Equal(t, 5, client.cbThreshold)
+ assert.Equal(t, 30*time.Second, client.cbTimeout)
+ assert.Equal(t, cbClosed, client.cbState)
+ assert.Equal(t, 0, client.cbFailures)
+ })
+
+ t.Run("default client has circuit breaker disabled", func(t *testing.T) {
+ client := mustNewClient(t, "http://localhost:8080")
+
+ assert.Equal(t, 0, client.cbThreshold)
+ assert.Equal(t, time.Duration(0), client.cbTimeout)
+ })
+}
+
+func TestClient_CircuitBreaker_StaysClosedOnSuccess(t *testing.T) {
+ config := newMinimalTenantConfig()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ require.NoError(t, json.NewEncoder(w).Encode(config))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(3, 30*time.Second))
+ ctx := context.Background()
+
+ // Multiple successful requests should keep circuit breaker closed
+ for i := 0; i < 5; i++ {
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.NoError(t, err)
+ assert.Equal(t, "tenant-123", result.ID)
+ }
+
+ assert.Equal(t, cbClosed, client.cbState)
+ assert.Equal(t, 0, client.cbFailures)
+}
+
+func TestClient_CircuitBreaker_OpensAfterThresholdFailures(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("internal error"))
+ }))
+ defer server.Close()
+
+ threshold := 3
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, 30*time.Second))
+ ctx := context.Background()
+
+ // Send threshold number of requests that trigger server errors
+ for i := 0; i < threshold; i++ {
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ assert.NotErrorIs(t, err, core.ErrCircuitBreakerOpen, "should not be circuit breaker error yet on failure %d", i+1)
+ }
+
+ // Circuit breaker should now be open
+ assert.Equal(t, cbOpen, client.cbState)
+ assert.Equal(t, threshold, client.cbFailures)
+}
+
+func TestClient_CircuitBreaker_ReturnsErrCircuitBreakerOpenWhenOpen(t *testing.T) {
+ var requestCount atomic.Int32
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ requestCount.Add(1)
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("internal error"))
+ }))
+ defer server.Close()
+
+ threshold := 2
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, 30*time.Second))
+ ctx := context.Background()
+
+ // Trigger circuit breaker to open
+ for i := 0; i < threshold; i++ {
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ }
+
+ assert.Equal(t, cbOpen, client.cbState)
+ countAfterOpen := requestCount.Load()
+
+ // Subsequent requests should fail fast without hitting the server
+ for i := 0; i < 5; i++ {
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, core.ErrCircuitBreakerOpen)
+ }
+
+ // No additional requests should have reached the server
+ assert.Equal(t, countAfterOpen, requestCount.Load(), "no additional HTTP requests should reach server when circuit is open")
+}
+
+func TestClient_CircuitBreaker_TransitionsToHalfOpenAfterTimeout(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("internal error"))
+ }))
+ defer server.Close()
+
+ threshold := 2
+ cbTimeout := 50 * time.Millisecond
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, cbTimeout))
+ ctx := context.Background()
+
+ // Trigger circuit breaker to open
+ for i := 0; i < threshold; i++ {
+ _, _ = client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ }
+
+ assert.Equal(t, cbOpen, client.cbState)
+
+ // Wait for the timeout to expire
+ time.Sleep(cbTimeout + 10*time.Millisecond)
+
+ // The next request should be allowed through (half-open probe)
+ // It will fail (server still returns 500), but the request should reach the server
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ assert.NotErrorIs(t, err, core.ErrCircuitBreakerOpen, "request should pass through in half-open state")
+}
+
+func TestClient_CircuitBreaker_ClosesOnSuccessfulHalfOpenRequest(t *testing.T) {
+ var shouldSucceed atomic.Bool
+
+ config := newMinimalTenantConfig()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if shouldSucceed.Load() {
+ w.Header().Set("Content-Type", "application/json")
+ require.NoError(t, json.NewEncoder(w).Encode(config))
+ return
+ }
+
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("internal error"))
+ }))
+ defer server.Close()
+
+ threshold := 2
+ cbTimeout := 50 * time.Millisecond
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, cbTimeout))
+ ctx := context.Background()
+
+ // Trigger circuit breaker to open
+ for i := 0; i < threshold; i++ {
+ _, _ = client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ }
+
+ assert.Equal(t, cbOpen, client.cbState)
+
+ // Wait for timeout, then make the server return success
+ time.Sleep(cbTimeout + 10*time.Millisecond)
+ shouldSucceed.Store(true)
+
+ // Half-open probe should succeed and close the circuit breaker
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.NoError(t, err)
+ assert.Equal(t, "tenant-123", result.ID)
+ assert.Equal(t, cbClosed, client.cbState)
+ assert.Equal(t, 0, client.cbFailures)
+}
+
+func TestClient_CircuitBreaker_404DoesNotCountAsFailure(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer server.Close()
+
+ threshold := 3
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, 30*time.Second))
+ ctx := context.Background()
+
+ // Multiple 404s should NOT trigger the circuit breaker
+ for i := 0; i < threshold+2; i++ {
+ _, err := client.GetTenantConfig(ctx, "non-existent", "ledger")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, core.ErrTenantNotFound)
+ }
+
+ assert.Equal(t, cbClosed, client.cbState, "404 responses should not open the circuit breaker")
+ assert.Equal(t, 0, client.cbFailures, "404 responses should not count as failures")
+}
+
+func TestClient_CircuitBreaker_403DoesNotCountAsFailure(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusForbidden)
+ require.NoError(t, json.NewEncoder(w).Encode(map[string]string{
+ "code": "TS-SUSPENDED",
+ "error": "service ledger is suspended for this tenant",
+ "status": "suspended",
+ }))
+ }))
+ defer server.Close()
+
+ threshold := 3
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, 30*time.Second))
+ ctx := context.Background()
+
+ // Multiple 403s should NOT trigger the circuit breaker
+ for i := 0; i < threshold+2; i++ {
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ assert.True(t, core.IsTenantSuspendedError(err))
+ }
+
+ assert.Equal(t, cbClosed, client.cbState, "403 responses should not open the circuit breaker")
+ assert.Equal(t, 0, client.cbFailures, "403 responses should not count as failures")
+}
+
+func TestClient_CircuitBreaker_400DoesNotCountAsFailure(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusBadRequest)
+ _, _ = w.Write([]byte("bad request"))
+ }))
+ defer server.Close()
+
+ threshold := 3
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, 30*time.Second))
+ ctx := context.Background()
+
+ // Multiple 400s should NOT trigger the circuit breaker
+ for i := 0; i < threshold+2; i++ {
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "400")
+ }
+
+ assert.Equal(t, cbClosed, client.cbState, "400 responses should not open the circuit breaker")
+ assert.Equal(t, 0, client.cbFailures, "400 responses should not count as failures")
+}
+
+func TestClient_CircuitBreaker_DisabledByDefault(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("internal error"))
+ }))
+ defer server.Close()
+
+ // No WithCircuitBreaker option - threshold is 0, circuit breaker disabled
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ // Even after many failures, requests should still go through
+ for i := 0; i < 10; i++ {
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ assert.NotErrorIs(t, err, core.ErrCircuitBreakerOpen)
+ assert.Contains(t, err.Error(), "500")
+ }
+
+ assert.Equal(t, cbClosed, client.cbState, "circuit breaker should remain closed when disabled")
+ assert.Equal(t, 0, client.cbFailures, "failures should not be counted when circuit breaker is disabled")
+}
+
+func TestClient_GetActiveTenantsByService_Success(t *testing.T) {
+ tenants := []*TenantSummary{
+ {ID: "tenant-1", Name: "Acme Corp", Status: "active"},
+ {ID: "tenant-2", Name: "Globex Inc", Status: "active"},
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "/v1/tenants/active", r.URL.Path)
+ assert.Equal(t, "ledger", r.URL.Query().Get("service"))
+
+ w.Header().Set("Content-Type", "application/json")
+ require.NoError(t, json.NewEncoder(w).Encode(tenants))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetActiveTenantsByService(ctx, "ledger")
+
+ require.NoError(t, err)
+ require.Len(t, result, 2)
+ assert.Equal(t, "tenant-1", result[0].ID)
+ assert.Equal(t, "Acme Corp", result[0].Name)
+ assert.Equal(t, "active", result[0].Status)
+ assert.Equal(t, "tenant-2", result[1].ID)
+ assert.Equal(t, "Globex Inc", result[1].Name)
+ assert.Equal(t, "active", result[1].Status)
+}
+
+func TestClient_CircuitBreaker_GetActiveTenantsByService(t *testing.T) {
+ t.Run("opens on server errors", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusServiceUnavailable)
+ _, _ = w.Write([]byte("service unavailable"))
+ }))
+ defer server.Close()
+
+ threshold := 2
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, 30*time.Second))
+ ctx := context.Background()
+
+ // Trigger circuit breaker via GetActiveTenantsByService
+ for i := 0; i < threshold; i++ {
+ _, err := client.GetActiveTenantsByService(ctx, "ledger")
+ require.Error(t, err)
+ }
+
+ assert.Equal(t, cbOpen, client.cbState)
+
+ // Should fail fast
+ _, err := client.GetActiveTenantsByService(ctx, "ledger")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, core.ErrCircuitBreakerOpen)
+ })
+
+ t.Run("shared state with GetTenantConfig", func(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusBadGateway)
+ _, _ = w.Write([]byte("bad gateway"))
+ }))
+ defer server.Close()
+
+ threshold := 3
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, 30*time.Second))
+ ctx := context.Background()
+
+ // Mix failures from both methods - they share the same circuit breaker
+ _, _ = client.GetTenantConfig(ctx, "t1", "ledger") // failure 1
+ _, _ = client.GetActiveTenantsByService(ctx, "ledger") // failure 2
+ _, _ = client.GetTenantConfig(ctx, "t2", "ledger") // failure 3 -> opens
+
+ assert.Equal(t, cbOpen, client.cbState)
+
+ // Both methods should fail fast
+ _, err := client.GetTenantConfig(ctx, "t3", "ledger")
+ assert.ErrorIs(t, err, core.ErrCircuitBreakerOpen)
+
+ _, err = client.GetActiveTenantsByService(ctx, "ledger")
+ assert.ErrorIs(t, err, core.ErrCircuitBreakerOpen)
+ })
+}
+
+func TestClient_CircuitBreaker_NetworkErrorCountsAsFailure(t *testing.T) {
+ // Use a URL that will definitely fail to connect
+ client := mustNewClient(t, "http://127.0.0.1:1",
+ WithCircuitBreaker(2, 30*time.Second),
+ WithTimeout(100*time.Millisecond),
+ )
+ ctx := context.Background()
+
+ // Network errors should count as failures
+ for i := 0; i < 2; i++ {
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ }
+
+ assert.Equal(t, cbOpen, client.cbState, "network errors should trigger circuit breaker")
+
+ // Should fail fast now
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ assert.ErrorIs(t, err, core.ErrCircuitBreakerOpen)
+}
+
+func TestClient_CircuitBreaker_SuccessResetsAfterPartialFailures(t *testing.T) {
+ var requestCount atomic.Int32
+
+ config := newMinimalTenantConfig()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ count := requestCount.Add(1)
+ // Fail on first 2 requests, succeed on the rest
+ if count <= 2 {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte("internal error"))
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ require.NoError(t, json.NewEncoder(w).Encode(config))
+ }))
+ defer server.Close()
+
+ threshold := 3
+ client := mustNewClient(t, server.URL, WithCircuitBreaker(threshold, 30*time.Second))
+ ctx := context.Background()
+
+ // 2 failures (below threshold)
+ _, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ _, err = client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.Error(t, err)
+ assert.Equal(t, 2, client.cbFailures)
+ assert.Equal(t, cbClosed, client.cbState, "should still be closed - below threshold")
+
+ // A success resets the counter
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+ require.NoError(t, err)
+ assert.Equal(t, "tenant-123", result.ID)
+ assert.Equal(t, 0, client.cbFailures, "success should reset failure count")
+ assert.Equal(t, cbClosed, client.cbState)
+}
+
+func TestIsServerError(t *testing.T) {
+ tests := []struct {
+ name string
+ statusCode int
+ expected bool
+ }{
+ {"200 OK is not server error", http.StatusOK, false},
+ {"400 Bad Request is not server error", http.StatusBadRequest, false},
+ {"403 Forbidden is not server error", http.StatusForbidden, false},
+ {"404 Not Found is not server error", http.StatusNotFound, false},
+ {"499 is not server error", 499, false},
+ {"500 Internal Server Error is server error", http.StatusInternalServerError, true},
+ {"502 Bad Gateway is server error", http.StatusBadGateway, true},
+ {"503 Service Unavailable is server error", http.StatusServiceUnavailable, true},
+ {"504 Gateway Timeout is server error", http.StatusGatewayTimeout, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.expected, isServerError(tt.statusCode))
+ })
+ }
+}
+
+func TestWithServiceAPIKey(t *testing.T) {
+ t.Run("sets serviceAPIKey on client", func(t *testing.T) {
+ client := mustNewClient(t, "http://localhost:8080",
+ WithServiceAPIKey("my-secret-key"),
+ )
+
+ assert.Equal(t, "my-secret-key", client.serviceAPIKey)
+ })
+
+ t.Run("missing WithServiceAPIKey returns error", func(t *testing.T) {
+ _, err := NewClient("http://localhost:8080", testutil.NewMockLogger(),
+ WithAllowInsecureHTTP(),
+ )
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, core.ErrServiceAPIKeyRequired)
+ })
+
+ t.Run("empty string returns error", func(t *testing.T) {
+ _, err := NewClient("http://localhost:8080", testutil.NewMockLogger(),
+ WithAllowInsecureHTTP(),
+ WithServiceAPIKey(""),
+ )
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, core.ErrServiceAPIKeyRequired)
+ })
+
+ t.Run("valid key succeeds", func(t *testing.T) {
+ c, err := NewClient("http://localhost:8080", testutil.NewMockLogger(),
+ WithAllowInsecureHTTP(),
+ WithServiceAPIKey("valid-key"),
+ )
+
+ require.NoError(t, err)
+ assert.Equal(t, "valid-key", c.serviceAPIKey)
+ })
+}
+
+func TestClient_GetTenantConfig_APIKeyHeader(t *testing.T) {
+ t.Run("sends X-API-Key header when serviceAPIKey is set", func(t *testing.T) {
+ config := newMinimalTenantConfig()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "my-secret-key", r.Header.Get("X-API-Key"))
+
+ w.Header().Set("Content-Type", "application/json")
+ require.NoError(t, json.NewEncoder(w).Encode(config))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL, WithServiceAPIKey("my-secret-key"))
+ ctx := context.Background()
+
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+
+ require.NoError(t, err)
+ assert.Equal(t, "tenant-123", result.ID)
+ })
+
+ t.Run("sends default test API key from mustNewClient", func(t *testing.T) {
+ config := newMinimalTenantConfig()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "test-api-key", r.Header.Get("X-API-Key"),
+ "mustNewClient provides a default test API key")
+
+ w.Header().Set("Content-Type", "application/json")
+ require.NoError(t, json.NewEncoder(w).Encode(config))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetTenantConfig(ctx, "tenant-123", "ledger")
+
+ require.NoError(t, err)
+ assert.Equal(t, "tenant-123", result.ID)
+ })
+}
+
+func TestClient_GetActiveTenantsByService_APIKeyHeader(t *testing.T) {
+ t.Run("sends X-API-Key header when serviceAPIKey is set", func(t *testing.T) {
+ tenants := []*TenantSummary{
+ {ID: "tenant-1", Name: "Acme Corp", Status: "active"},
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "my-secret-key", r.Header.Get("X-API-Key"))
+
+ w.Header().Set("Content-Type", "application/json")
+ require.NoError(t, json.NewEncoder(w).Encode(tenants))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL, WithServiceAPIKey("my-secret-key"))
+ ctx := context.Background()
+
+ result, err := client.GetActiveTenantsByService(ctx, "ledger")
+
+ require.NoError(t, err)
+ require.Len(t, result, 1)
+ assert.Equal(t, "tenant-1", result[0].ID)
+ })
+
+ t.Run("sends default test API key from mustNewClient", func(t *testing.T) {
+ tenants := []*TenantSummary{
+ {ID: "tenant-1", Name: "Acme Corp", Status: "active"},
+ }
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "test-api-key", r.Header.Get("X-API-Key"),
+ "mustNewClient provides a default test API key")
+
+ w.Header().Set("Content-Type", "application/json")
+ require.NoError(t, json.NewEncoder(w).Encode(tenants))
+ }))
+ defer server.Close()
+
+ client := mustNewClient(t, server.URL)
+ ctx := context.Background()
+
+ result, err := client.GetActiveTenantsByService(ctx, "ledger")
+
+ require.NoError(t, err)
+ require.Len(t, result, 1)
+ })
+}
+
+func TestCacheKeyPrefix(t *testing.T) {
+ t.Run("uses tenant-connections prefix", func(t *testing.T) {
+ assert.Equal(t, "tenant-connections", cacheKeyPrefix,
+ "cacheKeyPrefix must match the renamed /connections endpoint")
+ })
+}
+
+func TestIsCircuitBreakerOpenError(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ expected bool
+ }{
+ {"nil error returns false", nil, false},
+ {"ErrCircuitBreakerOpen returns true", core.ErrCircuitBreakerOpen, true},
+ {"other error returns false", core.ErrTenantNotFound, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.expected, core.IsCircuitBreakerOpenError(tt.err))
+ })
+ }
+}
diff --git a/commons/tenant-manager/consumer/goroutine_leak_test.go b/commons/tenant-manager/consumer/goroutine_leak_test.go
new file mode 100644
index 00000000..216394a1
--- /dev/null
+++ b/commons/tenant-manager/consumer/goroutine_leak_test.go
@@ -0,0 +1,109 @@
+package consumer
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ "github.com/stretchr/testify/assert"
+ "go.uber.org/goleak"
+)
+
+// TestMultiTenantConsumer_Run_CloseStopsSyncLoop proves that Close() alone
+// (without cancelling the original context) stops the sync loop goroutine.
+// This prevents goroutine leaks when callers pass context.Background().
+func TestMultiTenantConsumer_Run_CloseStopsSyncLoop(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ // Populate Redis so fetchTenantIDs succeeds during discovery
+ mr.SAdd(testActiveTenantsKey, "tenant-001")
+
+ consumer := mustNewConsumer(t,
+ dummyRabbitMQManager(),
+ redisClient,
+ MultiTenantConfig{
+ SyncInterval: 100 * time.Millisecond,
+ PrefetchCount: 10,
+ Service: testServiceName,
+ },
+ testutil.NewMockLogger(),
+ )
+
+ // Use context.Background() — never cancelled, like Midaz does in production.
+ ctx := context.Background()
+
+ err := consumer.Run(ctx)
+ if err != nil {
+ t.Fatalf("Run() returned unexpected error: %v", err)
+ }
+
+ assert.Eventually(t, func() bool {
+ return consumer.Stats().KnownTenants > 0
+ }, time.Second, 20*time.Millisecond)
+
+ // Close without cancelling ctx — this must stop the sync loop.
+ if closeErr := consumer.Close(); closeErr != nil {
+ t.Fatalf("Close() returned unexpected error: %v", closeErr)
+ }
+
+ assert.Eventually(t, func() bool {
+ return consumer.Stats().Closed && consumer.Stats().ActiveTenants == 0
+ }, time.Second, 20*time.Millisecond)
+
+ goleak.VerifyNone(t,
+ goleak.IgnoreTopFunction("github.com/alicebob/miniredis/v2/server.(*Server).servePeer"),
+ goleak.IgnoreTopFunction("github.com/alicebob/miniredis/v2.(*Miniredis).handleClient"),
+ goleak.IgnoreTopFunction("github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/cache.(*InMemoryCache).cleanupLoop"),
+ goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
+ )
+}
+
+// TestMultiTenantConsumer_Run_CancelAndCloseNoLeak proves that the normal
+// cleanup path (cancel context + Close) also leaves no leaked goroutines.
+func TestMultiTenantConsumer_Run_CancelAndCloseNoLeak(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ // Populate Redis so fetchTenantIDs succeeds during discovery
+ mr.SAdd(testActiveTenantsKey, "tenant-001")
+
+ consumer := mustNewConsumer(t,
+ dummyRabbitMQManager(),
+ redisClient,
+ MultiTenantConfig{
+ SyncInterval: 100 * time.Millisecond,
+ PrefetchCount: 10,
+ Service: testServiceName,
+ },
+ testutil.NewMockLogger(),
+ )
+
+ ctx, cancel := context.WithCancel(context.Background())
+
+ err := consumer.Run(ctx)
+ if err != nil {
+ t.Fatalf("Run() returned unexpected error: %v", err)
+ }
+
+ assert.Eventually(t, func() bool {
+ return consumer.Stats().KnownTenants > 0
+ }, time.Second, 20*time.Millisecond)
+
+ // Normal cleanup: cancel context first, then Close.
+ cancel()
+
+ if closeErr := consumer.Close(); closeErr != nil {
+ t.Fatalf("Close() returned unexpected error: %v", closeErr)
+ }
+
+ assert.Eventually(t, func() bool {
+ return consumer.Stats().Closed && consumer.Stats().ActiveTenants == 0
+ }, time.Second, 20*time.Millisecond)
+
+ goleak.VerifyNone(t,
+ goleak.IgnoreTopFunction("github.com/alicebob/miniredis/v2/server.(*Server).servePeer"),
+ goleak.IgnoreTopFunction("github.com/alicebob/miniredis/v2.(*Miniredis).handleClient"),
+ goleak.IgnoreTopFunction("github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/cache.(*InMemoryCache).cleanupLoop"),
+ goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
+ )
+}
diff --git a/commons/tenant-manager/consumer/multi_tenant.go b/commons/tenant-manager/consumer/multi_tenant.go
new file mode 100644
index 00000000..1add29c2
--- /dev/null
+++ b/commons/tenant-manager/consumer/multi_tenant.go
@@ -0,0 +1,370 @@
+// Package consumer provides multi-tenant message queue consumption management.
+package consumer
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ amqp "github.com/rabbitmq/amqp091-go"
+ "github.com/redis/go-redis/v9"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+ tmmongo "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/mongo"
+ tmpostgres "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/postgres"
+ tmrabbitmq "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/rabbitmq"
+)
+
+// HandlerFunc is a function that processes messages from a queue.
+// The context contains the tenant ID via core.SetTenantIDInContext.
+type HandlerFunc func(ctx context.Context, delivery amqp.Delivery) error
+
+// MultiTenantConfig holds configuration for the MultiTenantConsumer.
+type MultiTenantConfig struct {
+ // SyncInterval is the interval between tenant list synchronizations.
+ // Default: 30 seconds
+ SyncInterval time.Duration
+
+ // WorkersPerQueue is reserved for future use. It is currently not implemented
+ // and has no effect on consumer behavior. Each queue runs a single consumer goroutine.
+ // Setting this field is a no-op; it is retained only for backward compatibility.
+ //
+ // Deprecated: This field is not yet implemented. Setting it has no effect.
+ WorkersPerQueue int
+
+ // PrefetchCount is the QoS prefetch count per channel.
+ // Default: 10
+ PrefetchCount int
+
+ // MultiTenantURL is the fallback HTTP endpoint to fetch tenants if Redis cache misses.
+ // Format: http://tenant-manager:4003
+ MultiTenantURL string
+
+ // ServiceAPIKey is the API key sent as X-API-Key header on HTTP requests to the
+ // Tenant Manager. Required when MultiTenantURL is set. Typically sourced from
+ // the MULTI_TENANT_SERVICE_API_KEY environment variable.
+ ServiceAPIKey string
+
+ // Service is the service name to filter tenants by.
+ // This is passed to tenant-manager when fetching tenant list.
+ Service string
+
+ // Environment is the deployment environment (e.g., "staging", "production").
+ // Used to build environment-segmented Redis cache keys for active tenants.
+ // When set together with Service, the Redis key becomes:
+ // "tenant-manager:tenants:active:{Environment}:{Service}"
+ Environment string
+
+ // DiscoveryTimeout is the maximum time allowed for the initial tenant discovery
+ // (fetching tenant IDs at startup). If zero, 500ms is used. Increase this for
+ // high-latency or loaded environments where Redis or the tenant-manager API
+ // may respond slowly; discovery is best-effort and the sync loop will retry.
+ // Default: 500ms
+ DiscoveryTimeout time.Duration
+
+ // AllowInsecureHTTP permits the use of http:// (plaintext) URLs for the
+ // MultiTenantURL. By default, only https:// is accepted by the underlying
+ // client. Set this to true for in-cluster Kubernetes service URLs that use
+ // plain HTTP (e.g., http://tenant-manager.namespace.svc.cluster.local:4003).
+ // Default: false
+ AllowInsecureHTTP bool
+
+ // Deprecated: EagerStart is ignored. Consumers are always started eagerly.
+ // This field is retained only for backward compatibility with existing configs;
+ // setting it has no effect. It will be removed in a future major version.
+ EagerStart bool
+}
+
+// DefaultMultiTenantConfig returns a MultiTenantConfig with sensible defaults.
+func DefaultMultiTenantConfig() MultiTenantConfig {
+ return MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ PrefetchCount: 10,
+ DiscoveryTimeout: 500 * time.Millisecond,
+ }
+}
+
+// Option configures a MultiTenantConsumer.
+type Option func(*MultiTenantConsumer)
+
+// WithPostgresManager sets the postgres Manager on the consumer.
+// When set, database connections for removed tenants are automatically closed
+// during tenant synchronization.
+func WithPostgresManager(p *tmpostgres.Manager) Option {
+ return func(c *MultiTenantConsumer) { c.postgres = p }
+}
+
+// WithMongoManager sets the mongo Manager on the consumer.
+// When set, MongoDB connections for removed tenants are automatically closed
+// during tenant synchronization.
+func WithMongoManager(m *tmmongo.Manager) Option {
+ return func(c *MultiTenantConsumer) { c.mongo = m }
+}
+
+// MultiTenantConsumer manages message consumption across multiple tenant vhosts.
+// It dynamically discovers tenants from Redis cache and spawns consumer goroutines.
+// Run() discovers tenants and eagerly starts consumers for all known tenants.
+// New tenants discovered during background sync are also started immediately.
+type MultiTenantConsumer struct {
+ rabbitmq *tmrabbitmq.Manager
+ redisClient redis.UniversalClient
+ pmClient *client.Client // Tenant Manager client for fallback
+ handlers map[string]HandlerFunc
+ tenants map[string]context.CancelFunc // Active tenant goroutines
+ knownTenants map[string]bool // Discovered tenants (populated by discovery and sync)
+ // tenantAbsenceCount tracks consecutive syncs each tenant was missing from the fetched list.
+ // Used to avoid removing tenants on a single transient incomplete fetch.
+ tenantAbsenceCount map[string]int
+ config MultiTenantConfig
+ mu sync.RWMutex
+ logger *logcompat.Logger
+ closed bool
+
+ // postgres manages PostgreSQL connections per tenant.
+ // When set, connections are closed automatically when a tenant is removed.
+ postgres *tmpostgres.Manager
+
+ // mongo manages MongoDB connections per tenant.
+ // When set, connections are closed automatically when a tenant is removed.
+ mongo *tmmongo.Manager
+
+ // consumerLocks provides per-tenant mutexes for double-check locking in ensureConsumerStarted.
+ // Key: tenantID, Value: *sync.Mutex
+ consumerLocks sync.Map
+
+ // retryState holds per-tenant retry counters for connection failure resilience.
+ // Key: tenantID, Value: *retryStateEntry
+ retryState sync.Map
+
+ // parentCtx is the context passed to Run(), stored for use by ensureConsumerStarted.
+ parentCtx context.Context
+
+ // syncLoopCancel cancels the context used by the sync loop goroutine.
+ // Stored in Run() and called in Close() to ensure the sync loop stops
+ // even when the original context (e.g., context.Background()) is never cancelled.
+ syncLoopCancel context.CancelFunc
+}
+
+// NewMultiTenantConsumerWithError creates a new MultiTenantConsumer.
+// Parameters:
+// - rabbitmq: RabbitMQ connection manager for tenant vhosts (must not be nil)
+// - redisClient: Redis client for tenant cache access (must not be nil)
+// - config: Consumer configuration
+// - logger: Logger for operational logging
+// - opts: Optional configuration options (e.g., WithPostgresManager, WithMongoManager)
+//
+// Returns an error if rabbitmq or redisClient is nil, as they are required for core functionality.
+func NewMultiTenantConsumerWithError(
+ rabbitmq *tmrabbitmq.Manager,
+ redisClient redis.UniversalClient,
+ config MultiTenantConfig,
+ logger libLog.Logger,
+ opts ...Option,
+) (*MultiTenantConsumer, error) {
+ if rabbitmq == nil {
+ return nil, errors.New("consumer.NewMultiTenantConsumerWithError: rabbitmq must not be nil")
+ }
+
+ if redisClient == nil {
+ return nil, errors.New("consumer.NewMultiTenantConsumerWithError: redisClient must not be nil")
+ }
+
+ // Guard against nil logger to prevent panics downstream
+ if logger == nil {
+ logger = libLog.NewNop()
+ }
+
+ // Apply defaults
+ if config.SyncInterval <= 0 {
+ config.SyncInterval = 30 * time.Second
+ }
+
+ if config.PrefetchCount == 0 {
+ config.PrefetchCount = 10
+ }
+
+ consumer := &MultiTenantConsumer{
+ rabbitmq: rabbitmq,
+ redisClient: redisClient,
+ handlers: make(map[string]HandlerFunc),
+ tenants: make(map[string]context.CancelFunc),
+ knownTenants: make(map[string]bool),
+ tenantAbsenceCount: make(map[string]int),
+ config: config,
+ logger: logcompat.New(logger),
+ }
+
+ // Apply optional configurations
+ for _, opt := range opts {
+ opt(consumer)
+ }
+
+ // Create Tenant Manager client for fallback if URL is configured
+ if config.MultiTenantURL != "" {
+ clientOpts := []client.ClientOption{
+ client.WithServiceAPIKey(config.ServiceAPIKey),
+ }
+
+ if config.AllowInsecureHTTP {
+ clientOpts = append(clientOpts, client.WithAllowInsecureHTTP())
+ }
+
+ pmClient, err := client.NewClient(config.MultiTenantURL, consumer.logger.Base(), clientOpts...)
+ if err != nil {
+ return nil, fmt.Errorf("consumer.NewMultiTenantConsumerWithError: invalid MultiTenantURL: %w", err)
+ }
+
+ consumer.pmClient = pmClient
+ }
+
+ if config.WorkersPerQueue > 0 {
+ consumer.logger.Base().Log(context.Background(), libLog.LevelWarn,
+ "WorkersPerQueue is deprecated and has no effect; the field is reserved for future use",
+ libLog.Int("workers_per_queue", config.WorkersPerQueue))
+ }
+
+ return consumer, nil
+}
+
+// Register adds a queue handler for all tenant vhosts.
+// The handler will be invoked for messages from the specified queue in each tenant's vhost.
+//
+// Handlers should be registered before calling Run(). Handlers registered after Run()
+// has been called will only take effect for tenants whose consumers are spawned after
+// the registration; already-running tenant consumers will NOT pick up the new handler.
+//
+// Returns an error if handler is nil.
+func (c *MultiTenantConsumer) Register(queueName string, handler HandlerFunc) error {
+ if handler == nil {
+ return fmt.Errorf("consumer.Register: queue %q: %w", queueName, core.ErrNilHandlerFunc)
+ }
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.handlers[queueName] = handler
+ c.logger.Infof("registered handler for queue: %s", queueName)
+
+ return nil
+}
+
+// Run starts the multi-tenant consumer.
+// It discovers tenants (non-blocking, soft failure), eagerly starts consumers
+// for all discovered tenants, and starts background polling for new tenants.
+// Returns nil even on discovery failure (soft failure).
+func (c *MultiTenantConsumer) Run(ctx context.Context) error {
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ // Fall back to constructor logger when context has no logger attached
+ // (e.g., context.Background()). This prevents silent log loss.
+ if c.logger != nil {
+ logger = c.logger
+ }
+
+ ctx, span := tracer.Start(ctx, "consumer.multi_tenant_consumer.run")
+ defer span.End()
+
+ // Store parent context for use by ensureConsumerStarted.
+ // Protected by c.mu because ensureConsumerStarted reads it concurrently.
+ c.mu.Lock()
+ c.parentCtx = ctx
+ c.mu.Unlock()
+
+ // Discover tenants without blocking (soft failure - does not start consumers)
+ c.discoverTenants(ctx)
+
+ // Capture count under lock to avoid concurrent read race
+ c.mu.RLock()
+ knownCount := len(c.knownTenants)
+ c.mu.RUnlock()
+
+ logger.InfofCtx(ctx, "starting multi-tenant consumer, connection_mode=eager, known_tenants=%d",
+ knownCount)
+
+ // Eager start: start consumers for all discovered tenants immediately
+ if knownCount > 0 {
+ c.eagerStartKnownTenants(ctx)
+ }
+
+ // Background polling - ASYNC
+ // Create a derived context so Close() can stop the sync loop even when
+ // the caller passes a never-cancelled context (e.g., context.Background()).
+ syncCtx, syncCancel := context.WithCancel(ctx) //#nosec G118 -- cancel is stored in c.syncLoopCancel and called by Close()
+ c.syncLoopCancel = syncCancel
+
+ go c.syncActiveTenants(syncCtx)
+
+ return nil
+}
+
+// Close stops all consumer goroutines and marks the consumer as closed.
+// It also closes the fallback pmClient to prevent goroutine leaks from its
+// InMemoryCache cleanup loop.
+func (c *MultiTenantConsumer) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ c.closed = true
+
+ // Cancel the sync loop context first, so the background polling goroutine
+ // stops before we tear down individual tenant consumers.
+ if c.syncLoopCancel != nil {
+ c.syncLoopCancel()
+ }
+
+ // Cancel all tenant contexts
+ for tenantID, cancel := range c.tenants {
+ c.logger.Infof("stopping consumer for tenant: %s", tenantID)
+ cancel()
+ }
+
+ // Clear the maps
+
+ c.tenants = make(map[string]context.CancelFunc)
+ c.knownTenants = make(map[string]bool)
+ c.tenantAbsenceCount = make(map[string]int)
+
+ // Close fallback pmClient to release its InMemoryCache cleanup goroutine.
+ if c.pmClient != nil {
+ if err := c.pmClient.Close(); err != nil {
+ c.logger.Warnf("failed to close fallback tenant manager client: %v", err)
+ }
+ }
+
+ c.logger.Info("multi-tenant consumer closed")
+
+ return nil
+}
+
+// Stats holds statistics for the consumer.
+type Stats struct {
+ ActiveTenants int `json:"activeTenants"`
+ TenantIDs []string `json:"tenantIds"`
+ RegisteredQueues []string `json:"registeredQueues"`
+ Closed bool `json:"closed"`
+ KnownTenants int `json:"knownTenants"`
+ KnownTenantIDs []string `json:"knownTenantIds"`
+ PendingTenants int `json:"pendingTenants"`
+ PendingTenantIDs []string `json:"pendingTenantIds"`
+ DegradedTenants []string `json:"degradedTenants"`
+}
+
+// Prometheus-compatible metric name constants for multi-tenant consumer observability.
+// These constants provide a standardized naming scheme for metrics instrumentation.
+const (
+ // MetricTenantConnectionsTotal tracks the total number of tenant connections established.
+ MetricTenantConnectionsTotal = "tenant_connections_total"
+ // MetricTenantConnectionErrors tracks connection errors by tenant.
+ MetricTenantConnectionErrors = "tenant_connection_errors_total"
+ // MetricTenantConsumersActive tracks the number of currently active tenant consumers.
+ MetricTenantConsumersActive = "tenant_consumers_active"
+ // MetricTenantMessageProcessed tracks the total number of messages processed per tenant.
+ MetricTenantMessageProcessed = "tenant_messages_processed_total"
+)
diff --git a/commons/tenant-manager/consumer/multi_tenant_consume.go b/commons/tenant-manager/consumer/multi_tenant_consume.go
new file mode 100644
index 00000000..23cc3dd2
--- /dev/null
+++ b/commons/tenant-manager/consumer/multi_tenant_consume.go
@@ -0,0 +1,483 @@
+package consumer
+
+import (
+ "context"
+ crand "crypto/rand"
+ "encoding/binary"
+ "maps"
+ "sync"
+ "time"
+
+ amqp "github.com/rabbitmq/amqp091-go"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+)
+
+// retryStateEntry holds per-tenant retry state for connection failure resilience.
+type retryStateEntry struct {
+ mu sync.Mutex
+ retryCount int
+ degraded bool
+}
+
+// reset clears retry counters and degraded flag. Must be called with no other goroutine
+// holding the entry's mutex (e.g. after Load from sync.Map).
+func (e *retryStateEntry) reset() {
+ e.mu.Lock()
+ e.retryCount = 0
+ e.degraded = false
+ e.mu.Unlock()
+}
+
+// isDegraded returns whether the tenant is marked degraded.
+func (e *retryStateEntry) isDegraded() bool {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.degraded
+}
+
+// incRetryAndMaybeMarkDegraded increments retry count, optionally marks degraded if count >= max,
+// and returns the backoff delay and current retry count. justMarkedDegraded is true only when
+// the entry was not degraded and is now marked degraded by this call.
+func (e *retryStateEntry) incRetryAndMaybeMarkDegraded(maxBeforeDegraded int) (delay time.Duration, retryCount int, justMarkedDegraded bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ delay = backoffDelay(e.retryCount)
+ e.retryCount++
+
+ prev := e.degraded
+ if e.retryCount >= maxBeforeDegraded {
+ e.degraded = true
+ }
+
+ justMarkedDegraded = !prev && e.degraded
+
+ return delay, e.retryCount, justMarkedDegraded
+}
+
+// startTenantConsumer spawns a consumer goroutine for a tenant.
+// MUST be called with c.mu held.
+func (c *MultiTenantConsumer) startTenantConsumer(parentCtx context.Context, tenantID string) {
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(parentCtx)
+ logger := logcompat.New(baseLogger)
+
+ parentCtx, span := tracer.Start(parentCtx, "consumer.multi_tenant_consumer.start_tenant_consumer")
+ defer span.End()
+
+ // Create a cancellable context for this tenant
+ tenantCtx, cancel := context.WithCancel(parentCtx) //#nosec G118 -- cancel stored in c.tenants[tenantID] and called when tenant consumer is stopped
+
+ // Store the cancel function (caller holds lock)
+ c.tenants[tenantID] = cancel
+
+ logger.InfofCtx(parentCtx, "starting consumer for tenant: %s", tenantID)
+
+ // Spawn consumer goroutine
+ go c.superviseTenantQueues(tenantCtx, tenantID)
+}
+
+// superviseTenantQueues runs the consumer loop for a single tenant.
+func (c *MultiTenantConsumer) superviseTenantQueues(ctx context.Context, tenantID string) {
+ // Set tenantID in context for handlers
+ ctx = core.SetTenantIDInContext(ctx, tenantID)
+
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ ctx, span := tracer.Start(ctx, "consumer.multi_tenant_consumer.consume_for_tenant")
+ defer span.End()
+
+ logger = logger.WithFields("tenant_id", tenantID)
+ logger.InfoCtx(ctx, "consumer started for tenant")
+
+ // Get all registered handlers (read-only, no lock needed after initial registration)
+ c.mu.RLock()
+
+ handlers := make(map[string]HandlerFunc, len(c.handlers))
+ maps.Copy(handlers, c.handlers)
+
+ c.mu.RUnlock()
+
+ // Consume from each registered queue
+ for queueName, handler := range handlers {
+ go c.consumeTenantQueue(ctx, tenantID, queueName, handler, logger)
+ }
+
+ // Wait for context cancellation
+ <-ctx.Done()
+ logger.InfoCtx(ctx, "consumer stopped for tenant")
+}
+
+// consumeTenantQueue consumes messages from a specific queue for a tenant.
+// Each connection attempt creates a short-lived span to avoid accumulating events
+// on a long-lived span that would grow unbounded over the consumer's lifetime.
+func (c *MultiTenantConsumer) consumeTenantQueue(
+ ctx context.Context,
+ tenantID string,
+ queueName string,
+ handler HandlerFunc,
+ _ *logcompat.Logger,
+) {
+ baseLogger, _, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled
+ logger := logcompat.New(baseLogger).WithFields("tenant_id", tenantID, "queue", queueName)
+
+ // Guard against nil RabbitMQ manager (e.g., during lazy mode testing)
+ if c.rabbitmq == nil {
+ logger.WarnCtx(ctx, "RabbitMQ manager is nil, cannot consume from queue")
+ return
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ logger.InfoCtx(ctx, "queue consumer stopped")
+ return
+ default:
+ }
+
+ shouldContinue := c.attemptConsumeConnection(ctx, tenantID, queueName, handler, logger)
+ if !shouldContinue {
+ return
+ }
+
+ logger.WarnCtx(ctx, "channel closed, reconnecting...")
+ }
+}
+
+// attemptConsumeConnection attempts to establish a channel and consume messages.
+// Returns true if the loop should continue (reconnect), false if it should stop.
+// Uses exponential backoff with per-tenant retry state for connection failures.
+func (c *MultiTenantConsumer) attemptConsumeConnection(
+ ctx context.Context,
+ tenantID string,
+ queueName string,
+ handler HandlerFunc,
+ logger *logcompat.Logger,
+) bool {
+ _, tracer, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled
+
+ connCtx, span := tracer.Start(ctx, "consumer.multi_tenant_consumer.consume_connection")
+ defer span.End()
+
+ state := c.getRetryState(tenantID)
+
+ // Get channel for this tenant's vhost
+ ch, err := c.rabbitmq.GetChannel(connCtx, tenantID)
+ if err != nil {
+ // If the tenant is suspended or purged, stop the consumer instead of retrying.
+ // Retrying a suspended/purged tenant would cause infinite reconnect loops.
+ if core.IsTenantSuspendedError(err) || core.IsTenantPurgedError(err) {
+ logger.WarnfCtx(ctx, "tenant %s is suspended/purged, stopping consumer: %v", tenantID, err)
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "tenant suspended/purged, stopping consumer", err)
+ c.evictSuspendedTenant(ctx, tenantID, logger)
+
+ return false
+ }
+
+ delay, retryCount, justMarkedDegraded := state.incRetryAndMaybeMarkDegraded(maxRetryBeforeDegraded)
+ if justMarkedDegraded {
+ logger.WarnfCtx(ctx, "tenant %s marked as degraded after %d consecutive failures", tenantID, retryCount)
+ }
+
+ logger.WarnfCtx(ctx, "failed to get channel for tenant %s, retrying in %s (attempt %d): %v",
+ tenantID, delay, retryCount, err)
+ libOpentelemetry.HandleSpanError(span, "failed to get channel", err)
+
+ select {
+ case <-ctx.Done():
+ return false
+ case <-time.After(delay):
+ return true
+ }
+ }
+
+ // Set QoS
+ if err := ch.Qos(c.config.PrefetchCount, 0, false); err != nil {
+ _ = ch.Close() // Close channel to prevent leak
+
+ delay, retryCount, justMarkedDegraded := state.incRetryAndMaybeMarkDegraded(maxRetryBeforeDegraded)
+ if justMarkedDegraded {
+ logger.WarnfCtx(ctx, "tenant %s marked as degraded after %d consecutive failures", tenantID, retryCount)
+ }
+
+ logger.WarnfCtx(ctx, "failed to set QoS for tenant %s, retrying in %s (attempt %d): %v",
+ tenantID, delay, retryCount, err)
+ libOpentelemetry.HandleSpanError(span, "failed to set QoS", err)
+
+ select {
+ case <-ctx.Done():
+ return false
+ case <-time.After(delay):
+ return true
+ }
+ }
+
+ // Start consuming
+ msgs, err := ch.Consume(
+ queueName,
+ "", // consumer tag
+ false, // auto-ack
+ false, // exclusive
+ false, // no-local
+ false, // no-wait
+ nil, // args
+ )
+ if err != nil {
+ _ = ch.Close() // Close channel to prevent leak
+
+ delay, retryCount, justMarkedDegraded := state.incRetryAndMaybeMarkDegraded(maxRetryBeforeDegraded)
+ if justMarkedDegraded {
+ logger.WarnfCtx(ctx, "tenant %s marked as degraded after %d consecutive failures", tenantID, retryCount)
+ }
+
+ logger.WarnfCtx(ctx, "failed to start consuming for tenant %s, retrying in %s (attempt %d): %v",
+ tenantID, delay, retryCount, err)
+ libOpentelemetry.HandleSpanError(span, "failed to start consuming", err)
+
+ select {
+ case <-ctx.Done():
+ return false
+ case <-time.After(delay):
+ return true
+ }
+ }
+
+ // Connection succeeded: reset retry state
+ c.resetRetryState(tenantID)
+
+ logger.InfofCtx(ctx, "consuming started for tenant %s on queue %s", tenantID, queueName)
+
+ // Setup channel close notification
+ notifyClose := make(chan *amqp.Error, 1)
+ ch.NotifyClose(notifyClose)
+
+ // Process messages (blocks until channel closes or context is cancelled)
+ c.processMessages(ctx, tenantID, queueName, handler, msgs, notifyClose, logger)
+
+ return true
+}
+
+// processMessages processes messages from the channel until it closes.
+// Each message is processed with its own span to avoid accumulating events on a long-lived span.
+func (c *MultiTenantConsumer) processMessages(
+ ctx context.Context,
+ tenantID string,
+ queueName string,
+ handler HandlerFunc,
+ msgs <-chan amqp.Delivery,
+ notifyClose <-chan *amqp.Error,
+ _ *logcompat.Logger,
+) {
+ baseLogger, _, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled
+ logger := logcompat.New(baseLogger).WithFields("tenant_id", tenantID, "queue", queueName)
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case err := <-notifyClose:
+ if err != nil {
+ logger.WarnfCtx(ctx, "channel closed with error: %v", err)
+ }
+
+ return
+ case msg, ok := <-msgs:
+ if !ok {
+ logger.WarnCtx(ctx, "message channel closed")
+ return
+ }
+
+ c.handleMessage(ctx, tenantID, queueName, handler, msg, logger)
+ }
+ }
+}
+
+// handleMessage processes a single message with its own span.
+func (c *MultiTenantConsumer) handleMessage(
+ ctx context.Context,
+ tenantID string,
+ queueName string,
+ handler HandlerFunc,
+ msg amqp.Delivery,
+ logger *logcompat.Logger,
+) {
+ _, tracer, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled
+
+ // Process message with tenant context
+ msgCtx := core.SetTenantIDInContext(ctx, tenantID)
+
+ // Extract trace context from message headers
+ msgCtx = libOpentelemetry.ExtractTraceContextFromQueueHeaders(msgCtx, msg.Headers)
+
+ // Create a per-message span
+ msgCtx, span := tracer.Start(msgCtx, "consumer.multi_tenant_consumer.handle_message")
+ defer span.End()
+
+ if err := handler(msgCtx, msg); err != nil {
+ logger.ErrorfCtx(ctx, "handler error for queue %s: %v", queueName, err)
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "handler error", err)
+
+ if nackErr := msg.Nack(false, true); nackErr != nil {
+ logger.ErrorfCtx(ctx, "failed to nack message: %v", nackErr)
+ }
+ } else {
+ // Ack on success
+ if ackErr := msg.Ack(false); ackErr != nil {
+ logger.ErrorfCtx(ctx, "failed to ack message: %v", ackErr)
+ }
+ }
+}
+
+// initialBackoff is the base delay for exponential backoff on connection failures.
+const initialBackoff = 5 * time.Second
+
+// maxBackoff is the maximum delay between retry attempts.
+const maxBackoff = 40 * time.Second
+
+// maxRetryBeforeDegraded is the number of consecutive failures before marking a tenant as degraded.
+const maxRetryBeforeDegraded = 3
+
+// backoffDelay calculates the exponential backoff delay for a given retry count
+// with +/-25% jitter to prevent thundering herd when multiple tenants retry simultaneously.
+// Base sequence: 5s, 10s, 20s, 40s, 40s, ... (before jitter).
+func backoffDelay(retryCount int) time.Duration {
+ delay := initialBackoff
+ for range retryCount {
+ delay *= 2
+ if delay > maxBackoff {
+ delay = maxBackoff
+
+ break
+ }
+ }
+
+ // Apply +/-25% jitter: multiply by a random factor in [0.75, 1.25).
+ // Uses crypto/rand to satisfy gosec G404.
+ var b [8]byte
+
+ _, _ = crand.Read(b[:])
+
+ jitter := 0.75 + float64(binary.LittleEndian.Uint64(b[:]))/(1<<64)*0.5
+
+ return time.Duration(float64(delay) * jitter)
+}
+
+// getRetryState returns the retry state entry for a tenant, creating one if it does not exist.
+func (c *MultiTenantConsumer) getRetryState(tenantID string) *retryStateEntry {
+ entry, _ := c.retryState.LoadOrStore(tenantID, &retryStateEntry{})
+
+ val, ok := entry.(*retryStateEntry)
+ if !ok {
+ return &retryStateEntry{}
+ }
+
+ return val
+}
+
+// resetRetryState resets the retry counter and degraded flag for a tenant after a successful connection.
+// It reuses the existing entry when present (reset in place) to avoid allocation churn; only stores
+// a new entry when the tenant has no entry yet.
+func (c *MultiTenantConsumer) resetRetryState(tenantID string) {
+ if entry, ok := c.retryState.Load(tenantID); ok {
+ if state, ok := entry.(*retryStateEntry); ok {
+ state.reset()
+ return
+ }
+ }
+
+ c.retryState.Store(tenantID, &retryStateEntry{})
+}
+
+// ensureConsumerStarted ensures a consumer is running for the given tenant.
+// It uses double-check locking with a per-tenant mutex to guarantee exactly-once
+// consumer spawning under concurrent access.
+//
+// Consumers are only started for tenants that are known (resolved via discovery or
+// sync). Unknown tenants are rejected to prevent starting consumers for tenants
+// that have not been validated by the sync loop.
+func (c *MultiTenantConsumer) ensureConsumerStarted(ctx context.Context, tenantID string) {
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ ctx, span := tracer.Start(ctx, "consumer.multi_tenant_consumer.ensure_consumer_started")
+ defer span.End()
+
+ // Fast path: check if consumer is already active (read lock only)
+ c.mu.RLock()
+
+ _, exists := c.tenants[tenantID]
+ known := c.knownTenants[tenantID]
+ closed := c.closed
+ c.mu.RUnlock()
+
+ if exists || closed {
+ return
+ }
+
+ // Reject unknown tenants: they haven't been discovered or validated yet.
+ // The sync loop will add them to knownTenants when they appear.
+ if !known {
+ logger.WarnfCtx(ctx, "rejecting consumer start for unknown tenant: %s (not yet resolved by sync)", tenantID)
+
+ return
+ }
+
+ // Slow path: acquire per-tenant mutex for double-check locking
+ lockVal, _ := c.consumerLocks.LoadOrStore(tenantID, &sync.Mutex{})
+
+ tenantMu, ok := lockVal.(*sync.Mutex)
+ if !ok {
+ return
+ }
+
+ tenantMu.Lock()
+ defer tenantMu.Unlock()
+
+ // Double-check under per-tenant lock
+ c.mu.RLock()
+ _, exists = c.tenants[tenantID]
+ closed = c.closed
+ c.mu.RUnlock()
+
+ if exists || closed {
+ return
+ }
+
+ // Use stored parentCtx if available (from Run()), otherwise use the provided ctx.
+ // Protected by c.mu.RLock because Run() writes parentCtx concurrently.
+ c.mu.RLock()
+
+ startCtx := ctx
+ if c.parentCtx != nil {
+ startCtx = c.parentCtx
+ }
+
+ c.mu.RUnlock()
+
+ logger.InfofCtx(ctx, "on-demand consumer start for tenant: %s", tenantID)
+
+ c.mu.Lock()
+ c.startTenantConsumer(startCtx, tenantID)
+ c.mu.Unlock()
+}
+
+// IsDegraded returns true if the given tenant is currently in a degraded state
+// due to repeated connection failures (>= maxRetryBeforeDegraded consecutive failures).
+func (c *MultiTenantConsumer) IsDegraded(tenantID string) bool {
+ entry, ok := c.retryState.Load(tenantID)
+ if !ok {
+ return false
+ }
+
+ state, ok := entry.(*retryStateEntry)
+ if !ok {
+ return false
+ }
+
+ return state.isDegraded()
+}
diff --git a/commons/tenant-manager/consumer/multi_tenant_consume_test.go b/commons/tenant-manager/consumer/multi_tenant_consume_test.go
new file mode 100644
index 00000000..a04af932
--- /dev/null
+++ b/commons/tenant-manager/consumer/multi_tenant_consume_test.go
@@ -0,0 +1,96 @@
+package consumer
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ amqp "github.com/rabbitmq/amqp091-go"
+ "github.com/stretchr/testify/assert"
+)
+
+type fakeAcknowledger struct {
+ ackCalls int
+ nackCalls int
+ requeue bool
+}
+
+func (f *fakeAcknowledger) Ack(uint64, bool) error {
+ f.ackCalls++
+ return nil
+}
+
+func (f *fakeAcknowledger) Nack(uint64, bool, bool) error {
+ f.nackCalls++
+ f.requeue = true
+ return nil
+}
+
+func (f *fakeAcknowledger) Reject(uint64, bool) error { return nil }
+
+func TestMultiTenantConsumer_HandleMessage_AcksSuccessfulMessages(t *testing.T) {
+ t.Parallel()
+
+ consumer := &MultiTenantConsumer{}
+ ack := &fakeAcknowledger{}
+ logger := logcompat.New(testutil.NewMockLogger())
+
+ var seenTenantID string
+ msg := amqp.Delivery{Acknowledger: ack, DeliveryTag: 1, Headers: amqp.Table{}}
+
+ consumer.handleMessage(context.Background(), "tenant-ack", "queue-a", func(ctx context.Context, delivery amqp.Delivery) error {
+ seenTenantID = core.GetTenantIDFromContext(ctx)
+ return nil
+ }, msg, logger)
+
+ assert.Equal(t, "tenant-ack", seenTenantID)
+ assert.Equal(t, 1, ack.ackCalls)
+ assert.Equal(t, 0, ack.nackCalls)
+}
+
+func TestMultiTenantConsumer_HandleMessage_NacksFailedMessages(t *testing.T) {
+ t.Parallel()
+
+ consumer := &MultiTenantConsumer{}
+ ack := &fakeAcknowledger{}
+ logger := logcompat.New(testutil.NewMockLogger())
+
+ msg := amqp.Delivery{Acknowledger: ack, DeliveryTag: 2, Headers: amqp.Table{}}
+
+ consumer.handleMessage(context.Background(), "tenant-nack", "queue-b", func(context.Context, amqp.Delivery) error {
+ return errors.New("boom")
+ }, msg, logger)
+
+ assert.Equal(t, 0, ack.ackCalls)
+ assert.Equal(t, 1, ack.nackCalls)
+ assert.True(t, ack.requeue)
+}
+
+func TestMultiTenantConsumer_ProcessMessages_ReturnsOnChannelClose(t *testing.T) {
+ t.Parallel()
+
+ consumer := &MultiTenantConsumer{}
+ logger := logcompat.New(testutil.NewMockLogger())
+ msgs := make(chan amqp.Delivery)
+ notifyClose := make(chan *amqp.Error, 1)
+ done := make(chan struct{})
+
+ go func() {
+ consumer.processMessages(context.Background(), "tenant-close", "queue-c", func(context.Context, amqp.Delivery) error {
+ return nil
+ }, msgs, notifyClose, logger)
+ close(done)
+ }()
+
+ notifyClose <- &amqp.Error{Reason: "channel closed"}
+
+ select {
+ case <-done:
+ case <-time.After(time.Second):
+ t.Fatal("processMessages did not return after channel close notification")
+ }
+}
diff --git a/commons/tenant-manager/consumer/multi_tenant_retry_test.go b/commons/tenant-manager/consumer/multi_tenant_retry_test.go
new file mode 100644
index 00000000..f0dec698
--- /dev/null
+++ b/commons/tenant-manager/consumer/multi_tenant_retry_test.go
@@ -0,0 +1,151 @@
+package consumer
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ amqp "github.com/rabbitmq/amqp091-go"
+ "github.com/stretchr/testify/assert"
+)
+
+func applyRetryFailures(state *retryStateEntry, count int) {
+ for range count {
+ _, _, _ = state.incRetryAndMaybeMarkDegraded(maxRetryBeforeDegraded)
+ }
+}
+
+// TestMultiTenantConsumer_RetryState verifies per-tenant retry state management.
+func TestMultiTenantConsumer_RetryState(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ incrementRetries int
+ expectedDegraded bool
+ resetBeforeAssert bool
+ }{
+ {name: "initial_retry_state_is_zero", tenantID: "tenant-fresh", incrementRetries: 0, expectedDegraded: false},
+ {name: "2_retries_not_degraded", tenantID: "tenant-2-retries", incrementRetries: 2, expectedDegraded: false},
+ {name: "3_retries_marks_degraded", tenantID: "tenant-3-retries", incrementRetries: 3, expectedDegraded: true},
+ {name: "5_retries_stays_degraded", tenantID: "tenant-5-retries", incrementRetries: 5, expectedDegraded: true},
+ {name: "reset_clears_retry_state", tenantID: "tenant-reset", incrementRetries: 5, resetBeforeAssert: true, expectedDegraded: false},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ state := consumer.getRetryState(tt.tenantID)
+ applyRetryFailures(state, tt.incrementRetries)
+
+ if tt.resetBeforeAssert {
+ consumer.resetRetryState(tt.tenantID)
+ }
+
+ assert.Equal(t, tt.expectedDegraded, consumer.IsDegraded(tt.tenantID))
+ })
+ }
+}
+
+// TestMultiTenantConsumer_RetryStateIsolation verifies that retry state is
+// isolated between tenants (one tenant's failures don't affect another).
+func TestMultiTenantConsumer_RetryStateIsolation(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ applyRetryFailures(consumer.getRetryState("tenant-a"), 5)
+ _ = consumer.getRetryState("tenant-b")
+
+ assert.True(t, consumer.IsDegraded("tenant-a"))
+ assert.False(t, consumer.IsDegraded("tenant-b"))
+}
+
+// TestMultiTenantConsumer_Stats_Enhanced verifies the enhanced Stats() API
+// returns KnownTenants, PendingTenants, and DegradedTenants.
+func TestMultiTenantConsumer_Stats_Enhanced(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ redisTenantIDs []string
+ startConsumerForIDs []string
+ degradeTenantIDs []string
+ expectedKnown int
+ expectedActive int
+ expectedPending int
+ expectedDegradedCount int
+ }{
+ {name: "all_tenants_pending", redisTenantIDs: []string{"tenant-a", "tenant-b", "tenant-c"}, expectedKnown: 3, expectedActive: 0, expectedPending: 3, expectedDegradedCount: 0},
+ {name: "mix_of_active_and_pending", redisTenantIDs: []string{"tenant-a", "tenant-b", "tenant-c"}, startConsumerForIDs: []string{"tenant-a"}, expectedKnown: 3, expectedActive: 1, expectedPending: 2, expectedDegradedCount: 0},
+ {name: "degraded_tenant_appears_in_stats", redisTenantIDs: []string{"tenant-a", "tenant-b"}, degradeTenantIDs: []string{"tenant-b"}, expectedKnown: 2, expectedActive: 0, expectedPending: 2, expectedDegradedCount: 1},
+ {name: "empty_consumer_returns_zero_stats", expectedKnown: 0, expectedActive: 0, expectedPending: 0, expectedDegradedCount: 0},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mr, redisClient := setupMiniredis(t)
+
+ for _, id := range tt.redisTenantIDs {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: testServiceName,
+ }, testutil.NewMockLogger())
+
+ consumer.Register("test-queue", func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ consumer.parentCtx = ctx
+ consumer.discoverTenants(ctx)
+
+ for _, id := range tt.startConsumerForIDs {
+ consumer.mu.Lock()
+ consumer.startTenantConsumer(ctx, id)
+ consumer.mu.Unlock()
+ }
+
+ for _, id := range tt.degradeTenantIDs {
+ applyRetryFailures(consumer.getRetryState(id), maxRetryBeforeDegraded)
+ }
+
+ stats := consumer.Stats()
+
+ assert.Equal(t, tt.expectedKnown, stats.KnownTenants)
+ assert.Equal(t, tt.expectedActive, stats.ActiveTenants)
+ assert.Equal(t, tt.expectedPending, stats.PendingTenants)
+ assert.Equal(t, tt.expectedDegradedCount, len(stats.DegradedTenants))
+
+ consumer.Close()
+ })
+ }
+}
diff --git a/commons/tenant-manager/consumer/multi_tenant_revalidate.go b/commons/tenant-manager/consumer/multi_tenant_revalidate.go
new file mode 100644
index 00000000..f04c665d
--- /dev/null
+++ b/commons/tenant-manager/consumer/multi_tenant_revalidate.go
@@ -0,0 +1,118 @@
+package consumer
+
+import (
+ "context"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+)
+
+// revalidateConnectionSettings fetches current settings from the Tenant Manager
+// for each active tenant and applies any changed connection pool settings to
+// existing PostgreSQL and MongoDB connections.
+//
+// For PostgreSQL, SetMaxOpenConns/SetMaxIdleConns are thread-safe and take effect
+// immediately for new connections from the pool without recreating the connection.
+// For MongoDB, the driver does not support pool resize after creation, so a warning
+// is logged and changes take effect on the next connection recreation.
+//
+// This method is called after syncTenants in each sync iteration. Errors fetching
+// config for individual tenants are logged and skipped (will retry next cycle).
+// If the Tenant Manager is down, the circuit breaker handles fast-fail.
+func (c *MultiTenantConsumer) revalidateConnectionSettings(ctx context.Context) {
+ if c.postgres == nil && c.mongo == nil {
+ return
+ }
+
+ if c.pmClient == nil || c.config.Service == "" {
+ return
+ }
+
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ ctx, span := tracer.Start(ctx, "consumer.multi_tenant_consumer.revalidate_connection_settings")
+ defer span.End()
+
+ // Snapshot current tenant IDs under lock to avoid holding the lock during HTTP calls.
+ c.mu.RLock()
+
+ tenantIDs := make([]string, 0, len(c.tenants))
+ for tenantID := range c.tenants {
+ tenantIDs = append(tenantIDs, tenantID)
+ }
+
+ c.mu.RUnlock()
+
+ if len(tenantIDs) == 0 {
+ return
+ }
+
+ var revalidated int
+
+ for _, tenantID := range tenantIDs {
+ config, err := c.pmClient.GetTenantConfig(ctx, tenantID, c.config.Service)
+ if err != nil {
+ // If tenant service was suspended/purged, stop consumer and close connections.
+ if core.IsTenantSuspendedError(err) {
+ c.evictSuspendedTenant(ctx, tenantID, logger)
+ continue
+ }
+
+ logger.WarnfCtx(ctx, "failed to fetch config for tenant %s during settings revalidation: %v", tenantID, err)
+
+ continue
+ }
+
+ if c.postgres != nil {
+ c.postgres.ApplyConnectionSettings(tenantID, config)
+ }
+
+ if c.mongo != nil {
+ c.mongo.ApplyConnectionSettings(tenantID, config)
+ }
+
+ revalidated++
+ }
+
+ if revalidated > 0 {
+ logger.InfofCtx(ctx, "revalidated connection settings for %d/%d active tenants", revalidated, len(tenantIDs))
+ }
+}
+
+// evictSuspendedTenant stops the consumer and closes all database connections for a
+// tenant whose service was suspended or purged by the Tenant Manager. The tenant is
+// removed from both tenants and knownTenants maps so it will not be restarted by the
+// sync loop. The next request for this tenant will receive the 403 error directly.
+func (c *MultiTenantConsumer) evictSuspendedTenant(ctx context.Context, tenantID string, logger *logcompat.Logger) {
+ logger.WarnfCtx(ctx, "tenant %s service suspended, stopping consumer and closing connections", tenantID)
+
+ c.mu.Lock()
+
+ if cancel, ok := c.tenants[tenantID]; ok {
+ cancel()
+ delete(c.tenants, tenantID)
+ }
+
+ delete(c.knownTenants, tenantID)
+ c.mu.Unlock()
+
+ if c.postgres != nil {
+ if err := c.postgres.CloseConnection(ctx, tenantID); err != nil {
+ logger.WarnfCtx(ctx, "failed to close PostgreSQL connection for suspended tenant %s: %v", tenantID, err)
+ }
+ }
+
+ if c.mongo != nil {
+ if err := c.mongo.CloseConnection(ctx, tenantID); err != nil {
+ logger.WarnfCtx(ctx, "failed to close MongoDB connection for suspended tenant %s: %v", tenantID, err)
+ }
+ }
+
+ if c.rabbitmq != nil {
+ if err := c.rabbitmq.CloseConnection(ctx, tenantID); err != nil {
+ logger.WarnfCtx(ctx, "failed to close RabbitMQ connection for suspended tenant %s: %v", tenantID, err)
+ }
+ }
+}
diff --git a/commons/tenant-manager/consumer/multi_tenant_stats.go b/commons/tenant-manager/consumer/multi_tenant_stats.go
new file mode 100644
index 00000000..a4fccbe8
--- /dev/null
+++ b/commons/tenant-manager/consumer/multi_tenant_stats.go
@@ -0,0 +1,59 @@
+package consumer
+
+// Stats returns statistics about the consumer.
+func (c *MultiTenantConsumer) Stats() Stats {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ tenantIDs := make([]string, 0, len(c.tenants))
+ for id := range c.tenants {
+ tenantIDs = append(tenantIDs, id)
+ }
+
+ queueNames := make([]string, 0, len(c.handlers))
+ for name := range c.handlers {
+ queueNames = append(queueNames, name)
+ }
+
+ knownTenantIDs := make([]string, 0, len(c.knownTenants))
+ for id := range c.knownTenants {
+ knownTenantIDs = append(knownTenantIDs, id)
+ }
+
+ // Compute pending tenants (known but not yet active)
+ pendingTenantIDs := make([]string, 0)
+
+ for id := range c.knownTenants {
+ if _, active := c.tenants[id]; !active {
+ pendingTenantIDs = append(pendingTenantIDs, id)
+ }
+ }
+
+ // Collect degraded tenants from retry state
+ degradedTenantIDs := make([]string, 0)
+
+ c.retryState.Range(func(key, value any) bool {
+ tenantID, ok := key.(string)
+ if !ok {
+ return true
+ }
+
+ if entry, ok := value.(*retryStateEntry); ok && entry.isDegraded() {
+ degradedTenantIDs = append(degradedTenantIDs, tenantID)
+ }
+
+ return true
+ })
+
+ return Stats{
+ ActiveTenants: len(c.tenants),
+ TenantIDs: tenantIDs,
+ RegisteredQueues: queueNames,
+ Closed: c.closed,
+ KnownTenants: len(c.knownTenants),
+ KnownTenantIDs: knownTenantIDs,
+ PendingTenants: len(pendingTenantIDs),
+ PendingTenantIDs: pendingTenantIDs,
+ DegradedTenants: degradedTenantIDs,
+ }
+}
diff --git a/commons/tenant-manager/consumer/multi_tenant_sync.go b/commons/tenant-manager/consumer/multi_tenant_sync.go
new file mode 100644
index 00000000..f8a7d403
--- /dev/null
+++ b/commons/tenant-manager/consumer/multi_tenant_sync.go
@@ -0,0 +1,407 @@
+package consumer
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+)
+
+// absentSyncsBeforeRemoval is the number of consecutive syncs a tenant can be
+// missing from the fetched list before it is removed from knownTenants and
+// any active consumer is stopped. Prevents transient incomplete fetches from
+// purging tenants immediately.
+const absentSyncsBeforeRemoval = 3
+
+// buildActiveTenantsKey returns an environment+service segmented Redis key for active tenants.
+// The key format is always: "tenant-manager:tenants:active:{env}:{service}"
+// The caller is responsible for providing valid env and service values.
+func buildActiveTenantsKey(env, service string) string {
+ return fmt.Sprintf("tenant-manager:tenants:active:%s:%s", env, service)
+}
+
+// eagerStartKnownTenants starts consumers for all known tenants.
+// Called during Run() and when new tenants are discovered during sync.
+func (c *MultiTenantConsumer) eagerStartKnownTenants(ctx context.Context) {
+ c.mu.RLock()
+
+ tenantIDs := make([]string, 0, len(c.knownTenants))
+ for id := range c.knownTenants {
+ tenantIDs = append(tenantIDs, id)
+ }
+
+ c.mu.RUnlock()
+
+ c.logger.InfofCtx(ctx, "eager start: bootstrapping consumers for %d tenants", len(tenantIDs))
+
+ for _, tenantID := range tenantIDs {
+ c.ensureConsumerStarted(ctx, tenantID)
+ }
+}
+
+// discoverTenants fetches tenant IDs and populates knownTenants without starting consumers.
+// This is the initial discovery step at startup. Actual consumer spawning is handled by
+// eagerStartKnownTenants() after discovery completes. Failures are logged as warnings
+// (soft failure) and do not propagate errors to the caller.
+// A short timeout is applied to avoid blocking startup on unresponsive infrastructure.
+func (c *MultiTenantConsumer) discoverTenants(ctx context.Context) {
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ if c.logger != nil {
+ logger = c.logger
+ }
+
+ ctx, span := tracer.Start(ctx, "consumer.multi_tenant_consumer.discover_tenants")
+ defer span.End()
+
+ // Apply a short timeout to prevent blocking startup when infrastructure is down.
+ // Discovery is best-effort; the background sync loop will retry periodically.
+ discoveryTimeout := c.config.DiscoveryTimeout
+ if discoveryTimeout == 0 {
+ discoveryTimeout = 500 * time.Millisecond
+ }
+
+ discoveryCtx, cancel := context.WithTimeout(ctx, discoveryTimeout)
+ defer cancel()
+
+ tenantIDs, err := c.fetchTenantIDs(discoveryCtx)
+ if err != nil {
+ logger.WarnfCtx(ctx, "tenant discovery failed (soft failure, will retry in background): %v", err)
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "tenant discovery failed (soft failure)", err)
+
+ return
+ }
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ for _, id := range tenantIDs {
+ c.knownTenants[id] = true
+ }
+
+ logger.InfofCtx(ctx, "discovered %d tenants", len(tenantIDs))
+}
+
+// syncActiveTenants periodically syncs the tenant list.
+// Each iteration creates its own span to avoid accumulating events on a long-lived span.
+func (c *MultiTenantConsumer) syncActiveTenants(ctx context.Context) {
+ baseLogger, _, _, _ := libCommons.NewTrackingFromContext(ctx) //nolint:dogsled
+ logger := logcompat.New(baseLogger)
+
+ if c.logger != nil {
+ logger = c.logger
+ }
+
+ ticker := time.NewTicker(c.config.SyncInterval)
+ defer ticker.Stop()
+
+ logger.InfoCtx(ctx, "sync loop started")
+
+ for {
+ select {
+ case <-ticker.C:
+ c.runSyncIteration(ctx)
+ case <-ctx.Done():
+ logger.InfoCtx(ctx, "sync loop stopped: context cancelled")
+ return
+ }
+ }
+}
+
+// runSyncIteration executes a single sync iteration with its own span.
+func (c *MultiTenantConsumer) runSyncIteration(ctx context.Context) {
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ if c.logger != nil {
+ logger = c.logger
+ }
+
+ ctx, span := tracer.Start(ctx, "consumer.multi_tenant_consumer.sync_iteration")
+ defer span.End()
+
+ if err := c.syncTenants(ctx); err != nil {
+ logger.WarnfCtx(ctx, "tenant sync failed (continuing): %v", err)
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "tenant sync failed (continuing)", err)
+ }
+
+ // Revalidate connection settings for active tenants.
+ // This runs outside syncTenants to avoid holding c.mu during HTTP calls.
+ c.revalidateConnectionSettings(ctx)
+}
+
+// syncTenants fetches tenant IDs and updates the known tenant registry.
+// New tenants are added to knownTenants and consumers are started immediately.
+// Tenants missing from the fetched list are retained in knownTenants for up to
+// absentSyncsBeforeRemoval consecutive syncs; only after that threshold are they
+// removed from knownTenants and any active consumers stopped. This avoids purging
+// tenants on a single transient incomplete fetch.
+// Error handling: if fetchTenantIDs fails, syncTenants returns the error immediately
+// without modifying the current tenant state. The caller (runSyncIteration) logs
+// the failure and continues retrying on the next sync interval.
+func (c *MultiTenantConsumer) syncTenants(ctx context.Context) error {
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ if c.logger != nil {
+ logger = c.logger
+ }
+
+ ctx, span := tracer.Start(ctx, "consumer.multi_tenant_consumer.sync_tenants")
+ defer span.End()
+
+ // Fetch tenant IDs from Redis cache
+ tenantIDs, err := c.fetchTenantIDs(ctx)
+ if err != nil {
+ logger.ErrorfCtx(ctx, "failed to fetch tenant IDs: %v", err)
+ libOpentelemetry.HandleSpanError(span, "failed to fetch tenant IDs", err)
+
+ return fmt.Errorf("failed to fetch tenant IDs: %w", err)
+ }
+
+ validTenantIDs, currentTenants := c.filterValidTenantIDs(ctx, tenantIDs, logger)
+
+ c.mu.Lock()
+
+ if c.closed {
+ c.mu.Unlock()
+ return errors.New("consumer is closed")
+ }
+
+ previousKnown := c.snapshotKnownTenantsLocked()
+ removedTenants := c.reconcileTenantPresence(previousKnown, currentTenants)
+ newTenants := c.identifyNewTenants(validTenantIDs, previousKnown)
+ c.cancelRemovedTenantConsumers(removedTenants)
+
+ // Capture stats under lock for the final log line.
+ knownCount := len(c.knownTenants)
+ activeCount := len(c.tenants)
+
+ c.mu.Unlock()
+
+ // Close database connections for removed tenants outside the lock (network I/O).
+ c.closeRemovedTenantConnections(ctx, removedTenants, logger)
+
+ if len(newTenants) > 0 {
+ logger.InfofCtx(ctx, "discovered %d new tenants (starting consumers): %v",
+ len(newTenants), newTenants)
+ }
+
+ logger.InfofCtx(ctx, "sync complete: %d known, %d active, %d discovered, %d removed",
+ knownCount, activeCount, len(newTenants), len(removedTenants))
+
+ // Start consumers for newly discovered tenants.
+ // ensureConsumerStarted is called outside the lock (already unlocked above).
+ for _, tenantID := range newTenants {
+ c.ensureConsumerStarted(ctx, tenantID)
+ }
+
+ return nil
+}
+
+// filterValidTenantIDs validates the fetched tenant IDs and returns both the
+// valid ID slice and a set for quick lookup.
+func (c *MultiTenantConsumer) filterValidTenantIDs(
+ ctx context.Context,
+ tenantIDs []string,
+ logger *logcompat.Logger,
+) ([]string, map[string]bool) {
+ validTenantIDs := make([]string, 0, len(tenantIDs))
+
+ for _, id := range tenantIDs {
+ if core.IsValidTenantID(id) {
+ validTenantIDs = append(validTenantIDs, id)
+ } else {
+ logger.WarnfCtx(ctx, "skipping invalid tenant ID: %q", id)
+ }
+ }
+
+ currentTenants := make(map[string]bool, len(validTenantIDs))
+ for _, id := range validTenantIDs {
+ currentTenants[id] = true
+ }
+
+ return validTenantIDs, currentTenants
+}
+
+// snapshotKnownTenantsLocked copies the current known-tenants set.
+// MUST be called with c.mu held.
+func (c *MultiTenantConsumer) snapshotKnownTenantsLocked() map[string]bool {
+ previousKnown := make(map[string]bool, len(c.knownTenants))
+ for id := range c.knownTenants {
+ previousKnown[id] = true
+ }
+
+ return previousKnown
+}
+
+// reconcileTenantPresence updates knownTenants by merging the current fetch with
+// previously known tenants, applying the absence-count threshold. It returns the
+// list of tenant IDs that exceeded the threshold and should be removed.
+// MUST be called with c.mu held.
+func (c *MultiTenantConsumer) reconcileTenantPresence(previousKnown, currentTenants map[string]bool) []string {
+ newKnown := make(map[string]bool, len(currentTenants)+len(previousKnown))
+
+ var removedTenants []string
+
+ for id := range currentTenants {
+ newKnown[id] = true
+ c.tenantAbsenceCount[id] = 0
+ }
+
+ for id := range previousKnown {
+ if currentTenants[id] {
+ continue
+ }
+
+ abs := c.tenantAbsenceCount[id] + 1
+
+ c.tenantAbsenceCount[id] = abs
+ if abs < absentSyncsBeforeRemoval {
+ newKnown[id] = true
+ } else {
+ delete(c.tenantAbsenceCount, id)
+
+ if _, running := c.tenants[id]; running {
+ removedTenants = append(removedTenants, id)
+ }
+ }
+ }
+
+ c.knownTenants = newKnown
+
+ return removedTenants
+}
+
+// identifyNewTenants returns tenant IDs from the valid list that are neither
+// already running nor present in the pre-sync known-tenants snapshot.
+// This prevents logging lazy-known tenants as "new" on every sync iteration
+// while still correctly surfacing tenants first discovered in the current sync.
+// MUST be called with c.mu held.
+func (c *MultiTenantConsumer) identifyNewTenants(validTenantIDs []string, previousKnown map[string]bool) []string {
+ var newTenants []string
+
+ for _, tenantID := range validTenantIDs {
+ if _, running := c.tenants[tenantID]; running {
+ continue
+ }
+
+ // Only report as "new" if not already in the pre-sync known set.
+ // Tenants that are known but not yet active are "pending", not "new".
+ if previousKnown[tenantID] {
+ continue
+ }
+
+ newTenants = append(newTenants, tenantID)
+ }
+
+ return newTenants
+}
+
+// cancelRemovedTenantConsumers cancels goroutines and removes tenants from internal maps.
+// MUST be called with c.mu held.
+func (c *MultiTenantConsumer) cancelRemovedTenantConsumers(removedTenants []string) {
+ for _, tenantID := range removedTenants {
+ if cancel, ok := c.tenants[tenantID]; ok {
+ cancel()
+ delete(c.tenants, tenantID)
+ }
+ }
+}
+
+// closeRemovedTenantConnections closes database and messaging connections for
+// tenants that have been removed from the known tenant registry.
+// This method performs network I/O and MUST be called WITHOUT holding c.mu.
+// The caller is responsible for cancelling goroutines and cleaning internal maps
+// under the lock before invoking this function.
+func (c *MultiTenantConsumer) closeRemovedTenantConnections(ctx context.Context, removedTenants []string, logger *logcompat.Logger) {
+ for _, tenantID := range removedTenants {
+ logger.InfofCtx(ctx, "closing connections for removed tenant: %s", tenantID)
+
+ if c.rabbitmq != nil {
+ if err := c.rabbitmq.CloseConnection(ctx, tenantID); err != nil {
+ logger.WarnfCtx(ctx, "failed to close RabbitMQ connection for tenant %s: %v", tenantID, err)
+ }
+ }
+
+ if c.postgres != nil {
+ if err := c.postgres.CloseConnection(ctx, tenantID); err != nil {
+ logger.WarnfCtx(ctx, "failed to close PostgreSQL connection for tenant %s: %v", tenantID, err)
+ }
+ }
+
+ if c.mongo != nil {
+ if err := c.mongo.CloseConnection(ctx, tenantID); err != nil {
+ logger.WarnfCtx(ctx, "failed to close MongoDB connection for tenant %s: %v", tenantID, err)
+ }
+ }
+ }
+}
+
+// fetchTenantIDs gets tenant IDs from Redis cache, falling back to Tenant Manager API.
+func (c *MultiTenantConsumer) fetchTenantIDs(ctx context.Context) ([]string, error) {
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ ctx, span := tracer.Start(ctx, "consumer.multi_tenant_consumer.fetch_tenant_ids")
+ defer span.End()
+
+ // Build environment+service segmented Redis key
+ cacheKey := buildActiveTenantsKey(c.config.Environment, c.config.Service)
+
+ // Try Redis cache first
+ tenantIDs, err := c.redisClient.SMembers(ctx, cacheKey).Result()
+ if err == nil && len(tenantIDs) > 0 {
+ logger.InfofCtx(ctx, "fetched %d tenant IDs from cache", len(tenantIDs))
+ return tenantIDs, nil
+ }
+
+ if err != nil {
+ logger.WarnfCtx(ctx, "Redis cache fetch failed: %v", err)
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "Redis cache fetch failed", err)
+ }
+
+ // Fallback to Tenant Manager API
+ if c.pmClient != nil && c.config.Service != "" {
+ logger.InfoCtx(ctx, "falling back to Tenant Manager API for tenant list")
+
+ tenants, apiErr := c.pmClient.GetActiveTenantsByService(ctx, c.config.Service)
+ if apiErr != nil {
+ logger.ErrorfCtx(ctx, "Tenant Manager API fallback failed: %v", apiErr)
+ libOpentelemetry.HandleSpanError(span, "Tenant Manager API fallback failed", apiErr)
+ // Return Redis error if API also fails
+ if err != nil {
+ return nil, err
+ }
+
+ return nil, apiErr
+ }
+
+ // Extract IDs from tenant summaries
+ ids := make([]string, 0, len(tenants))
+ for _, t := range tenants {
+ if t == nil {
+ continue
+ }
+
+ ids = append(ids, t.ID)
+ }
+
+ logger.InfofCtx(ctx, "fetched %d tenant IDs from Tenant Manager API", len(ids))
+
+ return ids, nil
+ }
+
+ // No tenants available
+ if err != nil {
+ return nil, err
+ }
+
+ return []string{}, nil
+}
diff --git a/commons/tenant-manager/consumer/multi_tenant_sync_test.go b/commons/tenant-manager/consumer/multi_tenant_sync_test.go
new file mode 100644
index 00000000..765cd440
--- /dev/null
+++ b/commons/tenant-manager/consumer/multi_tenant_sync_test.go
@@ -0,0 +1,43 @@
+package consumer
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMultiTenantConsumer_SyncTenants_EagerModeStartsNewTenant(t *testing.T) {
+ t.Parallel()
+
+ mr, redisClient := setupMiniredis(t)
+ mr.SAdd(testActiveTenantsKey, "tenant-a")
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: testServiceName,
+ }, testutil.NewMockLogger())
+ defer func() { require.NoError(t, consumer.Close()) }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ consumer.discoverTenants(ctx)
+ mr.SAdd(testActiveTenantsKey, "tenant-b")
+
+ err := consumer.syncTenants(ctx)
+ require.NoError(t, err)
+
+ assert.Eventually(t, func() bool {
+ consumer.mu.RLock()
+ defer consumer.mu.RUnlock()
+
+ _, active := consumer.tenants["tenant-b"]
+ return consumer.knownTenants["tenant-b"] && active
+ }, time.Second, 10*time.Millisecond)
+}
diff --git a/commons/tenant-manager/consumer/multi_tenant_test.go b/commons/tenant-manager/consumer/multi_tenant_test.go
new file mode 100644
index 00000000..da15421a
--- /dev/null
+++ b/commons/tenant-manager/consumer/multi_tenant_test.go
@@ -0,0 +1,3044 @@
+package consumer
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ tmmongo "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/mongo"
+ tmpostgres "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/postgres"
+ tmrabbitmq "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/rabbitmq"
+ "github.com/alicebob/miniredis/v2"
+ amqp "github.com/rabbitmq/amqp091-go"
+ "github.com/redis/go-redis/v9"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// NewMultiTenantConsumer is a test convenience wrapper around NewMultiTenantConsumerWithError
+// that panics on error. This keeps test code concise while preserving the v4 constructor signature.
+func NewMultiTenantConsumer(
+ rabbitmq *tmrabbitmq.Manager,
+ redisClient redis.UniversalClient,
+ config MultiTenantConfig,
+ logger libLog.Logger,
+ opts ...Option,
+) *MultiTenantConsumer {
+ c, err := NewMultiTenantConsumerWithError(rabbitmq, redisClient, config, logger, opts...)
+ if err != nil {
+ panic(fmt.Sprintf("NewMultiTenantConsumer (test helper): %v", err))
+ }
+
+ return c
+}
+
+// mustNewConsumer is an alternative test helper that takes *testing.T and calls t.Fatal on error.
+func mustNewConsumer(
+ t *testing.T,
+ rabbitmq *tmrabbitmq.Manager,
+ redisClient redis.UniversalClient,
+ config MultiTenantConfig,
+ logger libLog.Logger,
+ opts ...Option,
+) *MultiTenantConsumer {
+ t.Helper()
+
+ c, err := NewMultiTenantConsumerWithError(rabbitmq, redisClient, config, logger, opts...)
+ if err != nil {
+ t.Fatalf("mustNewConsumer: %v", err)
+ }
+
+ return c
+}
+
+// generateTenantIDs creates a slice of N tenant IDs for testing.
+func generateTenantIDs(n int) []string {
+ ids := make([]string, n)
+ for i := range n {
+ ids[i] = fmt.Sprintf("tenant-%04d", i)
+ }
+
+ return ids
+}
+
+// setupMiniredis creates a miniredis instance and returns it with a go-redis client.
+func setupMiniredis(t *testing.T) (*miniredis.Miniredis, redis.UniversalClient) {
+ t.Helper()
+
+ mr, err := miniredis.Run()
+ require.NoError(t, err, "failed to start miniredis")
+
+ redisClient := redis.NewClient(&redis.Options{
+ Addr: mr.Addr(),
+ })
+
+ t.Cleanup(func() {
+ redisClient.Close()
+ mr.Close()
+ })
+
+ return mr, redisClient
+}
+
+// dummyRabbitMQManager returns a minimal non-nil *tmrabbitmq.Manager for tests that
+// do not exercise RabbitMQ connections. Required because NewMultiTenantConsumer
+// validates that rabbitmq is non-nil. A dummy Client is attached so that
+// consumer goroutines spawned by ensureConsumerStarted do not panic on nil
+// dereference; they will receive connection errors instead.
+func dummyRabbitMQManager() *tmrabbitmq.Manager {
+ dummyClient, err := client.NewClient("http://127.0.0.1:0", testutil.NewMockLogger(), client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ if err != nil {
+ panic(fmt.Sprintf("dummyRabbitMQManager: failed to create client: %v", err))
+ }
+
+ return tmrabbitmq.NewManager(dummyClient, "test-service")
+}
+
+// dummyRedisClient returns a miniredis-backed Redis client for tests that need a
+// non-nil redisClient but do not exercise Redis. The caller does not need to
+// close the returned client; it is registered for cleanup via t.Cleanup.
+func dummyRedisClient(t *testing.T) redis.UniversalClient {
+ t.Helper()
+
+ _, redisClient := setupMiniredis(t)
+
+ return redisClient
+}
+
+// setupTenantManagerAPIServer creates an httptest server that returns active tenants.
+func setupTenantManagerAPIServer(t *testing.T, tenants []*client.TenantSummary) *httptest.Server {
+ t.Helper()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(tenants); err != nil {
+ t.Errorf("failed to encode tenant response: %v", err)
+ }
+ }))
+
+ t.Cleanup(func() {
+ server.Close()
+ })
+
+ return server
+}
+
+// makeTenantSummaries generates N TenantSummary entries for testing.
+func makeTenantSummaries(n int) []*client.TenantSummary {
+ tenants := make([]*client.TenantSummary, n)
+ for i := range n {
+ tenants[i] = &client.TenantSummary{
+ ID: fmt.Sprintf("tenant-%04d", i),
+ Name: fmt.Sprintf("Tenant %d", i),
+ Status: "active",
+ }
+ }
+ return tenants
+}
+
+// testServiceName is the service name used by most tests.
+const testServiceName = "test-service"
+
+// testActiveTenantsKey is the Redis key used by tests with Service="test-service" and no Environment.
+// This matches the key that fetchTenantIDs will read from when Environment is empty.
+var testActiveTenantsKey = buildActiveTenantsKey("", testServiceName)
+
+// maxRunDuration is the maximum time Run() is allowed to take.
+// The requirement specifies <1 second. We use 1 second as the hard deadline.
+const maxRunDuration = 1 * time.Second
+
+// TestMultiTenantConsumer_Run_EagerMode validates that Run() completes within 1 second,
+// returns nil error (soft failure), and populates knownTenants.
+func TestMultiTenantConsumer_Run_EagerMode(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ redisTenantIDs []string
+ apiTenants []*client.TenantSummary
+ apiServerDown bool
+ redisDown bool
+ expectedKnownTenantCount int
+ expectError bool
+ expectConsumersStarted bool
+ }{
+ {
+ name: "returns_within_1s_with_0_tenants_configured",
+ redisTenantIDs: []string{},
+ apiTenants: nil,
+ expectedKnownTenantCount: 0,
+ expectError: false,
+ expectConsumersStarted: false,
+ },
+ {
+ name: "returns_within_1s_with_100_tenants_in_Redis_cache",
+ redisTenantIDs: generateTenantIDs(100),
+ apiTenants: nil,
+ expectedKnownTenantCount: 100,
+ expectError: false,
+ expectConsumersStarted: true,
+ },
+ {
+ name: "returns_within_1s_with_500_tenants_from_Tenant_Manager_API",
+ redisTenantIDs: []string{},
+ apiTenants: makeTenantSummaries(500),
+ expectedKnownTenantCount: 500,
+ expectError: false,
+ expectConsumersStarted: true,
+ },
+ {
+ name: "returns_nil_error_when_both_Redis_and_API_are_down",
+ redisTenantIDs: nil,
+ redisDown: true,
+ apiServerDown: true,
+ expectedKnownTenantCount: 0,
+ expectError: false,
+ expectConsumersStarted: false,
+ },
+ {
+ name: "returns_nil_error_when_API_server_is_down",
+ redisTenantIDs: []string{},
+ apiServerDown: true,
+ expectedKnownTenantCount: 0,
+ expectError: false,
+ expectConsumersStarted: false,
+ },
+ // Edge case: single tenant in Redis
+ {
+ name: "returns_within_1s_with_1_tenant_in_Redis_cache",
+ redisTenantIDs: []string{"single-tenant"},
+ apiTenants: nil,
+ expectedKnownTenantCount: 1,
+ expectError: false,
+ expectConsumersStarted: true,
+ },
+ // Edge case: Redis empty but API returns tenants (fallback path)
+ {
+ name: "falls_back_to_API_when_Redis_cache_is_empty",
+ redisTenantIDs: []string{},
+ apiTenants: makeTenantSummaries(3),
+ expectedKnownTenantCount: 3,
+ expectError: false,
+ expectConsumersStarted: true,
+ },
+ // Edge case: Redis down but API is up. Discovery timeout (500ms) may
+ // be consumed by the Redis connection attempt, so API fallback may not
+ // complete in time. In this case, discoverTenants treats it as soft failure
+ // and the background sync loop will retry. We expect 0 tenants known at startup.
+ {
+ name: "returns_nil_error_when_Redis_down_and_API_configured",
+ redisTenantIDs: nil,
+ redisDown: true,
+ apiServerDown: false,
+ apiTenants: makeTenantSummaries(5),
+ expectedKnownTenantCount: 0,
+ expectError: false,
+ expectConsumersStarted: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt // capture loop variable for parallel subtests
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Setup miniredis
+ mr, redisClient := setupMiniredis(t)
+
+ // Populate Redis SET with tenant IDs (if provided and Redis is up)
+ if !tt.redisDown && len(tt.redisTenantIDs) > 0 {
+ for _, id := range tt.redisTenantIDs {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+ }
+
+ // If Redis should be down, close it
+ if tt.redisDown {
+ mr.Close()
+ }
+
+ // Setup Tenant Manager API server and pmClient
+ var apiURL string
+ if !tt.apiServerDown && tt.apiTenants != nil {
+ server := setupTenantManagerAPIServer(t, tt.apiTenants)
+ apiURL = server.URL
+ } else if tt.apiServerDown {
+ apiURL = "http://127.0.0.1:0" // unreachable port
+ }
+
+ // Create consumer config (MultiTenantURL left empty; pmClient set manually below)
+ config := MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }
+
+ // Create the consumer
+ mockLogger := testutil.NewMockLogger()
+ consumer := NewMultiTenantConsumer(
+ dummyRabbitMQManager(),
+ redisClient,
+ config,
+ mockLogger,
+ )
+
+ // Manually create pmClient for http:// test URLs (bypasses HTTPS enforcement in constructor)
+ if apiURL != "" {
+ pmClient, pmErr := client.NewClient(apiURL, mockLogger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, pmErr)
+ consumer.pmClient = pmClient
+ consumer.config.Service = "test-service"
+ }
+
+ // Register a handler
+ consumer.Register("test-queue", func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ // Measure execution time of Run()
+ start := time.Now()
+ err := consumer.Run(ctx)
+ elapsed := time.Since(start)
+
+ // ASSERTION 1: Run() completes within maxRunDuration
+ assert.Less(t, elapsed, maxRunDuration,
+ "Run() must complete within %s, took %s", maxRunDuration, elapsed)
+
+ // ASSERTION 2: Run() returns nil error (even on discovery failure)
+ if !tt.expectError {
+ assert.NoError(t, err,
+ "Run() must return nil error (soft failure on discovery)")
+ }
+
+ // ASSERTION 3: knownTenants is populated
+ consumer.mu.RLock()
+ knownCount := len(consumer.knownTenants)
+ consumersStarted := len(consumer.tenants)
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, tt.expectedKnownTenantCount, knownCount,
+ "knownTenants should have %d entries after Run(), got %d",
+ tt.expectedKnownTenantCount, knownCount)
+
+ // ASSERTION 4: Consumers started for discovered tenants (eager mode)
+ if tt.expectConsumersStarted {
+ assert.Greater(t, consumersStarted, 0,
+ "consumers should be started eagerly for discovered tenants")
+ } else {
+ assert.Equal(t, 0, consumersStarted,
+ "no consumers should be started when no tenants discovered")
+ }
+
+ // Cleanup
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_Run_SignatureUnchanged verifies the Run() method signature
+// matches the expected interface: func (c *MultiTenantConsumer) Run(ctx context.Context) error
+// This is a compile-time assertion. If the signature changes, this test will not compile.
+// Covers: AC-T1
+func TestMultiTenantConsumer_Run_SignatureUnchanged(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ }{
+ {name: "Run_accepts_context_and_returns_error"},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Compile-time signature assertion: Run must accept context.Context and return error.
+ // If the signature changes, this assignment will fail to compile.
+ var fn func(ctx context.Context) error
+
+ _, redisClient := setupMiniredis(t)
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ fn = consumer.Run
+ assert.NotNil(t, fn, "Run method must exist and match expected signature")
+ })
+ }
+}
+
+// TestMultiTenantConsumer_DiscoverTenants_ReuseFetchTenantIDs verifies that
+// discoverTenants() delegates to fetchTenantIDs() internally by confirming that
+// tenant IDs sourced from Redis (via fetchTenantIDs) end up in knownTenants.
+// Covers: AC-T2
+func TestMultiTenantConsumer_DiscoverTenants_ReuseFetchTenantIDs(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ redisTenantIDs []string
+ expectedCount int
+ }{
+ {
+ name: "discovers_tenants_from_Redis_via_fetchTenantIDs",
+ redisTenantIDs: []string{"tenant-a", "tenant-b", "tenant-c"},
+ expectedCount: 3,
+ },
+ {
+ name: "discovers_zero_tenants_when_Redis_is_empty",
+ redisTenantIDs: []string{},
+ expectedCount: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mr, redisClient := setupMiniredis(t)
+
+ // This test uses no Service or Environment, so the key has empty segments
+ noServiceKey := buildActiveTenantsKey("", "")
+ for _, id := range tt.redisTenantIDs {
+ mr.SAdd(noServiceKey, id)
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ ctx := context.Background()
+
+ // Call discoverTenants which internally uses fetchTenantIDs
+ consumer.discoverTenants(ctx)
+
+ consumer.mu.RLock()
+ knownCount := len(consumer.knownTenants)
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, tt.expectedCount, knownCount,
+ "discoverTenants should populate knownTenants via fetchTenantIDs")
+
+ // Verify each tenant ID is present in knownTenants
+ consumer.mu.RLock()
+ for _, id := range tt.redisTenantIDs {
+ assert.True(t, consumer.knownTenants[id],
+ "tenant %q should be in knownTenants after discovery", id)
+ }
+ consumer.mu.RUnlock()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_Run_StartupLog verifies that Run() produces a log message
+// containing "connection_mode=eager" during startup.
+func TestMultiTenantConsumer_Run_StartupLog(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ expectedLogPart string
+ }{
+ {
+ name: "startup_log_contains_connection_mode_eager",
+ expectedLogPart: "connection_mode=eager",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ config := MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }
+
+ logger := testutil.NewCapturingLogger()
+
+ consumer := NewMultiTenantConsumer(
+ dummyRabbitMQManager(),
+ redisClient,
+ config,
+ logger,
+ )
+
+ // Set the capturing logger in context so NewTrackingFromContext returns it
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ err := consumer.Run(ctx)
+ assert.NoError(t, err, "Run() should return nil")
+
+ // Verify the startup log contains connection_mode=eager
+ assert.True(t, logger.ContainsSubstring(tt.expectedLogPart),
+ "startup log must contain %q, got messages: %v",
+ tt.expectedLogPart, logger.GetMessages())
+
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_Run_BackgroundSyncStarts verifies that syncActiveTenants
+// is started in the background after Run() returns.
+// Covers: AC-T4
+func TestMultiTenantConsumer_Run_BackgroundSyncStarts(t *testing.T) {
+ // Not parallel: relies on timing (time.Sleep) for sync loop detection
+ tests := []struct {
+ name string
+ syncInterval time.Duration
+ tenantToAdd string
+ expectedCount int
+ }{
+ {
+ name: "sync_loop_discovers_tenants_added_after_Run",
+ syncInterval: 100 * time.Millisecond,
+ tenantToAdd: "new-tenant-001",
+ expectedCount: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ config := MultiTenantConfig{
+ SyncInterval: tt.syncInterval,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }
+
+ consumer := NewMultiTenantConsumer(
+ dummyRabbitMQManager(),
+ redisClient,
+ config,
+ testutil.NewMockLogger(),
+ )
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ // Run() should return quickly
+ err := consumer.Run(ctx)
+ require.NoError(t, err, "Run() should succeed")
+
+ // After Run, add tenants to Redis - the sync loop should pick them up
+ mr.SAdd(testActiveTenantsKey, tt.tenantToAdd)
+
+ // Wait for at least one sync cycle to complete
+ time.Sleep(3 * tt.syncInterval)
+
+ // The background sync loop should have discovered the new tenant
+ consumer.mu.RLock()
+ knownCount := len(consumer.knownTenants)
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, tt.expectedCount, knownCount,
+ "background syncActiveTenants should discover tenants added after Run(), found %d", knownCount)
+
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_Run_ReadinessWithinDeadline verifies that the service
+// becomes ready (Run() returns) within 5 seconds across all tenant configurations.
+// Covers: AC-O1
+func TestMultiTenantConsumer_Run_ReadinessWithinDeadline(t *testing.T) {
+ t.Parallel()
+
+ const readinessDeadline = 5 * time.Second
+
+ tests := []struct {
+ name string
+ redisTenantIDs []string
+ apiTenants []*client.TenantSummary
+ }{
+ {
+ name: "ready_within_5s_with_0_tenants",
+ redisTenantIDs: []string{},
+ },
+ {
+ name: "ready_within_5s_with_100_tenants",
+ redisTenantIDs: generateTenantIDs(100),
+ },
+ {
+ name: "ready_within_5s_with_500_tenants_via_API",
+ apiTenants: makeTenantSummaries(500),
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mr, redisClient := setupMiniredis(t)
+
+ for _, id := range tt.redisTenantIDs {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ var apiURL string
+ if tt.apiTenants != nil {
+ server := setupTenantManagerAPIServer(t, tt.apiTenants)
+ apiURL = server.URL
+ }
+
+ mockLogger := testutil.NewMockLogger()
+ config := MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, config, mockLogger)
+
+ if apiURL != "" {
+ pmClient, pmErr := client.NewClient(apiURL, mockLogger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, pmErr)
+ consumer.pmClient = pmClient
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), readinessDeadline)
+ defer cancel()
+
+ start := time.Now()
+ err := consumer.Run(ctx)
+ elapsed := time.Since(start)
+
+ assert.NoError(t, err, "Run() must not return error")
+ assert.Less(t, elapsed, readinessDeadline,
+ "Run() must complete within readiness deadline (%s), took %s", readinessDeadline, elapsed)
+
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_Run_StartupTimeVariance verifies that startup time variance
+// is <= 1 second across 0/100/500 tenant configurations.
+// Covers: AC-O2
+func TestMultiTenantConsumer_Run_StartupTimeVariance(t *testing.T) {
+ // Not parallel: measures timing across sequential runs
+
+ tests := []struct {
+ name string
+ redisTenantIDs []string
+ apiTenants []*client.TenantSummary
+ }{
+ {name: "0_tenants", redisTenantIDs: []string{}},
+ {name: "100_tenants", redisTenantIDs: generateTenantIDs(100)},
+ {name: "500_tenants_via_API", apiTenants: makeTenantSummaries(500)},
+ }
+
+ var durations []time.Duration
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ for _, id := range tt.redisTenantIDs {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ var apiURL string
+ if tt.apiTenants != nil {
+ server := setupTenantManagerAPIServer(t, tt.apiTenants)
+ apiURL = server.URL
+ }
+
+ mockLogger := testutil.NewMockLogger()
+ config := MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, config, mockLogger)
+
+ if apiURL != "" {
+ pmClient, pmErr := client.NewClient(apiURL, mockLogger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, pmErr)
+ consumer.pmClient = pmClient
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ start := time.Now()
+ err := consumer.Run(ctx)
+ elapsed := time.Since(start)
+
+ assert.NoError(t, err, "Run() must not return error")
+ durations = append(durations, elapsed)
+
+ cancel()
+ consumer.Close()
+ })
+ }
+
+ // After all subtests run, verify variance
+ if len(durations) >= 2 {
+ var minDuration, maxDuration time.Duration
+ minDuration = durations[0]
+ maxDuration = durations[0]
+
+ for _, d := range durations[1:] {
+ if d < minDuration {
+ minDuration = d
+ }
+ if d > maxDuration {
+ maxDuration = d
+ }
+ }
+
+ variance := maxDuration - minDuration
+ assert.LessOrEqual(t, variance, 1*time.Second,
+ "startup time variance must be <= 1s, got %s (min=%s, max=%s)",
+ variance, minDuration, maxDuration)
+ }
+}
+
+// TestMultiTenantConsumer_DiscoveryFailure_LogsWarning verifies that when tenant
+// discovery fails, a warning is logged but Run() does not return an error.
+// Covers: AC-O3 (explicit warning log verification)
+func TestMultiTenantConsumer_DiscoveryFailure_LogsWarning(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ redisDown bool
+ apiDown bool
+ expectedLogPart string
+ }{
+ {
+ name: "logs_warning_when_Redis_and_API_both_fail",
+ redisDown: true,
+ apiDown: true,
+ expectedLogPart: "tenant discovery failed",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mr, redisClient := setupMiniredis(t)
+
+ if tt.redisDown {
+ mr.Close()
+ }
+
+ var apiURL string
+ if tt.apiDown {
+ apiURL = "http://127.0.0.1:0"
+ }
+
+ config := MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }
+
+ logger := testutil.NewCapturingLogger()
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, config, logger)
+
+ if apiURL != "" {
+ pmClient, pmErr := client.NewClient(apiURL, logger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, pmErr)
+ consumer.pmClient = pmClient
+ }
+
+ // Set the capturing logger in context so NewTrackingFromContext returns it
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ err := consumer.Run(ctx)
+
+ // Run() must return nil even when discovery fails
+ assert.NoError(t, err, "Run() must return nil on discovery failure (soft failure)")
+
+ // Warning log must contain discovery failure message
+ assert.True(t, logger.ContainsSubstring(tt.expectedLogPart),
+ "discovery failure must log warning containing %q, got: %v",
+ tt.expectedLogPart, logger.GetMessages())
+
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_DefaultMultiTenantConfig verifies DefaultMultiTenantConfig
+// returns sensible defaults.
+func TestMultiTenantConsumer_DefaultMultiTenantConfig(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ expectedSync time.Duration
+ expectedWorkers int
+ expectedPrefetch int
+ expectedDiscoveryTO time.Duration
+ }{
+ {
+ name: "returns_default_values",
+ expectedSync: 30 * time.Second,
+ expectedWorkers: 0, // WorkersPerQueue is deprecated, default is 0
+ expectedPrefetch: 10,
+ expectedDiscoveryTO: 500 * time.Millisecond,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ config := DefaultMultiTenantConfig()
+
+ assert.Equal(t, tt.expectedSync, config.SyncInterval,
+ "default SyncInterval should be %s", tt.expectedSync)
+ assert.Equal(t, tt.expectedWorkers, config.WorkersPerQueue,
+ "default WorkersPerQueue should be %d", tt.expectedWorkers)
+ assert.Equal(t, tt.expectedPrefetch, config.PrefetchCount,
+ "default PrefetchCount should be %d", tt.expectedPrefetch)
+ assert.Equal(t, tt.expectedDiscoveryTO, config.DiscoveryTimeout,
+ "default DiscoveryTimeout should be %s", tt.expectedDiscoveryTO)
+ assert.Empty(t, config.MultiTenantURL, "default MultiTenantURL should be empty")
+ assert.Empty(t, config.Service, "default Service should be empty")
+ assert.False(t, config.AllowInsecureHTTP, "default AllowInsecureHTTP should be false")
+ })
+ }
+}
+
+// TestMultiTenantConsumer_NewWithZeroConfig verifies that NewMultiTenantConsumer
+// applies defaults when config fields are zero-valued.
+func TestMultiTenantConsumer_NewWithZeroConfig(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ config MultiTenantConfig
+ expectedSync time.Duration
+ expectedWorkers int
+ expectedPrefetch int
+ expectPMClient bool
+ }{
+ {
+ name: "applies_defaults_for_all_zero_fields",
+ config: MultiTenantConfig{},
+ expectedSync: 30 * time.Second,
+ expectedWorkers: 0, // WorkersPerQueue is deprecated, default is 0
+ expectedPrefetch: 10,
+ expectPMClient: false,
+ },
+ {
+ name: "preserves_explicit_values",
+ config: MultiTenantConfig{
+ SyncInterval: 60 * time.Second,
+ WorkersPerQueue: 5,
+ PrefetchCount: 20,
+ },
+ expectedSync: 60 * time.Second,
+ expectedWorkers: 5,
+ expectedPrefetch: 20,
+ expectPMClient: false,
+ },
+ {
+ name: "creates_pmClient_when_URL_configured",
+ config: MultiTenantConfig{
+ MultiTenantURL: "https://tenant-manager:4003",
+ ServiceAPIKey: "test-key",
+ },
+ expectedSync: 30 * time.Second,
+ expectedWorkers: 0, // WorkersPerQueue is deprecated, default is 0
+ expectedPrefetch: 10,
+ expectPMClient: true,
+ },
+ {
+ name: "creates_pmClient_with_http_URL_when_AllowInsecureHTTP_set",
+ config: MultiTenantConfig{
+ MultiTenantURL: "http://tenant-manager.namespace.svc.cluster.local:4003",
+ ServiceAPIKey: "test-key",
+ AllowInsecureHTTP: true,
+ },
+ expectedSync: 30 * time.Second,
+ expectedWorkers: 0,
+ expectedPrefetch: 10,
+ expectPMClient: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, tt.config, testutil.NewMockLogger())
+
+ assert.NotNil(t, consumer, "consumer must not be nil")
+ assert.Equal(t, tt.expectedSync, consumer.config.SyncInterval)
+ assert.Equal(t, tt.expectedWorkers, consumer.config.WorkersPerQueue)
+ assert.Equal(t, tt.expectedPrefetch, consumer.config.PrefetchCount)
+ assert.NotNil(t, consumer.handlers, "handlers map must be initialized")
+ assert.NotNil(t, consumer.tenants, "tenants map must be initialized")
+ assert.NotNil(t, consumer.knownTenants, "knownTenants map must be initialized")
+
+ if tt.expectPMClient {
+ assert.NotNil(t, consumer.pmClient,
+ "pmClient should be created when MultiTenantURL is configured")
+ } else {
+ assert.Nil(t, consumer.pmClient,
+ "pmClient should be nil when MultiTenantURL is empty")
+ }
+ })
+ }
+}
+
+// TestMultiTenantConsumer_Stats verifies the Stats() method returns correct statistics.
+func TestMultiTenantConsumer_Stats(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ registerQueues []string
+ expectClosed bool
+ closeBeforeStat bool
+ }{
+ {
+ name: "returns_stats_with_no_registered_queues",
+ registerQueues: nil,
+ expectClosed: false,
+ closeBeforeStat: false,
+ },
+ {
+ name: "returns_stats_with_registered_queues",
+ registerQueues: []string{"queue-a", "queue-b"},
+ expectClosed: false,
+ closeBeforeStat: false,
+ },
+ {
+ name: "returns_closed_true_after_Close",
+ registerQueues: []string{"queue-a"},
+ expectClosed: true,
+ closeBeforeStat: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ for _, q := range tt.registerQueues {
+ consumer.Register(q, func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+ }
+
+ if tt.closeBeforeStat {
+ consumer.Close()
+ }
+
+ stats := consumer.Stats()
+
+ assert.Equal(t, 0, stats.ActiveTenants,
+ "no tenants should be active (no Run() called)")
+ assert.Equal(t, len(tt.registerQueues), len(stats.RegisteredQueues),
+ "registered queues count should match")
+ assert.Equal(t, tt.expectClosed, stats.Closed, "closed flag mismatch")
+ })
+ }
+}
+
+// TestMultiTenantConsumer_Close verifies the Close() method lifecycle behavior.
+func TestMultiTenantConsumer_Close(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ }{
+ {name: "close_marks_consumer_as_closed_and_clears_maps"},
+ {name: "close_is_idempotent_on_double_call"},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ // Pre-populate sync.Map entries to verify they are cleaned on Close
+ consumer.consumerLocks.Store("tenant-x", &sync.Mutex{})
+ consumer.consumerLocks.Store("tenant-y", &sync.Mutex{})
+ consumer.retryState.Store("tenant-x", &retryStateEntry{})
+ consumer.retryState.Store("tenant-y", &retryStateEntry{})
+
+ // First close
+ err := consumer.Close()
+ assert.NoError(t, err, "Close() should not return error")
+
+ consumer.mu.RLock()
+ assert.True(t, consumer.closed, "consumer should be marked as closed")
+ assert.Empty(t, consumer.tenants, "tenants map should be cleared after Close()")
+ assert.Empty(t, consumer.knownTenants, "knownTenants map should be cleared after Close()")
+ consumer.mu.RUnlock()
+
+ // Note: sync.Map entries (consumerLocks, retryState) are NOT cleared by Close().
+ // Close() clears regular maps (tenants, knownTenants, tenantAbsenceCount) only.
+ // sync.Map entries are cleaned lazily during syncTenants / eviction.
+
+ if tt.name == "close_is_idempotent_on_double_call" {
+ // Second close should not panic
+ err2 := consumer.Close()
+ assert.NoError(t, err2, "second Close() should not return error")
+ }
+ })
+ }
+}
+
+// TestMultiTenantConsumer_SyncTenants_RemovesTenants verifies that syncTenants()
+// removes tenants that are no longer in the Redis cache.
+func TestMultiTenantConsumer_SyncTenants_RemovesTenants(t *testing.T) {
+ // Not parallel: relies on internal state manipulation
+
+ tests := []struct {
+ name string
+ initialTenants []string
+ postSyncTenants []string
+ expectedKnownAfterSync int
+ }{
+ {
+ name: "removes_tenants_no_longer_in_cache",
+ initialTenants: []string{"tenant-a", "tenant-b", "tenant-c"},
+ postSyncTenants: []string{"tenant-a"},
+ expectedKnownAfterSync: 1,
+ },
+ {
+ name: "handles_all_tenants_removed",
+ initialTenants: []string{"tenant-a", "tenant-b"},
+ postSyncTenants: []string{},
+ expectedKnownAfterSync: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ // Populate initial tenants
+ for _, id := range tt.initialTenants {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }, testutil.NewMockLogger())
+
+ ctx := context.Background()
+
+ // Initial discovery
+ consumer.discoverTenants(ctx)
+
+ consumer.mu.RLock()
+ initialCount := len(consumer.knownTenants)
+ consumer.mu.RUnlock()
+ assert.Equal(t, len(tt.initialTenants), initialCount,
+ "initial discovery should find all tenants")
+
+ // Pre-populate consumerLocks, retryState, and active consumers for
+ // all initial tenants to verify they are cleaned up when tenants are removed.
+ // Active consumers (c.tenants) are required because stopRemovedTenants only
+ // processes tenants that have a running consumer.
+ consumer.mu.Lock()
+ for _, id := range tt.initialTenants {
+ consumer.consumerLocks.Store(id, &sync.Mutex{})
+ consumer.retryState.Store(id, &retryStateEntry{})
+ _, cancel := context.WithCancel(ctx)
+ consumer.tenants[id] = cancel
+ }
+ consumer.mu.Unlock()
+
+ // Update Redis to reflect post-sync state (remove some tenants)
+ mr.Del(testActiveTenantsKey)
+ for _, id := range tt.postSyncTenants {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ // Run syncTenants absentSyncsBeforeRemoval times so retained tenants
+ // exceed the absence threshold and are actually removed.
+ for i := 0; i < absentSyncsBeforeRemoval; i++ {
+ err := consumer.syncTenants(ctx)
+ assert.NoError(t, err, "syncTenants should not return error")
+ }
+
+ consumer.mu.RLock()
+ afterSyncCount := len(consumer.knownTenants)
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, tt.expectedKnownAfterSync, afterSyncCount,
+ "after %d syncs, knownTenants should reflect updated tenant list", absentSyncsBeforeRemoval)
+
+ // Verify consumerLocks and retryState are cleaned for removed tenants
+ removedSet := make(map[string]bool, len(tt.initialTenants))
+ for _, id := range tt.initialTenants {
+ removedSet[id] = true
+ }
+
+ for _, id := range tt.postSyncTenants {
+ delete(removedSet, id)
+ }
+
+ // Note: sync.Map entries (consumerLocks, retryState) are NOT cleaned by
+ // syncTenants/cancelRemovedTenantConsumers. They are cleaned lazily.
+ // Only regular maps (tenants, knownTenants) are reconciled during sync.
+ _ = removedSet
+ })
+ }
+}
+
+// TestMultiTenantConsumer_SyncTenants_EagerMode verifies that syncTenants() populates
+// knownTenants for new tenants AND starts consumer goroutines eagerly.
+func TestMultiTenantConsumer_SyncTenants_EagerMode(t *testing.T) {
+ tests := []struct {
+ name string
+ initialRedisTenants []string
+ newRedisTenants []string
+ expectedKnownCount int
+ expectConsumers bool
+ }{
+ {
+ name: "new_tenants_added_and_consumers_started",
+ initialRedisTenants: []string{},
+ newRedisTenants: []string{"tenant-a", "tenant-b", "tenant-c"},
+ expectedKnownCount: 3,
+ expectConsumers: true,
+ },
+ {
+ name: "sync_discovers_tenants_and_starts_consumers",
+ initialRedisTenants: []string{},
+ newRedisTenants: generateTenantIDs(10),
+ expectedKnownCount: 10,
+ expectConsumers: true,
+ },
+ {
+ name: "sync_with_zero_tenants_starts_no_consumers",
+ initialRedisTenants: []string{},
+ newRedisTenants: []string{},
+ expectedKnownCount: 0,
+ expectConsumers: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ // Populate initial tenants
+ for _, id := range tt.initialRedisTenants {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }, testutil.NewMockLogger())
+
+ // Register a handler so startTenantConsumer has something to consume
+ consumer.Register("test-queue", func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ consumer.parentCtx = ctx
+
+ // Initial discovery (populates knownTenants only)
+ consumer.discoverTenants(ctx)
+
+ // Update Redis with new tenants
+ mr.Del(testActiveTenantsKey)
+ for _, id := range tt.newRedisTenants {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ // Run syncTenants - should populate knownTenants and start consumers
+ err := consumer.syncTenants(ctx)
+ assert.NoError(t, err, "syncTenants should not return error")
+
+ consumer.mu.RLock()
+ knownCount := len(consumer.knownTenants)
+ consumerCount := len(consumer.tenants)
+ consumer.mu.RUnlock()
+
+ // ASSERTION 1: knownTenants is populated with discovered tenants
+ assert.Equal(t, tt.expectedKnownCount, knownCount,
+ "syncTenants must populate knownTenants (expected %d, got %d)",
+ tt.expectedKnownCount, knownCount)
+
+ // ASSERTION 2: Consumers started for discovered tenants
+ if tt.expectConsumers {
+ assert.Greater(t, consumerCount, 0,
+ "syncTenants should start consumers eagerly for discovered tenants")
+ } else {
+ assert.Equal(t, 0, consumerCount,
+ "no consumers expected when no tenants discovered")
+ }
+
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_SyncTenants_RemovalCleansKnownTenants verifies that when
+// a tenant is removed from Redis, syncTenants() cleans it from knownTenants and
+// cancels any active consumer for that tenant.
+// Covers: T-005 AC-F3, AC-F4
+func TestMultiTenantConsumer_SyncTenants_RemovalCleansKnownTenants(t *testing.T) {
+ tests := []struct {
+ name string
+ initialTenants []string
+ remainingTenants []string
+ expectedKnownAfterRemoval int
+ }{
+ {
+ name: "removed_tenant_cleaned_from_knownTenants",
+ initialTenants: []string{"tenant-a", "tenant-b", "tenant-c"},
+ remainingTenants: []string{"tenant-a"},
+ expectedKnownAfterRemoval: 1,
+ },
+ {
+ name: "all_tenants_removed_cleans_knownTenants",
+ initialTenants: []string{"tenant-a", "tenant-b"},
+ remainingTenants: []string{},
+ expectedKnownAfterRemoval: 0,
+ },
+ {
+ name: "no_tenants_removed_keeps_all_in_knownTenants",
+ initialTenants: []string{"tenant-a", "tenant-b"},
+ remainingTenants: []string{"tenant-a", "tenant-b"},
+ expectedKnownAfterRemoval: 2,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ // Populate initial tenants
+ for _, id := range tt.initialTenants {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }, testutil.NewMockLogger())
+
+ ctx := context.Background()
+
+ // First sync to populate initial state
+ err := consumer.syncTenants(ctx)
+ require.NoError(t, err, "initial syncTenants should succeed")
+
+ // Verify initial knownTenants count
+ consumer.mu.RLock()
+ initialKnown := len(consumer.knownTenants)
+ consumer.mu.RUnlock()
+ assert.Equal(t, len(tt.initialTenants), initialKnown,
+ "initial sync should discover all tenants")
+
+ // Remove tenants from Redis
+ mr.Del(testActiveTenantsKey)
+ for _, id := range tt.remainingTenants {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ // Run sync absentSyncsBeforeRemoval times so retained tenants exceed
+ // the absence threshold and are cleaned from knownTenants.
+ for i := 0; i < absentSyncsBeforeRemoval; i++ {
+ err = consumer.syncTenants(ctx)
+ require.NoError(t, err, "syncTenants should succeed")
+ }
+
+ consumer.mu.RLock()
+ afterRemovalKnown := len(consumer.knownTenants)
+ // Verify removed tenants are NOT in knownTenants
+ for _, id := range tt.initialTenants {
+ isRemaining := false
+ for _, remaining := range tt.remainingTenants {
+ if id == remaining {
+ isRemaining = true
+ break
+ }
+ }
+ if !isRemaining {
+ assert.False(t, consumer.knownTenants[id],
+ "removed tenant %q must be cleaned from knownTenants after %d absences", id, absentSyncsBeforeRemoval)
+ }
+ }
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, tt.expectedKnownAfterRemoval, afterRemovalKnown,
+ "after %d absences, knownTenants should have %d entries, got %d",
+ absentSyncsBeforeRemoval, tt.expectedKnownAfterRemoval, afterRemovalKnown)
+ })
+ }
+}
+
+// TestMultiTenantConsumer_SyncTenants_SyncLoopContinuesOnError verifies that the
+// sync loop continues operating when individual sync iterations fail.
+// Covers: T-005 AC-O3
+func TestMultiTenantConsumer_SyncTenants_SyncLoopContinuesOnError(t *testing.T) {
+ tests := []struct {
+ name string
+ breakRedisOnFirst bool
+ restoreBefore int // restore Redis before this sync iteration
+ }{
+ {
+ name: "continues_after_transient_error",
+ breakRedisOnFirst: true,
+ restoreBefore: 2,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ // Populate tenants
+ mr.SAdd(testActiveTenantsKey, "tenant-001")
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 100 * time.Millisecond,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }, testutil.NewMockLogger())
+
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ // First sync succeeds
+ err := consumer.syncTenants(ctx)
+ assert.NoError(t, err, "first syncTenants should succeed")
+
+ // Break Redis
+ mr.Close()
+
+ // Second sync should fail but not crash
+ err = consumer.syncTenants(ctx)
+ assert.Error(t, err, "syncTenants should return error when Redis is down")
+
+ // Verify consumer still functional (not panicked)
+ consumer.mu.RLock()
+ assert.False(t, consumer.closed, "consumer should not be closed after sync error")
+ consumer.mu.RUnlock()
+
+ consumer.Close()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_SyncTenants_ClosedConsumer verifies that syncTenants
+// returns an error when the consumer is already closed.
+func TestMultiTenantConsumer_SyncTenants_ClosedConsumer(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ errContains string
+ }{
+ {
+ name: "returns_error_when_consumer_is_closed",
+ errContains: "consumer is closed",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mr, redisClient := setupMiniredis(t)
+ mr.SAdd(testActiveTenantsKey, "tenant-001")
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }, testutil.NewMockLogger())
+
+ // Close consumer first
+ consumer.Close()
+
+ // syncTenants should detect closed state
+ err := consumer.syncTenants(context.Background())
+ require.Error(t, err, "syncTenants must return error for closed consumer")
+ assert.Contains(t, err.Error(), tt.errContains,
+ "error message should indicate consumer is closed")
+ })
+ }
+}
+
+// TestMultiTenantConsumer_FetchTenantIDs verifies fetchTenantIDs behavior in isolation.
+func TestMultiTenantConsumer_FetchTenantIDs(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ redisTenantIDs []string
+ apiTenants []*client.TenantSummary
+ redisDown bool
+ apiDown bool
+ expectError bool
+ expectedCount int
+ errContains string
+ }{
+ {
+ name: "returns_tenants_from_Redis_cache",
+ redisTenantIDs: []string{"t1", "t2", "t3"},
+ expectedCount: 3,
+ },
+ {
+ name: "returns_empty_list_when_no_tenants",
+ redisTenantIDs: []string{},
+ expectedCount: 0,
+ },
+ {
+ name: "falls_back_to_API_when_Redis_is_empty",
+ apiTenants: makeTenantSummaries(2),
+ expectedCount: 2,
+ },
+ {
+ name: "returns_error_when_both_Redis_and_API_fail",
+ redisDown: true,
+ apiDown: true,
+ expectError: true,
+ },
+ {
+ name: "returns_tenants_from_API_when_Redis_fails",
+ redisDown: true,
+ apiTenants: makeTenantSummaries(4),
+ expectedCount: 4,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mr, redisClient := setupMiniredis(t)
+
+ if !tt.redisDown {
+ for _, id := range tt.redisTenantIDs {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+ } else {
+ mr.Close()
+ }
+
+ var apiURL string
+ if tt.apiTenants != nil && !tt.apiDown {
+ server := setupTenantManagerAPIServer(t, tt.apiTenants)
+ apiURL = server.URL
+ } else if tt.apiDown {
+ apiURL = "http://127.0.0.1:0"
+ }
+
+ mockLogger := testutil.NewMockLogger()
+ config := MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, config, mockLogger)
+
+ if apiURL != "" {
+ pmClient, pmErr := client.NewClient(apiURL, mockLogger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, pmErr)
+ consumer.pmClient = pmClient
+ }
+
+ ids, err := consumer.fetchTenantIDs(context.Background())
+
+ if tt.expectError {
+ assert.Error(t, err, "fetchTenantIDs should return error")
+ if tt.errContains != "" {
+ assert.Contains(t, err.Error(), tt.errContains)
+ }
+ } else {
+ assert.NoError(t, err, "fetchTenantIDs should not return error")
+ assert.Len(t, ids, tt.expectedCount,
+ "expected %d tenant IDs, got %d", tt.expectedCount, len(ids))
+ }
+ })
+ }
+}
+
+// TestMultiTenantConsumer_Register verifies handler registration.
+func TestMultiTenantConsumer_Register(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ queueNames []string
+ expectedCount int
+ }{
+ {
+ name: "registers_single_queue_handler",
+ queueNames: []string{"queue-a"},
+ expectedCount: 1,
+ },
+ {
+ name: "registers_multiple_queue_handlers",
+ queueNames: []string{"queue-a", "queue-b", "queue-c"},
+ expectedCount: 3,
+ },
+ {
+ name: "overwrites_handler_for_same_queue",
+ queueNames: []string{"queue-a", "queue-a"},
+ expectedCount: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ for _, q := range tt.queueNames {
+ consumer.Register(q, func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+ }
+
+ consumer.mu.RLock()
+ handlerCount := len(consumer.handlers)
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, tt.expectedCount, handlerCount,
+ "expected %d registered handlers, got %d", tt.expectedCount, handlerCount)
+ })
+ }
+}
+
+// TestMultiTenantConsumer_NilLogger verifies that NewMultiTenantConsumer does not panic
+// when a nil logger is provided and defaults to NoneLogger.
+func TestMultiTenantConsumer_NilLogger(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ }{
+ {name: "nil_logger_does_not_panic_on_creation"},
+ {name: "nil_logger_consumer_can_register_handler"},
+ {name: "nil_logger_consumer_can_close"},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ assert.NotPanics(t, func() {
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, nil) // nil logger
+
+ assert.NotNil(t, consumer, "consumer must not be nil even with nil logger")
+
+ if tt.name == "nil_logger_consumer_can_register_handler" {
+ consumer.Register("test-queue", func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+ }
+
+ if tt.name == "nil_logger_consumer_can_close" {
+ err := consumer.Close()
+ assert.NoError(t, err, "Close() should not panic with nil-guarded logger")
+ }
+ })
+ })
+ }
+}
+
+// TestIsValidTenantID verifies tenant ID validation logic.
+func TestIsValidTenantID(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ expected bool
+ }{
+ {name: "valid_alphanumeric", tenantID: "tenant123", expected: true},
+ {name: "valid_with_hyphens", tenantID: "tenant-123-abc", expected: true},
+ {name: "valid_with_underscores", tenantID: "tenant_123_abc", expected: true},
+ {name: "valid_uuid_format", tenantID: "550e8400-e29b-41d4-a716-446655440000", expected: true},
+ {name: "valid_single_char", tenantID: "t", expected: true},
+ {name: "invalid_empty", tenantID: "", expected: false},
+ {name: "invalid_starts_with_hyphen", tenantID: "-tenant", expected: false},
+ {name: "invalid_starts_with_underscore", tenantID: "_tenant", expected: false},
+ {name: "invalid_contains_slash", tenantID: "tenant/../../etc", expected: false},
+ {name: "invalid_contains_space", tenantID: "tenant 123", expected: false},
+ {name: "invalid_contains_dots", tenantID: "tenant.123", expected: false},
+ {name: "invalid_contains_special_chars", tenantID: "tenant@123!", expected: false},
+ {name: "invalid_exceeds_max_length", tenantID: string(make([]byte, 257)), expected: false},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := core.IsValidTenantID(tt.tenantID)
+ assert.Equal(t, tt.expected, result,
+ "IsValidTenantID(%q) = %v, want %v", tt.tenantID, result, tt.expected)
+ })
+ }
+}
+
+// TestMultiTenantConsumer_SyncTenants_FiltersInvalidIDs verifies that syncTenants
+// skips tenant IDs that fail validation.
+func TestMultiTenantConsumer_SyncTenants_FiltersInvalidIDs(t *testing.T) {
+ tests := []struct {
+ name string
+ redisTenantIDs []string
+ expectedKnownIDs int
+ }{
+ {
+ name: "filters_out_path_traversal_attempts",
+ redisTenantIDs: []string{"valid-tenant", "../../etc/passwd", "also-valid"},
+ expectedKnownIDs: 2,
+ },
+ {
+ name: "filters_out_empty_strings",
+ redisTenantIDs: []string{"valid-tenant", "", "another-valid"},
+ expectedKnownIDs: 2,
+ },
+ {
+ name: "all_valid_tenants_pass",
+ redisTenantIDs: []string{"tenant-a", "tenant-b", "tenant-c"},
+ expectedKnownIDs: 3,
+ },
+ {
+ name: "all_invalid_tenants_filtered",
+ redisTenantIDs: []string{"../etc", "tenant with spaces", ""},
+ expectedKnownIDs: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ for _, id := range tt.redisTenantIDs {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }, testutil.NewMockLogger())
+
+ ctx := context.Background()
+ err := consumer.syncTenants(ctx)
+ assert.NoError(t, err, "syncTenants should not return error")
+
+ consumer.mu.RLock()
+ knownCount := len(consumer.knownTenants)
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, tt.expectedKnownIDs, knownCount,
+ "expected %d known tenants after filtering, got %d", tt.expectedKnownIDs, knownCount)
+ })
+ }
+}
+
+// ---------------------
+// T-002: On-Demand Consumer Spawning Tests
+// ---------------------
+
+// TestMultiTenantConsumer_EnsureConsumerStarted_SpawnsExactlyOnce verifies that
+// concurrent calls to ensureConsumerStarted for the same tenant spawn exactly one consumer.
+// Covers: T-002 exactly-once guarantee under concurrency
+func TestMultiTenantConsumer_EnsureConsumerStarted_SpawnsExactlyOnce(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ concurrentCalls int
+ expectedConsumer int
+ }{
+ {
+ name: "10_concurrent_calls_spawn_exactly_1_consumer",
+ tenantID: "tenant-001",
+ concurrentCalls: 10,
+ expectedConsumer: 1,
+ },
+ {
+ name: "50_concurrent_calls_spawn_exactly_1_consumer",
+ tenantID: "tenant-002",
+ concurrentCalls: 50,
+ expectedConsumer: 1,
+ },
+ {
+ name: "100_concurrent_calls_spawn_exactly_1_consumer",
+ tenantID: "tenant-003",
+ concurrentCalls: 100,
+ expectedConsumer: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ // Register a handler so startTenantConsumer has something to work with
+ consumer.Register("test-queue", func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ // Store parentCtx (normally done by Run())
+ consumer.parentCtx = ctx
+
+ // Add tenant to knownTenants (normally done by discoverTenants)
+ consumer.mu.Lock()
+ consumer.knownTenants[tt.tenantID] = true
+ consumer.mu.Unlock()
+
+ // Launch N concurrent calls to ensureConsumerStarted
+ var wg sync.WaitGroup
+ wg.Add(tt.concurrentCalls)
+
+ for i := 0; i < tt.concurrentCalls; i++ {
+ go func() {
+ defer wg.Done()
+ consumer.ensureConsumerStarted(ctx, tt.tenantID)
+ }()
+ }
+
+ wg.Wait()
+
+ // Verify exactly one consumer was spawned
+ consumer.mu.RLock()
+ consumerCount := len(consumer.tenants)
+ _, hasCancel := consumer.tenants[tt.tenantID]
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, tt.expectedConsumer, consumerCount,
+ "expected exactly %d consumer, got %d", tt.expectedConsumer, consumerCount)
+ assert.True(t, hasCancel,
+ "tenant %q should have an active cancel func in tenants map", tt.tenantID)
+
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_EnsureConsumerStarted_NoopWhenActive verifies that
+// ensureConsumerStarted is a no-op when the consumer is already running.
+func TestMultiTenantConsumer_EnsureConsumerStarted_NoopWhenActive(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ }{
+ {
+ name: "noop_when_consumer_already_active",
+ tenantID: "tenant-active",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ consumer.Register("test-queue", func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ consumer.parentCtx = ctx
+
+ // Add tenant to knownTenants (normally done by discoverTenants)
+ consumer.mu.Lock()
+ consumer.knownTenants[tt.tenantID] = true
+ consumer.mu.Unlock()
+
+ // First call spawns the consumer
+ consumer.ensureConsumerStarted(ctx, tt.tenantID)
+
+ consumer.mu.RLock()
+ countAfterFirst := len(consumer.tenants)
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, 1, countAfterFirst, "first call should spawn 1 consumer")
+
+ // Second call should be a no-op
+ consumer.ensureConsumerStarted(ctx, tt.tenantID)
+
+ consumer.mu.RLock()
+ countAfterSecond := len(consumer.tenants)
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, 1, countAfterSecond,
+ "second call should NOT spawn another consumer, count should remain 1")
+
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// TestMultiTenantConsumer_EnsureConsumerStarted_SkipsWhenClosed verifies that
+// ensureConsumerStarted is a no-op when the consumer has been closed.
+func TestMultiTenantConsumer_EnsureConsumerStarted_SkipsWhenClosed(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ }{
+ {
+ name: "noop_when_consumer_is_closed",
+ tenantID: "tenant-closed",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ consumer.Register("test-queue", func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+
+ ctx := context.Background()
+ consumer.parentCtx = ctx
+
+ // Close before calling ensureConsumerStarted
+ consumer.Close()
+
+ // Should be a no-op
+ consumer.ensureConsumerStarted(ctx, tt.tenantID)
+
+ consumer.mu.RLock()
+ consumerCount := len(consumer.tenants)
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, 0, consumerCount,
+ "no consumer should be spawned after Close()")
+ })
+ }
+}
+
+// TestMultiTenantConsumer_EnsureConsumerStarted_MultipleTenants verifies that
+// ensureConsumerStarted can spawn consumers for different tenants concurrently.
+func TestMultiTenantConsumer_EnsureConsumerStarted_MultipleTenants(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantIDs []string
+ }{
+ {
+ name: "spawns_independent_consumers_for_3_tenants",
+ tenantIDs: []string{"tenant-a", "tenant-b", "tenant-c"},
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger())
+
+ consumer.Register("test-queue", func(ctx context.Context, delivery amqp.Delivery) error {
+ return nil
+ })
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ consumer.parentCtx = ctx
+
+ // Add tenants to knownTenants (normally done by discoverTenants)
+ consumer.mu.Lock()
+ for _, id := range tt.tenantIDs {
+ consumer.knownTenants[id] = true
+ }
+ consumer.mu.Unlock()
+
+ // Spawn consumers for all tenants concurrently
+ var wg sync.WaitGroup
+ wg.Add(len(tt.tenantIDs))
+
+ for _, id := range tt.tenantIDs {
+ go func(tenantID string) {
+ defer wg.Done()
+ consumer.ensureConsumerStarted(ctx, tenantID)
+ }(id)
+ }
+
+ wg.Wait()
+
+ consumer.mu.RLock()
+ consumerCount := len(consumer.tenants)
+ for _, id := range tt.tenantIDs {
+ _, exists := consumer.tenants[id]
+ assert.True(t, exists, "consumer for tenant %q should be active", id)
+ }
+ consumer.mu.RUnlock()
+
+ assert.Equal(t, len(tt.tenantIDs), consumerCount,
+ "expected %d consumers, got %d", len(tt.tenantIDs), consumerCount)
+
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// ---------------------
+// T-004: Connection Failure Resilience Tests
+// ---------------------
+
+// TestBackoffDelay verifies the exponential backoff delay calculation.
+// Expected base sequence: 5s, 10s, 20s, 40s, 40s (capped), with ±25% jitter applied.
+func TestBackoffDelay(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ retryCount int
+ baseDelay time.Duration
+ }{
+ {name: "retry_0_base_5s", retryCount: 0, baseDelay: 5 * time.Second},
+ {name: "retry_1_base_10s", retryCount: 1, baseDelay: 10 * time.Second},
+ {name: "retry_2_base_20s", retryCount: 2, baseDelay: 20 * time.Second},
+ {name: "retry_3_base_40s", retryCount: 3, baseDelay: 40 * time.Second},
+ {name: "retry_4_capped_at_40s", retryCount: 4, baseDelay: 40 * time.Second},
+ {name: "retry_10_capped_at_40s", retryCount: 10, baseDelay: 40 * time.Second},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ delay := backoffDelay(tt.retryCount)
+ // backoffDelay applies ±25% jitter: delay ∈ [0.75*base, 1.25*base)
+ minDelay := time.Duration(float64(tt.baseDelay) * 0.75)
+ maxDelay := time.Duration(float64(tt.baseDelay) * 1.25)
+ assert.GreaterOrEqual(t, delay, minDelay,
+ "backoffDelay(%d) = %s, want >= %s (0.75 * %s)", tt.retryCount, delay, minDelay, tt.baseDelay)
+ assert.Less(t, delay, maxDelay,
+ "backoffDelay(%d) = %s, want < %s (1.25 * %s)", tt.retryCount, delay, maxDelay, tt.baseDelay)
+ })
+ }
+}
+
+// TestMultiTenantConsumer_MetricConstants verifies that metric name constants are defined.
+func TestMultiTenantConsumer_MetricConstants(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ constant string
+ expected string
+ }{
+ {
+ name: "tenant_connections_total",
+ constant: MetricTenantConnectionsTotal,
+ expected: "tenant_connections_total",
+ },
+ {
+ name: "tenant_connection_errors_total",
+ constant: MetricTenantConnectionErrors,
+ expected: "tenant_connection_errors_total",
+ },
+ {
+ name: "tenant_consumers_active",
+ constant: MetricTenantConsumersActive,
+ expected: "tenant_consumers_active",
+ },
+ {
+ name: "tenant_messages_processed_total",
+ constant: MetricTenantMessageProcessed,
+ expected: "tenant_messages_processed_total",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, tt.expected, tt.constant,
+ "metric constant %q should equal %q", tt.constant, tt.expected)
+ })
+ }
+}
+
+// TestMultiTenantConsumer_StructuredLogEvents verifies that key operations
+// produce structured log messages with tenant_id context.
+func TestMultiTenantConsumer_StructuredLogEvents(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ operation string
+ expectedLogPart string
+ }{
+ {
+ name: "run_logs_connection_mode",
+ operation: "run",
+ expectedLogPart: "connection_mode=eager",
+ },
+ {
+ name: "discover_logs_tenant_count",
+ operation: "discover",
+ expectedLogPart: "discovered",
+ },
+ {
+ name: "ensure_consumer_logs_on_demand",
+ operation: "ensure",
+ expectedLogPart: "on-demand consumer start",
+ },
+ {
+ name: "sync_logs_summary",
+ operation: "sync",
+ expectedLogPart: "sync complete",
+ },
+ {
+ name: "register_logs_queue",
+ operation: "register",
+ expectedLogPart: "registered handler for queue",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mr, redisClient := setupMiniredis(t)
+ mr.SAdd(testActiveTenantsKey, "tenant-log-test")
+
+ logger := testutil.NewCapturingLogger()
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: "test-service",
+ }, logger)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ consumer.parentCtx = ctx
+
+ switch tt.operation {
+ case "run":
+ consumer.Run(ctx)
+ case "discover":
+ consumer.discoverTenants(ctx)
+ case "ensure":
+ consumer.Register("test-queue", func(ctx context.Context, d amqp.Delivery) error {
+ return nil
+ })
+ // Add tenant to knownTenants so ensureConsumerStarted doesn't reject it
+ consumer.mu.Lock()
+ consumer.knownTenants["tenant-log-test"] = true
+ consumer.mu.Unlock()
+ consumer.ensureConsumerStarted(ctx, "tenant-log-test")
+ case "sync":
+ consumer.syncTenants(ctx)
+ case "register":
+ consumer.Register("test-queue", func(ctx context.Context, d amqp.Delivery) error {
+ return nil
+ })
+ }
+
+ assert.True(t, logger.ContainsSubstring(tt.expectedLogPart),
+ "operation %q should produce log containing %q, got: %v",
+ tt.operation, tt.expectedLogPart, logger.GetMessages())
+
+ cancel()
+ consumer.Close()
+ })
+ }
+}
+
+// BenchmarkMultiTenantConsumer_Run_Startup measures startup time of Run().
+// Target: <1 second for all tenant configurations.
+// Covers: AC-Q2
+func BenchmarkMultiTenantConsumer_Run_Startup(b *testing.B) {
+ benchmarks := []struct {
+ name string
+ tenantCount int
+ useRedis bool
+ }{
+ {name: "0_tenants", tenantCount: 0, useRedis: true},
+ {name: "100_tenants_Redis", tenantCount: 100, useRedis: true},
+ {name: "500_tenants_Redis", tenantCount: 500, useRedis: true},
+ }
+
+ for _, bm := range benchmarks {
+ b.Run(bm.name, func(b *testing.B) {
+ mr, err := miniredis.Run()
+ require.NoError(b, err)
+ defer mr.Close()
+
+ redisClient := redis.NewClient(&redis.Options{
+ Addr: mr.Addr(),
+ })
+ defer redisClient.Close()
+
+ benchService := "bench-service"
+ benchKey := buildActiveTenantsKey("", benchService)
+
+ if bm.useRedis && bm.tenantCount > 0 {
+ ids := generateTenantIDs(bm.tenantCount)
+ for _, id := range ids {
+ mr.SAdd(benchKey, id)
+ }
+ }
+
+ config := MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: benchService,
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, config, testutil.NewMockLogger())
+
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ err := consumer.Run(ctx)
+ if err != nil {
+ b.Fatalf("Run() returned error: %v", err)
+ }
+ cancel()
+ consumer.Close()
+ }
+ })
+ }
+}
+
+// ---------------------
+// Environment-Aware Cache Key Tests
+// ---------------------
+
+// TestBuildActiveTenantsKey verifies environment+service segmented Redis key construction.
+func TestBuildActiveTenantsKey(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ env string
+ service string
+ expected string
+ }{
+ {
+ name: "env_and_service_produces_segmented_key",
+ env: "staging",
+ service: "ledger",
+ expected: "tenant-manager:tenants:active:staging:ledger",
+ },
+ {
+ name: "production_env_with_service",
+ env: "production",
+ service: "transaction",
+ expected: "tenant-manager:tenants:active:production:transaction",
+ },
+ {
+ name: "only_service_produces_key_with_empty_env",
+ env: "",
+ service: "ledger",
+ expected: "tenant-manager:tenants:active::ledger",
+ },
+ {
+ name: "neither_env_nor_service_produces_key_with_empty_segments",
+ env: "",
+ service: "",
+ expected: "tenant-manager:tenants:active::",
+ },
+ {
+ name: "env_without_service_produces_key_with_empty_service",
+ env: "staging",
+ service: "",
+ expected: "tenant-manager:tenants:active:staging:",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result := buildActiveTenantsKey(tt.env, tt.service)
+ assert.Equal(t, tt.expected, result,
+ "buildActiveTenantsKey(%q, %q) = %q, want %q",
+ tt.env, tt.service, result, tt.expected)
+ })
+ }
+}
+
+// TestMultiTenantConsumer_FetchTenantIDs_EnvironmentAwareKey verifies that
+// fetchTenantIDs reads from the environment+service segmented Redis key.
+func TestMultiTenantConsumer_FetchTenantIDs_EnvironmentAwareKey(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ env string
+ service string
+ redisKey string
+ redisTenants []string
+ expectedCount int
+ }{
+ {
+ name: "reads_from_env_service_segmented_key",
+ env: "staging",
+ service: "ledger",
+ redisKey: "tenant-manager:tenants:active:staging:ledger",
+ redisTenants: []string{"tenant-a", "tenant-b"},
+ expectedCount: 2,
+ },
+ {
+ name: "reads_from_key_with_empty_env",
+ env: "",
+ service: "transaction",
+ redisKey: "tenant-manager:tenants:active::transaction",
+ redisTenants: []string{"tenant-x"},
+ expectedCount: 1,
+ },
+ {
+ name: "reads_from_key_with_empty_env_and_service",
+ env: "",
+ service: "",
+ redisKey: "tenant-manager:tenants:active::",
+ redisTenants: []string{"tenant-1", "tenant-2", "tenant-3"},
+ expectedCount: 3,
+ },
+ {
+ name: "does_not_read_from_wrong_key",
+ env: "staging",
+ service: "ledger",
+ redisKey: "tenant-manager:tenants:active::", // Wrong key - empty segments instead of segmented
+ redisTenants: []string{"tenant-a"},
+ expectedCount: 0, // Should NOT find tenants
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ mr, redisClient := setupMiniredis(t)
+
+ // Write tenants to the specified Redis key
+ for _, id := range tt.redisTenants {
+ mr.SAdd(tt.redisKey, id)
+ }
+
+ config := MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Environment: tt.env,
+ Service: tt.service,
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, config, testutil.NewMockLogger())
+
+ ids, err := consumer.fetchTenantIDs(context.Background())
+ assert.NoError(t, err, "fetchTenantIDs should not return error")
+ assert.Len(t, ids, tt.expectedCount,
+ "expected %d tenant IDs from key %q, got %d",
+ tt.expectedCount, tt.redisKey, len(ids))
+ })
+ }
+}
+
+// ---------------------
+// Consumer Option Tests
+// ---------------------
+
+// TestMultiTenantConsumer_WithOptions verifies that option functions configure the consumer correctly.
+func TestMultiTenantConsumer_WithOptions(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ withPostgres bool
+ withMongo bool
+ expectPostgres bool
+ expectMongo bool
+ }{
+ {
+ name: "no_options_leaves_managers_nil",
+ withPostgres: false,
+ withMongo: false,
+ expectPostgres: false,
+ expectMongo: false,
+ },
+ {
+ name: "with_postgres_manager",
+ withPostgres: true,
+ withMongo: false,
+ expectPostgres: true,
+ expectMongo: false,
+ },
+ {
+ name: "with_mongo_manager",
+ withPostgres: false,
+ withMongo: true,
+ expectPostgres: false,
+ expectMongo: true,
+ },
+ {
+ name: "with_both_managers",
+ withPostgres: true,
+ withMongo: true,
+ expectPostgres: true,
+ expectMongo: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ _, redisClient := setupMiniredis(t)
+
+ var opts []Option
+
+ if tt.withPostgres {
+ pgManager := tmpostgres.NewManager(nil, "test-service")
+ opts = append(opts, WithPostgresManager(pgManager))
+ }
+
+ if tt.withMongo {
+ mongoManager := tmmongo.NewManager(nil, "test-service")
+ opts = append(opts, WithMongoManager(mongoManager))
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ }, testutil.NewMockLogger(), opts...)
+
+ if tt.expectPostgres {
+ assert.NotNil(t, consumer.postgres, "postgres manager should be set")
+ } else {
+ assert.Nil(t, consumer.postgres, "postgres manager should be nil")
+ }
+
+ if tt.expectMongo {
+ assert.NotNil(t, consumer.mongo, "mongo manager should be set")
+ } else {
+ assert.Nil(t, consumer.mongo, "mongo manager should be nil")
+ }
+ })
+ }
+}
+
+// TestMultiTenantConsumer_DefaultMultiTenantConfig_IncludesEnvironment verifies that
+// DefaultMultiTenantConfig returns an empty Environment field.
+func TestMultiTenantConsumer_DefaultMultiTenantConfig_IncludesEnvironment(t *testing.T) {
+ t.Parallel()
+
+ config := DefaultMultiTenantConfig()
+ assert.Empty(t, config.Environment, "default Environment should be empty")
+}
+
+// ---------------------
+// Connection Cleanup on Tenant Removal Tests
+// ---------------------
+
+// TestMultiTenantConsumer_SyncTenants_ClosesConnectionsOnRemoval verifies that
+// when a tenant is removed during sync, its database connections are closed.
+// Note: Uses NewManager constructors from sub-packages since we cannot access
+// unexported fields (connections map) from the consumer package. CloseConnection
+// returns nil for unknown tenants, so the test verifies log messages instead.
+func TestMultiTenantConsumer_SyncTenants_ClosesConnectionsOnRemoval(t *testing.T) {
+ tests := []struct {
+ name string
+ initialTenants []string
+ remainingTenants []string
+ removedTenants []string
+ }{
+ {
+ name: "closes_connections_for_single_removed_tenant",
+ initialTenants: []string{"tenant-a", "tenant-b"},
+ remainingTenants: []string{"tenant-a"},
+ removedTenants: []string{"tenant-b"},
+ },
+ {
+ name: "closes_connections_for_all_removed_tenants",
+ initialTenants: []string{"tenant-a", "tenant-b", "tenant-c"},
+ remainingTenants: []string{},
+ removedTenants: []string{"tenant-a", "tenant-b", "tenant-c"},
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ mr, redisClient := setupMiniredis(t)
+
+ // Use a capturing logger to verify close log messages
+ logger := testutil.NewCapturingLogger()
+
+ config := MultiTenantConfig{
+ SyncInterval: 30 * time.Second,
+ WorkersPerQueue: 1,
+ PrefetchCount: 10,
+ Service: testServiceName,
+ }
+
+ // Create managers using sub-package constructors.
+ // CloseConnection returns nil for tenants not in the connections map,
+ // so we verify behavior through log messages.
+ pgManager := tmpostgres.NewManager(nil, "test-service")
+ mongoManager := tmmongo.NewManager(nil, "test-service")
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), redisClient, config, logger,
+ WithPostgresManager(pgManager),
+ WithMongoManager(mongoManager),
+ )
+
+ // Populate initial tenants in Redis
+ for _, id := range tt.initialTenants {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ ctx := context.Background()
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ // Initial sync to populate state
+ err := consumer.syncTenants(ctx)
+ require.NoError(t, err, "initial syncTenants should succeed")
+
+ // Simulate active consumers for all tenants (so removal code path is triggered)
+ consumer.mu.Lock()
+ for _, id := range tt.initialTenants {
+ _, cancel := context.WithCancel(ctx)
+ consumer.tenants[id] = cancel
+ }
+ consumer.mu.Unlock()
+
+ // Update Redis to remaining tenants only
+ mr.Del(testActiveTenantsKey)
+ for _, id := range tt.remainingTenants {
+ mr.SAdd(testActiveTenantsKey, id)
+ }
+
+ // Run sync absentSyncsBeforeRemoval times so removals are confirmed and connections closed
+ for i := 0; i < absentSyncsBeforeRemoval; i++ {
+ err = consumer.syncTenants(ctx)
+ require.NoError(t, err, "syncTenants should succeed")
+ }
+
+ // Verify removed tenants are gone from tenants map
+ consumer.mu.RLock()
+ for _, id := range tt.removedTenants {
+ _, exists := consumer.tenants[id]
+ assert.False(t, exists,
+ "removed tenant %q should not be in tenants map", id)
+ }
+ consumer.mu.RUnlock()
+
+ // Verify log messages contain removal information for each removed tenant
+ for _, id := range tt.removedTenants {
+ assert.True(t, logger.ContainsSubstring("closing connections for removed tenant: "+id),
+ "should log closing connections for removed tenant %q", id)
+ }
+ })
+ }
+}
+
+// TestMultiTenantConsumer_RevalidateConnectionSettings tests revalidation behavior.
+// Note: Tests that require injecting connections into the postgres/mongo manager's
+// internal connections map (applies_settings_to_active_tenants, continues_on_individual_tenant_error)
+// are tested in the postgres sub-package's own test file since they need access to
+// unexported fields. Here we test the consumer-level skip conditions.
+func TestMultiTenantConsumer_RevalidateConnectionSettings(t *testing.T) {
+ t.Parallel()
+
+ t.Run("skips_when_no_managers_configured", func(t *testing.T) {
+ t.Parallel()
+
+ logger := testutil.NewCapturingLogger()
+ config := MultiTenantConfig{
+ Service: "ledger",
+ SyncInterval: 30 * time.Second,
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), dummyRedisClient(t), config, logger)
+
+ ctx := context.Background()
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ // Should return immediately without logging
+ consumer.revalidateConnectionSettings(ctx)
+
+ assert.False(t, logger.ContainsSubstring("revalidated connection settings"),
+ "should not log revalidation when no managers are configured")
+ })
+
+ t.Run("skips_when_no_pmClient_configured", func(t *testing.T) {
+ t.Parallel()
+
+ logger := testutil.NewCapturingLogger()
+ pgManager := tmpostgres.NewManager(nil, "ledger")
+
+ config := MultiTenantConfig{
+ Service: "ledger",
+ SyncInterval: 30 * time.Second,
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), dummyRedisClient(t), config, logger,
+ WithPostgresManager(pgManager),
+ )
+ // Explicitly ensure no pmClient
+ consumer.pmClient = nil
+
+ ctx := context.Background()
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ consumer.revalidateConnectionSettings(ctx)
+
+ assert.False(t, logger.ContainsSubstring("revalidated connection settings"),
+ "should not log revalidation when pmClient is nil")
+ })
+
+ t.Run("skips_when_no_active_tenants", func(t *testing.T) {
+ t.Parallel()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ t.Error("should not call Tenant Manager when no active tenants")
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer server.Close()
+
+ logger := testutil.NewCapturingLogger()
+ tmClient, tmErr := client.NewClient(server.URL, logger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, tmErr)
+
+ pgManager := tmpostgres.NewManager(tmClient, "ledger")
+
+ config := MultiTenantConfig{
+ Service: "ledger",
+ SyncInterval: 30 * time.Second,
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), dummyRedisClient(t), config, logger,
+ WithPostgresManager(pgManager),
+ )
+ consumer.pmClient = tmClient
+
+ ctx := context.Background()
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ consumer.revalidateConnectionSettings(ctx)
+
+ assert.False(t, logger.ContainsSubstring("revalidated connection settings"),
+ "should not log revalidation when no active tenants")
+ })
+
+ t.Run("applies_settings_to_active_tenants", func(t *testing.T) {
+ t.Parallel()
+
+ // Set up a mock Tenant Manager that returns config with connection settings
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ resp := `{
+ "id": "tenant-abc",
+ "tenantSlug": "abc",
+ "databases": {
+ "onboarding": {
+ "connectionSettings": {
+ "maxOpenConns": 50,
+ "maxIdleConns": 15
+ }
+ }
+ }
+ }`
+ w.Write([]byte(resp))
+ }))
+ defer server.Close()
+
+ logger := testutil.NewCapturingLogger()
+ tmClient, tmErr := client.NewClient(server.URL, logger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, tmErr)
+
+ pgManager := tmpostgres.NewManager(tmClient, "ledger",
+ tmpostgres.WithModule("onboarding"),
+ tmpostgres.WithLogger(logger),
+ )
+
+ config := MultiTenantConfig{
+ Service: "ledger",
+ SyncInterval: 30 * time.Second,
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), dummyRedisClient(t), config, logger,
+ WithPostgresManager(pgManager),
+ )
+ consumer.pmClient = tmClient
+
+ // Simulate active tenant
+ consumer.mu.Lock()
+ _, cancel := context.WithCancel(context.Background())
+ consumer.tenants["tenant-abc"] = cancel
+ consumer.mu.Unlock()
+
+ ctx := context.Background()
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ consumer.revalidateConnectionSettings(ctx)
+
+ // ApplyConnectionSettings was called but since there is no actual connection
+ // in the pgManager's internal map, it is effectively a no-op for the settings.
+ // We verify that revalidation was attempted by checking the log message.
+ assert.True(t, logger.ContainsSubstring("revalidated connection settings"),
+ "should log revalidation summary")
+ })
+
+ t.Run("continues_on_individual_tenant_error", func(t *testing.T) {
+ t.Parallel()
+
+ callCount := 0
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ callCount++
+ if strings.Contains(r.URL.Path, "tenant-fail") {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ resp := `{
+ "id": "tenant-ok",
+ "tenantSlug": "ok",
+ "databases": {
+ "onboarding": {
+ "connectionSettings": {
+ "maxOpenConns": 25,
+ "maxIdleConns": 5
+ }
+ }
+ }
+ }`
+ w.Write([]byte(resp))
+ }))
+ defer server.Close()
+
+ logger := testutil.NewCapturingLogger()
+ tmClient, tmErr := client.NewClient(server.URL, logger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, tmErr)
+
+ pgManager := tmpostgres.NewManager(tmClient, "ledger",
+ tmpostgres.WithModule("onboarding"),
+ tmpostgres.WithLogger(logger),
+ )
+
+ config := MultiTenantConfig{
+ Service: "ledger",
+ SyncInterval: 30 * time.Second,
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), dummyRedisClient(t), config, logger,
+ WithPostgresManager(pgManager),
+ )
+ consumer.pmClient = tmClient
+
+ // Simulate active tenants
+ consumer.mu.Lock()
+ ctx := context.Background()
+ _, cancelOK := context.WithCancel(ctx)
+ _, cancelFail := context.WithCancel(ctx)
+ consumer.tenants["tenant-ok"] = cancelOK
+ consumer.tenants["tenant-fail"] = cancelFail
+ consumer.mu.Unlock()
+
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ consumer.revalidateConnectionSettings(ctx)
+
+ // Should log warning about failed tenant
+ assert.True(t, logger.ContainsSubstring("failed to fetch config for tenant tenant-fail"),
+ "should log warning about fetch failure")
+ })
+}
+
+// TestMultiTenantConsumer_RevalidateSettings_StopsSuspendedTenant verifies that
+// revalidateConnectionSettings stops the consumer and removes the tenant from
+// knownTenants and tenants maps when the Tenant Manager returns 403 (suspended/purged).
+func TestMultiTenantConsumer_RevalidateSettings_StopsSuspendedTenant(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ responseBody string
+ suspendedTenantID string
+ healthyTenantID string
+ expectLogSubstring string
+ }{
+ {
+ name: "stops_suspended_tenant_and_keeps_healthy_tenant",
+ responseBody: `{"code":"TS-SUSPENDED","error":"service suspended","status":"suspended"}`,
+ suspendedTenantID: "tenant-suspended",
+ healthyTenantID: "tenant-healthy",
+ expectLogSubstring: "tenant tenant-suspended service suspended, stopping consumer and closing connections",
+ },
+ {
+ name: "stops_purged_tenant_and_keeps_healthy_tenant",
+ responseBody: `{"code":"TS-SUSPENDED","error":"service purged","status":"purged"}`,
+ suspendedTenantID: "tenant-purged",
+ healthyTenantID: "tenant-healthy",
+ expectLogSubstring: "service suspended, stopping consumer and closing connections",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Set up a mock Tenant Manager that returns 403 for the suspended tenant
+ // and 200 with valid config for the healthy tenant
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+
+ if strings.Contains(r.URL.Path, tt.suspendedTenantID) {
+ w.WriteHeader(http.StatusForbidden)
+ w.Write([]byte(tt.responseBody))
+
+ return
+ }
+
+ // Return valid config for healthy tenant
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{
+ "id": "` + tt.healthyTenantID + `",
+ "tenantSlug": "healthy",
+ "databases": {
+ "onboarding": {
+ "connectionSettings": {
+ "maxOpenConns": 25,
+ "maxIdleConns": 5
+ }
+ }
+ }
+ }`))
+ }))
+ defer server.Close()
+
+ logger := testutil.NewCapturingLogger()
+ tmClient, tmErr := client.NewClient(server.URL, logger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, tmErr)
+
+ pgManager := tmpostgres.NewManager(tmClient, "ledger",
+ tmpostgres.WithModule("onboarding"),
+ tmpostgres.WithLogger(logger),
+ )
+
+ config := MultiTenantConfig{
+ Service: "ledger",
+ SyncInterval: 30 * time.Second,
+ }
+
+ consumer := NewMultiTenantConsumer(dummyRabbitMQManager(), dummyRedisClient(t), config, logger,
+ WithPostgresManager(pgManager),
+ )
+ consumer.pmClient = tmClient
+
+ // Pre-populate per-tenant sync.Map entries for the suspended tenant
+ // to verify they are cleaned up during eviction.
+ consumer.consumerLocks.Store(tt.suspendedTenantID, &sync.Mutex{})
+ consumer.retryState.Store(tt.suspendedTenantID, &retryStateEntry{})
+ consumer.consumerLocks.Store(tt.healthyTenantID, &sync.Mutex{})
+ consumer.retryState.Store(tt.healthyTenantID, &retryStateEntry{})
+
+ // Simulate active tenants with cancel functions
+ consumer.mu.Lock()
+ suspendedCanceled := false
+ _, cancelSuspended := context.WithCancel(context.Background())
+ wrappedCancel := func() {
+ suspendedCanceled = true
+ cancelSuspended()
+ }
+ _, cancelHealthy := context.WithCancel(context.Background())
+ consumer.tenants[tt.suspendedTenantID] = wrappedCancel
+ consumer.tenants[tt.healthyTenantID] = cancelHealthy
+ consumer.knownTenants[tt.suspendedTenantID] = true
+ consumer.knownTenants[tt.healthyTenantID] = true
+ // Pre-populate tenantAbsenceCount for the suspended tenant
+ consumer.tenantAbsenceCount[tt.suspendedTenantID] = 1
+ consumer.mu.Unlock()
+
+ ctx := context.Background()
+ ctx = libCommons.ContextWithLogger(ctx, logger)
+
+ // Trigger revalidation
+ consumer.revalidateConnectionSettings(ctx)
+
+ // Verify the suspended tenant was removed from tenants map
+ consumer.mu.RLock()
+ _, suspendedInTenants := consumer.tenants[tt.suspendedTenantID]
+ _, suspendedInKnown := consumer.knownTenants[tt.suspendedTenantID]
+ _, healthyInTenants := consumer.tenants[tt.healthyTenantID]
+ _, healthyInKnown := consumer.knownTenants[tt.healthyTenantID]
+ consumer.mu.RUnlock()
+
+ assert.False(t, suspendedInTenants,
+ "suspended tenant should be removed from tenants map")
+ assert.False(t, suspendedInKnown,
+ "suspended tenant should be removed from knownTenants map")
+ assert.True(t, suspendedCanceled,
+ "suspended tenant's context cancel should have been called")
+
+ // Verify the healthy tenant is still active
+ assert.True(t, healthyInTenants,
+ "healthy tenant should still be in tenants map")
+ assert.True(t, healthyInKnown,
+ "healthy tenant should still be in knownTenants map")
+
+ // Verify the appropriate log message was produced
+ assert.True(t, logger.ContainsSubstring(tt.expectLogSubstring),
+ "expected log message containing %q, got: %v",
+ tt.expectLogSubstring, logger.GetMessages())
+
+ // Verify that the healthy tenant was still revalidated
+ assert.True(t, logger.ContainsSubstring("revalidated connection settings for 1/"),
+ "should log revalidation summary for the healthy tenant")
+
+ // Note: evictSuspendedTenant does NOT clean sync.Map entries (consumerLocks, retryState)
+ // or tenantAbsenceCount. It only cleans regular maps (tenants, knownTenants).
+ // sync.Map entries persist until overwritten or garbage collected.
+
+ // Verify healthy tenant's sync.Map entries are NOT affected
+ _, healthyLockExists := consumer.consumerLocks.Load(tt.healthyTenantID)
+ assert.True(t, healthyLockExists,
+ "consumerLocks should still exist for healthy tenant %q", tt.healthyTenantID)
+
+ _, healthyRetryExists := consumer.retryState.Load(tt.healthyTenantID)
+ assert.True(t, healthyRetryExists,
+ "retryState should still exist for healthy tenant %q", tt.healthyTenantID)
+ })
+ }
+}
+
+// TestMultiTenantConsumer_AllowInsecureHTTP verifies that the AllowInsecureHTTP
+// config field controls whether http:// MultiTenantURLs are accepted by the constructor.
+func TestMultiTenantConsumer_AllowInsecureHTTP(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ config MultiTenantConfig
+ expectError bool
+ errContains string
+ }{
+ {
+ name: "rejects_http_URL_when_AllowInsecureHTTP_is_false",
+ config: MultiTenantConfig{
+ MultiTenantURL: "http://tenant-manager.namespace.svc.cluster.local:4003",
+ ServiceAPIKey: "test-key",
+ },
+ expectError: true,
+ errContains: "insecure HTTP",
+ },
+ {
+ name: "accepts_http_URL_when_AllowInsecureHTTP_is_true",
+ config: MultiTenantConfig{
+ MultiTenantURL: "http://tenant-manager.namespace.svc.cluster.local:4003",
+ ServiceAPIKey: "test-key",
+ AllowInsecureHTTP: true,
+ },
+ expectError: false,
+ },
+ {
+ name: "accepts_https_URL_regardless_of_AllowInsecureHTTP",
+ config: MultiTenantConfig{
+ MultiTenantURL: "https://tenant-manager.dev.example.com",
+ ServiceAPIKey: "test-key",
+ },
+ expectError: false,
+ },
+ {
+ name: "no_error_when_MultiTenantURL_is_empty",
+ config: MultiTenantConfig{
+ MultiTenantURL: "",
+ },
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ redisClient := dummyRedisClient(t)
+ consumer, err := NewMultiTenantConsumerWithError(
+ dummyRabbitMQManager(), redisClient, tt.config, testutil.NewMockLogger(),
+ )
+
+ if tt.expectError {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errContains)
+ assert.Nil(t, consumer)
+ } else {
+ require.NoError(t, err)
+ assert.NotNil(t, consumer)
+ if consumer != nil {
+ _ = consumer.Close()
+ }
+ }
+ })
+ }
+}
diff --git a/commons/tenant-manager/core/context.go b/commons/tenant-manager/core/context.go
new file mode 100644
index 00000000..ecc1fb9f
--- /dev/null
+++ b/commons/tenant-manager/core/context.go
@@ -0,0 +1,140 @@
+package core
+
+import (
+ "context"
+
+ "github.com/bxcodec/dbresolver/v2"
+ "go.mongodb.org/mongo-driver/mongo"
+)
+
+// nonNilContext returns ctx if non-nil, otherwise context.Background().
+// This guards every exported setter/getter against nil-context panics.
+func nonNilContext(ctx context.Context) context.Context {
+ if ctx == nil {
+ return context.Background()
+ }
+
+ return ctx
+}
+
+// Context key types for storing tenant information.
+// Use unexported struct keys to avoid collisions across packages.
+type contextKey struct {
+ name string
+}
+
+var (
+ // tenantIDKey is the context key for storing the tenant ID.
+ tenantIDKey = contextKey{name: "tenantID"}
+ // tenantPGConnectionKey is the context key for storing the resolved dbresolver.DB connection.
+ tenantPGConnectionKey = contextKey{name: "tenantPGConnection"}
+ // tenantMongoKey is the context key for storing the tenant MongoDB database.
+ tenantMongoKey = contextKey{name: "tenantMongo"}
+)
+
+// SetTenantIDInContext stores the tenant ID in the context.
+func SetTenantIDInContext(ctx context.Context, tenantID string) context.Context {
+ return context.WithValue(nonNilContext(ctx), tenantIDKey, tenantID)
+}
+
+// GetTenantIDFromContext retrieves the tenant ID from the context.
+// Returns empty string if not found.
+func GetTenantIDFromContext(ctx context.Context) string {
+ if id, ok := nonNilContext(ctx).Value(tenantIDKey).(string); ok {
+ return id
+ }
+
+ return ""
+}
+
+// GetTenantID is an alias for GetTenantIDFromContext.
+// Returns the tenant ID from context, or empty string if not found.
+func GetTenantID(ctx context.Context) string {
+ return GetTenantIDFromContext(ctx)
+}
+
+// ContextWithTenantID stores the tenant ID in the context.
+// Alias for SetTenantIDInContext for compatibility with middleware.
+func ContextWithTenantID(ctx context.Context, tenantID string) context.Context {
+ return SetTenantIDInContext(ctx, tenantID)
+}
+
+// ContextWithTenantPGConnection stores the resolved dbresolver.DB connection in the context.
+// This is used by the middleware to store the tenant-specific database connection.
+func ContextWithTenantPGConnection(ctx context.Context, db dbresolver.DB) context.Context {
+ return context.WithValue(nonNilContext(ctx), tenantPGConnectionKey, db)
+}
+
+// GetTenantPGConnectionFromContext retrieves the resolved dbresolver.DB from the context.
+// Returns nil if not found.
+func GetTenantPGConnectionFromContext(ctx context.Context) dbresolver.DB {
+ if db, ok := nonNilContext(ctx).Value(tenantPGConnectionKey).(dbresolver.DB); ok {
+ return db
+ }
+
+ return nil
+}
+
+// GetPostgresForTenant returns the PostgreSQL database connection for the current tenant from context.
+// If no tenant connection is found in context, returns ErrTenantContextRequired.
+// This function ALWAYS requires tenant context - there is no fallback to default connections.
+func GetPostgresForTenant(ctx context.Context) (dbresolver.DB, error) {
+ if tenantDB := GetTenantPGConnectionFromContext(ctx); tenantDB != nil {
+ return tenantDB, nil
+ }
+
+ return nil, ErrTenantContextRequired
+}
+
+// moduleContextKey generates a dynamic context key for a given module name.
+// This allows any module to store its own PostgreSQL connection in context
+// without requiring changes to lib-commons.
+func moduleContextKey(moduleName string) contextKey {
+ return contextKey{name: "tenantPGConnection:" + moduleName}
+}
+
+// ContextWithModulePGConnection stores a module-specific PostgreSQL connection in context.
+// moduleName identifies the module (e.g., "onboarding", "transaction").
+// This is used in multi-module processes where each module needs its own database connection
+// in context to avoid cross-module conflicts.
+func ContextWithModulePGConnection(ctx context.Context, moduleName string, db dbresolver.DB) context.Context {
+ return context.WithValue(nonNilContext(ctx), moduleContextKey(moduleName), db)
+}
+
+// GetModulePostgresForTenant returns the module-specific PostgreSQL connection from context.
+// moduleName identifies the module (e.g., "onboarding", "transaction").
+// Returns ErrTenantContextRequired if no connection is found for the given module.
+// This function does NOT fallback to the generic tenantPGConnectionKey.
+func GetModulePostgresForTenant(ctx context.Context, moduleName string) (dbresolver.DB, error) {
+ if db, ok := nonNilContext(ctx).Value(moduleContextKey(moduleName)).(dbresolver.DB); ok && db != nil {
+ return db, nil
+ }
+
+ return nil, ErrTenantContextRequired
+}
+
+// ContextWithTenantMongo stores the MongoDB database in the context.
+func ContextWithTenantMongo(ctx context.Context, db *mongo.Database) context.Context {
+ return context.WithValue(nonNilContext(ctx), tenantMongoKey, db)
+}
+
+// GetMongoFromContext retrieves the MongoDB database from the context.
+// Returns nil if not found.
+func GetMongoFromContext(ctx context.Context) *mongo.Database {
+ if db, ok := nonNilContext(ctx).Value(tenantMongoKey).(*mongo.Database); ok {
+ return db
+ }
+
+ return nil
+}
+
+// GetMongoForTenant returns the MongoDB database for the current tenant from context.
+// If no tenant connection is found in context, returns ErrTenantContextRequired.
+// This function ALWAYS requires tenant context - there is no fallback to default connections.
+func GetMongoForTenant(ctx context.Context) (*mongo.Database, error) {
+ if db := GetMongoFromContext(ctx); db != nil {
+ return db, nil
+ }
+
+ return nil, ErrTenantContextRequired
+}
diff --git a/commons/tenant-manager/core/context_test.go b/commons/tenant-manager/core/context_test.go
new file mode 100644
index 00000000..1e2c0940
--- /dev/null
+++ b/commons/tenant-manager/core/context_test.go
@@ -0,0 +1,359 @@
+package core
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "testing"
+ "time"
+
+ "github.com/bxcodec/dbresolver/v2"
+ "github.com/stretchr/testify/assert"
+ "go.mongodb.org/mongo-driver/mongo"
+)
+
+func TestSetTenantIDInContext(t *testing.T) {
+ ctx := context.Background()
+
+ ctx = SetTenantIDInContext(ctx, "tenant-123")
+
+ assert.Equal(t, "tenant-123", GetTenantIDFromContext(ctx))
+}
+
+func TestGetTenantIDFromContext_NotSet(t *testing.T) {
+ ctx := context.Background()
+
+ id := GetTenantIDFromContext(ctx)
+
+ assert.Equal(t, "", id)
+}
+
+func TestContextWithTenantID(t *testing.T) {
+ ctx := context.Background()
+
+ ctx = ContextWithTenantID(ctx, "tenant-456")
+
+ assert.Equal(t, "tenant-456", GetTenantIDFromContext(ctx))
+}
+
+func TestGetPostgresForTenant(t *testing.T) {
+ t.Run("returns error when no connection in context", func(t *testing.T) {
+ ctx := context.Background()
+
+ db, err := GetPostgresForTenant(ctx)
+
+ assert.Nil(t, db)
+ assert.ErrorIs(t, err, ErrTenantContextRequired)
+ })
+}
+
+// mockDB implements dbresolver.DB interface for testing purposes.
+type mockDB struct {
+ name string
+}
+
+// Ensure mockDB implements dbresolver.DB interface.
+var _ dbresolver.DB = (*mockDB)(nil)
+
+func (m *mockDB) Begin() (dbresolver.Tx, error) { return nil, nil }
+func (m *mockDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (dbresolver.Tx, error) {
+ return nil, nil
+}
+func (m *mockDB) Close() error { return nil }
+func (m *mockDB) Conn(ctx context.Context) (dbresolver.Conn, error) { return nil, nil }
+func (m *mockDB) Driver() driver.Driver { return nil }
+func (m *mockDB) Exec(query string, args ...interface{}) (sql.Result, error) { return nil, nil }
+func (m *mockDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+ return nil, nil
+}
+func (m *mockDB) Ping() error { return nil }
+func (m *mockDB) PingContext(ctx context.Context) error { return nil }
+func (m *mockDB) Prepare(query string) (dbresolver.Stmt, error) { return nil, nil }
+func (m *mockDB) PrepareContext(ctx context.Context, query string) (dbresolver.Stmt, error) {
+ return nil, nil
+}
+func (m *mockDB) Query(query string, args ...interface{}) (*sql.Rows, error) { return nil, nil }
+func (m *mockDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
+ return nil, nil
+}
+func (m *mockDB) QueryRow(query string, args ...interface{}) *sql.Row { return nil }
+func (m *mockDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
+ return nil
+}
+func (m *mockDB) SetConnMaxIdleTime(d time.Duration) {}
+func (m *mockDB) SetConnMaxLifetime(d time.Duration) {}
+func (m *mockDB) SetMaxIdleConns(n int) {}
+func (m *mockDB) SetMaxOpenConns(n int) {}
+func (m *mockDB) PrimaryDBs() []*sql.DB { return nil }
+func (m *mockDB) ReplicaDBs() []*sql.DB { return nil }
+func (m *mockDB) Stats() sql.DBStats { return sql.DBStats{} }
+
+func TestGetTenantPGConnectionFromContext(t *testing.T) {
+ t.Run("returns nil when no PG connection in context", func(t *testing.T) {
+ ctx := context.Background()
+
+ db := GetTenantPGConnectionFromContext(ctx)
+
+ assert.Nil(t, db)
+ })
+
+ t.Run("returns connection when set via ContextWithTenantPGConnection", func(t *testing.T) {
+ ctx := context.Background()
+ mockConn := &mockDB{name: "tenant-db"}
+
+ ctx = ContextWithTenantPGConnection(ctx, mockConn)
+ db := GetTenantPGConnectionFromContext(ctx)
+
+ assert.Equal(t, mockConn, db)
+ })
+}
+
+func TestContextWithModulePGConnection(t *testing.T) {
+ t.Run("stores and retrieves module connection", func(t *testing.T) {
+ ctx := context.Background()
+ mockConn := &mockDB{name: "module-db"}
+
+ ctx = ContextWithModulePGConnection(ctx, "onboarding", mockConn)
+ db, err := GetModulePostgresForTenant(ctx, "onboarding")
+
+ assert.NoError(t, err)
+ assert.Equal(t, mockConn, db)
+ })
+}
+
+func TestGetModulePostgresForTenant(t *testing.T) {
+ t.Run("returns error when no connection in context", func(t *testing.T) {
+ ctx := context.Background()
+
+ db, err := GetModulePostgresForTenant(ctx, "onboarding")
+
+ assert.Nil(t, db)
+ assert.ErrorIs(t, err, ErrTenantContextRequired)
+ })
+
+ t.Run("does not fallback to generic connection", func(t *testing.T) {
+ ctx := context.Background()
+ genericConn := &mockDB{name: "generic-db"}
+
+ ctx = ContextWithTenantPGConnection(ctx, genericConn)
+
+ db, err := GetModulePostgresForTenant(ctx, "onboarding")
+
+ assert.Nil(t, db)
+ assert.ErrorIs(t, err, ErrTenantContextRequired)
+ })
+
+ t.Run("does not fallback to other module connection", func(t *testing.T) {
+ ctx := context.Background()
+ txnConn := &mockDB{name: "transaction-db"}
+
+ ctx = ContextWithModulePGConnection(ctx, "transaction", txnConn)
+
+ db, err := GetModulePostgresForTenant(ctx, "onboarding")
+
+ assert.Nil(t, db)
+ assert.ErrorIs(t, err, ErrTenantContextRequired)
+ })
+
+ t.Run("works with arbitrary module names", func(t *testing.T) {
+ ctx := context.Background()
+ reportingConn := &mockDB{name: "reporting-db"}
+
+ ctx = ContextWithModulePGConnection(ctx, "reporting", reportingConn)
+ db, err := GetModulePostgresForTenant(ctx, "reporting")
+
+ assert.NoError(t, err)
+ assert.Equal(t, reportingConn, db)
+ })
+}
+
+func TestModuleConnectionIsolationGeneric(t *testing.T) {
+ t.Run("multiple modules are isolated from each other", func(t *testing.T) {
+ ctx := context.Background()
+ onbConn := &mockDB{name: "onboarding-db"}
+ txnConn := &mockDB{name: "transaction-db"}
+ rptConn := &mockDB{name: "reporting-db"}
+
+ ctx = ContextWithModulePGConnection(ctx, "onboarding", onbConn)
+ ctx = ContextWithModulePGConnection(ctx, "transaction", txnConn)
+ ctx = ContextWithModulePGConnection(ctx, "reporting", rptConn)
+
+ onbDB, onbErr := GetModulePostgresForTenant(ctx, "onboarding")
+ txnDB, txnErr := GetModulePostgresForTenant(ctx, "transaction")
+ rptDB, rptErr := GetModulePostgresForTenant(ctx, "reporting")
+
+ assert.NoError(t, onbErr)
+ assert.NoError(t, txnErr)
+ assert.NoError(t, rptErr)
+ assert.Equal(t, onbConn, onbDB)
+ assert.Equal(t, txnConn, txnDB)
+ assert.Equal(t, rptConn, rptDB)
+ })
+
+ t.Run("module connections are independent of generic connection", func(t *testing.T) {
+ ctx := context.Background()
+ genericConn := &mockDB{name: "generic-db"}
+ moduleConn := &mockDB{name: "module-db"}
+
+ ctx = ContextWithTenantPGConnection(ctx, genericConn)
+ ctx = ContextWithModulePGConnection(ctx, "mymodule", moduleConn)
+
+ genDB, genErr := GetPostgresForTenant(ctx)
+ modDB, modErr := GetModulePostgresForTenant(ctx, "mymodule")
+
+ assert.NoError(t, genErr)
+ assert.NoError(t, modErr)
+ assert.Equal(t, genericConn, genDB)
+ assert.Equal(t, moduleConn, modDB)
+ assert.NotEqual(t, genDB, modDB)
+ })
+}
+
+func TestGetMongoFromContext(t *testing.T) {
+ t.Run("returns nil when no mongo in context", func(t *testing.T) {
+ ctx := context.Background()
+
+ db := GetMongoFromContext(ctx)
+
+ assert.Nil(t, db)
+ })
+
+ t.Run("returns nil for nil mongo database stored in context", func(t *testing.T) {
+ ctx := context.Background()
+
+ var nilDB *mongo.Database
+ ctx = ContextWithTenantMongo(ctx, nilDB)
+
+ db := GetMongoFromContext(ctx)
+
+ assert.Nil(t, db)
+ })
+}
+
+func TestNilContext(t *testing.T) {
+ t.Run("SetTenantIDInContext with nil context does not panic and stores value", func(t *testing.T) {
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ ctx := SetTenantIDInContext(nil, "t1")
+
+ assert.Equal(t, "t1", GetTenantIDFromContext(ctx))
+ })
+
+ t.Run("GetTenantIDFromContext with nil context returns empty string", func(t *testing.T) {
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ id := GetTenantIDFromContext(nil)
+
+ assert.Equal(t, "", id)
+ })
+
+ t.Run("ContextWithTenantPGConnection with nil context does not panic", func(t *testing.T) {
+ mockConn := &mockDB{name: "test-db"}
+
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ ctx := ContextWithTenantPGConnection(nil, mockConn)
+
+ assert.Equal(t, mockConn, GetTenantPGConnectionFromContext(ctx))
+ })
+
+ t.Run("GetTenantPGConnectionFromContext with nil context returns nil", func(t *testing.T) {
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ db := GetTenantPGConnectionFromContext(nil)
+
+ assert.Nil(t, db)
+ })
+
+ t.Run("ContextWithTenantMongo with nil context does not panic", func(t *testing.T) {
+ // We cannot create a real *mongo.Database without a live client,
+ // but we can verify nil context does not panic with a nil DB value.
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ ctx := ContextWithTenantMongo(nil, nil)
+
+ assert.NotNil(t, ctx)
+ })
+
+ t.Run("GetMongoFromContext with nil context returns nil", func(t *testing.T) {
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ db := GetMongoFromContext(nil)
+
+ assert.Nil(t, db)
+ })
+
+ t.Run("GetTenantID alias with nil context returns empty string", func(t *testing.T) {
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ id := GetTenantID(nil)
+
+ assert.Equal(t, "", id)
+ })
+
+ t.Run("ContextWithTenantID alias with nil context does not panic", func(t *testing.T) {
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ ctx := ContextWithTenantID(nil, "t2")
+
+ assert.Equal(t, "t2", GetTenantIDFromContext(ctx))
+ })
+
+ t.Run("GetPostgresForTenant with nil context returns error", func(t *testing.T) {
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ db, err := GetPostgresForTenant(nil)
+
+ assert.Nil(t, db)
+ assert.ErrorIs(t, err, ErrTenantContextRequired)
+ })
+
+ t.Run("GetMongoForTenant with nil context returns error", func(t *testing.T) {
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ db, err := GetMongoForTenant(nil)
+
+ assert.Nil(t, db)
+ assert.ErrorIs(t, err, ErrTenantContextRequired)
+ })
+
+ t.Run("ContextWithModulePGConnection with nil context does not panic", func(t *testing.T) {
+ mockConn := &mockDB{name: "module-db"}
+
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ ctx := ContextWithModulePGConnection(nil, "mymod", mockConn)
+
+ db, err := GetModulePostgresForTenant(ctx, "mymod")
+
+ assert.NoError(t, err)
+ assert.Equal(t, mockConn, db)
+ })
+
+ t.Run("GetModulePostgresForTenant with nil context returns error", func(t *testing.T) {
+ //nolint:staticcheck // SA1012: intentionally passing nil context to test nil-safety guard
+ db, err := GetModulePostgresForTenant(nil, "mymod")
+
+ assert.Nil(t, db)
+ assert.ErrorIs(t, err, ErrTenantContextRequired)
+ })
+}
+
+func TestGetMongoForTenant(t *testing.T) {
+ t.Run("returns error when no connection in context", func(t *testing.T) {
+ ctx := context.Background()
+
+ db, err := GetMongoForTenant(ctx)
+
+ assert.Nil(t, db)
+ assert.ErrorIs(t, err, ErrTenantContextRequired)
+ })
+
+ t.Run("returns ErrTenantContextRequired for nil db in context", func(t *testing.T) {
+ ctx := context.Background()
+
+ // Use ContextWithTenantMongo with a nil *mongo.Database to test the path
+ // (We cannot create a real *mongo.Database without a live client,
+ // but we can test the nil path and the type assertion path.)
+ var nilDB *mongo.Database
+ ctx = ContextWithTenantMongo(ctx, nilDB)
+
+ // nil *mongo.Database stored in context: type assertion succeeds but value is nil
+ db := GetMongoFromContext(ctx)
+ assert.Nil(t, db)
+
+ // GetMongoForTenant should return error for nil db
+ result, err := GetMongoForTenant(ctx)
+ assert.Nil(t, result)
+ assert.ErrorIs(t, err, ErrTenantContextRequired)
+ })
+}
diff --git a/commons/tenant-manager/core/errors.go b/commons/tenant-manager/core/errors.go
new file mode 100644
index 00000000..015d6728
--- /dev/null
+++ b/commons/tenant-manager/core/errors.go
@@ -0,0 +1,158 @@
+package core
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+// ErrNilHandlerFunc is returned when a nil HandlerFunc is registered.
+var ErrNilHandlerFunc = errors.New("handler function must not be nil")
+
+// ErrNilCache is returned when a typed-nil cache implementation is provided.
+var ErrNilCache = errors.New("cache implementation must not be nil (received typed-nil interface)")
+
+// ErrNilConfig is returned when a required configuration pointer is nil.
+var ErrNilConfig = errors.New("configuration must not be nil")
+
+// ErrInsecureHTTP is returned when an HTTP URL is used without explicit opt-in.
+var ErrInsecureHTTP = errors.New("insecure HTTP is not allowed; use HTTPS or enable WithAllowInsecureHTTP()")
+
+// ErrServiceAPIKeyRequired is returned when NewClient is called without a non-empty service API key.
+var ErrServiceAPIKeyRequired = errors.New("service API key is required: use WithServiceAPIKey() with a non-empty key")
+
+// IsNilInterface reports whether v is a nil interface value or an interface
+// wrapping a nil pointer (typed-nil). This is necessary because Go interfaces
+// with a nil concrete value are not == nil.
+func IsNilInterface(v any) bool {
+ if v == nil {
+ return true
+ }
+
+ rv := reflect.ValueOf(v)
+ switch rv.Kind() {
+ case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func, reflect.Interface:
+ return rv.IsNil()
+ default:
+ return false
+ }
+}
+
+// ErrTenantNotFound is returned when the tenant is not found in Tenant Manager.
+var ErrTenantNotFound = errors.New("tenant not found")
+
+// ErrServiceNotConfigured is returned when the service is not configured for the tenant.
+var ErrServiceNotConfigured = errors.New("service not configured for tenant")
+
+// ErrTenantServiceAccessDenied is returned when the tenant-service association exists
+// but is not active (e.g., suspended or purged), resulting in an HTTP 403 from the Tenant Manager.
+var ErrTenantServiceAccessDenied = errors.New("tenant service access denied")
+
+// ErrManagerClosed is returned when attempting to use a closed connection manager.
+var ErrManagerClosed = errors.New("tenant connection manager is closed")
+
+// ErrTenantContextRequired is returned when no tenant context is found for a database operation.
+// This error indicates that a request attempted to access the database without proper tenant identification.
+// The tenant connection must be set in context via middleware before database operations.
+var ErrTenantContextRequired = errors.New("tenant context required: no tenant database connection found in context")
+
+// ErrTenantNotProvisioned is returned when the tenant database schema has not been initialized.
+// This typically happens when migrations have not been run on the tenant's database.
+// PostgreSQL error code 42P01 (undefined_table) indicates this condition.
+var ErrTenantNotProvisioned = errors.New("tenant database not provisioned: schema has not been initialized")
+
+// ErrCircuitBreakerOpen is returned when the circuit breaker is in the open state,
+// indicating the Tenant Manager service is temporarily unavailable.
+// Callers should retry after the circuit breaker timeout elapses.
+var ErrCircuitBreakerOpen = errors.New("tenant manager circuit breaker is open: service temporarily unavailable")
+
+// ErrAuthorizationTokenRequired is returned when the Authorization header is missing.
+var ErrAuthorizationTokenRequired = errors.New("authorization token is required")
+
+// ErrInvalidAuthorizationToken is returned when the JWT token cannot be parsed.
+var ErrInvalidAuthorizationToken = errors.New("invalid authorization token")
+
+// ErrInvalidTenantClaims is returned when JWT claims are malformed.
+var ErrInvalidTenantClaims = errors.New("invalid tenant claims")
+
+// ErrMissingTenantIDClaim is returned when JWT does not include tenantId.
+var ErrMissingTenantIDClaim = errors.New("tenantId claim is required")
+
+// ErrConnectionFailed is returned when tenant DB connection resolution fails.
+var ErrConnectionFailed = errors.New("tenant connection failed")
+
+// IsCircuitBreakerOpenError checks whether err (or any error in its chain) is ErrCircuitBreakerOpen.
+func IsCircuitBreakerOpenError(err error) bool {
+ return errors.Is(err, ErrCircuitBreakerOpen)
+}
+
+// TenantSuspendedError is returned when the tenant-service association exists but is not active
+// (e.g., suspended or purged). This allows callers to distinguish between "not found" and
+// "access denied due to status" scenarios.
+type TenantSuspendedError struct {
+ TenantID string // The tenant identifier that was requested
+ Status string // The current status (e.g., "suspended", "purged")
+ Message string // Human-readable error message from the server
+}
+
+// Error implements the error interface.
+func (e *TenantSuspendedError) Error() string {
+ if e == nil {
+ return "tenant service is unavailable"
+ }
+
+ if e.Message != "" {
+ return e.Message
+ }
+
+ return fmt.Sprintf("tenant service is %s for tenant %s", e.Status, e.TenantID)
+}
+
+// IsTenantSuspendedError checks whether err (or any error in its chain) is a *TenantSuspendedError.
+func IsTenantSuspendedError(err error) bool {
+ var target *TenantSuspendedError
+ return errors.As(err, &target)
+}
+
+// IsTenantPurgedError checks whether err (or any error in its chain) is a
+// *TenantSuspendedError whose Status is "purged". This allows callers to
+// distinguish purged tenants from suspended ones for eviction decisions.
+func IsTenantPurgedError(err error) bool {
+ var target *TenantSuspendedError
+ if errors.As(err, &target) {
+ return target.Status == "purged"
+ }
+
+ return false
+}
+
+// IsTenantNotProvisionedError checks if the error indicates an unprovisioned tenant database.
+// It first checks the error chain using errors.Is for the sentinel ErrTenantNotProvisioned,
+// then falls back to string matching for PostgreSQL SQLSTATE 42P01 (undefined_table).
+// This typically occurs when migrations have not been run on the tenant database.
+func IsTenantNotProvisionedError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ // Prefer errors.Is for wrapped sentinel errors
+ if errors.Is(err, ErrTenantNotProvisioned) {
+ return true
+ }
+
+ errStr := err.Error()
+
+ // Check for PostgreSQL error code 42P01 (undefined_table)
+ // This is the standard SQLSTATE for "relation does not exist"
+ if strings.Contains(errStr, "42P01") {
+ return true
+ }
+
+ // Also check for the common error message pattern
+ if strings.Contains(errStr, "relation") && strings.Contains(errStr, "does not exist") {
+ return true
+ }
+
+ return false
+}
diff --git a/commons/tenant-manager/core/errors_test.go b/commons/tenant-manager/core/errors_test.go
new file mode 100644
index 00000000..c18affc5
--- /dev/null
+++ b/commons/tenant-manager/core/errors_test.go
@@ -0,0 +1,132 @@
+package core
+
+import (
+ "errors"
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestTenantSuspendedError(t *testing.T) {
+ t.Run("Error returns message when set", func(t *testing.T) {
+ err := &TenantSuspendedError{
+ TenantID: "tenant-123",
+ Status: "suspended",
+ Message: "service ledger is suspended for this tenant",
+ }
+
+ assert.Equal(t, "service ledger is suspended for this tenant", err.Error())
+ })
+
+ t.Run("Error returns default message when message is empty", func(t *testing.T) {
+ err := &TenantSuspendedError{
+ TenantID: "tenant-123",
+ Status: "purged",
+ }
+
+ assert.Equal(t, "tenant service is purged for tenant tenant-123", err.Error())
+ })
+
+ t.Run("implements error interface", func(t *testing.T) {
+ var err error = &TenantSuspendedError{
+ TenantID: "tenant-123",
+ Status: "suspended",
+ Message: "test",
+ }
+
+ assert.Error(t, err)
+ })
+}
+
+func TestTenantSuspendedError_NilReceiver(t *testing.T) {
+ var err *TenantSuspendedError
+
+ assert.Equal(t, "tenant service is unavailable", err.Error())
+}
+
+func TestErrTenantServiceAccessDenied(t *testing.T) {
+ assert.Error(t, ErrTenantServiceAccessDenied)
+ assert.Equal(t, "tenant service access denied", ErrTenantServiceAccessDenied.Error())
+
+ // Verify errors.Is works with wrapped errors
+ wrapped := fmt.Errorf("wrap: %w", ErrTenantServiceAccessDenied)
+ assert.ErrorIs(t, wrapped, ErrTenantServiceAccessDenied)
+}
+
+func TestIsTenantSuspendedError(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ expected bool
+ }{
+ {
+ name: "nil error returns false",
+ err: nil,
+ expected: false,
+ },
+ {
+ name: "TenantSuspendedError returns true",
+ err: &TenantSuspendedError{TenantID: "t1", Status: "suspended"},
+ expected: true,
+ },
+ {
+ name: "wrapped TenantSuspendedError returns true",
+ err: fmt.Errorf("outer: %w", &TenantSuspendedError{TenantID: "t1", Status: "suspended"}),
+ expected: true,
+ },
+ {
+ name: "generic error returns false",
+ err: errors.New("some error"),
+ expected: false,
+ },
+ {
+ name: "ErrTenantNotFound returns false",
+ err: ErrTenantNotFound,
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := IsTenantSuspendedError(tt.err)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestIsTenantNotProvisionedError(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ expected bool
+ }{
+ {
+ name: "nil error returns false",
+ err: nil,
+ expected: false,
+ },
+ {
+ name: "42P01 error returns true",
+ err: errors.New("ERROR: relation \"table\" does not exist (SQLSTATE 42P01)"),
+ expected: true,
+ },
+ {
+ name: "relation does not exist returns true",
+ err: errors.New("pq: relation \"account\" does not exist"),
+ expected: true,
+ },
+ {
+ name: "generic error returns false",
+ err: errors.New("connection refused"),
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := IsTenantNotProvisionedError(tt.err)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
diff --git a/commons/tenant-manager/core/types.go b/commons/tenant-manager/core/types.go
new file mode 100644
index 00000000..e2e0fbc8
--- /dev/null
+++ b/commons/tenant-manager/core/types.go
@@ -0,0 +1,253 @@
+// Package core provides shared types, errors, and context helpers used by all
+// tenant-manager sub-packages.
+package core
+
+import (
+ "sort"
+ "time"
+)
+
+// PostgreSQLConfig holds PostgreSQL connection configuration.
+// Credentials are provided directly by the tenant-manager settings endpoint.
+type PostgreSQLConfig struct {
+ Host string `json:"host"`
+ Port int `json:"port"`
+ Database string `json:"database"`
+ Username string `json:"username"`
+ Password string `json:"password"` // #nosec G117
+ Schema string `json:"schema,omitempty"`
+ SSLMode string `json:"sslmode,omitempty"`
+ SSLRootCert string `json:"sslrootcert,omitempty"` // path to CA certificate file
+ SSLCert string `json:"sslcert,omitempty"` // path to client certificate file
+ SSLKey string `json:"sslkey,omitempty"` // path to client private key file
+}
+
+// MongoDBConfig holds MongoDB connection configuration.
+// Credentials are provided directly by the tenant-manager settings endpoint.
+type MongoDBConfig struct {
+ Host string `json:"host,omitempty"`
+ Port int `json:"port,omitempty"`
+ Database string `json:"database"`
+ Username string `json:"username,omitempty"`
+ Password string `json:"password,omitempty"` // #nosec G117
+ URI string `json:"uri,omitempty"`
+ AuthSource string `json:"authSource,omitempty"`
+ DirectConnection bool `json:"directConnection,omitempty"`
+ MaxPoolSize uint64 `json:"maxPoolSize,omitempty"`
+ TLS bool `json:"tls,omitempty"`
+ TLSCAFile string `json:"tlsCAFile,omitempty"` // path to CA certificate file
+ TLSCertFile string `json:"tlsCertFile,omitempty"` // path to client certificate file
+ TLSKeyFile string `json:"tlsKeyFile,omitempty"` // path to client private key file
+ // TLSSkipVerify disables both certificate-chain validation and hostname
+ // verification (maps to MongoDB tlsInsecure). Use only in trusted environments;
+ // enabling this flag significantly increases the risk of man-in-the-middle attacks.
+ TLSSkipVerify bool `json:"tlsSkipVerify,omitempty"`
+}
+
+// RabbitMQConfig holds RabbitMQ connection configuration for tenant vhosts.
+type RabbitMQConfig struct {
+ Host string `json:"host"`
+ Port int `json:"port"`
+ VHost string `json:"vhost"`
+ Username string `json:"username"`
+ Password string `json:"password"` // #nosec G117
+ TLS *bool `json:"tls,omitempty"` // enable TLS (amqps://); nil = use global default
+ TLSCAFile string `json:"tlsCAFile,omitempty"` // path to CA certificate file for custom CAs
+}
+
+// MessagingConfig holds messaging configuration for a tenant.
+type MessagingConfig struct {
+ RabbitMQ *RabbitMQConfig `json:"rabbitmq,omitempty"`
+}
+
+// DatabaseConfig holds database configurations for a module (onboarding, transaction, etc.).
+// In the flat format returned by tenant-manager, the Databases map is keyed by module name
+// directly (e.g., "onboarding", "transaction"), without an intermediate service wrapper.
+type DatabaseConfig struct {
+ PostgreSQL *PostgreSQLConfig `json:"postgresql,omitempty"`
+ PostgreSQLReplica *PostgreSQLConfig `json:"postgresqlReplica,omitempty"`
+ MongoDB *MongoDBConfig `json:"mongodb,omitempty"`
+ ConnectionSettings *ConnectionSettings `json:"connectionSettings,omitempty"`
+}
+
+// ConnectionSettings holds per-tenant database connection pool settings.
+// When present in the tenant config response, these values override the global
+// defaults configured on the PostgresManager or MongoManager.
+// If nil (e.g., for older associations without settings), global defaults apply.
+type ConnectionSettings struct {
+ MaxOpenConns int `json:"maxOpenConns"`
+ MaxIdleConns int `json:"maxIdleConns"`
+}
+
+// TenantConfig represents the tenant configuration from Tenant Manager.
+// The Databases map is keyed by module name (e.g., "onboarding", "transaction").
+// This matches the flat format returned by the tenant-manager /v1/.../connections endpoint.
+type TenantConfig struct {
+ ID string `json:"id"`
+ TenantSlug string `json:"tenantSlug"`
+ TenantName string `json:"tenantName,omitempty"`
+ Service string `json:"service,omitempty"`
+ Status string `json:"status,omitempty"`
+ IsolationMode string `json:"isolationMode,omitempty"`
+ Databases map[string]DatabaseConfig `json:"databases,omitempty"`
+ Messaging *MessagingConfig `json:"messaging,omitempty"`
+ ConnectionSettings *ConnectionSettings `json:"connectionSettings,omitempty"`
+ CreatedAt time.Time `json:"createdAt,omitzero"`
+ UpdatedAt time.Time `json:"updatedAt,omitzero"`
+}
+
+// sortedDatabaseKeys returns the keys of the Databases map in sorted order.
+// This ensures deterministic behavior when module is empty.
+func sortedDatabaseKeys(databases map[string]DatabaseConfig) []string {
+ keys := make([]string, 0, len(databases))
+ for k := range databases {
+ keys = append(keys, k)
+ }
+
+ sort.Strings(keys)
+
+ return keys
+}
+
+// GetPostgreSQLConfig returns the PostgreSQL config for a module.
+// module: e.g., "onboarding", "transaction"
+// If module is empty, returns the first PostgreSQL config found (sorted by key for determinism).
+// The service parameter is accepted for backward compatibility but is ignored
+// since the flat format returned by tenant-manager keys databases by module directly.
+func (tc *TenantConfig) GetPostgreSQLConfig(service, module string) *PostgreSQLConfig {
+ if tc == nil {
+ return nil
+ }
+
+ if tc.Databases == nil {
+ return nil
+ }
+
+ if module != "" {
+ if db, ok := tc.Databases[module]; ok {
+ return db.PostgreSQL
+ }
+
+ return nil
+ }
+
+ // Return first PostgreSQL config found (deterministic: sorted by key)
+ keys := sortedDatabaseKeys(tc.Databases)
+ for _, key := range keys {
+ if db := tc.Databases[key]; db.PostgreSQL != nil {
+ return db.PostgreSQL
+ }
+ }
+
+ return nil
+}
+
+// GetPostgreSQLReplicaConfig returns the PostgreSQL replica config for a module.
+// module: e.g., "onboarding", "transaction"
+// If module is empty, returns the first PostgreSQL replica config found (sorted by key for determinism).
+// Returns nil if no replica is configured (callers should fall back to primary).
+// The service parameter is accepted for backward compatibility but is ignored
+// since the flat format returned by tenant-manager keys databases by module directly.
+func (tc *TenantConfig) GetPostgreSQLReplicaConfig(service, module string) *PostgreSQLConfig {
+ if tc == nil {
+ return nil
+ }
+
+ if tc.Databases == nil {
+ return nil
+ }
+
+ if module != "" {
+ if db, ok := tc.Databases[module]; ok {
+ return db.PostgreSQLReplica
+ }
+
+ return nil
+ }
+
+ // Return first PostgreSQL replica config found (deterministic: sorted by key)
+ keys := sortedDatabaseKeys(tc.Databases)
+ for _, key := range keys {
+ if db := tc.Databases[key]; db.PostgreSQLReplica != nil {
+ return db.PostgreSQLReplica
+ }
+ }
+
+ return nil
+}
+
+// GetMongoDBConfig returns the MongoDB config for a module.
+// module: e.g., "onboarding", "transaction"
+// If module is empty, returns the first MongoDB config found (sorted by key for determinism).
+// The service parameter is accepted for backward compatibility but is ignored
+// since the flat format returned by tenant-manager keys databases by module directly.
+func (tc *TenantConfig) GetMongoDBConfig(service, module string) *MongoDBConfig {
+ if tc == nil {
+ return nil
+ }
+
+ if tc.Databases == nil {
+ return nil
+ }
+
+ if module != "" {
+ if db, ok := tc.Databases[module]; ok {
+ return db.MongoDB
+ }
+
+ return nil
+ }
+
+ // Return first MongoDB config found (deterministic: sorted by key)
+ keys := sortedDatabaseKeys(tc.Databases)
+ for _, key := range keys {
+ if db := tc.Databases[key]; db.MongoDB != nil {
+ return db.MongoDB
+ }
+ }
+
+ return nil
+}
+
+// IsSchemaMode returns true if the tenant is configured for schema-based isolation.
+// In schema mode, all tenants share the same database but have separate schemas.
+func (tc *TenantConfig) IsSchemaMode() bool {
+ if tc == nil {
+ return false
+ }
+
+ return tc.IsolationMode == "schema"
+}
+
+// IsIsolatedMode returns true if the tenant has a dedicated database (isolated mode).
+// This is the default mode when IsolationMode is empty or explicitly set to "isolated" or "database".
+func (tc *TenantConfig) IsIsolatedMode() bool {
+ if tc == nil {
+ return false
+ }
+
+ return tc.IsolationMode == "" || tc.IsolationMode == "isolated" || tc.IsolationMode == "database"
+}
+
+// GetRabbitMQConfig returns the RabbitMQ config for the tenant.
+// Returns nil if messaging or RabbitMQ is not configured.
+func (tc *TenantConfig) GetRabbitMQConfig() *RabbitMQConfig {
+ if tc == nil {
+ return nil
+ }
+
+ if tc.Messaging == nil {
+ return nil
+ }
+
+ return tc.Messaging.RabbitMQ
+}
+
+// HasRabbitMQ returns true if the tenant has RabbitMQ configured.
+func (tc *TenantConfig) HasRabbitMQ() bool {
+ if tc == nil {
+ return false
+ }
+
+ return tc.GetRabbitMQConfig() != nil
+}
diff --git a/commons/tenant-manager/core/types_test.go b/commons/tenant-manager/core/types_test.go
new file mode 100644
index 00000000..46810864
--- /dev/null
+++ b/commons/tenant-manager/core/types_test.go
@@ -0,0 +1,607 @@
+package core
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// newTenantConfigFixture returns a fully populated TenantConfig with PostgreSQL,
+// PostgreSQL replica, and MongoDB configurations for two modules (onboarding
+// and transaction). Callers can override or nil-out fields for edge case tests.
+func newTenantConfigFixture() *TenantConfig {
+ return &TenantConfig{
+ ID: "tenant-fixture",
+ TenantSlug: "fixture-tenant",
+ Service: "ledger",
+ Status: "active",
+ IsolationMode: "database",
+ Databases: map[string]DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &PostgreSQLConfig{
+ Host: "onboarding-db.example.com",
+ Port: 5432,
+ },
+ PostgreSQLReplica: &PostgreSQLConfig{
+ Host: "onboarding-replica.example.com",
+ Port: 5433,
+ },
+ MongoDB: &MongoDBConfig{
+ Host: "onboarding-mongo.example.com",
+ Port: 27017,
+ Database: "onboarding_db",
+ },
+ },
+ "transaction": {
+ PostgreSQL: &PostgreSQLConfig{
+ Host: "transaction-db.example.com",
+ Port: 5432,
+ },
+ PostgreSQLReplica: &PostgreSQLConfig{
+ Host: "transaction-replica.example.com",
+ Port: 5433,
+ },
+ MongoDB: &MongoDBConfig{
+ Host: "transaction-mongo.example.com",
+ Port: 27017,
+ Database: "transaction_db",
+ },
+ },
+ },
+ }
+}
+
+func TestTenantConfig_GetPostgreSQLConfig(t *testing.T) {
+ tests := []struct {
+ name string
+ config *TenantConfig
+ service string
+ module string
+ expectNil bool
+ expectedHost string
+ }{
+ {
+ name: "returns config for onboarding module",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "onboarding",
+ expectedHost: "onboarding-db.example.com",
+ },
+ {
+ name: "returns config for transaction module",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "transaction",
+ expectedHost: "transaction-db.example.com",
+ },
+ {
+ name: "returns nil for unknown module",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "unknown",
+ expectNil: true,
+ },
+ {
+ name: "returns first config when module is empty",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "",
+ expectedHost: "", // non-nil but host depends on map iteration order
+ },
+ {
+ name: "returns nil when databases is nil",
+ config: &TenantConfig{},
+ service: "ledger",
+ module: "onboarding",
+ expectNil: true,
+ },
+ {
+ name: "service parameter is ignored in flat format",
+ config: newTenantConfigFixture(),
+ service: "audit",
+ module: "onboarding",
+ expectedHost: "onboarding-db.example.com",
+ },
+ {
+ name: "empty service still resolves module",
+ config: newTenantConfigFixture(),
+ service: "",
+ module: "onboarding",
+ expectedHost: "onboarding-db.example.com",
+ },
+ {
+ name: "returns nil when module exists but has no PostgreSQL config",
+ config: &TenantConfig{
+ Databases: map[string]DatabaseConfig{
+ "onboarding": {
+ MongoDB: &MongoDBConfig{Host: "mongo.example.com"},
+ },
+ },
+ },
+ service: "ledger",
+ module: "onboarding",
+ expectNil: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.config.GetPostgreSQLConfig(tt.service, tt.module)
+
+ if tt.expectNil {
+ assert.Nil(t, result)
+ return
+ }
+
+ require.NotNil(t, result)
+ if tt.expectedHost != "" {
+ assert.Equal(t, tt.expectedHost, result.Host)
+ }
+ })
+ }
+}
+
+func TestTenantConfig_GetPostgreSQLReplicaConfig(t *testing.T) {
+ tests := []struct {
+ name string
+ config *TenantConfig
+ service string
+ module string
+ expectNil bool
+ expectedHost string
+ expectedPort int
+ }{
+ {
+ name: "returns replica config for onboarding module",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "onboarding",
+ expectedHost: "onboarding-replica.example.com",
+ expectedPort: 5433,
+ },
+ {
+ name: "returns replica config for transaction module",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "transaction",
+ expectedHost: "transaction-replica.example.com",
+ expectedPort: 5433,
+ },
+ {
+ name: "returns nil when replica not configured",
+ config: &TenantConfig{
+ Databases: map[string]DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &PostgreSQLConfig{
+ Host: "primary-db.example.com",
+ Port: 5432,
+ },
+ },
+ },
+ },
+ service: "ledger",
+ module: "onboarding",
+ expectNil: true,
+ },
+ {
+ name: "returns nil for unknown module",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "unknown",
+ expectNil: true,
+ },
+ {
+ name: "returns first replica config when module is empty",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "",
+ expectedHost: "", // non-nil but host depends on map iteration order
+ },
+ {
+ name: "returns nil when databases is nil",
+ config: &TenantConfig{},
+ service: "ledger",
+ module: "onboarding",
+ expectNil: true,
+ },
+ {
+ name: "returns nil when module exists but has no replica config",
+ config: &TenantConfig{
+ Databases: map[string]DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &PostgreSQLConfig{Host: "primary.example.com"},
+ },
+ },
+ },
+ service: "ledger",
+ module: "onboarding",
+ expectNil: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.config.GetPostgreSQLReplicaConfig(tt.service, tt.module)
+
+ if tt.expectNil {
+ assert.Nil(t, result)
+ return
+ }
+
+ require.NotNil(t, result)
+ if tt.expectedHost != "" {
+ assert.Equal(t, tt.expectedHost, result.Host)
+ }
+ if tt.expectedPort != 0 {
+ assert.Equal(t, tt.expectedPort, result.Port)
+ }
+ })
+ }
+}
+
+func TestTenantConfig_GetMongoDBConfig(t *testing.T) {
+ tests := []struct {
+ name string
+ config *TenantConfig
+ service string
+ module string
+ expectNil bool
+ expectedHost string
+ expectedDatabase string
+ }{
+ {
+ name: "returns config for onboarding module",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "onboarding",
+ expectedHost: "onboarding-mongo.example.com",
+ expectedDatabase: "onboarding_db",
+ },
+ {
+ name: "returns config for transaction module",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "transaction",
+ expectedHost: "transaction-mongo.example.com",
+ expectedDatabase: "transaction_db",
+ },
+ {
+ name: "returns nil for unknown module",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "unknown",
+ expectNil: true,
+ },
+ {
+ name: "returns first config when module is empty",
+ config: newTenantConfigFixture(),
+ service: "ledger",
+ module: "",
+ expectedHost: "", // non-nil but host depends on map iteration order
+ },
+ {
+ name: "returns nil when databases is nil",
+ config: &TenantConfig{},
+ service: "ledger",
+ module: "onboarding",
+ expectNil: true,
+ },
+ {
+ name: "returns nil when module exists but has no MongoDB config",
+ config: &TenantConfig{
+ Databases: map[string]DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &PostgreSQLConfig{Host: "pg.example.com"},
+ },
+ },
+ },
+ service: "ledger",
+ module: "onboarding",
+ expectNil: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.config.GetMongoDBConfig(tt.service, tt.module)
+
+ if tt.expectNil {
+ assert.Nil(t, result)
+ return
+ }
+
+ require.NotNil(t, result)
+ if tt.expectedHost != "" {
+ assert.Equal(t, tt.expectedHost, result.Host)
+ }
+ if tt.expectedDatabase != "" {
+ assert.Equal(t, tt.expectedDatabase, result.Database)
+ }
+ })
+ }
+}
+
+func TestTenantConfig_IsSchemaMode(t *testing.T) {
+ tests := []struct {
+ name string
+ config *TenantConfig
+ expected bool
+ }{
+ {
+ name: "returns true when isolation mode is schema",
+ config: &TenantConfig{IsolationMode: "schema"},
+ expected: true,
+ },
+ {
+ name: "returns false when isolation mode is isolated",
+ config: &TenantConfig{IsolationMode: "isolated"},
+ expected: false,
+ },
+ {
+ name: "returns false when isolation mode is empty",
+ config: &TenantConfig{IsolationMode: ""},
+ expected: false,
+ },
+ {
+ name: "returns false when isolation mode is unknown",
+ config: &TenantConfig{IsolationMode: "unknown"},
+ expected: false,
+ },
+ {
+ name: "returns false for nil receiver",
+ config: nil,
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.config.IsSchemaMode()
+
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestTenantConfig_IsIsolatedMode(t *testing.T) {
+ tests := []struct {
+ name string
+ config *TenantConfig
+ expected bool
+ }{
+ {
+ name: "returns true when isolation mode is isolated",
+ config: &TenantConfig{IsolationMode: "isolated"},
+ expected: true,
+ },
+ {
+ name: "returns true when isolation mode is database",
+ config: &TenantConfig{IsolationMode: "database"},
+ expected: true,
+ },
+ {
+ name: "returns true when isolation mode is empty (default)",
+ config: &TenantConfig{IsolationMode: ""},
+ expected: true,
+ },
+ {
+ name: "returns false when isolation mode is schema",
+ config: &TenantConfig{IsolationMode: "schema"},
+ expected: false,
+ },
+ {
+ name: "returns false when isolation mode is unknown",
+ config: &TenantConfig{IsolationMode: "unknown"},
+ expected: false,
+ },
+ {
+ name: "returns false for nil receiver",
+ config: nil,
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.config.IsIsolatedMode()
+
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestTenantConfig_GetRabbitMQConfig(t *testing.T) {
+ tests := []struct {
+ name string
+ config *TenantConfig
+ expectNil bool
+ expectedVHost string
+ }{
+ {
+ name: "returns nil for nil receiver",
+ config: nil,
+ expectNil: true,
+ },
+ {
+ name: "returns nil when messaging is nil",
+ config: &TenantConfig{},
+ expectNil: true,
+ },
+ {
+ name: "returns nil when rabbitmq is nil in messaging",
+ config: &TenantConfig{
+ Messaging: &MessagingConfig{},
+ },
+ expectNil: true,
+ },
+ {
+ name: "returns config when rabbitmq is set",
+ config: &TenantConfig{
+ Messaging: &MessagingConfig{
+ RabbitMQ: &RabbitMQConfig{
+ Host: "rabbitmq.example.com",
+ Port: 5672,
+ VHost: "tenant-vhost",
+ },
+ },
+ },
+ expectedVHost: "tenant-vhost",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.config.GetRabbitMQConfig()
+
+ if tt.expectNil {
+ assert.Nil(t, result)
+ return
+ }
+
+ require.NotNil(t, result)
+ assert.Equal(t, tt.expectedVHost, result.VHost)
+ })
+ }
+}
+
+func TestTenantConfig_HasRabbitMQ(t *testing.T) {
+ tests := []struct {
+ name string
+ config *TenantConfig
+ expected bool
+ }{
+ {
+ name: "returns false for nil receiver",
+ config: nil,
+ expected: false,
+ },
+ {
+ name: "returns false when messaging is nil",
+ config: &TenantConfig{},
+ expected: false,
+ },
+ {
+ name: "returns false when rabbitmq is nil in messaging",
+ config: &TenantConfig{
+ Messaging: &MessagingConfig{},
+ },
+ expected: false,
+ },
+ {
+ name: "returns true when rabbitmq is configured",
+ config: &TenantConfig{
+ Messaging: &MessagingConfig{
+ RabbitMQ: &RabbitMQConfig{
+ Host: "rabbitmq.example.com",
+ Port: 5672,
+ VHost: "tenant-vhost",
+ },
+ },
+ },
+ expected: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tt.config.HasRabbitMQ()
+
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestTenantConfig_ConnectionSettings(t *testing.T) {
+ t.Run("deserializes connectionSettings from JSON", func(t *testing.T) {
+ jsonData := `{
+ "id": "cfg-123",
+ "tenantSlug": "acme",
+ "isolationMode": "schema",
+ "connectionSettings": {
+ "maxOpenConns": 20,
+ "maxIdleConns": 10
+ },
+ "databases": {
+ "onboarding": {
+ "postgresql": {
+ "host": "localhost",
+ "port": 5432,
+ "database": "testdb",
+ "username": "user",
+ "password": "pass"
+ }
+ }
+ }
+ }`
+
+ var config TenantConfig
+ err := json.Unmarshal([]byte(jsonData), &config)
+
+ require.NoError(t, err)
+ require.NotNil(t, config.ConnectionSettings)
+ assert.Equal(t, 20, config.ConnectionSettings.MaxOpenConns)
+ assert.Equal(t, 10, config.ConnectionSettings.MaxIdleConns)
+ })
+
+ t.Run("connectionSettings is nil when not present in JSON", func(t *testing.T) {
+ jsonData := `{
+ "id": "cfg-123",
+ "tenantSlug": "acme",
+ "isolationMode": "schema",
+ "databases": {
+ "onboarding": {
+ "postgresql": {
+ "host": "localhost",
+ "port": 5432,
+ "database": "testdb",
+ "username": "user",
+ "password": "pass"
+ }
+ }
+ }
+ }`
+
+ var config TenantConfig
+ err := json.Unmarshal([]byte(jsonData), &config)
+
+ require.NoError(t, err)
+ assert.Nil(t, config.ConnectionSettings)
+ })
+
+ t.Run("connectionSettings with zero values deserializes correctly", func(t *testing.T) {
+ jsonData := `{
+ "id": "cfg-123",
+ "connectionSettings": {
+ "maxOpenConns": 0,
+ "maxIdleConns": 0
+ }
+ }`
+
+ var config TenantConfig
+ err := json.Unmarshal([]byte(jsonData), &config)
+
+ require.NoError(t, err)
+ require.NotNil(t, config.ConnectionSettings)
+ assert.Equal(t, 0, config.ConnectionSettings.MaxOpenConns)
+ assert.Equal(t, 0, config.ConnectionSettings.MaxIdleConns)
+ })
+
+ t.Run("connectionSettings with partial values deserializes correctly", func(t *testing.T) {
+ jsonData := `{
+ "id": "cfg-123",
+ "connectionSettings": {
+ "maxOpenConns": 30
+ }
+ }`
+
+ var config TenantConfig
+ err := json.Unmarshal([]byte(jsonData), &config)
+
+ require.NoError(t, err)
+ require.NotNil(t, config.ConnectionSettings)
+ assert.Equal(t, 30, config.ConnectionSettings.MaxOpenConns)
+ assert.Equal(t, 0, config.ConnectionSettings.MaxIdleConns)
+ })
+}
diff --git a/commons/tenant-manager/core/validation.go b/commons/tenant-manager/core/validation.go
new file mode 100644
index 00000000..8ffedadd
--- /dev/null
+++ b/commons/tenant-manager/core/validation.go
@@ -0,0 +1,22 @@
+package core
+
+import "regexp"
+
+// MaxTenantIDLength is the maximum allowed length for a tenant ID.
+const MaxTenantIDLength = 256
+
+// validTenantIDPattern enforces a character whitelist for tenant IDs.
+// Only alphanumeric characters, hyphens, and underscores are allowed.
+// The first character must be alphanumeric.
+var validTenantIDPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
+
+// IsValidTenantID validates a tenant ID against security constraints.
+// Valid tenant IDs must be non-empty, at most MaxTenantIDLength characters,
+// and match validTenantIDPattern.
+func IsValidTenantID(id string) bool {
+ if id == "" || len(id) > MaxTenantIDLength {
+ return false
+ }
+
+ return validTenantIDPattern.MatchString(id)
+}
diff --git a/commons/tenant-manager/internal/eviction/lru.go b/commons/tenant-manager/internal/eviction/lru.go
new file mode 100644
index 00000000..25703408
--- /dev/null
+++ b/commons/tenant-manager/internal/eviction/lru.go
@@ -0,0 +1,88 @@
+// Package eviction provides shared LRU eviction logic for multi-tenant
+// connection managers. Each manager (postgres, mongo, rabbitmq) delegates
+// the "find oldest idle candidate" decision to this package and keeps only
+// the technology-specific cleanup (closing the actual connection, removing
+// from manager-specific maps).
+package eviction
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+// DefaultIdleTimeout is the default duration before a tenant connection becomes
+// eligible for eviction. Connections accessed within this window are considered
+// active and will not be evicted, allowing the pool to grow beyond maxConnections.
+const DefaultIdleTimeout = 5 * time.Minute
+
+// FindLRUEvictionCandidate finds the oldest idle connection that exceeds the
+// idle timeout. It returns the ID to evict and true, or an empty string and
+// false if no eviction is needed.
+//
+// The function performs two checks before scanning:
+// 1. If maxConnections <= 0, eviction is disabled (unlimited pool) -- return immediately.
+// 2. If connectionCount < maxConnections, the pool has room -- return immediately.
+//
+// When eviction IS needed, the function iterates lastAccessed and selects the
+// entry with the oldest timestamp that has been idle longer than idleTimeout.
+// If all connections are active (used within the idle timeout), the pool is
+// allowed to grow beyond the soft limit and no eviction occurs.
+func FindLRUEvictionCandidate(
+ connectionCount int,
+ maxConnections int,
+ lastAccessed map[string]time.Time,
+ idleTimeout time.Duration,
+ logger log.Logger,
+) (string, bool) {
+ if maxConnections <= 0 || connectionCount < maxConnections {
+ return "", false
+ }
+
+ if idleTimeout <= 0 {
+ idleTimeout = DefaultIdleTimeout
+ }
+
+ now := time.Now()
+
+ var oldestID string
+
+ var oldestTime time.Time
+
+ for id, t := range lastAccessed {
+ idleDuration := now.Sub(t)
+ if idleDuration < idleTimeout {
+ continue
+ }
+
+ if oldestID == "" || t.Before(oldestTime) {
+ oldestID = id
+ oldestTime = t
+ }
+ }
+
+ if oldestID == "" {
+ if logger != nil {
+ logger.Log(context.Background(), log.LevelWarn,
+ "connection pool at capacity but no idle connections to evict",
+ log.Int("connection_count", connectionCount),
+ log.Int("max_connections", maxConnections),
+ )
+ }
+
+ return "", false
+ }
+
+ if logger != nil {
+ logger.Log(context.Background(), log.LevelInfo,
+ "evicting idle tenant connection",
+ log.String("tenant_id", oldestID),
+ log.String("idle_duration", fmt.Sprintf("%v", now.Sub(oldestTime))),
+ log.String("idle_timeout", fmt.Sprintf("%v", idleTimeout)),
+ )
+ }
+
+ return oldestID, true
+}
diff --git a/commons/tenant-manager/internal/eviction/lru_test.go b/commons/tenant-manager/internal/eviction/lru_test.go
new file mode 100644
index 00000000..17eb4549
--- /dev/null
+++ b/commons/tenant-manager/internal/eviction/lru_test.go
@@ -0,0 +1,450 @@
+package eviction
+
+import (
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestFindLRUEvictionCandidate_EmptyMap(t *testing.T) {
+ t.Parallel()
+
+ id, ok := FindLRUEvictionCandidate(
+ 5, // connectionCount
+ 5, // maxConnections (at capacity)
+ map[string]time.Time{}, // empty lastAccessed
+ time.Minute, // idleTimeout
+ testutil.NewMockLogger(),
+ )
+
+ assert.Empty(t, id)
+ assert.False(t, ok)
+}
+
+func TestFindLRUEvictionCandidate_SingleEntry(t *testing.T) {
+ t.Parallel()
+
+ t.Run("active entry is not evicted", func(t *testing.T) {
+ t.Parallel()
+
+ lastAccessed := map[string]time.Time{
+ "tenant-1": time.Now().Add(-10 * time.Second), // recently accessed
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 1, // connectionCount
+ 1, // maxConnections (at capacity)
+ lastAccessed,
+ time.Minute, // idleTimeout = 1 min, entry only 10s old
+ testutil.NewMockLogger(),
+ )
+
+ assert.Empty(t, id)
+ assert.False(t, ok)
+ })
+
+ t.Run("idle entry is evicted", func(t *testing.T) {
+ t.Parallel()
+
+ lastAccessed := map[string]time.Time{
+ "tenant-1": time.Now().Add(-10 * time.Minute), // idle for 10 minutes
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 1, // connectionCount
+ 1, // maxConnections (at capacity)
+ lastAccessed,
+ time.Minute, // idleTimeout = 1 min
+ testutil.NewMockLogger(),
+ )
+
+ assert.Equal(t, "tenant-1", id)
+ assert.True(t, ok)
+ })
+}
+
+func TestFindLRUEvictionCandidate_MultipleEntries(t *testing.T) {
+ t.Parallel()
+
+ t.Run("one idle among active entries", func(t *testing.T) {
+ t.Parallel()
+
+ now := time.Now()
+ lastAccessed := map[string]time.Time{
+ "tenant-active-1": now.Add(-10 * time.Second), // active
+ "tenant-idle": now.Add(-10 * time.Minute), // idle
+ "tenant-active-2": now.Add(-30 * time.Second), // active
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 3,
+ 3,
+ lastAccessed,
+ time.Minute,
+ testutil.NewMockLogger(),
+ )
+
+ assert.Equal(t, "tenant-idle", id)
+ assert.True(t, ok)
+ })
+
+ t.Run("all idle returns the oldest", func(t *testing.T) {
+ t.Parallel()
+
+ now := time.Now()
+ lastAccessed := map[string]time.Time{
+ "tenant-recent-idle": now.Add(-5 * time.Minute), // idle 5 min
+ "tenant-oldest-idle": now.Add(-30 * time.Minute), // idle 30 min (LRU)
+ "tenant-medium-idle": now.Add(-15 * time.Minute), // idle 15 min
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 3,
+ 3,
+ lastAccessed,
+ time.Minute,
+ testutil.NewMockLogger(),
+ )
+
+ require.True(t, ok)
+ assert.Equal(t, "tenant-oldest-idle", id)
+ })
+
+ t.Run("none idle allows pool to grow beyond limit", func(t *testing.T) {
+ t.Parallel()
+
+ now := time.Now()
+ lastAccessed := map[string]time.Time{
+ "tenant-1": now.Add(-10 * time.Second),
+ "tenant-2": now.Add(-20 * time.Second),
+ "tenant-3": now.Add(-30 * time.Second),
+ "tenant-4": now.Add(-40 * time.Second),
+ }
+
+ // connectionCount (4) > maxConnections (3), but nothing is idle
+ id, ok := FindLRUEvictionCandidate(
+ 4,
+ 3,
+ lastAccessed,
+ time.Minute,
+ testutil.NewMockLogger(),
+ )
+
+ assert.Empty(t, id)
+ assert.False(t, ok)
+ })
+}
+
+func TestFindLRUEvictionCandidate_MaxConnectionsZero(t *testing.T) {
+ t.Parallel()
+
+ // maxConnections <= 0 disables eviction entirely (unlimited pool).
+ // Even idle entries should NOT be evicted.
+ lastAccessed := map[string]time.Time{
+ "tenant-1": time.Now().Add(-1 * time.Hour), // very idle
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 10,
+ 0, // unlimited: no eviction
+ lastAccessed,
+ time.Minute,
+ testutil.NewMockLogger(),
+ )
+
+ assert.Empty(t, id)
+ assert.False(t, ok)
+}
+
+func TestFindLRUEvictionCandidate_MaxConnectionsNegative(t *testing.T) {
+ t.Parallel()
+
+ // Negative maxConnections is treated the same as zero (unlimited).
+ lastAccessed := map[string]time.Time{
+ "tenant-1": time.Now().Add(-1 * time.Hour),
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 5,
+ -1,
+ lastAccessed,
+ time.Minute,
+ testutil.NewMockLogger(),
+ )
+
+ assert.Empty(t, id)
+ assert.False(t, ok)
+}
+
+func TestFindLRUEvictionCandidate_BelowCapacity(t *testing.T) {
+ t.Parallel()
+
+ // When connectionCount < maxConnections the pool has room -- no eviction.
+ lastAccessed := map[string]time.Time{
+ "tenant-1": time.Now().Add(-1 * time.Hour), // very idle, but pool has room
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 1, // connectionCount
+ 10, // maxConnections (plenty of room)
+ lastAccessed,
+ time.Minute,
+ testutil.NewMockLogger(),
+ )
+
+ assert.Empty(t, id)
+ assert.False(t, ok)
+}
+
+func TestFindLRUEvictionCandidate_DefaultIdleTimeout(t *testing.T) {
+ t.Parallel()
+
+ // When idleTimeout is 0, the function defaults to DefaultIdleTimeout (5 min).
+ now := time.Now()
+ lastAccessed := map[string]time.Time{
+ "tenant-within-default": now.Add(-3 * time.Minute), // 3 min < 5 min default
+ "tenant-beyond-default": now.Add(-10 * time.Minute), // 10 min > 5 min default
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 2,
+ 2,
+ lastAccessed,
+ 0, // triggers default idle timeout
+ testutil.NewMockLogger(),
+ )
+
+ require.True(t, ok)
+ assert.Equal(t, "tenant-beyond-default", id)
+}
+
+func TestFindLRUEvictionCandidate_NilLogger(t *testing.T) {
+ t.Parallel()
+
+ t.Run("eviction found with nil logger", func(t *testing.T) {
+ t.Parallel()
+
+ lastAccessed := map[string]time.Time{
+ "tenant-1": time.Now().Add(-10 * time.Minute),
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 1,
+ 1,
+ lastAccessed,
+ time.Minute,
+ nil, // nil logger -- must not panic
+ )
+
+ assert.Equal(t, "tenant-1", id)
+ assert.True(t, ok)
+ })
+
+ t.Run("no eviction candidate with nil logger", func(t *testing.T) {
+ t.Parallel()
+
+ lastAccessed := map[string]time.Time{
+ "tenant-1": time.Now().Add(-10 * time.Second), // active
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 1,
+ 1,
+ lastAccessed,
+ time.Minute,
+ nil, // nil logger -- must not panic on warn log path
+ )
+
+ assert.Empty(t, id)
+ assert.False(t, ok)
+ })
+}
+
+func TestFindLRUEvictionCandidate_LogMessages(t *testing.T) {
+ t.Parallel()
+
+ t.Run("logs warning when at capacity but nothing to evict", func(t *testing.T) {
+ t.Parallel()
+
+ logger := testutil.NewCapturingLogger()
+ lastAccessed := map[string]time.Time{
+ "tenant-1": time.Now().Add(-5 * time.Second), // active
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 1,
+ 1,
+ lastAccessed,
+ time.Minute,
+ logger,
+ )
+
+ assert.Empty(t, id)
+ assert.False(t, ok)
+ assert.True(t, logger.ContainsSubstring("no idle connections to evict"),
+ "expected warning about no idle connections, got: %v", logger.GetMessages())
+ })
+
+ t.Run("logs info when evicting", func(t *testing.T) {
+ t.Parallel()
+
+ logger := testutil.NewCapturingLogger()
+ lastAccessed := map[string]time.Time{
+ "tenant-evicted": time.Now().Add(-10 * time.Minute),
+ }
+
+ id, ok := FindLRUEvictionCandidate(
+ 1,
+ 1,
+ lastAccessed,
+ time.Minute,
+ logger,
+ )
+
+ require.True(t, ok)
+ assert.Equal(t, "tenant-evicted", id)
+ assert.True(t, logger.ContainsSubstring("evicting idle tenant connection"),
+ "expected eviction info log, got: %v", logger.GetMessages())
+ assert.True(t, logger.ContainsSubstring("tenant-evicted"),
+ "expected tenant ID in log, got: %v", logger.GetMessages())
+ })
+}
+
+func TestFindLRUEvictionCandidate_TableDriven(t *testing.T) {
+ t.Parallel()
+
+ now := time.Now()
+ idleTimeout := time.Minute
+
+ tests := []struct {
+ name string
+ connectionCount int
+ maxConnections int
+ lastAccessed map[string]time.Time
+ idleTimeout time.Duration
+ expectedID string
+ expectedOK bool
+ }{
+ {
+ name: "empty map at capacity",
+ connectionCount: 5,
+ maxConnections: 5,
+ lastAccessed: map[string]time.Time{},
+ idleTimeout: idleTimeout,
+ expectedID: "",
+ expectedOK: false,
+ },
+ {
+ name: "nil map at capacity",
+ connectionCount: 5,
+ maxConnections: 5,
+ lastAccessed: nil,
+ idleTimeout: idleTimeout,
+ expectedID: "",
+ expectedOK: false,
+ },
+ {
+ name: "below capacity with idle entries",
+ connectionCount: 2,
+ maxConnections: 5,
+ lastAccessed: map[string]time.Time{
+ "t1": now.Add(-10 * time.Minute),
+ },
+ idleTimeout: idleTimeout,
+ expectedID: "",
+ expectedOK: false,
+ },
+ {
+ name: "at capacity single idle",
+ connectionCount: 1,
+ maxConnections: 1,
+ lastAccessed: map[string]time.Time{
+ "t1": now.Add(-5 * time.Minute),
+ },
+ idleTimeout: idleTimeout,
+ expectedID: "t1",
+ expectedOK: true,
+ },
+ {
+ name: "above capacity selects oldest idle",
+ connectionCount: 5,
+ maxConnections: 3,
+ lastAccessed: map[string]time.Time{
+ "recent": now.Add(-2 * time.Minute),
+ "oldest": now.Add(-20 * time.Minute),
+ "middle": now.Add(-10 * time.Minute),
+ "active1": now.Add(-10 * time.Second),
+ "active2": now.Add(-30 * time.Second),
+ },
+ idleTimeout: idleTimeout,
+ expectedID: "oldest",
+ expectedOK: true,
+ },
+ {
+ name: "maxConnections zero disables eviction",
+ connectionCount: 100,
+ maxConnections: 0,
+ lastAccessed: map[string]time.Time{
+ "t1": now.Add(-1 * time.Hour),
+ },
+ idleTimeout: idleTimeout,
+ expectedID: "",
+ expectedOK: false,
+ },
+ {
+ name: "boundary: idle duration well under timeout is not evicted",
+ connectionCount: 1,
+ maxConnections: 1,
+ lastAccessed: map[string]time.Time{
+ // The eviction check uses `idleDuration < idleTimeout` (strictly
+ // less-than), so entries whose idle time equals the timeout ARE
+ // eligible. We place the entry comfortably under the threshold
+ // (30 second buffer) to avoid clock drift between the test's
+ // `now` and FindLRUEvictionCandidate's internal time.Now().
+ "t1": now.Add(-idleTimeout + 30*time.Second),
+ },
+ idleTimeout: idleTimeout,
+ expectedID: "",
+ expectedOK: false,
+ },
+ {
+ name: "boundary: idle duration just past timeout is evicted",
+ connectionCount: 1,
+ maxConnections: 1,
+ lastAccessed: map[string]time.Time{
+ "t1": now.Add(-idleTimeout - 30*time.Second),
+ },
+ idleTimeout: idleTimeout,
+ expectedID: "t1",
+ expectedOK: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ id, ok := FindLRUEvictionCandidate(
+ tt.connectionCount,
+ tt.maxConnections,
+ tt.lastAccessed,
+ tt.idleTimeout,
+ testutil.NewMockLogger(),
+ )
+
+ assert.Equal(t, tt.expectedOK, ok, "eviction decision mismatch")
+ assert.Equal(t, tt.expectedID, id, "evicted tenant ID mismatch")
+ })
+ }
+}
+
+func TestDefaultIdleTimeout(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, 5*time.Minute, DefaultIdleTimeout,
+ "DefaultIdleTimeout should be 5 minutes")
+}
diff --git a/commons/tenant-manager/internal/logcompat/logger.go b/commons/tenant-manager/internal/logcompat/logger.go
new file mode 100644
index 00000000..e9767ec9
--- /dev/null
+++ b/commons/tenant-manager/internal/logcompat/logger.go
@@ -0,0 +1,196 @@
+package logcompat
+
+import (
+ "context"
+ "fmt"
+
+ liblog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ tmlog "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/log"
+)
+
+type Logger struct {
+ base liblog.Logger
+}
+
+func New(logger liblog.Logger) *Logger {
+ if logger == nil {
+ logger = liblog.NewNop()
+ }
+
+ return &Logger{base: tmlog.NewTenantAwareLogger(logger)}
+}
+
+func (l *Logger) WithFields(kv ...any) *Logger {
+ if l == nil || l.base == nil {
+ return New(nil)
+ }
+
+ return &Logger{base: l.base.With(toFields(kv...)...)}
+}
+
+func (l *Logger) enabled(level liblog.Level) bool {
+ return l != nil && l.base != nil && l.base.Enabled(level)
+}
+
+func (l *Logger) log(ctx context.Context, level liblog.Level, msg string) {
+ if l == nil || l.base == nil {
+ return
+ }
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ l.base.Log(ctx, level, msg)
+}
+
+func (l *Logger) InfoCtx(ctx context.Context, args ...any) {
+ if !l.enabled(liblog.LevelInfo) {
+ return
+ }
+
+ l.log(ctx, liblog.LevelInfo, fmt.Sprint(args...))
+}
+
+func (l *Logger) WarnCtx(ctx context.Context, args ...any) {
+ if !l.enabled(liblog.LevelWarn) {
+ return
+ }
+
+ l.log(ctx, liblog.LevelWarn, fmt.Sprint(args...))
+}
+
+func (l *Logger) ErrorCtx(ctx context.Context, args ...any) {
+ if !l.enabled(liblog.LevelError) {
+ return
+ }
+
+ l.log(ctx, liblog.LevelError, fmt.Sprint(args...))
+}
+
+func (l *Logger) InfofCtx(ctx context.Context, f string, args ...any) {
+ if !l.enabled(liblog.LevelInfo) {
+ return
+ }
+
+ l.log(ctx, liblog.LevelInfo, fmt.Sprintf(f, args...))
+}
+
+func (l *Logger) WarnfCtx(ctx context.Context, f string, args ...any) {
+ if !l.enabled(liblog.LevelWarn) {
+ return
+ }
+
+ l.log(ctx, liblog.LevelWarn, fmt.Sprintf(f, args...))
+}
+
+func (l *Logger) ErrorfCtx(ctx context.Context, f string, args ...any) {
+ if !l.enabled(liblog.LevelError) {
+ return
+ }
+
+ l.log(ctx, liblog.LevelError, fmt.Sprintf(f, args...))
+}
+
+func (l *Logger) Info(args ...any) {
+ if !l.enabled(liblog.LevelInfo) {
+ return
+ }
+
+ l.log(context.Background(), liblog.LevelInfo, fmt.Sprint(args...))
+}
+
+func (l *Logger) Warn(args ...any) {
+ if !l.enabled(liblog.LevelWarn) {
+ return
+ }
+
+ l.log(context.Background(), liblog.LevelWarn, fmt.Sprint(args...))
+}
+
+func (l *Logger) Error(args ...any) {
+ if !l.enabled(liblog.LevelError) {
+ return
+ }
+
+ l.log(context.Background(), liblog.LevelError, fmt.Sprint(args...))
+}
+
+func (l *Logger) Debug(args ...any) {
+ if !l.enabled(liblog.LevelDebug) {
+ return
+ }
+
+ l.log(context.Background(), liblog.LevelDebug, fmt.Sprint(args...))
+}
+
+func (l *Logger) Infof(f string, args ...any) {
+ if !l.enabled(liblog.LevelInfo) {
+ return
+ }
+
+ l.log(context.Background(), liblog.LevelInfo, fmt.Sprintf(f, args...))
+}
+
+func (l *Logger) Warnf(f string, args ...any) {
+ if !l.enabled(liblog.LevelWarn) {
+ return
+ }
+
+ l.log(context.Background(), liblog.LevelWarn, fmt.Sprintf(f, args...))
+}
+
+func (l *Logger) Errorf(f string, args ...any) {
+ if !l.enabled(liblog.LevelError) {
+ return
+ }
+
+ l.log(context.Background(), liblog.LevelError, fmt.Sprintf(f, args...))
+}
+
+func (l *Logger) Debugf(f string, args ...any) {
+ if !l.enabled(liblog.LevelDebug) {
+ return
+ }
+
+ l.log(context.Background(), liblog.LevelDebug, fmt.Sprintf(f, args...))
+}
+
+func (l *Logger) Sync() error {
+ if l == nil || l.base == nil {
+ return nil
+ }
+
+ return l.base.Sync(context.Background())
+}
+
+func (l *Logger) Base() liblog.Logger {
+ if l == nil || l.base == nil {
+ return liblog.NewNop()
+ }
+
+ return l.base
+}
+
+func toFields(kv ...any) []liblog.Field {
+ if len(kv) == 0 {
+ return nil
+ }
+
+ fields := make([]liblog.Field, 0, (len(kv)+1)/2)
+ for i := 0; i < len(kv); i += 2 {
+ key := fmt.Sprintf("arg_%d", i)
+ if ks, ok := kv[i].(string); ok && ks != "" {
+ key = ks
+ }
+
+ if i+1 >= len(kv) {
+ fields = append(fields, liblog.Any(key, nil))
+ continue
+ }
+
+ fields = append(fields, liblog.Any(key, kv[i+1]))
+ }
+
+ return fields
+}
diff --git a/commons/tenant-manager/internal/testutil/logger.go b/commons/tenant-manager/internal/testutil/logger.go
new file mode 100644
index 00000000..1cd7f213
--- /dev/null
+++ b/commons/tenant-manager/internal/testutil/logger.go
@@ -0,0 +1,87 @@
+// Copyright (c) 2026 Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+// Package testutil provides shared test helpers for the tenant-manager
+// sub-packages, eliminating duplicated mock implementations across test files.
+package testutil
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+)
+
+// NewMockLogger returns a no-op logger that satisfies log.Logger.
+// It delegates to log.NewNop() to avoid duplicating the standard no-op implementation.
+func NewMockLogger() log.Logger {
+ return log.NewNop()
+}
+
+// CapturingLogger implements log.Logger and captures log messages for assertion.
+// This enables verifying log output content in tests (e.g., connection_mode=lazy).
+// Messages are private to prevent unsafe concurrent access; use GetMessages() or
+// ContainsSubstring() for thread-safe reads.
+type CapturingLogger struct {
+ mu sync.Mutex
+ messages []string
+}
+
+func (cl *CapturingLogger) record(msg string) {
+ cl.mu.Lock()
+ defer cl.mu.Unlock()
+
+ cl.messages = append(cl.messages, msg)
+}
+
+// GetMessages returns a thread-safe copy of all captured messages.
+func (cl *CapturingLogger) GetMessages() []string {
+ cl.mu.Lock()
+ defer cl.mu.Unlock()
+
+ copied := make([]string, len(cl.messages))
+ copy(copied, cl.messages)
+
+ return copied
+}
+
+// ContainsSubstring returns true if any captured message contains the given substring.
+func (cl *CapturingLogger) ContainsSubstring(sub string) bool {
+ cl.mu.Lock()
+ defer cl.mu.Unlock()
+
+ for _, msg := range cl.messages {
+ if strings.Contains(msg, sub) {
+ return true
+ }
+ }
+
+ return false
+}
+
+func (cl *CapturingLogger) Log(_ context.Context, _ log.Level, msg string, fields ...log.Field) {
+ if len(fields) == 0 {
+ cl.record(msg)
+
+ return
+ }
+
+ parts := make([]string, 0, len(fields))
+ for _, field := range fields {
+ parts = append(parts, fmt.Sprintf("%s=%v", field.Key, field.Value))
+ }
+
+ cl.record(fmt.Sprintf("%s %s", msg, strings.Join(parts, " ")))
+}
+func (cl *CapturingLogger) With(_ ...log.Field) log.Logger { return cl }
+func (cl *CapturingLogger) WithGroup(_ string) log.Logger { return cl }
+func (cl *CapturingLogger) Enabled(_ log.Level) bool { return true }
+func (cl *CapturingLogger) Sync(_ context.Context) error { return nil }
+
+// NewCapturingLogger returns a new CapturingLogger that records all log messages.
+func NewCapturingLogger() *CapturingLogger {
+ return &CapturingLogger{}
+}
diff --git a/commons/tenant-manager/log/tenant_logger.go b/commons/tenant-manager/log/tenant_logger.go
new file mode 100644
index 00000000..ddb500d4
--- /dev/null
+++ b/commons/tenant-manager/log/tenant_logger.go
@@ -0,0 +1,44 @@
+package log
+
+import (
+ "context"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ tmcore "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+)
+
+type TenantAwareLogger struct {
+ base log.Logger
+}
+
+func NewTenantAwareLogger(base log.Logger) *TenantAwareLogger {
+ return &TenantAwareLogger{base: base}
+}
+
+func (l *TenantAwareLogger) Log(ctx context.Context, level log.Level, msg string, fields ...log.Field) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if tenantID := tmcore.GetTenantIDFromContext(ctx); tenantID != "" {
+ fields = append(fields, log.String("tenant_id", tenantID))
+ }
+
+ l.base.Log(ctx, level, msg, fields...)
+}
+
+func (l *TenantAwareLogger) With(fields ...log.Field) log.Logger {
+ return l.base.With(fields...)
+}
+
+func (l *TenantAwareLogger) WithGroup(name string) log.Logger {
+ return l.base.WithGroup(name)
+}
+
+func (l *TenantAwareLogger) Enabled(level log.Level) bool {
+ return l.base.Enabled(level)
+}
+
+func (l *TenantAwareLogger) Sync(ctx context.Context) error {
+ return l.base.Sync(ctx)
+}
diff --git a/commons/tenant-manager/log/tenant_logger_test.go b/commons/tenant-manager/log/tenant_logger_test.go
new file mode 100644
index 00000000..be713192
--- /dev/null
+++ b/commons/tenant-manager/log/tenant_logger_test.go
@@ -0,0 +1,154 @@
+package log
+
+import (
+ "context"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/mock/gomock"
+)
+
+func TestTenantAwareLogger_Log(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ t.Run("injects tenant_id when present in context", func(t *testing.T) {
+ mockLogger := log.NewMockLogger(ctrl)
+
+ var capturedFields []log.Field
+
+ mockLogger.EXPECT().
+ Log(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ DoAndReturn(func(ctx context.Context, level log.Level, msg string, fields ...log.Field) {
+ capturedFields = fields
+ })
+
+ logger := NewTenantAwareLogger(mockLogger)
+ ctx := core.SetTenantIDInContext(context.Background(), "tenant-123")
+
+ logger.Log(ctx, log.LevelInfo, "test message", log.String("key", "value"))
+
+ require.Len(t, capturedFields, 2)
+ assert.Equal(t, "key", capturedFields[0].Key)
+ assert.Equal(t, "value", capturedFields[0].Value)
+ assert.Equal(t, "tenant_id", capturedFields[1].Key)
+ assert.Equal(t, "tenant-123", capturedFields[1].Value)
+ })
+
+ t.Run("works normally when tenant_id is not in context", func(t *testing.T) {
+ mockLogger := log.NewMockLogger(ctrl)
+
+ var capturedFields []log.Field
+
+ mockLogger.EXPECT().
+ Log(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ DoAndReturn(func(ctx context.Context, level log.Level, msg string, fields ...log.Field) {
+ capturedFields = fields
+ })
+
+ logger := NewTenantAwareLogger(mockLogger)
+ ctx := context.Background()
+
+ logger.Log(ctx, log.LevelInfo, "test message", log.String("key", "value"))
+
+ require.Len(t, capturedFields, 1)
+ assert.Equal(t, "key", capturedFields[0].Key)
+ assert.Equal(t, "value", capturedFields[0].Value)
+ })
+
+ t.Run("does not overwrite caller-provided tenant_id field", func(t *testing.T) {
+ mockLogger := log.NewMockLogger(ctrl)
+
+ var capturedFields []log.Field
+
+ mockLogger.EXPECT().
+ Log(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+ DoAndReturn(func(ctx context.Context, level log.Level, msg string, fields ...log.Field) {
+ capturedFields = fields
+ })
+
+ logger := NewTenantAwareLogger(mockLogger)
+ ctx := core.SetTenantIDInContext(context.Background(), "tenant-123")
+
+ logger.Log(ctx, log.LevelInfo, "test message",
+ log.String("tenant_id", "caller-tenant"),
+ log.String("key", "value"),
+ )
+
+ require.Len(t, capturedFields, 3)
+ assert.Equal(t, "tenant_id", capturedFields[0].Key)
+ assert.Equal(t, "caller-tenant", capturedFields[0].Value)
+ assert.Equal(t, "key", capturedFields[1].Key)
+ assert.Equal(t, "value", capturedFields[1].Value)
+ assert.Equal(t, "tenant_id", capturedFields[2].Key)
+ assert.Equal(t, "tenant-123", capturedFields[2].Value)
+ })
+
+ t.Run("nil context handled gracefully", func(t *testing.T) {
+ mockLogger := log.NewMockLogger(ctrl)
+
+ mockLogger.EXPECT().
+ Log(gomock.Any(), gomock.Any(), gomock.Any()).
+ DoAndReturn(func(ctx context.Context, level log.Level, msg string, fields ...log.Field) {
+ assert.NotNil(t, ctx, "base logger should receive non-nil context")
+ })
+
+ logger := NewTenantAwareLogger(mockLogger)
+
+ logger.Log(nil, log.LevelInfo, "test message")
+ })
+}
+
+func TestTenantAwareLogger_OtherMethods(t *testing.T) {
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ t.Run("With delegates to base logger", func(t *testing.T) {
+ mockLogger := log.NewMockLogger(ctrl)
+ wrappedLogger := log.NewMockLogger(ctrl)
+
+ mockLogger.EXPECT().With(log.String("key", "value")).Return(wrappedLogger)
+
+ logger := NewTenantAwareLogger(mockLogger)
+ result := logger.With(log.String("key", "value"))
+
+ assert.Equal(t, wrappedLogger, result)
+ })
+
+ t.Run("WithGroup delegates to base logger", func(t *testing.T) {
+ mockLogger := log.NewMockLogger(ctrl)
+ wrappedLogger := log.NewMockLogger(ctrl)
+
+ mockLogger.EXPECT().WithGroup("group").Return(wrappedLogger)
+
+ logger := NewTenantAwareLogger(mockLogger)
+ result := logger.WithGroup("group")
+
+ assert.Equal(t, wrappedLogger, result)
+ })
+
+ t.Run("Enabled delegates to base logger", func(t *testing.T) {
+ mockLogger := log.NewMockLogger(ctrl)
+
+ mockLogger.EXPECT().Enabled(log.LevelInfo).Return(true)
+
+ logger := NewTenantAwareLogger(mockLogger)
+ result := logger.Enabled(log.LevelInfo)
+
+ assert.True(t, result)
+ })
+
+ t.Run("Sync delegates to base logger", func(t *testing.T) {
+ mockLogger := log.NewMockLogger(ctrl)
+
+ mockLogger.EXPECT().Sync(gomock.Any()).Return(nil)
+
+ logger := NewTenantAwareLogger(mockLogger)
+ err := logger.Sync(context.Background())
+
+ assert.NoError(t, err)
+ })
+}
diff --git a/commons/tenant-manager/middleware/multi_pool.go b/commons/tenant-manager/middleware/multi_pool.go
new file mode 100644
index 00000000..f18cbee7
--- /dev/null
+++ b/commons/tenant-manager/middleware/multi_pool.go
@@ -0,0 +1,470 @@
+// Copyright (c) 2026 Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+package middleware
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libHTTP "github.com/LerianStudio/lib-commons/v4/commons/net/http"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+ tmmongo "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/mongo"
+ tmpostgres "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/postgres"
+ "github.com/gofiber/fiber/v2"
+ "github.com/golang-jwt/jwt/v5"
+ "go.opentelemetry.io/otel/trace"
+)
+
+// ErrorMapper converts tenant-manager errors into Fiber HTTP responses.
+// If nil, the default error mapping is used.
+type ErrorMapper func(c *fiber.Ctx, err error, tenantID string) error
+
+// PoolRoute defines a path-based route to a module's database pools.
+// Each route maps one or more URL path prefixes to a specific module and
+// its associated PostgreSQL and/or MongoDB tenant connection managers.
+type PoolRoute struct {
+ paths []string
+ module string
+ pgPool *tmpostgres.Manager
+ mongoPool *tmmongo.Manager
+}
+
+// MultiPoolOption configures a MultiPoolMiddleware.
+type MultiPoolOption func(*MultiPoolMiddleware)
+
+// MultiPoolMiddleware routes requests to module-specific tenant pools
+// based on URL path matching. It handles JWT extraction, pool resolution,
+// connection injection, and error mapping.
+type MultiPoolMiddleware struct {
+ routes []*PoolRoute
+ defaultRoute *PoolRoute
+ publicPaths []string
+ crossModule bool
+ errorMapper ErrorMapper
+ logger *logcompat.Logger
+ enabled bool
+}
+
+// WithRoute registers a path-based route mapping URL prefixes to a module's
+// database pools. Multiple routes can be registered; the first matching route
+// wins. The paths parameter contains URL path prefixes to match against.
+func WithRoute(paths []string, module string, pgPool *tmpostgres.Manager, mongoPool *tmmongo.Manager) MultiPoolOption {
+ return func(m *MultiPoolMiddleware) {
+ m.routes = append(m.routes, &PoolRoute{
+ paths: paths,
+ module: module,
+ pgPool: pgPool,
+ mongoPool: mongoPool,
+ })
+ }
+}
+
+// WithDefaultRoute registers a fallback route used when no path-based route
+// matches. If no default is set and no route matches, the middleware passes
+// through to the next handler.
+func WithDefaultRoute(module string, pgPool *tmpostgres.Manager, mongoPool *tmmongo.Manager) MultiPoolOption {
+ return func(m *MultiPoolMiddleware) {
+ m.defaultRoute = &PoolRoute{
+ module: module,
+ pgPool: pgPool,
+ mongoPool: mongoPool,
+ }
+ }
+}
+
+// WithPublicPaths registers URL path prefixes that bypass tenant resolution.
+// Requests matching any of the given prefixes skip JWT extraction and proceed
+// directly to the next handler.
+func WithPublicPaths(paths ...string) MultiPoolOption {
+ return func(m *MultiPoolMiddleware) {
+ m.publicPaths = append(m.publicPaths, paths...)
+ }
+}
+
+// WithCrossModuleInjection enables resolution of database connections for all
+// registered routes, not just the matched one. This is useful when a request
+// handler needs access to multiple module databases (e.g., cross-module queries).
+func WithCrossModuleInjection() MultiPoolOption {
+ return func(m *MultiPoolMiddleware) {
+ m.crossModule = true
+ }
+}
+
+// WithErrorMapper sets a custom error mapper function that converts tenant-manager
+// errors into Fiber HTTP responses. When nil (the default), the built-in
+// mapDefaultError is used.
+func WithErrorMapper(fn ErrorMapper) MultiPoolOption {
+ return func(m *MultiPoolMiddleware) {
+ m.errorMapper = fn
+ }
+}
+
+// WithMultiPoolLogger sets the logger for the MultiPoolMiddleware.
+// When not set, the middleware extracts the logger from request context.
+func WithMultiPoolLogger(l log.Logger) MultiPoolOption {
+ return func(m *MultiPoolMiddleware) {
+ m.logger = logcompat.New(l)
+ }
+}
+
+// NewMultiPoolMiddleware creates a new MultiPoolMiddleware with the given options.
+// The middleware is enabled if at least one route has a PG or Mongo pool with
+// IsMultiTenant() == true.
+// By default, health probe paths (/healthz, /readyz, /livez, /health) are public
+// and bypass JWT extraction. Additional paths can be added via WithPublicPaths().
+func NewMultiPoolMiddleware(opts ...MultiPoolOption) *MultiPoolMiddleware {
+ m := &MultiPoolMiddleware{
+ publicPaths: []string{"/healthz", "/readyz", "/livez", "/health"},
+ }
+
+ for _, opt := range opts {
+ opt(m)
+ }
+
+ // Enable if at least one route has a multi-tenant PG or Mongo pool
+ for _, route := range m.routes {
+ if (route.pgPool != nil && route.pgPool.IsMultiTenant()) ||
+ (route.mongoPool != nil && route.mongoPool.IsMultiTenant()) {
+ m.enabled = true
+
+ break
+ }
+ }
+
+ if !m.enabled && m.defaultRoute != nil {
+ if (m.defaultRoute.pgPool != nil && m.defaultRoute.pgPool.IsMultiTenant()) ||
+ (m.defaultRoute.mongoPool != nil && m.defaultRoute.mongoPool.IsMultiTenant()) {
+ m.enabled = true
+ }
+ }
+
+ return m
+}
+
+// WithTenantDB is a Fiber handler that extracts tenant context from JWT,
+// resolves the appropriate database connections based on URL path matching,
+// and stores them in the request context for downstream handlers.
+func (m *MultiPoolMiddleware) WithTenantDB(c *fiber.Ctx) error {
+ // Step 1: Public path check
+ if m.isPublicPath(c.Path()) {
+ return c.Next()
+ }
+
+ // Step 2: Route matching
+ route := m.matchRoute(c.Path())
+ if route == nil {
+ return c.Next()
+ }
+
+ // Step 3: Multi-tenant check — skip only if neither pool is multi-tenant
+ pgEnabled := route.pgPool != nil && route.pgPool.IsMultiTenant()
+ mongoEnabled := route.mongoPool != nil && route.mongoPool.IsMultiTenant()
+
+ if !pgEnabled && !mongoEnabled {
+ return c.Next()
+ }
+
+ // Step 4: Extract context + telemetry
+ ctx := m.initializeTracingContext(c)
+
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ ctx, span := tracer.Start(ctx, "middleware.multi_pool.with_tenant_db")
+ defer span.End()
+
+ // Step 5: Extract tenant ID from JWT
+ tenantID, err := m.extractTenantID(c)
+ if err != nil {
+ logger.ErrorCtx(ctx, fmt.Sprintf("failed to extract tenant ID: %v", err))
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "failed to extract tenant ID", err)
+
+ return m.handleTenantDBError(c, err, "")
+ }
+
+ logger.InfofCtx(ctx, "multi-pool tenant resolved: tenantID=%s, module=%s, path=%s",
+ tenantID, route.module, c.Path())
+
+ // Step 6: Set tenant ID in context
+ ctx = core.ContextWithTenantID(ctx, tenantID)
+
+ // Step 7: Resolve database connections BEFORE triggering consumer.
+ // This ensures the tenant is actually resolvable (not suspended/purged)
+ // before we start consuming messages for it.
+ ctx, err = m.resolveAllConnections(ctx, route, tenantID, pgEnabled, mongoEnabled, logger, span)
+ if err != nil {
+ return m.handleTenantDBError(c, err, tenantID)
+ }
+
+ // Step 8: Update context
+ c.SetUserContext(ctx)
+
+ logger.InfofCtx(ctx, "multi-pool connections injected: tenantID=%s, module=%s", tenantID, route.module)
+
+ return c.Next()
+}
+
+// initializeTracingContext extracts HTTP trace context from the Fiber request,
+// falling back to a background context if neither source provides one.
+func (m *MultiPoolMiddleware) initializeTracingContext(c *fiber.Ctx) context.Context {
+ baseCtx := c.UserContext()
+ if baseCtx == nil {
+ baseCtx = context.Background()
+ }
+
+ ctx := libOpentelemetry.ExtractHTTPContext(baseCtx, c)
+ if ctx == nil {
+ ctx = baseCtx
+ }
+
+ return ctx
+}
+
+// handleTenantDBError dispatches the error through the custom error mapper if
+// configured, otherwise falls back to the default error mapping. For empty
+// tenantID (auth errors), it returns a generic 401 when no mapper is set.
+func (m *MultiPoolMiddleware) handleTenantDBError(c *fiber.Ctx, err error, tenantID string) error {
+ if m.errorMapper != nil {
+ return m.errorMapper(c, err, tenantID)
+ }
+
+ if tenantID == "" {
+ return unauthorizedError(c, "UNAUTHORIZED", "Unauthorized")
+ }
+
+ return m.mapDefaultError(c, err, tenantID)
+}
+
+// resolveAllConnections resolves PG, cross-module, and Mongo connections for the
+// matched route and tenant. It returns the enriched context or the first error.
+func (m *MultiPoolMiddleware) resolveAllConnections(
+ ctx context.Context,
+ route *PoolRoute,
+ tenantID string,
+ pgEnabled, mongoEnabled bool,
+ logger *logcompat.Logger,
+ span trace.Span,
+) (context.Context, error) {
+ var err error
+
+ if pgEnabled {
+ ctx, err = m.resolvePGConnection(ctx, route, tenantID, logger, span)
+ if err != nil {
+ return ctx, err
+ }
+ }
+
+ if m.crossModule {
+ ctx = m.resolveCrossModuleConnections(ctx, route, tenantID, logger)
+ }
+
+ if mongoEnabled {
+ ctx, err = m.resolveMongoConnection(ctx, route, tenantID, logger, span)
+ if err != nil {
+ return ctx, err
+ }
+ }
+
+ return ctx, nil
+}
+
+// matchRoute finds the PoolRoute whose paths match the request path.
+// Returns the defaultRoute if no specific route matches, or nil if no
+// default is configured.
+func (m *MultiPoolMiddleware) matchRoute(path string) *PoolRoute {
+ for _, route := range m.routes {
+ for _, prefix := range route.paths {
+ if path == prefix || strings.HasPrefix(path, prefix+"/") {
+ return route
+ }
+ }
+ }
+
+ return m.defaultRoute
+}
+
+// isPublicPath checks whether the given path matches any registered public
+// path prefix. Public paths bypass all tenant resolution logic.
+func (m *MultiPoolMiddleware) isPublicPath(path string) bool {
+ for _, prefix := range m.publicPaths {
+ if path == prefix || strings.HasPrefix(path, prefix+"/") {
+ return true
+ }
+ }
+
+ return false
+}
+
+// extractTenantID extracts the tenant ID from the JWT token in the
+// Authorization header. Token signature is validated by upstream auth middleware.
+func (m *MultiPoolMiddleware) extractTenantID(c *fiber.Ctx) (string, error) {
+ accessToken := libHTTP.ExtractTokenFromHeader(c)
+ if accessToken == "" {
+ return "", core.ErrAuthorizationTokenRequired
+ }
+
+ token, _, err := new(jwt.Parser).ParseUnverified(accessToken, jwt.MapClaims{})
+ if err != nil {
+ return "", fmt.Errorf("%w: %w", core.ErrInvalidAuthorizationToken, err)
+ }
+
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ return "", core.ErrInvalidTenantClaims
+ }
+
+ tenantID, _ := claims["tenantId"].(string)
+ if tenantID == "" {
+ return "", core.ErrMissingTenantIDClaim
+ }
+
+ if !core.IsValidTenantID(tenantID) {
+ return "", core.ErrInvalidTenantClaims
+ }
+
+ return tenantID, nil
+}
+
+// resolvePGConnection resolves the PostgreSQL connection for the given route
+// and tenant, injecting it into the context using module-scoped context keys.
+func (m *MultiPoolMiddleware) resolvePGConnection(
+ ctx context.Context,
+ route *PoolRoute,
+ tenantID string,
+ logger *logcompat.Logger,
+ span trace.Span,
+) (context.Context, error) {
+ conn, err := route.pgPool.GetConnection(ctx, tenantID)
+ if err != nil {
+ logger.ErrorCtx(ctx, fmt.Sprintf("failed to get tenant PostgreSQL connection: module=%s, tenantID=%s, error=%v", route.module, tenantID, err))
+ libOpentelemetry.HandleSpanError(span, "failed to get tenant PostgreSQL connection", err)
+
+ return ctx, fmt.Errorf("%w: %w", core.ErrConnectionFailed, err)
+ }
+
+ db, err := conn.GetDB()
+ if err != nil {
+ logger.ErrorCtx(ctx, fmt.Sprintf("failed to get database from PostgreSQL connection: module=%s, tenantID=%s, error=%v", route.module, tenantID, err))
+ libOpentelemetry.HandleSpanError(span, "failed to get database from PostgreSQL connection", err)
+
+ return ctx, fmt.Errorf("%w: %w", core.ErrConnectionFailed, err)
+ }
+
+ ctx = core.ContextWithModulePGConnection(ctx, route.module, db)
+
+ return ctx, nil
+}
+
+// resolveCrossModuleConnections resolves PG connections for all routes other
+// than the matched one. Errors are logged but do not block the request.
+func (m *MultiPoolMiddleware) resolveCrossModuleConnections(
+ ctx context.Context,
+ matchedRoute *PoolRoute,
+ tenantID string,
+ logger *logcompat.Logger,
+) context.Context {
+ for _, route := range m.routes {
+ if route == matchedRoute || route.pgPool == nil || !route.pgPool.IsMultiTenant() {
+ continue
+ }
+
+ ctx = m.resolveAndInjectCrossModule(ctx, route, tenantID, logger) //nolint:fatcontext // intentional accumulation of per-module connections into ctx across iterations
+ }
+
+ // Also resolve default route if it differs from matched
+ if m.defaultRoute != nil && m.defaultRoute != matchedRoute &&
+ m.defaultRoute.pgPool != nil && m.defaultRoute.pgPool.IsMultiTenant() {
+ ctx = m.resolveAndInjectCrossModule(ctx, m.defaultRoute, tenantID, logger)
+ }
+
+ return ctx
+}
+
+// crossModuleErrorKey is a context key for storing cross-module resolution errors.
+type crossModuleErrorKey struct{}
+
+// ContextWithCrossModuleError stores a cross-module resolution error in context
+// so downstream handlers can inspect it if needed.
+func ContextWithCrossModuleError(ctx context.Context, err error) context.Context {
+ return context.WithValue(ctx, crossModuleErrorKey{}, err)
+}
+
+// CrossModuleErrorFromContext retrieves the cross-module resolution error, if any.
+func CrossModuleErrorFromContext(ctx context.Context) error {
+ if err, ok := ctx.Value(crossModuleErrorKey{}).(error); ok {
+ return err
+ }
+
+ return nil
+}
+
+// resolveAndInjectCrossModule resolves a single cross-module PG connection and
+// injects it into the context. Errors are logged and stored in context for
+// downstream visibility, but do not block the request.
+func (m *MultiPoolMiddleware) resolveAndInjectCrossModule(
+ ctx context.Context,
+ route *PoolRoute,
+ tenantID string,
+ logger *logcompat.Logger,
+) context.Context {
+ conn, err := route.pgPool.GetConnection(ctx, tenantID)
+ if err != nil {
+ logger.WarnfCtx(ctx, "cross-module PG resolution failed: module=%s, tenantID=%s, error=%v",
+ route.module, tenantID, err)
+
+ return ContextWithCrossModuleError(ctx,
+ fmt.Errorf("cross-module PG resolution failed for module %s: %w", route.module, err))
+ }
+
+ db, err := conn.GetDB()
+ if err != nil {
+ logger.WarnfCtx(ctx, "cross-module PG GetDB failed: module=%s, tenantID=%s, error=%v",
+ route.module, tenantID, err)
+
+ return ContextWithCrossModuleError(ctx,
+ fmt.Errorf("cross-module PG GetDB failed for module %s: %w", route.module, err))
+ }
+
+ return core.ContextWithModulePGConnection(ctx, route.module, db)
+}
+
+// resolveMongoConnection resolves the MongoDB database for the given route
+// and tenant, injecting it into the context.
+func (m *MultiPoolMiddleware) resolveMongoConnection(
+ ctx context.Context,
+ route *PoolRoute,
+ tenantID string,
+ logger *logcompat.Logger,
+ span trace.Span,
+) (context.Context, error) {
+ mongoDB, err := route.mongoPool.GetDatabaseForTenant(ctx, tenantID)
+ if err != nil {
+ logger.ErrorCtx(ctx, fmt.Sprintf("failed to get tenant MongoDB connection: module=%s, tenantID=%s, error=%v", route.module, tenantID, err))
+ libOpentelemetry.HandleSpanError(span, "failed to get tenant MongoDB connection", err)
+
+ return ctx, fmt.Errorf("%w: %w", core.ErrConnectionFailed, err)
+ }
+
+ ctx = core.ContextWithTenantMongo(ctx, mongoDB)
+
+ return ctx, nil
+}
+
+// mapDefaultError delegates to the centralized mapDomainErrorToHTTP function
+// to ensure consistent error-to-HTTP mapping across all middleware types.
+func (m *MultiPoolMiddleware) mapDefaultError(c *fiber.Ctx, err error, tenantID string) error {
+ return mapDomainErrorToHTTP(c, err, tenantID)
+}
+
+// Enabled returns whether the middleware is enabled.
+// The middleware is enabled when at least one route has a multi-tenant PG or Mongo pool.
+func (m *MultiPoolMiddleware) Enabled() bool {
+ return m.enabled
+}
diff --git a/commons/tenant-manager/middleware/multi_pool_test.go b/commons/tenant-manager/middleware/multi_pool_test.go
new file mode 100644
index 00000000..0096a39d
--- /dev/null
+++ b/commons/tenant-manager/middleware/multi_pool_test.go
@@ -0,0 +1,1097 @@
+// Copyright (c) 2026 Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+package middleware
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ tmmongo "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/mongo"
+ tmpostgres "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/postgres"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// newMultiPoolTestManagers creates postgres and mongo Managers backed by a test
+// client that has a non-nil client (so IsMultiTenant() returns true).
+func newMultiPoolTestManagers(t testing.TB, url string) (*tmpostgres.Manager, *tmmongo.Manager) {
+ t.Helper()
+ c, err := client.NewClient(url, nil, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, err)
+ return tmpostgres.NewManager(c, "ledger"), tmmongo.NewManager(c, "ledger")
+}
+
+// newSingleTenantManagers creates managers with a nil client (no tenant manager
+// configured), so IsMultiTenant() returns false.
+func newSingleTenantManagers() (*tmpostgres.Manager, *tmmongo.Manager) {
+ return tmpostgres.NewManager(nil, "ledger"), tmmongo.NewManager(nil, "ledger")
+}
+
+func TestNewMultiPoolMiddleware(t *testing.T) {
+ t.Parallel()
+
+ t.Run("creates disabled middleware when no options provided", func(t *testing.T) {
+ t.Parallel()
+
+ mid := NewMultiPoolMiddleware()
+
+ assert.NotNil(t, mid)
+ assert.False(t, mid.Enabled())
+ assert.Empty(t, mid.routes)
+ assert.Nil(t, mid.defaultRoute)
+ assert.Equal(t, []string{"/healthz", "/readyz", "/livez", "/health"}, mid.publicPaths)
+ assert.False(t, mid.crossModule)
+ assert.Nil(t, mid.errorMapper)
+ assert.Nil(t, mid.logger)
+ })
+
+ t.Run("creates enabled middleware when route has multi-tenant PG pool", func(t *testing.T) {
+ t.Parallel()
+
+ pgPool, mongoPool := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, mongoPool),
+ )
+
+ assert.NotNil(t, mid)
+ assert.True(t, mid.Enabled())
+ assert.Len(t, mid.routes, 1)
+ assert.Equal(t, "transaction", mid.routes[0].module)
+ assert.Equal(t, []string{"/v1/transactions"}, mid.routes[0].paths)
+ })
+
+ t.Run("creates enabled middleware when default route has multi-tenant PG pool", func(t *testing.T) {
+ t.Parallel()
+
+ pgPool, mongoPool := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithDefaultRoute("ledger", pgPool, mongoPool),
+ )
+
+ assert.NotNil(t, mid)
+ assert.True(t, mid.Enabled())
+ assert.NotNil(t, mid.defaultRoute)
+ assert.Equal(t, "ledger", mid.defaultRoute.module)
+ })
+
+ t.Run("creates disabled middleware when all pools are single-tenant", func(t *testing.T) {
+ t.Parallel()
+
+ pgPool, mongoPool := newSingleTenantManagers()
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, mongoPool),
+ WithDefaultRoute("ledger", pgPool, mongoPool),
+ )
+
+ assert.NotNil(t, mid)
+ assert.False(t, mid.Enabled())
+ })
+
+ t.Run("applies all options correctly", func(t *testing.T) {
+ t.Parallel()
+
+ pgPool, mongoPool := newMultiPoolTestManagers(t, "http://localhost:8080")
+ mapper := func(_ *fiber.Ctx, _ error, _ string) error { return nil }
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, mongoPool),
+ WithRoute([]string{"/v1/accounts"}, "account", pgPool, nil),
+ WithDefaultRoute("ledger", pgPool, mongoPool),
+ WithPublicPaths("/health", "/ready"),
+ WithCrossModuleInjection(),
+ WithErrorMapper(mapper),
+ )
+
+ assert.True(t, mid.Enabled())
+ assert.Len(t, mid.routes, 2)
+ assert.NotNil(t, mid.defaultRoute)
+ assert.Equal(t, []string{"/healthz", "/readyz", "/livez", "/health", "/health", "/ready"}, mid.publicPaths)
+ assert.True(t, mid.crossModule)
+ assert.NotNil(t, mid.errorMapper)
+ })
+}
+
+func TestMultiPoolMiddleware_matchRoute(t *testing.T) {
+ t.Parallel()
+
+ pgPool, mongoPool := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions", "/v1/tx"}, "transaction", pgPool, mongoPool),
+ WithRoute([]string{"/v1/accounts"}, "account", pgPool, nil),
+ WithDefaultRoute("ledger", pgPool, mongoPool),
+ )
+
+ tests := []struct {
+ name string
+ path string
+ expectedModule string
+ expectNil bool
+ }{
+ {
+ name: "matches first route by exact prefix",
+ path: "/v1/transactions/123",
+ expectedModule: "transaction",
+ },
+ {
+ name: "matches first route by alternative prefix",
+ path: "/v1/tx/456",
+ expectedModule: "transaction",
+ },
+ {
+ name: "matches second route",
+ path: "/v1/accounts/789",
+ expectedModule: "account",
+ },
+ {
+ name: "falls back to default route",
+ path: "/v1/unknown/path",
+ expectedModule: "ledger",
+ },
+ {
+ name: "matches root path to default",
+ path: "/",
+ expectedModule: "ledger",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ route := mid.matchRoute(tt.path)
+
+ if tt.expectNil {
+ assert.Nil(t, route)
+ } else {
+ require.NotNil(t, route)
+ assert.Equal(t, tt.expectedModule, route.module)
+ }
+ })
+ }
+}
+
+func TestMultiPoolMiddleware_matchRoute_NoDefault(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ )
+
+ route := mid.matchRoute("/v1/unknown")
+ assert.Nil(t, route)
+}
+
+func TestMultiPoolMiddleware_isPublicPath(t *testing.T) {
+ t.Parallel()
+
+ mid := NewMultiPoolMiddleware(
+ WithPublicPaths("/health", "/ready", "/version"),
+ )
+
+ tests := []struct {
+ name string
+ path string
+ expected bool
+ }{
+ {
+ name: "matches health endpoint",
+ path: "/health",
+ expected: true,
+ },
+ {
+ name: "matches ready endpoint",
+ path: "/ready",
+ expected: true,
+ },
+ {
+ name: "matches version endpoint",
+ path: "/version",
+ expected: true,
+ },
+ {
+ name: "matches health sub-path",
+ path: "/health/live",
+ expected: true,
+ },
+ {
+ name: "does not match non-public path",
+ path: "/v1/transactions",
+ expected: false,
+ },
+ {
+ name: "does not match partial prefix",
+ path: "/healthy",
+ expected: false, // boundary-aware: "/healthy" is not "/health" or "/health/..."
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ assert.Equal(t, tt.expected, mid.isPublicPath(tt.path))
+ })
+ }
+}
+
+func TestMultiPoolMiddleware_Enabled(t *testing.T) {
+ t.Parallel()
+
+ t.Run("returns false when no routes configured", func(t *testing.T) {
+ t.Parallel()
+
+ mid := NewMultiPoolMiddleware()
+ assert.False(t, mid.Enabled())
+ })
+
+ t.Run("returns true when route has multi-tenant pool", func(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/test"}, "test", pgPool, nil),
+ )
+
+ assert.True(t, mid.Enabled())
+ })
+
+ t.Run("returns false when route has single-tenant pool", func(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newSingleTenantManagers()
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/test"}, "test", pgPool, nil),
+ )
+
+ assert.False(t, mid.Enabled())
+ })
+
+ t.Run("returns true when route has multi-tenant Mongo pool only", func(t *testing.T) {
+ t.Parallel()
+
+ singlePG, _ := newSingleTenantManagers()
+ _, multiMongo := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/test"}, "test", singlePG, multiMongo),
+ )
+
+ assert.True(t, mid.Enabled())
+ })
+
+ t.Run("returns true when default route has multi-tenant Mongo pool only", func(t *testing.T) {
+ t.Parallel()
+
+ singlePG, _ := newSingleTenantManagers()
+ _, multiMongo := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithDefaultRoute("ledger", singlePG, multiMongo),
+ )
+
+ assert.True(t, mid.Enabled())
+ })
+
+ t.Run("returns true when route has nil PG pool and multi-tenant Mongo pool", func(t *testing.T) {
+ t.Parallel()
+
+ _, multiMongo := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/test"}, "test", nil, multiMongo),
+ )
+
+ assert.True(t, mid.Enabled())
+ })
+
+ t.Run("returns true when only default route is multi-tenant", func(t *testing.T) {
+ t.Parallel()
+
+ singlePG, _ := newSingleTenantManagers()
+ multiPG, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/test"}, "test", singlePG, nil),
+ WithDefaultRoute("ledger", multiPG, nil),
+ )
+
+ assert.True(t, mid.Enabled())
+ })
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_PublicPath(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ WithPublicPaths("/health", "/ready"),
+ )
+
+ nextCalled := false
+
+ app := fiber.New()
+ app.Use(mid.WithTenantDB)
+ app.Get("/health", func(c *fiber.Ctx) error {
+ nextCalled = true
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/health", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.True(t, nextCalled, "public path should bypass tenant resolution")
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_NoMatchingRoute(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ // No default route, so unmatched paths pass through
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ )
+
+ nextCalled := false
+
+ app := fiber.New()
+ app.Use(mid.WithTenantDB)
+ app.Get("/v1/unknown", func(c *fiber.Ctx) error {
+ nextCalled = true
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1/unknown", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.True(t, nextCalled, "unmatched route should pass through")
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_SingleTenantBypass(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newSingleTenantManagers()
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ )
+
+ nextCalled := false
+
+ app := fiber.New()
+ app.Use(mid.WithTenantDB)
+ app.Get("/v1/transactions", func(c *fiber.Ctx) error {
+ nextCalled = true
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1/transactions", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.True(t, nextCalled, "single-tenant pool should bypass tenant resolution")
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_MissingToken(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ )
+
+ app := fiber.New()
+ app.Use(mid.WithTenantDB)
+ app.Get("/v1/transactions", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1/transactions", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Contains(t, string(body), "Unauthorized")
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_InvalidToken(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ )
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Use(mid.WithTenantDB)
+ app.Get("/v1/transactions", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1/transactions", nil)
+ req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Contains(t, string(body), "Unauthorized")
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_MissingTenantID(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ )
+
+ token := buildTestJWT(t, map[string]any{
+ "sub": "user-123",
+ "email": "test@example.com",
+ })
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Use(mid.WithTenantDB)
+ app.Get("/v1/transactions", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1/transactions", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Contains(t, string(body), "Unauthorized")
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_ErrorMapperDelegation(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ customMapperCalled := false
+ customMapper := func(c *fiber.Ctx, _ error, _ string) error {
+ customMapperCalled = true
+
+ return c.Status(http.StatusTeapot).JSON(fiber.Map{
+ "code": "CUSTOM_ERROR",
+ "title": "Custom Error",
+ "message": "handled by custom mapper",
+ })
+ }
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ WithErrorMapper(customMapper),
+ )
+
+ app := fiber.New()
+ app.Use(mid.WithTenantDB)
+ app.Get("/v1/transactions", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ // No Authorization header -> triggers error -> should use custom mapper
+ req := httptest.NewRequest(http.MethodGet, "/v1/transactions", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ assert.True(t, customMapperCalled, "custom error mapper should be called")
+ assert.Equal(t, http.StatusTeapot, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Contains(t, string(body), "CUSTOM_ERROR")
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_DefaultRouteMatching(t *testing.T) {
+ t.Parallel()
+
+ // Create a mock Tenant Manager server that returns 404 to trigger an error
+ // response (proves the route was matched and tenant resolution attempted).
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ _, _ = w.Write([]byte(`{"error":"not found"}`))
+ }))
+ defer server.Close()
+
+ pgPool, _ := newMultiPoolTestManagers(t, server.URL)
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ WithDefaultRoute("ledger", pgPool, nil),
+ )
+
+ token := buildTestJWT(t, map[string]any{
+ "sub": "user-123",
+ "tenantId": "tenant-abc",
+ })
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Use(mid.WithTenantDB)
+ app.Get("/v1/unknown", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1/unknown", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ // The request should NOT return 200 because the PG connection resolution
+ // will fail with the mock server (proving the default route was matched
+ // and multi-tenant resolution was attempted).
+ assert.NotEqual(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_PGFailureBlocksHandler(t *testing.T) {
+ t.Parallel()
+
+ // The middleware injects the tenant ID into context (step 6 in WithTenantDB)
+ // BEFORE attempting PG resolution (step 8). However, on PG resolution failure
+ // the middleware returns an error without calling c.Next(), so the downstream
+ // handler is never reached and cannot observe the injected tenant ID.
+ //
+ // This test validates the observable behavior: JWT parsing succeeds and the
+ // middleware reaches PG resolution (returning 503), but the handler is NOT
+ // called because the PG connection cannot be established without a real
+ // Tenant Manager backend.
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := &MultiPoolMiddleware{
+ routes: []*PoolRoute{
+ {
+ paths: []string{"/v1/test"},
+ module: "test",
+ pgPool: pgPool,
+ },
+ },
+ enabled: true,
+ }
+
+ token := buildTestJWT(t, map[string]any{
+ "sub": "user-123",
+ "tenantId": "tenant-xyz",
+ })
+
+ handlerCalled := false
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Use(mid.WithTenantDB)
+ app.Get("/v1/test", func(c *fiber.Ctx) error {
+ handlerCalled = true
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ // PG resolution fails (no real Tenant Manager), so the middleware returns
+ // a service-unavailable error and the handler is never reached.
+ assert.NotEqual(t, http.StatusOK, resp.StatusCode,
+ "expected non-200 because PG resolution fails without a real Tenant Manager")
+ assert.False(t, handlerCalled,
+ "handler should not be called when PG resolution fails")
+}
+
+func TestMultiPoolMiddleware_mapDefaultError(t *testing.T) {
+ t.Parallel()
+
+ mid := &MultiPoolMiddleware{}
+
+ tests := []struct {
+ name string
+ err error
+ tenantID string
+ expectedCode int
+ expectedBody string
+ }{
+ {
+ name: "tenant not found returns 404",
+ err: core.ErrTenantNotFound,
+ tenantID: "tenant-123",
+ expectedCode: http.StatusNotFound,
+ expectedBody: "TENANT_NOT_FOUND",
+ },
+ {
+ name: "tenant suspended returns 403",
+ err: &core.TenantSuspendedError{TenantID: "t1", Status: "suspended"},
+ tenantID: "t1",
+ expectedCode: http.StatusForbidden,
+ expectedBody: "Service Suspended",
+ },
+ {
+ name: "manager closed returns 503",
+ err: core.ErrManagerClosed,
+ tenantID: "t1",
+ expectedCode: http.StatusServiceUnavailable,
+ expectedBody: "SERVICE_UNAVAILABLE",
+ },
+ {
+ name: "service not configured returns 503",
+ err: core.ErrServiceNotConfigured,
+ tenantID: "t1",
+ expectedCode: http.StatusServiceUnavailable,
+ expectedBody: "SERVICE_UNAVAILABLE",
+ },
+ {
+ name: "connection error returns 503",
+ err: errors.New("connection refused"),
+ tenantID: "t1",
+ expectedCode: http.StatusInternalServerError,
+ expectedBody: "TENANT_DB_ERROR",
+ },
+ {
+ name: "generic error returns 500",
+ err: errors.New("something unexpected"),
+ tenantID: "t1",
+ expectedCode: http.StatusInternalServerError,
+ expectedBody: "TENANT_DB_ERROR",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return mid.mapDefaultError(c, tt.err, tt.tenantID)
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ assert.Equal(t, tt.expectedCode, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+
+ assert.Contains(t, string(body), tt.expectedBody)
+ })
+ }
+}
+
+func TestMultiPoolMiddleware_extractTenantID(t *testing.T) {
+ t.Parallel()
+
+ mid := &MultiPoolMiddleware{}
+
+ t.Run("returns error when no Authorization header", func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Get("/test", func(c *fiber.Ctx) error {
+ _, err := mid.extractTenantID(c)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "authorization token is required")
+
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+ })
+
+ t.Run("returns error when token is malformed", func(t *testing.T) {
+ t.Parallel()
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Get("/test", func(c *fiber.Ctx) error {
+ _, err := mid.extractTenantID(c)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid authorization token")
+
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer not-valid")
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+ })
+
+ t.Run("returns error when tenantId claim is missing", func(t *testing.T) {
+ t.Parallel()
+
+ token := buildTestJWT(t, map[string]any{
+ "sub": "user-123",
+ })
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Get("/test", func(c *fiber.Ctx) error {
+ _, err := mid.extractTenantID(c)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "tenantId claim is required")
+
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+ })
+
+ t.Run("returns tenant ID from valid token", func(t *testing.T) {
+ t.Parallel()
+
+ token := buildTestJWT(t, map[string]any{
+ "sub": "user-123",
+ "tenantId": "tenant-abc",
+ })
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Get("/test", func(c *fiber.Ctx) error {
+ tenantID, err := mid.extractTenantID(c)
+ assert.NoError(t, err)
+ assert.Equal(t, "tenant-abc", tenantID)
+
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+ })
+}
+
+func TestMultiPoolMiddleware_WithTenantDB_CrossModuleInjection(t *testing.T) {
+ t.Parallel()
+
+ // Create a mock Tenant Manager that returns 404 (so PG connection fails).
+ // We verify that even with crossModule enabled, the middleware attempts
+ // resolution for the matched route first. Since PG resolution fails,
+ // we get an error response (proving the route was matched and cross-module
+ // logic was reached or would be reached after primary resolution).
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ _, _ = w.Write([]byte(`{"error":"not found"}`))
+ }))
+ defer server.Close()
+
+ pgPoolA, _ := newMultiPoolTestManagers(t, server.URL)
+ pgPoolB, _ := newMultiPoolTestManagers(t, server.URL)
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPoolA, nil),
+ WithRoute([]string{"/v1/accounts"}, "account", pgPoolB, nil),
+ WithCrossModuleInjection(),
+ )
+
+ assert.True(t, mid.crossModule, "crossModule flag should be set")
+ assert.Len(t, mid.routes, 2)
+
+ token := buildTestJWT(t, map[string]any{
+ "sub": "user-123",
+ "tenantId": "tenant-abc",
+ })
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Use(mid.WithTenantDB)
+ app.Get("/v1/transactions", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/v1/transactions", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+
+ defer resp.Body.Close()
+
+ // PG resolution will fail with the mock server, producing an error response.
+ // This confirms the middleware reached the PG resolution step, which happens
+ // before cross-module injection.
+ assert.NotEqual(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestWithRoute(t *testing.T) {
+ t.Parallel()
+
+ pgPool, mongoPool := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := &MultiPoolMiddleware{}
+ opt := WithRoute([]string{"/v1/test", "/v1/test2"}, "test-module", pgPool, mongoPool)
+ opt(mid)
+
+ require.Len(t, mid.routes, 1)
+ assert.Equal(t, "test-module", mid.routes[0].module)
+ assert.Equal(t, []string{"/v1/test", "/v1/test2"}, mid.routes[0].paths)
+ assert.Equal(t, pgPool, mid.routes[0].pgPool)
+ assert.Equal(t, mongoPool, mid.routes[0].mongoPool)
+}
+
+func TestWithDefaultRoute(t *testing.T) {
+ t.Parallel()
+
+ pgPool, mongoPool := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := &MultiPoolMiddleware{}
+ opt := WithDefaultRoute("default-module", pgPool, mongoPool)
+ opt(mid)
+
+ require.NotNil(t, mid.defaultRoute)
+ assert.Equal(t, "default-module", mid.defaultRoute.module)
+ assert.Equal(t, pgPool, mid.defaultRoute.pgPool)
+ assert.Equal(t, mongoPool, mid.defaultRoute.mongoPool)
+ assert.Empty(t, mid.defaultRoute.paths)
+}
+
+func TestWithPublicPaths(t *testing.T) {
+ t.Parallel()
+
+ mid := &MultiPoolMiddleware{}
+ opt := WithPublicPaths("/health", "/ready")
+ opt(mid)
+
+ assert.Equal(t, []string{"/health", "/ready"}, mid.publicPaths)
+
+ // Applying again appends
+ opt2 := WithPublicPaths("/version")
+ opt2(mid)
+
+ assert.Equal(t, []string{"/health", "/ready", "/version"}, mid.publicPaths)
+}
+
+func TestWithCrossModuleInjection(t *testing.T) {
+ t.Parallel()
+
+ mid := &MultiPoolMiddleware{}
+ assert.False(t, mid.crossModule)
+
+ opt := WithCrossModuleInjection()
+ opt(mid)
+
+ assert.True(t, mid.crossModule)
+}
+
+func TestWithErrorMapper(t *testing.T) {
+ t.Parallel()
+
+ mapper := func(_ *fiber.Ctx, _ error, _ string) error { return nil }
+
+ mid := &MultiPoolMiddleware{}
+ assert.Nil(t, mid.errorMapper)
+
+ opt := WithErrorMapper(mapper)
+ opt(mid)
+
+ assert.NotNil(t, mid.errorMapper)
+}
+
+func TestWithMultiPoolLogger(t *testing.T) {
+ t.Parallel()
+
+ mid := &MultiPoolMiddleware{}
+ assert.Nil(t, mid.logger)
+
+ // We just verify the option sets the field. Using nil logger since we
+ // don't have a test logger implementation in scope.
+ opt := WithMultiPoolLogger(nil)
+ opt(mid)
+
+ assert.NotNil(t, mid.logger)
+}
+
+func TestMultiPoolMiddleware_DefaultHealthProbePaths(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ )
+
+ tests := []struct {
+ name string
+ path string
+ expectBypass bool
+ }{
+ {
+ name: "healthz bypasses auth",
+ path: "/healthz",
+ expectBypass: true,
+ },
+ {
+ name: "readyz bypasses auth",
+ path: "/readyz",
+ expectBypass: true,
+ },
+ {
+ name: "livez bypasses auth",
+ path: "/livez",
+ expectBypass: true,
+ },
+ {
+ name: "health bypasses auth",
+ path: "/health",
+ expectBypass: true,
+ },
+ {
+ name: "health sub-path bypasses auth",
+ path: "/health/live",
+ expectBypass: true,
+ },
+ {
+ name: "regular path requires auth",
+ path: "/v1/transactions",
+ expectBypass: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ nextCalled := false
+
+ app := fiber.New()
+ app.Use(mid.WithTenantDB)
+ app.Get(tt.path, func(c *fiber.Ctx) error {
+ nextCalled = true
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, tt.path, nil)
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ if tt.expectBypass {
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.True(t, nextCalled, "health probe path should bypass auth")
+ } else {
+ assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+ assert.False(t, nextCalled, "regular path should require auth")
+ }
+ })
+ }
+}
+
+func TestMultiPoolMiddleware_WithPublicPaths_AppendsToDefaults(t *testing.T) {
+ t.Parallel()
+
+ pgPool, _ := newMultiPoolTestManagers(t, "http://localhost:8080")
+
+ mid := NewMultiPoolMiddleware(
+ WithRoute([]string{"/v1/transactions"}, "transaction", pgPool, nil),
+ WithPublicPaths("/metrics", "/version"),
+ )
+
+ assert.Equal(t, []string{"/healthz", "/readyz", "/livez", "/health", "/metrics", "/version"}, mid.publicPaths)
+
+ nextCalled := false
+
+ app := fiber.New()
+ app.Use(mid.WithTenantDB)
+ app.Get("/metrics", func(c *fiber.Ctx) error {
+ nextCalled = true
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.True(t, nextCalled, "custom public path should bypass auth")
+}
diff --git a/commons/tenant-manager/middleware/tenant.go b/commons/tenant-manager/middleware/tenant.go
new file mode 100644
index 00000000..8a3e0999
--- /dev/null
+++ b/commons/tenant-manager/middleware/tenant.go
@@ -0,0 +1,298 @@
+package middleware
+
+import (
+ "context"
+ "errors"
+ "net/http"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ liblog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libHTTP "github.com/LerianStudio/lib-commons/v4/commons/net/http"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+ tmmongo "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/mongo"
+ tmpostgres "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/postgres"
+ "github.com/gofiber/fiber/v2"
+ "github.com/golang-jwt/jwt/v5"
+)
+
+// TenantMiddleware extracts tenantId from JWT token and resolves the database connection.
+// It stores the connection in context for downstream handlers and repositories.
+// Supports PostgreSQL only, MongoDB only, or both databases.
+type TenantMiddleware struct {
+ postgres *tmpostgres.Manager // PostgreSQL manager (optional)
+ mongo *tmmongo.Manager // MongoDB manager (optional)
+ enabled bool
+}
+
+// TenantMiddlewareOption configures a TenantMiddleware.
+type TenantMiddlewareOption func(*TenantMiddleware)
+
+// WithPostgresManager sets the PostgreSQL manager for the tenant middleware.
+// When configured, the middleware will resolve PostgreSQL connections for tenants.
+func WithPostgresManager(postgres *tmpostgres.Manager) TenantMiddlewareOption {
+ return func(m *TenantMiddleware) {
+ m.postgres = postgres
+ m.enabled = m.postgres != nil || m.mongo != nil
+ }
+}
+
+// WithMongoManager sets the MongoDB manager for the tenant middleware.
+// When configured, the middleware will resolve MongoDB connections for tenants.
+func WithMongoManager(mongo *tmmongo.Manager) TenantMiddlewareOption {
+ return func(m *TenantMiddleware) {
+ m.mongo = mongo
+ m.enabled = m.postgres != nil || m.mongo != nil
+ }
+}
+
+// NewTenantMiddleware creates a new TenantMiddleware with the given options.
+// Use WithPostgresManager and/or WithMongoManager to configure which databases to use.
+// The middleware is enabled if at least one manager is configured.
+//
+// Usage examples:
+//
+// // PostgreSQL only
+// mid := middleware.NewTenantMiddleware(middleware.WithPostgresManager(pgManager))
+//
+// // MongoDB only
+// mid := middleware.NewTenantMiddleware(middleware.WithMongoManager(mongoManager))
+//
+// // Both PostgreSQL and MongoDB
+// mid := middleware.NewTenantMiddleware(
+// middleware.WithPostgresManager(pgManager),
+// middleware.WithMongoManager(mongoManager),
+// )
+func NewTenantMiddleware(opts ...TenantMiddlewareOption) *TenantMiddleware {
+ m := &TenantMiddleware{}
+
+ for _, opt := range opts {
+ opt(m)
+ }
+
+ // Enable if any manager is configured
+ m.enabled = m.postgres != nil || m.mongo != nil
+
+ return m
+}
+
+// WithTenantDB returns a Fiber handler that extracts tenant context and resolves DB connection.
+// It parses the JWT token to get tenantId and fetches the appropriate connection from Tenant Manager.
+// The connection is stored in the request context for use by repositories.
+//
+// Usage in routes.go:
+//
+// tenantMid := middleware.NewTenantMiddleware(middleware.WithPostgresManager(pgManager))
+// f.Use(tenantMid.WithTenantDB)
+func (m *TenantMiddleware) WithTenantDB(c *fiber.Ctx) error {
+ // If middleware is disabled, pass through
+ if !m.enabled {
+ return c.Next()
+ }
+
+ ctx := c.UserContext()
+
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ ctx, span := tracer.Start(ctx, "middleware.tenant.resolve_db")
+ defer span.End()
+
+ // Extract JWT token from Authorization header
+ accessToken := libHTTP.ExtractTokenFromHeader(c)
+ if accessToken == "" {
+ logger.ErrorCtx(ctx, "no authorization token - multi-tenant mode requires JWT with tenantId")
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "missing authorization token",
+ core.ErrAuthorizationTokenRequired)
+
+ return unauthorizedError(c, "MISSING_TOKEN", "Authorization token is required")
+ }
+
+ // Parse JWT token without signature verification.
+ // Token signature is validated by upstream auth middleware before this point.
+ token, _, err := new(jwt.Parser).ParseUnverified(accessToken, jwt.MapClaims{})
+ if err != nil {
+ logger.Base().Log(ctx, liblog.LevelError, "failed to parse JWT token", liblog.Err(err))
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "failed to parse token", err)
+
+ return unauthorizedError(c, "INVALID_TOKEN", "Failed to parse authorization token")
+ }
+
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ logger.ErrorCtx(ctx, "JWT claims are not in expected format")
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "invalid claims format",
+ core.ErrInvalidTenantClaims)
+
+ return unauthorizedError(c, "INVALID_TOKEN", "JWT claims are not in expected format")
+ }
+
+ // Extract tenantId from claims
+ tenantID, _ := claims["tenantId"].(string)
+ if tenantID == "" {
+ logger.ErrorCtx(ctx, "no tenantId in JWT - multi-tenant mode requires tenantId claim")
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "missing tenantId in JWT",
+ core.ErrMissingTenantIDClaim)
+
+ return unauthorizedError(c, "MISSING_TENANT", "tenantId is required in JWT token")
+ }
+
+ if !core.IsValidTenantID(tenantID) {
+ logger.Base().Log(ctx, liblog.LevelError, "invalid tenantId format in JWT",
+ liblog.String("tenant_id", tenantID))
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "invalid tenantId format",
+ core.ErrInvalidTenantClaims)
+
+ return unauthorizedError(c, "INVALID_TENANT", "tenantId has invalid format")
+ }
+
+ logger.Base().Log(ctx, liblog.LevelInfo, "tenant context resolved",
+ liblog.String("tenant_id", tenantID))
+
+ // Store tenant ID in context
+ ctx = core.ContextWithTenantID(ctx, tenantID)
+
+ // Handle PostgreSQL if manager is configured
+ if m.postgres != nil {
+ conn, err := m.postgres.GetConnection(ctx, tenantID)
+ if err != nil {
+ logger.Base().Log(ctx, liblog.LevelError, "failed to get tenant PostgreSQL connection", liblog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "failed to get tenant PostgreSQL connection", err)
+
+ return mapDomainErrorToHTTP(c, err, tenantID)
+ }
+
+ // Get the database connection from PostgresConnection
+ db, err := conn.GetDB()
+ if err != nil {
+ logger.Base().Log(ctx, liblog.LevelError, "failed to get database from PostgreSQL connection", liblog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "failed to get database from PostgreSQL connection", err)
+
+ return internalServerError(c, "TENANT_DB_ERROR", "Failed to get tenant database connection")
+ }
+
+ // Store PostgreSQL connection in context
+ ctx = core.ContextWithTenantPGConnection(ctx, db)
+ }
+
+ // Handle MongoDB if manager is configured
+ if m.mongo != nil {
+ mongoDB, err := m.mongo.GetDatabaseForTenant(ctx, tenantID)
+ if err != nil {
+ logger.Base().Log(ctx, liblog.LevelError, "failed to get tenant MongoDB connection", liblog.Err(err))
+ libOpentelemetry.HandleSpanError(span, "failed to get tenant MongoDB connection", err)
+
+ return mapDomainErrorToHTTP(c, err, tenantID)
+ }
+
+ ctx = core.ContextWithTenantMongo(ctx, mongoDB)
+ }
+
+ // Update Fiber context
+ c.SetUserContext(ctx)
+
+ return c.Next()
+}
+
+// mapDomainErrorToHTTP is a centralized error-to-HTTP mapping function shared by
+// both TenantMiddleware and MultiPoolMiddleware to ensure consistent status codes
+// for the same domain errors.
+func mapDomainErrorToHTTP(c *fiber.Ctx, err error, tenantID string) error {
+ // Missing token or JWT errors -> 401
+ if errors.Is(err, core.ErrAuthorizationTokenRequired) ||
+ errors.Is(err, core.ErrInvalidAuthorizationToken) ||
+ errors.Is(err, core.ErrInvalidTenantClaims) ||
+ errors.Is(err, core.ErrMissingTenantIDClaim) {
+ return unauthorizedError(c, "UNAUTHORIZED", "Unauthorized")
+ }
+
+ // Tenant not found -> 404
+ if errors.Is(err, core.ErrTenantNotFound) {
+ return c.Status(http.StatusNotFound).JSON(fiber.Map{
+ "code": "TENANT_NOT_FOUND",
+ "title": "Tenant Not Found",
+ "message": "tenant not found: " + tenantID,
+ })
+ }
+
+ // Tenant suspended/purged -> 403
+ var suspErr *core.TenantSuspendedError
+ if errors.As(err, &suspErr) {
+ return forbiddenError(c, "0131", "Service Suspended",
+ "tenant service is "+suspErr.Status)
+ }
+
+ // Generic access denied (403 without parsed status) -> 403
+ if errors.Is(err, core.ErrTenantServiceAccessDenied) {
+ return forbiddenError(c, "0131", "Access Denied",
+ "tenant service access denied")
+ }
+
+ // Manager closed or service not configured -> 503
+ if errors.Is(err, core.ErrManagerClosed) || errors.Is(err, core.ErrServiceNotConfigured) {
+ return c.Status(http.StatusServiceUnavailable).JSON(fiber.Map{
+ "code": "SERVICE_UNAVAILABLE",
+ "title": "Service Unavailable",
+ "message": "Service temporarily unavailable",
+ })
+ }
+
+ // Circuit breaker open -> 503
+ if errors.Is(err, core.ErrCircuitBreakerOpen) {
+ return c.Status(http.StatusServiceUnavailable).JSON(fiber.Map{
+ "code": "SERVICE_UNAVAILABLE",
+ "title": "Service Unavailable",
+ "message": "Service temporarily unavailable",
+ })
+ }
+
+ // Connection errors -> 503
+ if errors.Is(err, core.ErrConnectionFailed) {
+ return c.Status(http.StatusServiceUnavailable).JSON(fiber.Map{
+ "code": "SERVICE_UNAVAILABLE",
+ "title": "Service Unavailable",
+ "message": "Failed to resolve tenant database",
+ })
+ }
+
+ // Default -> 500
+ return internalServerError(c, "TENANT_DB_ERROR", "Failed to resolve tenant database")
+}
+
+// forbiddenError sends an HTTP 403 Forbidden response.
+// Used when the tenant-service association exists but is not active (suspended or purged).
+func forbiddenError(c *fiber.Ctx, code, title, message string) error {
+ return c.Status(http.StatusForbidden).JSON(fiber.Map{
+ "code": code,
+ "title": title,
+ "message": message,
+ })
+}
+
+// internalServerError sends an HTTP 500 Internal Server Error response.
+func internalServerError(c *fiber.Ctx, code, title string) error {
+ return c.Status(http.StatusInternalServerError).JSON(fiber.Map{
+ "code": code,
+ "title": title,
+ "message": "Internal server error",
+ })
+}
+
+// unauthorizedError sends an HTTP 401 Unauthorized response.
+func unauthorizedError(c *fiber.Ctx, code, message string) error {
+ return c.Status(http.StatusUnauthorized).JSON(fiber.Map{
+ "code": code,
+ "title": "Unauthorized",
+ "message": message,
+ })
+}
+
+// Enabled returns whether the middleware is enabled.
+func (m *TenantMiddleware) Enabled() bool {
+ return m.enabled
+}
diff --git a/commons/tenant-manager/middleware/tenant_test.go b/commons/tenant-manager/middleware/tenant_test.go
new file mode 100644
index 00000000..e4c62a7d
--- /dev/null
+++ b/commons/tenant-manager/middleware/tenant_test.go
@@ -0,0 +1,302 @@
+package middleware
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ tmmongo "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/mongo"
+ tmpostgres "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/postgres"
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// newTestManagers creates a postgres and mongo Manager backed by a test client.
+// Centralises the repeated client.NewClient + NewManager boilerplate so each
+// sub-test only declares what is unique to its scenario.
+func newTestManagers(t testing.TB) (*tmpostgres.Manager, *tmmongo.Manager) {
+ t.Helper()
+ c, err := client.NewClient("http://localhost:8080", nil, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, err)
+ return tmpostgres.NewManager(c, "ledger"), tmmongo.NewManager(c, "ledger")
+}
+
+func TestNewTenantMiddleware(t *testing.T) {
+ t.Run("creates disabled middleware when no managers are configured", func(t *testing.T) {
+ middleware := NewTenantMiddleware()
+
+ assert.NotNil(t, middleware)
+ assert.False(t, middleware.Enabled())
+ assert.Nil(t, middleware.postgres)
+ assert.Nil(t, middleware.mongo)
+ })
+
+ t.Run("creates enabled middleware with PostgreSQL only", func(t *testing.T) {
+ pgManager, _ := newTestManagers(t)
+
+ middleware := NewTenantMiddleware(WithPostgresManager(pgManager))
+
+ assert.NotNil(t, middleware)
+ assert.True(t, middleware.Enabled())
+ assert.Equal(t, pgManager, middleware.postgres)
+ assert.Nil(t, middleware.mongo)
+ })
+
+ t.Run("creates enabled middleware with MongoDB only", func(t *testing.T) {
+ _, mongoManager := newTestManagers(t)
+
+ middleware := NewTenantMiddleware(WithMongoManager(mongoManager))
+
+ assert.NotNil(t, middleware)
+ assert.True(t, middleware.Enabled())
+ assert.Nil(t, middleware.postgres)
+ assert.Equal(t, mongoManager, middleware.mongo)
+ })
+
+ t.Run("creates middleware with both PostgreSQL and MongoDB managers", func(t *testing.T) {
+ pgManager, mongoManager := newTestManagers(t)
+
+ middleware := NewTenantMiddleware(
+ WithPostgresManager(pgManager),
+ WithMongoManager(mongoManager),
+ )
+
+ assert.NotNil(t, middleware)
+ assert.True(t, middleware.Enabled())
+ assert.Equal(t, pgManager, middleware.postgres)
+ assert.Equal(t, mongoManager, middleware.mongo)
+ })
+}
+
+func TestWithPostgresManager(t *testing.T) {
+ t.Run("sets postgres manager on middleware", func(t *testing.T) {
+ pgManager, _ := newTestManagers(t)
+
+ middleware := NewTenantMiddleware()
+ assert.Nil(t, middleware.postgres)
+ assert.False(t, middleware.Enabled())
+
+ // Apply option manually
+ opt := WithPostgresManager(pgManager)
+ opt(middleware)
+
+ assert.Equal(t, pgManager, middleware.postgres)
+ assert.True(t, middleware.Enabled())
+ })
+
+ t.Run("enables middleware when postgres manager is set", func(t *testing.T) {
+ pgManager, _ := newTestManagers(t)
+
+ middleware := &TenantMiddleware{}
+ assert.False(t, middleware.enabled)
+
+ opt := WithPostgresManager(pgManager)
+ opt(middleware)
+
+ assert.True(t, middleware.enabled)
+ })
+}
+
+func TestWithMongoManager(t *testing.T) {
+ t.Run("sets mongo manager on middleware", func(t *testing.T) {
+ _, mongoManager := newTestManagers(t)
+
+ middleware := NewTenantMiddleware()
+ assert.Nil(t, middleware.mongo)
+ assert.False(t, middleware.Enabled())
+
+ // Apply option manually
+ opt := WithMongoManager(mongoManager)
+ opt(middleware)
+
+ assert.Equal(t, mongoManager, middleware.mongo)
+ assert.True(t, middleware.Enabled())
+ })
+
+ t.Run("enables middleware when mongo manager is set", func(t *testing.T) {
+ _, mongoManager := newTestManagers(t)
+
+ middleware := &TenantMiddleware{}
+ assert.False(t, middleware.enabled)
+
+ opt := WithMongoManager(mongoManager)
+ opt(middleware)
+
+ assert.True(t, middleware.enabled)
+ })
+}
+
+func TestTenantMiddleware_Enabled(t *testing.T) {
+ t.Run("returns false when no managers are configured", func(t *testing.T) {
+ middleware := NewTenantMiddleware()
+ assert.False(t, middleware.Enabled())
+ })
+
+ t.Run("returns true when only PostgreSQL manager is set", func(t *testing.T) {
+ pgManager, _ := newTestManagers(t)
+
+ middleware := NewTenantMiddleware(WithPostgresManager(pgManager))
+ assert.True(t, middleware.Enabled())
+ })
+
+ t.Run("returns true when only MongoDB manager is set", func(t *testing.T) {
+ _, mongoManager := newTestManagers(t)
+
+ middleware := NewTenantMiddleware(WithMongoManager(mongoManager))
+ assert.True(t, middleware.Enabled())
+ })
+
+ t.Run("returns true when both managers are set", func(t *testing.T) {
+ pgManager, mongoManager := newTestManagers(t)
+
+ middleware := NewTenantMiddleware(
+ WithPostgresManager(pgManager),
+ WithMongoManager(mongoManager),
+ )
+ assert.True(t, middleware.Enabled())
+ })
+}
+
+// buildTestJWT constructs a minimal unsigned JWT token string from the given claims.
+// The token is not cryptographically signed (signature is empty), which is acceptable
+// because the middleware uses ParseUnverified (lib-auth already validated the token).
+func buildTestJWT(t testing.TB, claims map[string]any) string {
+ t.Helper()
+ header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
+
+ payload, err := json.Marshal(claims)
+ require.NoError(t, err)
+ encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
+
+ return header + "." + encodedPayload + "."
+}
+
+// simulateAuthMiddleware returns a Fiber handler that sets c.Locals("user_id")
+// to simulate upstream lib-auth middleware having validated the request.
+// hasUpstreamAuthAssertion checks c.Locals("user_id"), not HTTP headers.
+func simulateAuthMiddleware(userID string) fiber.Handler {
+ return func(c *fiber.Ctx) error {
+ c.Locals("user_id", userID)
+ return c.Next()
+ }
+}
+
+func TestTenantMiddleware_WithTenantDB(t *testing.T) {
+ t.Run("no Authorization header returns 401", func(t *testing.T) {
+ pgManager, _ := newTestManagers(t)
+
+ middleware := NewTenantMiddleware(WithPostgresManager(pgManager))
+
+ app := fiber.New()
+ app.Use(middleware.WithTenantDB)
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Contains(t, string(body), "Unauthorized")
+ })
+
+ t.Run("malformed JWT returns 401", func(t *testing.T) {
+ _, mongoManager := newTestManagers(t)
+
+ middleware := NewTenantMiddleware(WithMongoManager(mongoManager))
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Use(middleware.WithTenantDB)
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer not-a-valid-jwt")
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Contains(t, string(body), "Unauthorized")
+ })
+
+ t.Run("valid JWT missing tenantId claim returns 401", func(t *testing.T) {
+ pgManager, _ := newTestManagers(t)
+
+ middleware := NewTenantMiddleware(WithPostgresManager(pgManager))
+
+ token := buildTestJWT(t, map[string]any{
+ "sub": "user-123",
+ "email": "test@example.com",
+ })
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Use(middleware.WithTenantDB)
+ app.Get("/test", func(c *fiber.Ctx) error {
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
+
+ body, err := io.ReadAll(resp.Body)
+ require.NoError(t, err)
+ assert.Contains(t, string(body), "Unauthorized")
+ })
+
+ t.Run("valid JWT with tenantId calls next handler", func(t *testing.T) {
+ // Create an enabled middleware with no real managers configured.
+ // Both postgres and mongo pointers remain nil, so the middleware skips
+ // DB resolution and proceeds to c.Next() after JWT parsing.
+ middleware := &TenantMiddleware{enabled: true}
+
+ token := buildTestJWT(t, map[string]any{
+ "sub": "user-123",
+ "tenantId": "tenant-abc",
+ })
+
+ var capturedTenantID string
+ nextCalled := false
+
+ app := fiber.New()
+ app.Use(simulateAuthMiddleware("user-123"))
+ app.Use(middleware.WithTenantDB)
+ app.Get("/test", func(c *fiber.Ctx) error {
+ nextCalled = true
+ capturedTenantID = core.GetTenantIDFromContext(c.UserContext())
+ return c.SendString("ok")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ req.Header.Set("Authorization", "Bearer "+token)
+ resp, err := app.Test(req, -1)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ assert.Equal(t, http.StatusOK, resp.StatusCode)
+ assert.True(t, nextCalled, "next handler should have been called")
+ assert.Equal(t, "tenant-abc", capturedTenantID, "tenantId should be injected in context")
+ })
+}
diff --git a/commons/tenant-manager/mongo/manager.go b/commons/tenant-manager/mongo/manager.go
new file mode 100644
index 00000000..cc1aea8a
--- /dev/null
+++ b/commons/tenant-manager/mongo/manager.go
@@ -0,0 +1,1034 @@
+// Package mongo provides multi-tenant MongoDB connection management.
+// It fetches tenant-specific database credentials from Tenant Manager service
+// and manages connections per tenant using LRU eviction with idle timeout.
+package mongo
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "fmt"
+ "net/url"
+ "os"
+ "strconv"
+ "sync"
+ "time"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ mongolib "github.com/LerianStudio/lib-commons/v4/commons/mongo"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/eviction"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+ "go.mongodb.org/mongo-driver/mongo"
+ "go.mongodb.org/mongo-driver/mongo/options"
+ "go.opentelemetry.io/otel/trace"
+)
+
+// mongoPingTimeout is the maximum duration for MongoDB connection health check pings.
+const mongoPingTimeout = 3 * time.Second
+
+// defaultSettingsCheckInterval is the default interval between periodic
+// connection pool settings revalidation checks. When a cached connection is
+// returned by GetConnection and this interval has elapsed since the last check,
+// fresh config is fetched from the Tenant Manager asynchronously.
+const defaultSettingsCheckInterval = 30 * time.Second
+
+// settingsRevalidationTimeout is the maximum duration for the HTTP call
+// to the Tenant Manager during async settings revalidation.
+const settingsRevalidationTimeout = 5 * time.Second
+
+// DefaultMaxConnections is the default max connections for MongoDB.
+const DefaultMaxConnections uint64 = 100
+
+// defaultIdleTimeout is the default duration before a tenant connection becomes
+// eligible for eviction. Connections accessed within this window are considered
+// active and will not be evicted, allowing the pool to grow beyond maxConnections.
+// Defined centrally in the eviction package; aliased here for local convenience.
+var defaultIdleTimeout = eviction.DefaultIdleTimeout
+
+// Stats contains statistics for the Manager.
+type Stats struct {
+ TotalConnections int `json:"totalConnections"`
+ MaxConnections int `json:"maxConnections"`
+ ActiveConnections int `json:"activeConnections"`
+ TenantIDs []string `json:"tenantIds"`
+ Closed bool `json:"closed"`
+}
+
+// Manager manages MongoDB connections per tenant.
+// Credentials are provided directly by the tenant-manager settings endpoint.
+// When maxConnections is set (> 0), the manager uses LRU eviction with an idle
+// timeout as a soft limit. Connections idle longer than the timeout are eligible
+// for eviction when the pool exceeds maxConnections. If all connections are active
+// (used within the idle timeout), the pool grows beyond the soft limit and
+// naturally shrinks back as tenants become idle.
+type Manager struct {
+ client *client.Client
+ service string
+ module string
+ logger *logcompat.Logger
+
+ mu sync.RWMutex
+ connections map[string]*MongoConnection
+ databaseNames map[string]string // tenantID -> database name (cached from createConnection)
+ closed bool
+ maxConnections int // soft limit for pool size (0 = unlimited)
+ idleTimeout time.Duration // how long before a connection is eligible for eviction
+ lastAccessed map[string]time.Time // LRU tracking per tenant
+
+ lastSettingsCheck map[string]time.Time // tracks per-tenant last settings revalidation time
+ settingsCheckInterval time.Duration // configurable interval between settings revalidation checks
+
+ // revalidateWG tracks in-flight revalidatePoolSettings goroutines so Close()
+ // can wait for them to finish before returning. Without this, goroutines
+ // spawned by GetConnection may access Manager state after Close() returns.
+ revalidateWG sync.WaitGroup
+}
+
+type MongoConnection struct {
+ // Adapter type used by tenant-manager package; keep fields aligned with
+ // tenant-manager migration contract and upstream lib-commons adapter semantics.
+ ConnectionStringSource string
+ Database string
+ Logger log.Logger
+ MaxPoolSize uint64
+ DB *mongo.Client
+
+ // tlsConfig, when non-nil, is applied to the mongo client options via
+ // SetTLSConfig. This is used when separate TLS certificate and key files
+ // are provided (tls.LoadX509KeyPair), since the MongoDB URI parameter
+ // tlsCertificateKeyFile only accepts a single combined PEM file.
+ tlsConfig *tls.Config
+
+ client *mongolib.Client
+}
+
+func (c *MongoConnection) Connect(ctx context.Context) error {
+ if c == nil {
+ return errors.New("mongo connection is nil")
+ }
+
+ // When a custom TLS config is required (e.g., separate cert+key files loaded
+ // via tls.LoadX509KeyPair), connect directly via the mongo driver so we can
+ // call SetTLSConfig on the client options. The mongolib.NewClient path does
+ // not expose this capability.
+ if c.tlsConfig != nil {
+ return c.connectWithTLS(ctx)
+ }
+
+ mongoTenantClient, err := mongolib.NewClient(ctx, mongolib.Config{
+ URI: c.ConnectionStringSource,
+ Database: c.Database,
+ MaxPoolSize: c.MaxPoolSize,
+ Logger: c.Logger,
+ })
+ if err != nil {
+ return err
+ }
+
+ mongoClient, err := mongoTenantClient.Client(ctx)
+ if err != nil {
+ return err
+ }
+
+ c.client = mongoTenantClient
+ c.DB = mongoClient
+
+ return nil
+}
+
+// connectWithTLS creates a MongoDB client using the raw driver, applying the
+// custom TLS configuration via SetTLSConfig. This path is used when separate
+// certificate and key files are provided (not a combined PEM).
+func (c *MongoConnection) connectWithTLS(ctx context.Context) error {
+ clientOptions := options.Client().ApplyURI(c.ConnectionStringSource)
+
+ if c.MaxPoolSize > 0 {
+ clientOptions.SetMaxPoolSize(c.MaxPoolSize)
+ }
+
+ clientOptions.SetTLSConfig(c.tlsConfig)
+
+ mongoClient, err := mongo.Connect(ctx, clientOptions)
+ if err != nil {
+ return fmt.Errorf("mongo connect with TLS failed: %w", err)
+ }
+
+ c.DB = mongoClient
+
+ return nil
+}
+
+// Option configures a Manager.
+type Option func(*Manager)
+
+// WithModule sets the module name for the MongoDB manager.
+func WithModule(module string) Option {
+ return func(p *Manager) {
+ p.module = module
+ }
+}
+
+// WithLogger sets the logger for the MongoDB manager.
+func WithLogger(logger log.Logger) Option {
+ return func(p *Manager) {
+ p.logger = logcompat.New(logger)
+ }
+}
+
+// WithMaxTenantPools sets the soft limit for the number of tenant connections in the pool.
+// When the pool reaches this limit and a new tenant needs a connection, only connections
+// that have been idle longer than the idle timeout are eligible for eviction. If all
+// connections are active (used within the idle timeout), the pool grows beyond this limit.
+// A value of 0 (default) means unlimited.
+func WithMaxTenantPools(maxSize int) Option {
+ return func(p *Manager) {
+ p.maxConnections = maxSize
+ }
+}
+
+// WithSettingsCheckInterval sets the interval between periodic connection pool settings
+// revalidation checks. When GetConnection returns a cached connection and this interval
+// has elapsed since the last check for that tenant, fresh config is fetched from the
+// Tenant Manager asynchronously. For MongoDB, the driver does not support runtime pool
+// resize, but revalidation detects suspended/purged tenants and evicts their connections.
+//
+// If d <= 0, revalidation is DISABLED (settingsCheckInterval is set to 0).
+// When disabled, no async revalidation checks are performed on cache hits.
+// Default: 30 seconds (defaultSettingsCheckInterval).
+func WithSettingsCheckInterval(d time.Duration) Option {
+ return func(p *Manager) {
+ p.settingsCheckInterval = max(d, 0)
+ }
+}
+
+// WithIdleTimeout sets the duration after which an unused tenant connection becomes
+// eligible for eviction. Only connections idle longer than this duration will be evicted
+// when the pool exceeds the soft limit (maxConnections). If all connections are active
+// (used within the idle timeout), the pool is allowed to grow beyond the soft limit.
+// Default: 5 minutes.
+func WithIdleTimeout(d time.Duration) Option {
+ return func(p *Manager) {
+ p.idleTimeout = d
+ }
+}
+
+// NewManager creates a new MongoDB connection manager.
+func NewManager(c *client.Client, service string, opts ...Option) *Manager {
+ p := &Manager{
+ client: c,
+ service: service,
+ logger: logcompat.New(nil),
+ connections: make(map[string]*MongoConnection),
+ databaseNames: make(map[string]string),
+ lastAccessed: make(map[string]time.Time),
+ lastSettingsCheck: make(map[string]time.Time),
+ settingsCheckInterval: defaultSettingsCheckInterval,
+ }
+
+ for _, opt := range opts {
+ opt(p)
+ }
+
+ return p
+}
+
+// GetConnection returns a MongoDB client for the tenant.
+// If a cached client fails a health check (e.g., due to credential rotation
+// after a tenant purge+re-associate), the stale client is evicted and a new
+// one is created with fresh credentials from the Tenant Manager.
+func (p *Manager) GetConnection(ctx context.Context, tenantID string) (*mongo.Client, error) { //nolint:gocognit // complexity from connection lifecycle (ping, revalidate, evict) is inherent
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if tenantID == "" {
+ return nil, errors.New("tenant ID is required")
+ }
+
+ p.mu.RLock()
+
+ if p.closed {
+ p.mu.RUnlock()
+ return nil, core.ErrManagerClosed
+ }
+
+ if conn, ok := p.connections[tenantID]; ok {
+ p.mu.RUnlock()
+
+ // Validate cached connection is still healthy (e.g., credentials may have changed).
+ // Ping is slow I/O, so we intentionally run it outside any lock.
+ if conn.DB != nil {
+ pingCtx, cancel := context.WithTimeout(ctx, mongoPingTimeout)
+ pingErr := conn.DB.Ping(pingCtx, nil)
+
+ cancel()
+
+ if pingErr != nil {
+ if p.logger != nil {
+ p.logger.WarnCtx(ctx, fmt.Sprintf("cached mongo connection unhealthy for tenant %s, reconnecting: %v", tenantID, pingErr))
+ }
+
+ if closeErr := p.CloseConnection(ctx, tenantID); closeErr != nil && p.logger != nil {
+ p.logger.WarnCtx(ctx, fmt.Sprintf("failed to close stale mongo connection for tenant %s: %v", tenantID, closeErr))
+ }
+
+ // Connection was unhealthy and has been evicted; create fresh.
+ return p.createConnection(ctx, tenantID)
+ }
+
+ // Ping succeeded. Re-acquire write lock to update LRU tracking,
+ // but re-check that the connection was not evicted while we were
+ // pinging (another goroutine may have called CloseConnection,
+ // Close, or evictLRU in the meantime).
+ now := time.Now()
+
+ p.mu.Lock()
+ if current, stillExists := p.connections[tenantID]; stillExists && current == conn {
+ p.lastAccessed[tenantID] = now
+
+ shouldRevalidate := p.client != nil && p.settingsCheckInterval > 0 && time.Since(p.lastSettingsCheck[tenantID]) > p.settingsCheckInterval
+ if shouldRevalidate {
+ p.lastSettingsCheck[tenantID] = now
+ p.revalidateWG.Add(1)
+ }
+
+ p.mu.Unlock()
+
+ if shouldRevalidate {
+ go func() { //#nosec G118 -- intentional: revalidatePoolSettings creates its own timeout context; must not use request-scoped context as this outlives the request
+ defer p.revalidateWG.Done()
+
+ p.revalidatePoolSettings(tenantID)
+ }()
+ }
+
+ return conn.DB, nil
+ }
+
+ p.mu.Unlock()
+
+ // Connection was evicted while we were pinging; fall through
+ // to createConnection which will fetch fresh credentials.
+ return p.createConnection(ctx, tenantID)
+ }
+
+ // conn.DB is nil -- cached entry is unusable, create a new connection.
+ return p.createConnection(ctx, tenantID)
+ }
+
+ p.mu.RUnlock()
+
+ return p.createConnection(ctx, tenantID)
+}
+
+// revalidatePoolSettings fetches fresh config from the Tenant Manager and detects
+// whether the tenant has been suspended or purged. For MongoDB, the driver does not
+// support changing pool size after client creation, so this method only checks for
+// tenant status changes and evicts the cached connection if the tenant is suspended.
+// This runs asynchronously (in a goroutine) and must never block GetConnection.
+// If the fetch fails, a warning is logged but the connection remains usable.
+func (p *Manager) revalidatePoolSettings(tenantID string) {
+ // Guard: recover from any panic to avoid crashing the process.
+ // This goroutine runs asynchronously and must never bring down the service.
+ defer func() {
+ if r := recover(); r != nil {
+ if p.logger != nil {
+ p.logger.Warnf("recovered from panic during settings revalidation for tenant %s: %v", tenantID, r)
+ }
+ }
+ }()
+
+ revalidateCtx, cancel := context.WithTimeout(context.Background(), settingsRevalidationTimeout)
+ defer cancel()
+
+ _, err := p.client.GetTenantConfig(revalidateCtx, tenantID, p.service, client.WithSkipCache())
+ if err != nil {
+ // If tenant service was suspended/purged, evict the cached connection immediately.
+ // The next request for this tenant will call createConnection, which fetches fresh
+ // config from the Tenant Manager and receives the 403 error directly.
+ if core.IsTenantSuspendedError(err) {
+ if p.logger != nil {
+ p.logger.Warnf("tenant %s service suspended, evicting cached connection", tenantID)
+ }
+
+ evictCtx, evictCancel := context.WithTimeout(context.Background(), settingsRevalidationTimeout)
+ defer evictCancel()
+
+ _ = p.CloseConnection(evictCtx, tenantID)
+
+ return
+ }
+
+ if p.logger != nil {
+ p.logger.Warnf("failed to revalidate connection settings for tenant %s: %v", tenantID, err)
+ }
+
+ return
+ }
+
+ p.ApplyConnectionSettings(tenantID, nil)
+}
+
+// createConnection fetches config from Tenant Manager and creates a MongoDB client.
+func (p *Manager) createConnection(ctx context.Context, tenantID string) (*mongo.Client, error) {
+ if p.client == nil {
+ return nil, errors.New("tenant manager client is required for multi-tenant connections")
+ }
+
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ ctx, span := tracer.Start(ctx, "mongo.create_connection")
+ defer span.End()
+
+ // Check for a cached connection under the write lock, but perform
+ // network I/O (Ping / Disconnect) outside the lock to avoid blocking
+ // other goroutines on slow network calls.
+ cachedConn, hasCached, err := p.snapshotCachedConnection(tenantID)
+ if err != nil {
+ return nil, err
+ }
+
+ if hasCached {
+ if reusedDB, reused := p.tryReuseCachedConnection(ctx, tenantID, cachedConn); reused {
+ return reusedDB, nil
+ }
+ }
+
+ return p.buildAndCacheNewConnection(ctx, tenantID, logger, span)
+}
+
+// snapshotCachedConnection reads the cached connection for tenantID under a
+// short lock and returns whether the manager is closed.
+func (p *Manager) snapshotCachedConnection(tenantID string) (*MongoConnection, bool, error) {
+ p.mu.Lock()
+ cachedConn, hasCached := p.connections[tenantID]
+ closed := p.closed
+ p.mu.Unlock()
+
+ if closed {
+ return nil, false, core.ErrManagerClosed
+ }
+
+ return cachedConn, hasCached, nil
+}
+
+// tryReuseCachedConnection validates a previously cached connection by pinging it.
+// If the connection is healthy and still in the cache, it updates the LRU timestamp
+// and returns it. If unhealthy or evicted, it cleans up and returns reused=false so
+// the caller falls through to create a new connection.
+func (p *Manager) tryReuseCachedConnection(
+ ctx context.Context,
+ tenantID string,
+ cachedConn *MongoConnection,
+) (*mongo.Client, bool) {
+ if cachedConn == nil || cachedConn.DB == nil {
+ p.removeStaleCacheEntry(tenantID, cachedConn)
+
+ return nil, false
+ }
+
+ pingCtx, cancel := context.WithTimeout(ctx, mongoPingTimeout)
+ pingErr := cachedConn.DB.Ping(pingCtx, nil)
+
+ cancel()
+
+ if pingErr == nil {
+ return p.reuseHealthyConnection(tenantID, cachedConn)
+ }
+
+ p.disconnectUnhealthyConnection(ctx, tenantID, cachedConn, pingErr)
+
+ return nil, false
+}
+
+// reuseHealthyConnection updates the LRU timestamp for a healthy cached connection.
+// Returns (client, true) if the entry still exists in the cache, or (nil, false) if
+// it was evicted while we were pinging.
+func (p *Manager) reuseHealthyConnection(tenantID string, cachedConn *MongoConnection) (*mongo.Client, bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if current, stillExists := p.connections[tenantID]; stillExists && current == cachedConn {
+ p.lastAccessed[tenantID] = time.Now()
+
+ return cachedConn.DB, true
+ }
+
+ return nil, false
+}
+
+// disconnectUnhealthyConnection disconnects a cached connection that failed its
+// health check and removes the stale cache entry.
+func (p *Manager) disconnectUnhealthyConnection(
+ ctx context.Context,
+ tenantID string,
+ cachedConn *MongoConnection,
+ pingErr error,
+) {
+ if p.logger != nil {
+ p.logger.WarnCtx(ctx, fmt.Sprintf("cached mongo connection unhealthy for tenant %s, reconnecting: %v", tenantID, pingErr))
+ }
+
+ discCtx, discCancel := context.WithTimeout(ctx, mongoPingTimeout)
+ if discErr := cachedConn.DB.Disconnect(discCtx); discErr != nil && p.logger != nil {
+ p.logger.WarnCtx(ctx, fmt.Sprintf("failed to disconnect unhealthy mongo connection for tenant %s: %v", tenantID, discErr))
+ }
+
+ discCancel()
+
+ p.removeStaleCacheEntry(tenantID, cachedConn)
+}
+
+// removeStaleCacheEntry removes a cache entry only if it still points to the
+// same connection reference (not replaced by another goroutine).
+func (p *Manager) removeStaleCacheEntry(tenantID string, cachedConn *MongoConnection) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if current, ok := p.connections[tenantID]; ok && current == cachedConn {
+ delete(p.connections, tenantID)
+ delete(p.databaseNames, tenantID)
+ delete(p.lastAccessed, tenantID)
+ delete(p.lastSettingsCheck, tenantID)
+ }
+}
+
+// buildAndCacheNewConnection fetches tenant config, builds a new MongoDB client,
+// and caches it.
+func (p *Manager) buildAndCacheNewConnection(
+ ctx context.Context,
+ tenantID string,
+ logger *logcompat.Logger,
+ span trace.Span,
+) (*mongo.Client, error) {
+ mongoConfig, err := p.getMongoConfigForTenant(ctx, tenantID, logger, span)
+ if err != nil {
+ return nil, err
+ }
+
+ uri, err := buildMongoURI(mongoConfig, logger)
+ if err != nil {
+ return nil, err
+ }
+
+ maxConnections := DefaultMaxConnections
+ if mongoConfig.MaxPoolSize > 0 {
+ maxConnections = mongoConfig.MaxPoolSize
+ }
+
+ conn := &MongoConnection{
+ ConnectionStringSource: uri,
+ Database: mongoConfig.Database,
+ Logger: p.logger.Base(),
+ MaxPoolSize: maxConnections,
+ }
+
+ // When separate TLS certificate and key files are provided, load the
+ // X.509 key pair and build a *tls.Config for the connection. The URI
+ // does not include tlsCertificateKeyFile in this case (see buildMongoQueryParams).
+ if hasSeparateCertAndKey(mongoConfig) {
+ tlsCfg, tlsErr := buildTLSConfigFromFiles(mongoConfig)
+ if tlsErr != nil {
+ logger.ErrorfCtx(ctx, "failed to build TLS config for tenant %s: %v", tenantID, tlsErr)
+ libOpentelemetry.HandleSpanError(span, "failed to build TLS config", tlsErr)
+
+ return nil, fmt.Errorf("failed to build TLS config: %w", tlsErr)
+ }
+
+ conn.tlsConfig = tlsCfg
+ }
+
+ if err := conn.Connect(ctx); err != nil {
+ logger.ErrorfCtx(ctx, "failed to connect to MongoDB for tenant %s: %v", tenantID, err)
+ libOpentelemetry.HandleSpanError(span, "failed to connect to MongoDB", err)
+
+ return nil, fmt.Errorf("failed to connect to MongoDB: %w", err)
+ }
+
+ logger.InfofCtx(ctx, "MongoDB connection created for tenant %s (database: %s)", tenantID, mongoConfig.Database)
+
+ return p.cacheConnection(ctx, tenantID, conn, mongoConfig.Database, logger.Base())
+}
+
+func (p *Manager) getMongoConfigForTenant(
+ ctx context.Context,
+ tenantID string,
+ logger *logcompat.Logger,
+ span trace.Span,
+) (*core.MongoDBConfig, error) {
+ config, err := p.client.GetTenantConfig(ctx, tenantID, p.service)
+ if err != nil {
+ var suspErr *core.TenantSuspendedError
+ if errors.As(err, &suspErr) {
+ logger.WarnfCtx(ctx, "tenant service is %s: tenantID=%s", suspErr.Status, tenantID)
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "tenant service suspended", err)
+
+ return nil, err
+ }
+
+ logger.ErrorfCtx(ctx, "failed to get tenant config: %v", err)
+ libOpentelemetry.HandleSpanError(span, "failed to get tenant config", err)
+
+ return nil, fmt.Errorf("failed to get tenant config: %w", err)
+ }
+
+ mongoConfig := config.GetMongoDBConfig(p.service, p.module)
+ if mongoConfig == nil {
+ logger.ErrorfCtx(ctx, "no MongoDB config for tenant %s service %s module %s", tenantID, p.service, p.module)
+
+ return nil, core.ErrServiceNotConfigured
+ }
+
+ return mongoConfig, nil
+}
+
+func (p *Manager) cacheConnection(
+ ctx context.Context,
+ tenantID string,
+ conn *MongoConnection,
+ databaseName string,
+ baseLogger log.Logger,
+) (*mongo.Client, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.closed {
+ if conn.DB != nil {
+ if discErr := conn.DB.Disconnect(ctx); discErr != nil && p.logger != nil {
+ p.logger.Base().Log(ctx, log.LevelWarn, "failed to disconnect mongo connection on closed manager",
+ log.String("tenant_id", tenantID),
+ log.Err(discErr),
+ )
+ }
+ }
+
+ return nil, core.ErrManagerClosed
+ }
+
+ if cached, ok := p.connections[tenantID]; ok && cached != nil && cached.DB != nil {
+ if conn.DB != nil {
+ if discErr := conn.DB.Disconnect(ctx); discErr != nil && p.logger != nil {
+ p.logger.Base().Log(ctx, log.LevelWarn, "failed to disconnect excess mongo connection",
+ log.String("tenant_id", tenantID),
+ log.Err(discErr),
+ )
+ }
+ }
+
+ p.lastAccessed[tenantID] = time.Now()
+
+ return cached.DB, nil
+ }
+
+ p.evictLRU(ctx, baseLogger)
+
+ p.connections[tenantID] = conn
+ p.databaseNames[tenantID] = databaseName
+ p.lastAccessed[tenantID] = time.Now()
+
+ return conn.DB, nil
+}
+
+// evictLRU removes the least recently used idle connection when the pool reaches the
+// soft limit. Only connections that have been idle longer than the idle timeout are
+// eligible for eviction. If all connections are active (used within the idle timeout),
+// the pool is allowed to grow beyond the soft limit.
+// Caller MUST hold p.mu write lock.
+func (p *Manager) evictLRU(ctx context.Context, logger log.Logger) {
+ candidateID, shouldEvict := eviction.FindLRUEvictionCandidate(
+ len(p.connections), p.maxConnections, p.lastAccessed, p.idleTimeout, logger,
+ )
+ if !shouldEvict {
+ return
+ }
+
+ // Manager-specific cleanup: disconnect the MongoDB client and remove from all maps.
+ if conn, ok := p.connections[candidateID]; ok {
+ if conn.DB != nil {
+ if discErr := conn.DB.Disconnect(ctx); discErr != nil {
+ if logger != nil {
+ logger.Log(ctx, log.LevelWarn,
+ "failed to disconnect evicted mongo connection",
+ log.String("tenant_id", candidateID),
+ log.String("error", discErr.Error()),
+ )
+ }
+ }
+ }
+
+ delete(p.connections, candidateID)
+ delete(p.databaseNames, candidateID)
+ delete(p.lastAccessed, candidateID)
+ delete(p.lastSettingsCheck, candidateID)
+ }
+}
+
+// ApplyConnectionSettings is a no-op for MongoDB. The MongoDB Go driver does not
+// support changing maxPoolSize after client creation. All MongoDB connections use
+// the global default pool size (DefaultMaxConnections or MongoDBConfig.MaxPoolSize).
+// Per-tenant pool sizing is only supported for PostgreSQL via SetMaxOpenConns.
+func (p *Manager) ApplyConnectionSettings(tenantID string, config *core.TenantConfig) {
+ // No-op: MongoDB driver does not support runtime pool resize.
+ // Pool size is determined at connection creation time and remains fixed.
+}
+
+// GetDatabase returns a MongoDB database for the tenant.
+func (p *Manager) GetDatabase(ctx context.Context, tenantID, database string) (*mongo.Database, error) {
+ mongoClient, err := p.GetConnection(ctx, tenantID)
+ if err != nil {
+ return nil, err
+ }
+
+ return mongoClient.Database(database), nil
+}
+
+// GetDatabaseForTenant returns the MongoDB database for a tenant by resolving
+// the database name from the cached mapping populated during createConnection.
+// This avoids a redundant HTTP call to the Tenant Manager since the database
+// name is already known from the initial connection setup.
+func (p *Manager) GetDatabaseForTenant(ctx context.Context, tenantID string) (*mongo.Database, error) {
+ if tenantID == "" {
+ return nil, errors.New("tenant ID is required")
+ }
+
+ // GetConnection handles config fetching and caches both the connection
+ // and the database name (in p.databaseNames).
+ mongoClient, err := p.GetConnection(ctx, tenantID)
+ if err != nil {
+ return nil, err
+ }
+
+ // Look up the database name cached during createConnection.
+ p.mu.RLock()
+ dbName, ok := p.databaseNames[tenantID]
+ p.mu.RUnlock()
+
+ if ok {
+ return mongoClient.Database(dbName), nil
+ }
+
+ // Fallback: database name not cached (e.g., connection was pre-populated
+ // outside createConnection). Fetch config as a last resort.
+ if p.client == nil {
+ return nil, errors.New("tenant manager client is required for multi-tenant connections")
+ }
+
+ config, err := p.client.GetTenantConfig(ctx, tenantID, p.service)
+ if err != nil {
+ // Propagate TenantSuspendedError directly so the middleware can
+ // return a specific 403 response instead of a generic 503.
+ if core.IsTenantSuspendedError(err) {
+ return nil, err
+ }
+
+ return nil, fmt.Errorf("failed to get tenant config: %w", err)
+ }
+
+ mongoConfig := config.GetMongoDBConfig(p.service, p.module)
+ if mongoConfig == nil {
+ return nil, core.ErrServiceNotConfigured
+ }
+
+ // Cache for future calls
+ p.mu.Lock()
+ p.databaseNames[tenantID] = mongoConfig.Database
+ p.mu.Unlock()
+
+ return mongoClient.Database(mongoConfig.Database), nil
+}
+
+// Close closes all MongoDB connections.
+// It waits for any in-flight revalidatePoolSettings goroutines to finish
+// before returning, preventing goroutine leaks and use-after-close races.
+//
+// Uses snapshot-then-cleanup to avoid holding the mutex during network I/O
+// (Disconnect calls), which could block other goroutines on slow networks.
+func (p *Manager) Close(ctx context.Context) error {
+ // Phase 1: Under lock — mark closed, snapshot all connections, clear maps.
+ p.mu.Lock()
+ p.closed = true
+
+ snapshot := make([]*MongoConnection, 0, len(p.connections))
+ for _, conn := range p.connections {
+ snapshot = append(snapshot, conn)
+ }
+
+ // Clear all maps while still under lock.
+ clear(p.connections)
+ clear(p.databaseNames)
+ clear(p.lastAccessed)
+ clear(p.lastSettingsCheck)
+
+ p.mu.Unlock()
+
+ // Phase 2: Outside lock — disconnect each snapshotted connection.
+ var errs []error
+
+ for _, conn := range snapshot {
+ if conn.DB != nil {
+ if err := conn.DB.Disconnect(ctx); err != nil {
+ errs = append(errs, err)
+ }
+ }
+ }
+
+ // Phase 3: Wait for in-flight revalidatePoolSettings goroutines OUTSIDE the lock.
+ // revalidatePoolSettings acquires p.mu internally (via CloseConnection),
+ // so waiting with the lock held would deadlock.
+ p.revalidateWG.Wait()
+
+ return errors.Join(errs...)
+}
+
+// CloseConnection closes the MongoDB client for a specific tenant.
+//
+// Uses snapshot-then-cleanup to avoid holding the mutex during Disconnect,
+// which performs network I/O and could block other goroutines.
+func (p *Manager) CloseConnection(ctx context.Context, tenantID string) error {
+ // Step 1: Under lock — remove entry from maps, capture the connection.
+ p.mu.Lock()
+
+ conn, ok := p.connections[tenantID]
+ if !ok {
+ p.mu.Unlock()
+ return nil
+ }
+
+ delete(p.connections, tenantID)
+ delete(p.databaseNames, tenantID)
+ delete(p.lastAccessed, tenantID)
+ delete(p.lastSettingsCheck, tenantID)
+
+ p.mu.Unlock()
+
+ // Step 2: Outside lock — disconnect the captured connection.
+ if conn.DB != nil {
+ return conn.DB.Disconnect(ctx)
+ }
+
+ return nil
+}
+
+// Stats returns connection statistics.
+func (p *Manager) Stats() Stats {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ tenantIDs := make([]string, 0, len(p.connections))
+
+ activeCount := 0
+ now := time.Now()
+
+ idleTimeout := p.idleTimeout
+ if idleTimeout == 0 {
+ idleTimeout = defaultIdleTimeout
+ }
+
+ for id := range p.connections {
+ tenantIDs = append(tenantIDs, id)
+
+ if t, ok := p.lastAccessed[id]; ok && now.Sub(t) < idleTimeout {
+ activeCount++
+ }
+ }
+
+ return Stats{
+ TotalConnections: len(p.connections),
+ MaxConnections: p.maxConnections,
+ ActiveConnections: activeCount,
+ TenantIDs: tenantIDs,
+ Closed: p.closed,
+ }
+}
+
+// IsMultiTenant returns true if the manager is configured with a Tenant Manager client.
+func (p *Manager) IsMultiTenant() bool {
+ return p.client != nil
+}
+
+// buildMongoURI builds MongoDB connection URI from config.
+//
+// The function uses net/url.URL to construct the URI, which guarantees that
+// all components (credentials, host, database, query parameters) are properly
+// escaped according to RFC 3986. This prevents injection of URI control
+// characters through tenant-supplied configuration values.
+func buildMongoURI(cfg *core.MongoDBConfig, logger *logcompat.Logger) (string, error) {
+ if cfg.URI != "" {
+ return validateAndReturnRawURI(cfg.URI, logger)
+ }
+
+ if err := validateMongoHostPort(cfg); err != nil {
+ return "", err
+ }
+
+ u := buildMongoBaseURL(cfg)
+ query := buildMongoQueryParams(cfg)
+
+ if len(query) > 0 {
+ u.RawQuery = query.Encode()
+ }
+
+ return u.String(), nil
+}
+
+// validateAndReturnRawURI validates and returns a raw MongoDB URI when provided directly.
+func validateAndReturnRawURI(uri string, logger *logcompat.Logger) (string, error) {
+ parsed, err := url.Parse(uri)
+ if err != nil {
+ return "", fmt.Errorf("invalid mongo URI: %w", err)
+ }
+
+ if parsed.Scheme != "mongodb" && parsed.Scheme != "mongodb+srv" {
+ return "", fmt.Errorf("invalid mongo URI scheme %q", parsed.Scheme)
+ }
+
+ if logger != nil {
+ logger.Warn("using raw mongodb URI from tenant configuration")
+ }
+
+ return uri, nil
+}
+
+// validateMongoHostPort validates that host and port are present when no URI is provided.
+func validateMongoHostPort(cfg *core.MongoDBConfig) error {
+ if cfg.Host == "" {
+ return errors.New("mongo host is required when URI is not provided")
+ }
+
+ if cfg.Port == 0 {
+ return errors.New("mongo port is required when URI is not provided")
+ }
+
+ return nil
+}
+
+// buildMongoBaseURL constructs the base MongoDB URL with scheme, host, credentials, and database path.
+func buildMongoBaseURL(cfg *core.MongoDBConfig) *url.URL {
+ u := &url.URL{
+ Scheme: "mongodb",
+ Host: cfg.Host + ":" + strconv.Itoa(cfg.Port),
+ }
+
+ if cfg.Username != "" && cfg.Password != "" {
+ u.User = url.UserPassword(cfg.Username, cfg.Password)
+ }
+
+ if cfg.Database != "" {
+ u.Path = "/" + cfg.Database
+ u.RawPath = "/" + url.PathEscape(cfg.Database)
+ } else {
+ u.Path = "/"
+ }
+
+ return u
+}
+
+// hasSeparateCertAndKey returns true when TLS is enabled and the config provides
+// distinct certificate and key files (not a single combined PEM).
+func hasSeparateCertAndKey(cfg *core.MongoDBConfig) bool {
+ return cfg.TLS && cfg.TLSCertFile != "" && cfg.TLSKeyFile != "" && cfg.TLSCertFile != cfg.TLSKeyFile
+}
+
+// buildTLSConfigFromFiles creates a *tls.Config by loading the X.509 key pair
+// from separate certificate and private-key files. When a CA file is provided
+// it is added to the root CA pool. When TLSSkipVerify is true, both certificate
+// chain validation and hostname verification are skipped.
+func buildTLSConfigFromFiles(cfg *core.MongoDBConfig) (*tls.Config, error) {
+ cert, err := tls.LoadX509KeyPair(cfg.TLSCertFile, cfg.TLSKeyFile)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load TLS certificate key pair: %w", err)
+ }
+
+ tlsCfg := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ MinVersion: tls.VersionTLS12,
+ }
+
+ if cfg.TLSCAFile != "" {
+ caCert, readErr := os.ReadFile(cfg.TLSCAFile)
+ if readErr != nil {
+ return nil, fmt.Errorf("failed to read CA certificate file: %w", readErr)
+ }
+
+ caPool := x509.NewCertPool()
+ if !caPool.AppendCertsFromPEM(caCert) {
+ return nil, fmt.Errorf("failed to parse CA certificate from %s", cfg.TLSCAFile)
+ }
+
+ tlsCfg.RootCAs = caPool
+ }
+
+ if cfg.TLSSkipVerify {
+ tlsCfg.InsecureSkipVerify = true //#nosec G402 -- controlled by explicit config flag
+ }
+
+ return tlsCfg, nil
+}
+
+// buildMongoQueryParams builds the query parameters for the MongoDB URI.
+// Defaults authSource to "admin" when database and credentials are present
+// but no explicit authSource is configured, preserving backward compatibility
+// with deployments where users are created in the "admin" database.
+//
+// When TLS is enabled in the config, the corresponding query parameters are added:
+// - tls=true enables TLS on the connection
+// - tlsCAFile points to the CA certificate (only added when cert+key are NOT separate files)
+// - tlsCertificateKeyFile points to a combined PEM file (only when a single file is provided)
+// - tlsInsecure=true skips server certificate verification (not for production)
+//
+// When both TLSCertFile and TLSKeyFile are provided as distinct files, they are
+// NOT added to the URI; instead, buildTLSConfigFromFiles is used to load the
+// X.509 key pair and the resulting *tls.Config is applied via SetTLSConfig.
+func buildMongoQueryParams(cfg *core.MongoDBConfig) url.Values {
+ query := url.Values{}
+
+ if cfg.AuthSource != "" {
+ query.Set("authSource", cfg.AuthSource)
+ } else if cfg.Database != "" && cfg.Username != "" {
+ query.Set("authSource", "admin")
+ }
+
+ if cfg.DirectConnection {
+ query.Set("directConnection", "true")
+ }
+
+ if cfg.TLS {
+ query.Set("tls", "true")
+
+ // When separate cert+key files are provided, TLS configuration is
+ // handled via tls.LoadX509KeyPair + SetTLSConfig (not URI params).
+ // CA, client cert, and insecure settings are all set programmatically
+ // in that case, so we skip adding them to the URI entirely.
+ if hasSeparateCertAndKey(cfg) {
+ return query
+ }
+
+ if cfg.TLSCAFile != "" {
+ query.Set("tlsCAFile", cfg.TLSCAFile)
+ }
+
+ if cfg.TLSCertFile != "" || cfg.TLSKeyFile != "" {
+ // MongoDB driver uses a single PEM file containing both the client
+ // certificate and the private key via the tlsCertificateKeyFile option.
+ // When only one is provided, we use it directly since it may be a combined PEM.
+ certKeyFile := cfg.TLSCertFile
+ if certKeyFile == "" {
+ certKeyFile = cfg.TLSKeyFile
+ }
+
+ query.Set("tlsCertificateKeyFile", certKeyFile)
+ }
+
+ if cfg.TLSSkipVerify {
+ query.Set("tlsInsecure", "true")
+ }
+ }
+
+ return query
+}
diff --git a/commons/tenant-manager/mongo/manager_test.go b/commons/tenant-manager/mongo/manager_test.go
new file mode 100644
index 00000000..52b7520f
--- /dev/null
+++ b/commons/tenant-manager/mongo/manager_test.go
@@ -0,0 +1,1764 @@
+package mongo
+
+import (
+ "context"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/binary"
+ "encoding/pem"
+ "fmt"
+ "io"
+ "math/big"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.mongodb.org/mongo-driver/bson"
+ "go.mongodb.org/mongo-driver/mongo"
+ "go.mongodb.org/mongo-driver/mongo/options"
+ "go.uber.org/goleak"
+)
+
+func TestMain(m *testing.M) {
+ goleak.VerifyTestMain(m,
+ goleak.IgnoreTopFunction("github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/cache.(*InMemoryCache).cleanupLoop"),
+ goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
+ goleak.IgnoreTopFunction("net/http.(*persistConn).writeLoop"),
+ goleak.IgnoreTopFunction("net/http.(*persistConn).readLoop"),
+ )
+}
+
+func startFakeMongoServer(t *testing.T) (*mongo.Client, func()) {
+ t.Helper()
+
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+
+ go func() {
+ for {
+ conn, acceptErr := ln.Accept()
+ if acceptErr != nil {
+ return
+ }
+
+ go serveFakeMongoConn(conn)
+ }
+ }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ addr := ln.Addr().String()
+
+ mongoClient, err := mongo.Connect(ctx, options.Client().
+ ApplyURI(fmt.Sprintf("mongodb://%s/?directConnection=true", addr)).
+ SetServerSelectionTimeout(2*time.Second))
+ require.NoError(t, err)
+
+ require.NoError(t, mongoClient.Ping(ctx, nil))
+
+ cleanup := func() {
+ _ = mongoClient.Disconnect(context.Background())
+ ln.Close()
+ }
+
+ return mongoClient, cleanup
+}
+
+func serveFakeMongoConn(conn net.Conn) {
+ defer conn.Close()
+
+ for {
+ header := make([]byte, 16)
+ if _, err := io.ReadFull(conn, header); err != nil {
+ return
+ }
+
+ msgLen := int(binary.LittleEndian.Uint32(header[0:4]))
+ reqID := binary.LittleEndian.Uint32(header[4:8])
+
+ body := make([]byte, msgLen-16)
+ if _, err := io.ReadFull(conn, body); err != nil {
+ return
+ }
+
+ resp := bson.D{
+ {Key: "ismaster", Value: true},
+ {Key: "ok", Value: 1.0},
+ {Key: "maxWireVersion", Value: int32(21)},
+ {Key: "minWireVersion", Value: int32(0)},
+ {Key: "maxBsonObjectSize", Value: int32(16777216)},
+ {Key: "maxMessageSizeBytes", Value: int32(48000000)},
+ {Key: "maxWriteBatchSize", Value: int32(100000)},
+ {Key: "localTime", Value: time.Now()},
+ {Key: "connectionId", Value: int32(1)},
+ }
+
+ respBytes, _ := bson.Marshal(resp)
+
+ var payload []byte
+ payload = append(payload, 0, 0, 0, 0)
+ payload = append(payload, 0)
+ payload = append(payload, respBytes...)
+
+ totalLen := uint32(16 + len(payload))
+ respHeader := make([]byte, 16)
+ binary.LittleEndian.PutUint32(respHeader[0:4], totalLen)
+ binary.LittleEndian.PutUint32(respHeader[4:8], reqID+1)
+ binary.LittleEndian.PutUint32(respHeader[8:12], reqID)
+ binary.LittleEndian.PutUint32(respHeader[12:16], 2013)
+
+ _, _ = conn.Write(respHeader)
+ _, _ = conn.Write(payload)
+ }
+}
+
+// mustNewTestClient creates a test client or fails the test immediately.
+// Centralises the repeated client.NewClient + error-check boilerplate.
+// Tests use httptest servers (http://), so WithAllowInsecureHTTP is applied.
+func mustNewTestClient(t testing.TB, baseURL string) *client.Client {
+ t.Helper()
+
+ c, err := client.NewClient(baseURL, testutil.NewMockLogger(), client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, err)
+
+ return c
+}
+
+func TestNewManager(t *testing.T) {
+ t.Run("creates manager with client and service", func(t *testing.T) {
+ c := &client.Client{}
+ manager := NewManager(c, "ledger")
+
+ assert.NotNil(t, manager)
+ assert.Equal(t, "ledger", manager.service)
+ assert.NotNil(t, manager.connections)
+ })
+}
+
+func TestManager_GetConnection_NoTenantID(t *testing.T) {
+ c := &client.Client{}
+ manager := NewManager(c, "ledger")
+
+ _, err := manager.GetConnection(context.Background(), "")
+
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "tenant ID is required")
+}
+
+func TestManager_GetConnection_ManagerClosed(t *testing.T) {
+ c := &client.Client{}
+ manager := NewManager(c, "ledger")
+ require.NoError(t, manager.Close(context.Background()))
+
+ _, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ assert.ErrorIs(t, err, core.ErrManagerClosed)
+}
+
+func TestManager_GetDatabaseForTenant_NoTenantID(t *testing.T) {
+ c := &client.Client{}
+ manager := NewManager(c, "ledger")
+
+ _, err := manager.GetDatabaseForTenant(context.Background(), "")
+
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "tenant ID is required")
+}
+
+func TestManager_GetConnection_NilDBCachedConnection(t *testing.T) {
+ t.Run("returns nil client when cached connection has nil DB", func(t *testing.T) {
+ manager := NewManager(nil, "ledger")
+
+ // Pre-populate cache with a connection that has nil DB
+ cachedConn := &MongoConnection{
+ DB: nil,
+ }
+ manager.connections["tenant-123"] = cachedConn
+
+ // Nil cached DB now triggers a reconnect path. With nil tenant-manager
+ // client configured, this should return a deterministic error instead of panic.
+ result, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "tenant manager client is required")
+ assert.Nil(t, result)
+ })
+}
+
+func TestManager_CloseConnection_EvictsFromCache(t *testing.T) {
+ t.Run("evicts connection from cache on close", func(t *testing.T) {
+ c := &client.Client{}
+ manager := NewManager(c, "ledger")
+
+ // Pre-populate cache with a connection that has nil DB (to avoid disconnect errors)
+ cachedConn := &MongoConnection{
+ DB: nil,
+ }
+ manager.connections["tenant-123"] = cachedConn
+
+ err := manager.CloseConnection(context.Background(), "tenant-123")
+
+ assert.NoError(t, err)
+
+ manager.mu.RLock()
+ _, exists := manager.connections["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.False(t, exists, "connection should have been evicted from cache")
+ })
+}
+
+func TestManager_EvictLRU(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ maxConnections int
+ idleTimeout time.Duration
+ preloadCount int
+ oldTenantAge time.Duration
+ newTenantAge time.Duration
+ expectEviction bool
+ expectedPoolSize int
+ expectedEvictedID string
+ }{
+ {
+ name: "evicts oldest idle connection when pool is at soft limit",
+ maxConnections: 2,
+ idleTimeout: 5 * time.Minute,
+ preloadCount: 2,
+ oldTenantAge: 10 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: true,
+ expectedPoolSize: 1,
+ expectedEvictedID: "tenant-old",
+ },
+ {
+ name: "does not evict when pool is below soft limit",
+ maxConnections: 3,
+ idleTimeout: 5 * time.Minute,
+ preloadCount: 2,
+ oldTenantAge: 10 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: false,
+ expectedPoolSize: 2,
+ },
+ {
+ name: "does not evict when maxConnections is zero (unlimited)",
+ maxConnections: 0,
+ preloadCount: 5,
+ oldTenantAge: 10 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: false,
+ expectedPoolSize: 5,
+ },
+ {
+ name: "does not evict when all connections are active (within idle timeout)",
+ maxConnections: 2,
+ idleTimeout: 5 * time.Minute,
+ preloadCount: 2,
+ oldTenantAge: 2 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: false,
+ expectedPoolSize: 2,
+ },
+ {
+ name: "respects custom idle timeout",
+ maxConnections: 2,
+ idleTimeout: 30 * time.Second,
+ preloadCount: 2,
+ oldTenantAge: 1 * time.Minute,
+ newTenantAge: 10 * time.Second,
+ expectEviction: true,
+ expectedPoolSize: 1,
+ expectedEvictedID: "tenant-old",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ opts := []Option{
+ WithLogger(testutil.NewMockLogger()),
+ WithMaxTenantPools(tt.maxConnections),
+ }
+ if tt.idleTimeout > 0 {
+ opts = append(opts, WithIdleTimeout(tt.idleTimeout))
+ }
+
+ c := &client.Client{}
+ manager := NewManager(c, "ledger", opts...)
+
+ // Pre-populate pool with connections (nil DB to avoid real MongoDB)
+ if tt.preloadCount >= 1 {
+ manager.connections["tenant-old"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-old"] = time.Now().Add(-tt.oldTenantAge)
+ }
+
+ if tt.preloadCount >= 2 {
+ manager.connections["tenant-new"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-new"] = time.Now().Add(-tt.newTenantAge)
+ }
+
+ // For unlimited test, add more connections
+ for i := 2; i < tt.preloadCount; i++ {
+ id := "tenant-extra-" + time.Now().Add(time.Duration(i)*time.Second).Format("150405")
+ manager.connections[id] = &MongoConnection{DB: nil}
+ manager.lastAccessed[id] = time.Now().Add(-time.Duration(i) * time.Minute)
+ }
+
+ // Call evictLRU (caller must hold write lock)
+ manager.mu.Lock()
+ manager.evictLRU(context.Background(), testutil.NewMockLogger())
+ manager.mu.Unlock()
+
+ // Verify pool size
+ assert.Equal(t, tt.expectedPoolSize, len(manager.connections),
+ "pool size mismatch after eviction")
+
+ if tt.expectEviction {
+ // Verify the oldest tenant was evicted
+ _, exists := manager.connections[tt.expectedEvictedID]
+ assert.False(t, exists,
+ "expected tenant %s to be evicted from pool", tt.expectedEvictedID)
+
+ // Verify lastAccessed was also cleaned up
+ _, accessExists := manager.lastAccessed[tt.expectedEvictedID]
+ assert.False(t, accessExists,
+ "expected lastAccessed entry for %s to be removed", tt.expectedEvictedID)
+ }
+ })
+ }
+}
+
+func TestManager_PoolGrowsBeyondSoftLimit_WhenAllActive(t *testing.T) {
+ t.Parallel()
+
+ c := &client.Client{}
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithMaxTenantPools(2),
+ WithIdleTimeout(5*time.Minute),
+ )
+
+ // Pre-populate with 2 connections, both accessed recently (within idle timeout)
+ for _, id := range []string{"tenant-1", "tenant-2"} {
+ manager.connections[id] = &MongoConnection{DB: nil}
+ manager.lastAccessed[id] = time.Now().Add(-1 * time.Minute)
+ }
+
+ // Try to evict - should not evict because all connections are active
+ manager.mu.Lock()
+ manager.evictLRU(context.Background(), testutil.NewMockLogger())
+ manager.mu.Unlock()
+
+ // Pool should remain at 2 (no eviction occurred)
+ assert.Equal(t, 2, len(manager.connections),
+ "pool should not shrink when all connections are active")
+
+ // Simulate adding a third connection (pool grows beyond soft limit)
+ manager.connections["tenant-3"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-3"] = time.Now()
+
+ assert.Equal(t, 3, len(manager.connections),
+ "pool should grow beyond soft limit when all connections are active")
+}
+
+func TestManager_WithIdleTimeout_Option(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ idleTimeout time.Duration
+ expectedTimeout time.Duration
+ }{
+ {
+ name: "sets custom idle timeout",
+ idleTimeout: 10 * time.Minute,
+ expectedTimeout: 10 * time.Minute,
+ },
+ {
+ name: "sets short idle timeout",
+ idleTimeout: 30 * time.Second,
+ expectedTimeout: 30 * time.Second,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := &client.Client{}
+ manager := NewManager(c, "ledger",
+ WithIdleTimeout(tt.idleTimeout),
+ )
+
+ assert.Equal(t, tt.expectedTimeout, manager.idleTimeout)
+ })
+ }
+}
+
+func TestManager_LRU_LastAccessedUpdatedOnCacheHit(t *testing.T) {
+ t.Parallel()
+
+ c := &client.Client{}
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithMaxTenantPools(5),
+ )
+
+ // Pre-populate cache with a connection that has nil DB.
+ cachedConn := &MongoConnection{DB: nil}
+
+ initialTime := time.Now().Add(-5 * time.Minute)
+ manager.connections["tenant-123"] = cachedConn
+ manager.lastAccessed["tenant-123"] = initialTime
+
+ // Accessing the connection now follows the reconnect path for nil DB.
+ result, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to get tenant config")
+ assert.Nil(t, result, "nil DB should return nil client")
+
+ // Verify lastAccessed entry was evicted because reconnect path removes stale cache entry.
+ manager.mu.RLock()
+ updatedTime, exists := manager.lastAccessed["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.False(t, exists, "lastAccessed entry should be removed on reconnect path")
+ assert.True(t, updatedTime.IsZero(),
+ "lastAccessed should be zero value when entry is removed: initial=%v, updated=%v",
+ initialTime, updatedTime)
+}
+
+func TestManager_CloseConnection_CleansUpLastAccessed(t *testing.T) {
+ t.Parallel()
+
+ c := &client.Client{}
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ )
+
+ // Pre-populate cache with a connection that has nil DB
+ manager.connections["tenant-123"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-123"] = time.Now()
+
+ // Close the specific tenant client
+ err := manager.CloseConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+
+ manager.mu.RLock()
+ _, connExists := manager.connections["tenant-123"]
+ _, accessExists := manager.lastAccessed["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.False(t, connExists, "connection should be removed after CloseConnection")
+ assert.False(t, accessExists, "lastAccessed should be removed after CloseConnection")
+}
+
+func TestManager_WithMaxTenantPools_Option(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ maxConnections int
+ expectedMax int
+ }{
+ {
+ name: "sets max connections via option",
+ maxConnections: 10,
+ expectedMax: 10,
+ },
+ {
+ name: "zero means unlimited",
+ maxConnections: 0,
+ expectedMax: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := &client.Client{}
+ manager := NewManager(c, "ledger",
+ WithMaxTenantPools(tt.maxConnections),
+ )
+
+ assert.Equal(t, tt.expectedMax, manager.maxConnections)
+ })
+ }
+}
+
+func TestManager_ApplyConnectionSettings(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ module string
+ config *core.TenantConfig
+ hasCachedConn bool
+ }{
+ {
+ name: "no-op with top-level connection settings and cached connection",
+ module: "onboarding",
+ config: &core.TenantConfig{
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 30,
+ },
+ },
+ hasCachedConn: true,
+ },
+ {
+ name: "no-op with module-level connection settings and cached connection",
+ module: "onboarding",
+ config: &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 50,
+ },
+ },
+ },
+ },
+ hasCachedConn: true,
+ },
+ {
+ name: "no-op with connection settings but no cached connection",
+ module: "onboarding",
+ config: &core.TenantConfig{ConnectionSettings: &core.ConnectionSettings{MaxOpenConns: 30}},
+ hasCachedConn: false,
+ },
+ {
+ name: "no-op with config that has no connection settings",
+ module: "onboarding",
+ config: &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ MongoDB: &core.MongoDBConfig{Host: "localhost"},
+ },
+ },
+ },
+ hasCachedConn: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ logger := testutil.NewCapturingLogger()
+ c := &client.Client{}
+ manager := NewManager(c, "ledger",
+ WithModule(tt.module),
+ WithLogger(logger),
+ )
+
+ if tt.hasCachedConn {
+ manager.connections["tenant-123"] = &MongoConnection{DB: nil}
+ }
+
+ // ApplyConnectionSettings is a no-op for MongoDB.
+ // The MongoDB driver does not support runtime pool resize.
+ // Verify it does not panic and produces no log output.
+ manager.ApplyConnectionSettings("tenant-123", tt.config)
+
+ assert.Empty(t, logger.GetMessages(),
+ "ApplyConnectionSettings should be a no-op and produce no log output")
+ })
+ }
+}
+
+func TestBuildMongoURI(t *testing.T) {
+ t.Run("rejects empty host when URI not provided", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ Port: 27017,
+ Database: "testdb",
+ }
+
+ _, err := buildMongoURI(cfg, nil)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "mongo host is required")
+ })
+
+ t.Run("rejects zero port when URI not provided", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ Host: "localhost",
+ Database: "testdb",
+ }
+
+ _, err := buildMongoURI(cfg, nil)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "mongo port is required")
+ })
+
+ t.Run("rejects both empty host and zero port when URI not provided", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ Database: "testdb",
+ }
+
+ _, err := buildMongoURI(cfg, nil)
+
+ require.Error(t, err)
+ // Host is checked first
+ assert.Contains(t, err.Error(), "mongo host is required")
+ })
+
+ t.Run("allows empty host and port when URI is provided", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ URI: "mongodb://custom-uri",
+ }
+
+ uri, err := buildMongoURI(cfg, nil)
+
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://custom-uri", uri)
+ })
+
+ t.Run("returns URI when provided", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ URI: "mongodb://custom-uri",
+ }
+
+ uri, err := buildMongoURI(cfg, nil)
+
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://custom-uri", uri)
+ })
+
+ t.Run("rejects unsupported URI scheme", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{URI: "http://example.com"}
+
+ _, err := buildMongoURI(cfg, nil)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid mongo URI scheme")
+ })
+
+ t.Run("builds URI with credentials", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ Username: "user",
+ Password: "pass",
+ }
+
+ uri, err := buildMongoURI(cfg, nil)
+
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://user:pass@localhost:27017/testdb?authSource=admin", uri)
+ })
+
+ t.Run("builds URI without credentials", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ }
+
+ uri, err := buildMongoURI(cfg, nil)
+
+ require.NoError(t, err)
+ assert.Equal(t, "mongodb://localhost:27017/testdb", uri)
+ })
+
+ t.Run("defaults authSource to admin with database and credentials", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ Host: "mongo.example.com",
+ Port: 27017,
+ Database: "tenantdb",
+ Username: "appuser",
+ Password: "secret",
+ }
+
+ uri, err := buildMongoURI(cfg, nil)
+
+ require.NoError(t, err)
+ assert.Contains(t, uri, "authSource=admin")
+ })
+
+ t.Run("explicit authSource overrides default", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ Host: "mongo.example.com",
+ Port: 27017,
+ Database: "tenantdb",
+ Username: "appuser",
+ Password: "secret",
+ AuthSource: "customauth",
+ }
+
+ uri, err := buildMongoURI(cfg, nil)
+
+ require.NoError(t, err)
+ assert.Contains(t, uri, "authSource=customauth")
+ assert.NotContains(t, uri, "authSource=admin")
+ })
+
+ t.Run("no authSource without credentials", func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ Host: "mongo.example.com",
+ Port: 27017,
+ Database: "tenantdb",
+ }
+
+ uri, err := buildMongoURI(cfg, nil)
+
+ require.NoError(t, err)
+ assert.NotContains(t, uri, "authSource")
+ })
+
+ t.Run("URL-encodes special characters in credentials", func(t *testing.T) {
+ tests := []struct {
+ name string
+ username string
+ password string
+ expectedUser string
+ expectedPassword string
+ }{
+ {
+ name: "at sign in password",
+ username: "admin",
+ password: "p@ss",
+ expectedUser: "admin",
+ expectedPassword: "p%40ss",
+ },
+ {
+ name: "colon in password",
+ username: "admin",
+ password: "p:ss",
+ expectedUser: "admin",
+ expectedPassword: "p%3Ass",
+ },
+ {
+ name: "slash in password",
+ username: "admin",
+ password: "p/ss",
+ expectedUser: "admin",
+ expectedPassword: "p%2Fss",
+ },
+ {
+ name: "special characters in both username and password",
+ username: "user@domain",
+ password: "p@ss:w/rd",
+ expectedUser: "user%40domain",
+ expectedPassword: "p%40ss%3Aw%2Frd",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ Username: tt.username,
+ Password: tt.password,
+ }
+
+ uri, err := buildMongoURI(cfg, nil)
+ require.NoError(t, err)
+
+ expectedURI := fmt.Sprintf("mongodb://%s:%s@localhost:27017/testdb?authSource=admin",
+ tt.expectedUser, tt.expectedPassword)
+ assert.Equal(t, expectedURI, uri)
+ assert.Contains(t, uri, tt.expectedUser)
+ assert.Contains(t, uri, tt.expectedPassword)
+ })
+ }
+ })
+}
+
+func TestManager_Stats(t *testing.T) {
+ t.Parallel()
+
+ t.Run("returns stats with no connections", func(t *testing.T) {
+ c := &client.Client{}
+ manager := NewManager(c, "ledger",
+ WithMaxTenantPools(10),
+ )
+
+ stats := manager.Stats()
+
+ assert.Equal(t, 0, stats.TotalConnections)
+ assert.Equal(t, 10, stats.MaxConnections)
+ assert.Equal(t, 0, stats.ActiveConnections)
+ assert.Empty(t, stats.TenantIDs)
+ assert.False(t, stats.Closed)
+ })
+
+ t.Run("returns stats with active and idle connections", func(t *testing.T) {
+ c := &client.Client{}
+ manager := NewManager(c, "ledger",
+ WithMaxTenantPools(10),
+ WithIdleTimeout(5*time.Minute),
+ )
+
+ // Add an active connection (accessed recently)
+ manager.connections["tenant-active"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-active"] = time.Now().Add(-1 * time.Minute)
+
+ // Add an idle connection (accessed long ago)
+ manager.connections["tenant-idle"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-idle"] = time.Now().Add(-10 * time.Minute)
+
+ stats := manager.Stats()
+
+ assert.Equal(t, 2, stats.TotalConnections)
+ assert.Equal(t, 10, stats.MaxConnections)
+ assert.Equal(t, 1, stats.ActiveConnections)
+ assert.Len(t, stats.TenantIDs, 2)
+ assert.False(t, stats.Closed)
+ })
+
+ t.Run("returns closed status after close", func(t *testing.T) {
+ c := &client.Client{}
+ manager := NewManager(c, "ledger")
+
+ require.NoError(t, manager.Close(context.Background()))
+
+ stats := manager.Stats()
+
+ assert.True(t, stats.Closed)
+ assert.Equal(t, 0, stats.TotalConnections)
+ })
+}
+
+func TestBuildMongoURI_TLS(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *core.MongoDBConfig
+ contains []string
+ excludes []string
+ }{
+ {
+ name: "adds tls=true when TLS is enabled",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: true,
+ },
+ contains: []string{"tls=true"},
+ },
+ {
+ name: "does not add tls param when TLS is disabled",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: false,
+ },
+ excludes: []string{"tls="},
+ },
+ {
+ name: "adds tlsCAFile when TLS is enabled with CA file",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: true,
+ TLSCAFile: "/etc/ssl/ca.pem",
+ },
+ contains: []string{"tls=true", "tlsCAFile=%2Fetc%2Fssl%2Fca.pem"},
+ },
+ {
+ name: "adds tlsCertificateKeyFile when TLS is enabled with cert file",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: true,
+ TLSCertFile: "/etc/ssl/client.pem",
+ },
+ contains: []string{"tls=true", "tlsCertificateKeyFile=%2Fetc%2Fssl%2Fclient.pem"},
+ },
+ {
+ name: "uses key file as tlsCertificateKeyFile when cert file is empty",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: true,
+ TLSKeyFile: "/etc/ssl/client-key.pem",
+ },
+ contains: []string{"tls=true", "tlsCertificateKeyFile=%2Fetc%2Fssl%2Fclient-key.pem"},
+ },
+ {
+ name: "omits tlsCertificateKeyFile when cert and key are separate files",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: true,
+ TLSCertFile: "/etc/ssl/client-cert.pem",
+ TLSKeyFile: "/etc/ssl/client-key.pem",
+ },
+ contains: []string{"tls=true"},
+ excludes: []string{"tlsCertificateKeyFile", "tlsCAFile", "tlsInsecure"},
+ },
+ {
+ name: "uses tlsCertificateKeyFile when cert and key point to the same combined PEM",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: true,
+ TLSCertFile: "/etc/ssl/client-combined.pem",
+ TLSKeyFile: "/etc/ssl/client-combined.pem",
+ },
+ contains: []string{"tlsCertificateKeyFile=%2Fetc%2Fssl%2Fclient-combined.pem"},
+ },
+ {
+ name: "adds tlsInsecure when TLS skip verify is enabled",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: true,
+ TLSSkipVerify: true,
+ },
+ contains: []string{"tls=true", "tlsInsecure=true"},
+ },
+ {
+ name: "does not add tlsInsecure when skip verify is false",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: true,
+ TLSSkipVerify: false,
+ },
+ contains: []string{"tls=true"},
+ excludes: []string{"tlsInsecure"},
+ },
+ {
+ name: "does not add TLS params when TLS is disabled even with files set",
+ cfg: &core.MongoDBConfig{
+ Host: "localhost",
+ Port: 27017,
+ Database: "testdb",
+ TLS: false,
+ TLSCAFile: "/etc/ssl/ca.pem",
+ TLSCertFile: "/etc/ssl/client.pem",
+ },
+ excludes: []string{"tls=", "tlsCAFile", "tlsCertificateKeyFile"},
+ },
+ {
+ name: "full TLS config with all options",
+ cfg: &core.MongoDBConfig{
+ Host: "mongo.prod.internal",
+ Port: 27017,
+ Database: "tenantdb",
+ Username: "appuser",
+ Password: "secret",
+ TLS: true,
+ TLSCAFile: "/etc/ssl/ca.pem",
+ TLSCertFile: "/etc/ssl/client.pem",
+ TLSSkipVerify: false,
+ },
+ contains: []string{
+ "tls=true",
+ "tlsCAFile=%2Fetc%2Fssl%2Fca.pem",
+ "tlsCertificateKeyFile=%2Fetc%2Fssl%2Fclient.pem",
+ "authSource=admin",
+ },
+ excludes: []string{"tlsInsecure"},
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ uri, err := buildMongoURI(tt.cfg, nil)
+
+ require.NoError(t, err)
+
+ for _, s := range tt.contains {
+ assert.Contains(t, uri, s, "URI should contain %q", s)
+ }
+
+ for _, s := range tt.excludes {
+ assert.NotContains(t, uri, s, "URI should NOT contain %q", s)
+ }
+ })
+ }
+}
+
+func TestManager_IsMultiTenant(t *testing.T) {
+ t.Parallel()
+
+ t.Run("returns true when client is configured", func(t *testing.T) {
+ c := &client.Client{}
+ manager := NewManager(c, "ledger")
+
+ assert.True(t, manager.IsMultiTenant())
+ })
+
+ t.Run("returns false when client is nil", func(t *testing.T) {
+ manager := NewManager(nil, "ledger")
+
+ assert.False(t, manager.IsMultiTenant())
+ })
+}
+
+func TestHasSeparateCertAndKey(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *core.MongoDBConfig
+ expected bool
+ }{
+ {
+ name: "true when TLS enabled with distinct cert and key files",
+ cfg: &core.MongoDBConfig{TLS: true, TLSCertFile: "/cert.pem", TLSKeyFile: "/key.pem"},
+ expected: true,
+ },
+ {
+ name: "false when TLS disabled",
+ cfg: &core.MongoDBConfig{TLS: false, TLSCertFile: "/cert.pem", TLSKeyFile: "/key.pem"},
+ expected: false,
+ },
+ {
+ name: "false when cert and key are the same file (combined PEM)",
+ cfg: &core.MongoDBConfig{TLS: true, TLSCertFile: "/combined.pem", TLSKeyFile: "/combined.pem"},
+ expected: false,
+ },
+ {
+ name: "false when only cert file is set",
+ cfg: &core.MongoDBConfig{TLS: true, TLSCertFile: "/cert.pem"},
+ expected: false,
+ },
+ {
+ name: "false when only key file is set",
+ cfg: &core.MongoDBConfig{TLS: true, TLSKeyFile: "/key.pem"},
+ expected: false,
+ },
+ {
+ name: "false when neither is set",
+ cfg: &core.MongoDBConfig{TLS: true},
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, tt.expected, hasSeparateCertAndKey(tt.cfg))
+ })
+ }
+}
+
+// generateTestCertAndKey creates a self-signed certificate and private key in
+// the given directory. Returns the paths to the cert and key files.
+func generateTestCertAndKey(t *testing.T, dir string) (certPath, keyPath string) {
+ t.Helper()
+
+ key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ require.NoError(t, err)
+
+ template := &x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{CommonName: "test"},
+ NotBefore: time.Now().Add(-1 * time.Hour),
+ NotAfter: time.Now().Add(24 * time.Hour),
+ }
+
+ certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
+ require.NoError(t, err)
+
+ certPath = filepath.Join(dir, "cert.pem")
+ certFile, err := os.Create(certPath)
+ require.NoError(t, err)
+ require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}))
+ require.NoError(t, certFile.Close())
+
+ keyDER, err := x509.MarshalECPrivateKey(key)
+ require.NoError(t, err)
+
+ keyPath = filepath.Join(dir, "key.pem")
+ keyFile, err := os.Create(keyPath)
+ require.NoError(t, err)
+ require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}))
+ require.NoError(t, keyFile.Close())
+
+ return certPath, keyPath
+}
+
+func TestBuildTLSConfigFromFiles(t *testing.T) {
+ t.Parallel()
+
+ t.Run("loads separate cert and key files successfully", func(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ certPath, keyPath := generateTestCertAndKey(t, dir)
+
+ cfg := &core.MongoDBConfig{
+ TLS: true,
+ TLSCertFile: certPath,
+ TLSKeyFile: keyPath,
+ }
+
+ tlsCfg, err := buildTLSConfigFromFiles(cfg)
+
+ require.NoError(t, err)
+ require.NotNil(t, tlsCfg)
+ assert.Len(t, tlsCfg.Certificates, 1, "should have loaded one certificate")
+ assert.Nil(t, tlsCfg.RootCAs, "should not have RootCAs when no CA file is set")
+ assert.False(t, tlsCfg.InsecureSkipVerify, "should not skip verify by default")
+ })
+
+ t.Run("loads CA file into RootCAs", func(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ certPath, keyPath := generateTestCertAndKey(t, dir)
+
+ // Write a self-signed CA cert (same cert, for test purposes)
+ caPath := filepath.Join(dir, "ca.pem")
+ certData, err := os.ReadFile(certPath)
+ require.NoError(t, err)
+ require.NoError(t, os.WriteFile(caPath, certData, 0o600))
+
+ cfg := &core.MongoDBConfig{
+ TLS: true,
+ TLSCertFile: certPath,
+ TLSKeyFile: keyPath,
+ TLSCAFile: caPath,
+ }
+
+ tlsCfg, err := buildTLSConfigFromFiles(cfg)
+
+ require.NoError(t, err)
+ require.NotNil(t, tlsCfg)
+ assert.NotNil(t, tlsCfg.RootCAs, "should have RootCAs when CA file is set")
+ })
+
+ t.Run("sets InsecureSkipVerify when TLSSkipVerify is true", func(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ certPath, keyPath := generateTestCertAndKey(t, dir)
+
+ cfg := &core.MongoDBConfig{
+ TLS: true,
+ TLSCertFile: certPath,
+ TLSKeyFile: keyPath,
+ TLSSkipVerify: true,
+ }
+
+ tlsCfg, err := buildTLSConfigFromFiles(cfg)
+
+ require.NoError(t, err)
+ assert.True(t, tlsCfg.InsecureSkipVerify, "should skip verify when TLSSkipVerify is true")
+ })
+
+ t.Run("returns error for invalid cert file", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := &core.MongoDBConfig{
+ TLS: true,
+ TLSCertFile: "/nonexistent/cert.pem",
+ TLSKeyFile: "/nonexistent/key.pem",
+ }
+
+ _, err := buildTLSConfigFromFiles(cfg)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to load TLS certificate key pair")
+ })
+
+ t.Run("returns error for invalid CA file", func(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ certPath, keyPath := generateTestCertAndKey(t, dir)
+
+ cfg := &core.MongoDBConfig{
+ TLS: true,
+ TLSCertFile: certPath,
+ TLSKeyFile: keyPath,
+ TLSCAFile: "/nonexistent/ca.pem",
+ }
+
+ _, err := buildTLSConfigFromFiles(cfg)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to read CA certificate file")
+ })
+
+ t.Run("returns error for unparseable CA PEM", func(t *testing.T) {
+ t.Parallel()
+
+ dir := t.TempDir()
+ certPath, keyPath := generateTestCertAndKey(t, dir)
+
+ badCAPath := filepath.Join(dir, "bad-ca.pem")
+ require.NoError(t, os.WriteFile(badCAPath, []byte("not a PEM"), 0o600))
+
+ cfg := &core.MongoDBConfig{
+ TLS: true,
+ TLSCertFile: certPath,
+ TLSKeyFile: keyPath,
+ TLSCAFile: badCAPath,
+ }
+
+ _, err := buildTLSConfigFromFiles(cfg)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to parse CA certificate")
+ })
+}
+
+func TestManager_WithSettingsCheckInterval_Option(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ interval time.Duration
+ expectedInterval time.Duration
+ }{
+ {
+ name: "sets custom settings check interval",
+ interval: 1 * time.Minute,
+ expectedInterval: 1 * time.Minute,
+ },
+ {
+ name: "sets short settings check interval",
+ interval: 5 * time.Second,
+ expectedInterval: 5 * time.Second,
+ },
+ {
+ name: "disables revalidation with zero duration",
+ interval: 0,
+ expectedInterval: 0,
+ },
+ {
+ name: "disables revalidation with negative duration",
+ interval: -1 * time.Second,
+ expectedInterval: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithSettingsCheckInterval(tt.interval),
+ )
+
+ assert.Equal(t, tt.expectedInterval, manager.settingsCheckInterval)
+ })
+ }
+}
+
+func TestManager_DefaultSettingsCheckInterval(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger")
+
+ assert.Equal(t, defaultSettingsCheckInterval, manager.settingsCheckInterval,
+ "default settings check interval should be set from named constant")
+ assert.NotNil(t, manager.lastSettingsCheck,
+ "lastSettingsCheck map should be initialized")
+}
+
+func TestManager_GetConnection_RevalidatesSettingsAfterInterval(t *testing.T) {
+ t.Parallel()
+
+ var callCount int32
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ atomic.AddInt32(&callCount, 1)
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{
+ "id": "tenant-123",
+ "tenantSlug": "test-tenant",
+ "databases": {
+ "onboarding": {
+ "mongodb": {"host": "localhost", "port": 27017, "database": "testdb", "username": "user", "password": "pass"}
+ }
+ }
+ }`))
+ }))
+ defer server.Close()
+
+ fakeDB, cleanupFake := startFakeMongoServer(t)
+ defer cleanupFake()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithModule("onboarding"),
+ WithSettingsCheckInterval(1*time.Millisecond),
+ )
+
+ cachedConn := &MongoConnection{DB: fakeDB}
+ manager.connections["tenant-123"] = cachedConn
+ manager.lastAccessed["tenant-123"] = time.Now()
+ manager.lastSettingsCheck["tenant-123"] = time.Now().Add(-1 * time.Hour)
+
+ db, err := manager.GetConnection(context.Background(), "tenant-123")
+ require.NoError(t, err)
+ assert.Equal(t, fakeDB, db)
+
+ assert.Eventually(t, func() bool {
+ return atomic.LoadInt32(&callCount) > 0
+ }, 500*time.Millisecond, 10*time.Millisecond, "should have fetched fresh config from Tenant Manager")
+
+ manager.mu.RLock()
+ lastCheck := manager.lastSettingsCheck["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.False(t, lastCheck.IsZero(), "lastSettingsCheck should have been updated")
+
+ manager.revalidateWG.Wait()
+ require.NoError(t, manager.Close(context.Background()))
+}
+
+func TestManager_GetConnection_DisabledRevalidation_WithZero(t *testing.T) {
+ t.Parallel()
+
+ var callCount int32
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ atomic.AddInt32(&callCount, 1)
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{
+ "id": "tenant-123",
+ "tenantSlug": "test-tenant",
+ "databases": {
+ "onboarding": {
+ "mongodb": {"host": "localhost", "port": 27017, "database": "testdb"}
+ }
+ }
+ }`))
+ }))
+ defer server.Close()
+
+ fakeDB, cleanupFake := startFakeMongoServer(t)
+ defer cleanupFake()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithModule("onboarding"),
+ WithSettingsCheckInterval(0),
+ )
+
+ assert.Equal(t, time.Duration(0), manager.settingsCheckInterval)
+
+ cachedConn := &MongoConnection{DB: fakeDB}
+ manager.connections["tenant-123"] = cachedConn
+ manager.lastAccessed["tenant-123"] = time.Now()
+ manager.lastSettingsCheck["tenant-123"] = time.Now().Add(-1 * time.Hour)
+
+ db, err := manager.GetConnection(context.Background(), "tenant-123")
+ require.NoError(t, err)
+ assert.Equal(t, fakeDB, db)
+
+ time.Sleep(50 * time.Millisecond)
+
+ assert.Equal(t, int32(0), atomic.LoadInt32(&callCount), "should NOT have fetched config - revalidation is disabled")
+
+ require.NoError(t, manager.Close(context.Background()))
+}
+
+func TestManager_GetConnection_DisabledRevalidation_WithNegative(t *testing.T) {
+ t.Parallel()
+
+ var callCount int32
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ atomic.AddInt32(&callCount, 1)
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{}`))
+ }))
+ defer server.Close()
+
+ fakeDB, cleanupFake := startFakeMongoServer(t)
+ defer cleanupFake()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "payment",
+ WithLogger(testutil.NewMockLogger()),
+ WithModule("payment"),
+ WithSettingsCheckInterval(-5*time.Second),
+ )
+
+ assert.Equal(t, time.Duration(0), manager.settingsCheckInterval)
+
+ cachedConn := &MongoConnection{DB: fakeDB}
+ manager.connections["tenant-456"] = cachedConn
+ manager.lastAccessed["tenant-456"] = time.Now()
+ manager.lastSettingsCheck["tenant-456"] = time.Now().Add(-1 * time.Hour)
+
+ db, err := manager.GetConnection(context.Background(), "tenant-456")
+ require.NoError(t, err)
+ assert.Equal(t, fakeDB, db)
+
+ time.Sleep(50 * time.Millisecond)
+
+ assert.Equal(t, int32(0), atomic.LoadInt32(&callCount), "should NOT have fetched config - revalidation is disabled via negative interval")
+
+ require.NoError(t, manager.Close(context.Background()))
+}
+
+func TestManager_RevalidateSettings_EvictsSuspendedTenant(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ responseStatus int
+ responseBody string
+ expectEviction bool
+ expectLogSubstring string
+ }{
+ {
+ name: "evicts_cached_connection_when_tenant_is_suspended",
+ responseStatus: http.StatusForbidden,
+ responseBody: `{"code":"TS-SUSPENDED","error":"service suspended","status":"suspended"}`,
+ expectEviction: true,
+ expectLogSubstring: "tenant tenant-suspended service suspended, evicting cached connection",
+ },
+ {
+ name: "evicts_cached_connection_when_tenant_is_purged",
+ responseStatus: http.StatusForbidden,
+ responseBody: `{"code":"TS-SUSPENDED","error":"service purged","status":"purged"}`,
+ expectEviction: true,
+ expectLogSubstring: "tenant tenant-suspended service suspended, evicting cached connection",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(tt.responseStatus)
+ w.Write([]byte(tt.responseBody))
+ }))
+ defer server.Close()
+
+ capLogger := testutil.NewCapturingLogger()
+ tmClient, err := client.NewClient(server.URL, capLogger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, err)
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(capLogger),
+ WithSettingsCheckInterval(1*time.Millisecond),
+ )
+
+ // Pre-populate a cached connection for the tenant (nil DB to avoid real MongoDB)
+ manager.connections["tenant-suspended"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-suspended"] = time.Now()
+ manager.lastSettingsCheck["tenant-suspended"] = time.Now()
+
+ // Verify the connection exists before revalidation
+ statsBefore := manager.Stats()
+ assert.Equal(t, 1, statsBefore.TotalConnections,
+ "should have 1 connection before revalidation")
+
+ // Trigger revalidatePoolSettings directly
+ manager.revalidatePoolSettings("tenant-suspended")
+
+ if tt.expectEviction {
+ // Verify the connection was evicted
+ statsAfter := manager.Stats()
+ assert.Equal(t, 0, statsAfter.TotalConnections,
+ "connection should be evicted after suspended tenant detected")
+
+ // Verify lastAccessed and lastSettingsCheck were cleaned up
+ manager.mu.RLock()
+ _, accessExists := manager.lastAccessed["tenant-suspended"]
+ _, settingsExists := manager.lastSettingsCheck["tenant-suspended"]
+ manager.mu.RUnlock()
+
+ assert.False(t, accessExists,
+ "lastAccessed should be removed for evicted tenant")
+ assert.False(t, settingsExists,
+ "lastSettingsCheck should be removed for evicted tenant")
+ }
+
+ // Verify the appropriate log message was produced
+ assert.True(t, capLogger.ContainsSubstring(tt.expectLogSubstring),
+ "expected log message containing %q, got: %v",
+ tt.expectLogSubstring, capLogger.GetMessages())
+ })
+ }
+}
+
+func TestManager_RevalidateSettings_BypassesClientCache(t *testing.T) {
+ t.Parallel()
+
+ // This test verifies that revalidatePoolSettings uses WithSkipCache()
+ // to bypass the client's in-memory cache. Without it, a cached "active"
+ // response would hide a subsequent 403 (suspended/purged) from tenant-manager.
+ //
+ // Setup: The httptest server returns 200 (active) on the first request
+ // and 403 (suspended) on all subsequent requests. We first call
+ // GetTenantConfig directly to populate the client cache, then trigger
+ // revalidatePoolSettings. If WithSkipCache is working, the revalidation
+ // hits the server (gets 403) and evicts the connection. If the cache
+ // were used, it would return the stale 200 and the connection would
+ // remain.
+ var requestCount atomic.Int32
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ count := requestCount.Add(1)
+ w.Header().Set("Content-Type", "application/json")
+
+ if count == 1 {
+ // First request: return active config (populates client cache)
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{
+ "id": "tenant-cache-test",
+ "tenantSlug": "cached-tenant",
+ "service": "ledger",
+ "status": "active",
+ "databases": {
+ "onboarding": {
+ "mongodb": {"host": "localhost", "port": 27017, "database": "testdb", "username": "user", "password": "pass"}
+ }
+ }
+ }`))
+
+ return
+ }
+
+ // Subsequent requests: return 403 (tenant suspended)
+ w.WriteHeader(http.StatusForbidden)
+ w.Write([]byte(`{"code":"TS-SUSPENDED","error":"service suspended","status":"suspended"}`))
+ }))
+ defer server.Close()
+
+ capLogger := testutil.NewCapturingLogger()
+ tmClient, err := client.NewClient(server.URL, capLogger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, err)
+
+ // Populate the client cache by calling GetTenantConfig directly
+ cfg, err := tmClient.GetTenantConfig(context.Background(), "tenant-cache-test", "ledger")
+ require.NoError(t, err)
+ assert.Equal(t, "tenant-cache-test", cfg.ID)
+ assert.Equal(t, int32(1), requestCount.Load(), "should have made exactly 1 HTTP request")
+
+ // Create a manager with a cached connection for this tenant
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(capLogger),
+ WithModule("onboarding"),
+ WithSettingsCheckInterval(1*time.Millisecond),
+ )
+
+ // Pre-populate a cached connection (nil DB to avoid real MongoDB)
+ manager.connections["tenant-cache-test"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-cache-test"] = time.Now()
+ manager.lastSettingsCheck["tenant-cache-test"] = time.Now()
+
+ // Trigger revalidatePoolSettings -- should bypass cache and hit the server
+ manager.revalidatePoolSettings("tenant-cache-test")
+
+ // Verify a second HTTP request was made (cache was bypassed)
+ assert.Equal(t, int32(2), requestCount.Load(),
+ "revalidatePoolSettings should bypass client cache and make a fresh HTTP request")
+
+ // Verify the connection was evicted (server returned 403)
+ statsAfter := manager.Stats()
+ assert.Equal(t, 0, statsAfter.TotalConnections,
+ "connection should be evicted after revalidation detected suspended tenant via cache bypass")
+
+ // Verify lastAccessed and lastSettingsCheck were cleaned up
+ manager.mu.RLock()
+ _, accessExists := manager.lastAccessed["tenant-cache-test"]
+ _, settingsExists := manager.lastSettingsCheck["tenant-cache-test"]
+ manager.mu.RUnlock()
+
+ assert.False(t, accessExists, "lastAccessed should be removed for evicted tenant")
+ assert.False(t, settingsExists, "lastSettingsCheck should be removed for evicted tenant")
+}
+
+func TestManager_RevalidateSettings_FailedDoesNotBreakConnection(t *testing.T) {
+ t.Parallel()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer server.Close()
+
+ capLogger := testutil.NewCapturingLogger()
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(capLogger),
+ WithModule("onboarding"),
+ WithSettingsCheckInterval(1*time.Millisecond),
+ )
+
+ // Pre-populate cache
+ manager.connections["tenant-123"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-123"] = time.Now()
+ manager.lastSettingsCheck["tenant-123"] = time.Now().Add(-1 * time.Hour)
+
+ // Trigger revalidation directly - should fail but not evict
+ manager.revalidatePoolSettings("tenant-123")
+
+ // Connection should still exist (not evicted on transient failure)
+ stats := manager.Stats()
+ assert.Equal(t, 1, stats.TotalConnections,
+ "connection should NOT be evicted after transient revalidation failure")
+
+ // Verify a warning was logged
+ assert.True(t, capLogger.ContainsSubstring("failed to revalidate connection settings"),
+ "should log a warning when revalidation fails")
+}
+
+func TestManager_RevalidateSettings_RecoverFromPanic(t *testing.T) {
+ t.Parallel()
+
+ capLogger := testutil.NewCapturingLogger()
+
+ // Create a manager with nil client to trigger a panic path
+ manager := &Manager{
+ logger: logcompat.New(capLogger),
+ connections: make(map[string]*MongoConnection),
+ databaseNames: make(map[string]string),
+ lastAccessed: make(map[string]time.Time),
+ lastSettingsCheck: make(map[string]time.Time),
+ settingsCheckInterval: 1 * time.Millisecond,
+ }
+
+ // Should not panic -- the recovery handler should catch it
+ assert.NotPanics(t, func() {
+ manager.revalidatePoolSettings("tenant-panic")
+ })
+}
+
+func TestManager_CloseConnection_CleansUpLastSettingsCheck(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ )
+
+ // Pre-populate cache
+ manager.connections["tenant-123"] = &MongoConnection{DB: nil}
+ manager.lastAccessed["tenant-123"] = time.Now()
+ manager.lastSettingsCheck["tenant-123"] = time.Now()
+
+ err := manager.CloseConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+
+ manager.mu.RLock()
+ _, connExists := manager.connections["tenant-123"]
+ _, accessExists := manager.lastAccessed["tenant-123"]
+ _, settingsCheckExists := manager.lastSettingsCheck["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.False(t, connExists, "connection should be removed after CloseConnection")
+ assert.False(t, accessExists, "lastAccessed should be removed after CloseConnection")
+ assert.False(t, settingsCheckExists, "lastSettingsCheck should be removed after CloseConnection")
+}
+
+func TestManager_Close_CleansUpLastSettingsCheck(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ )
+
+ // Pre-populate cache with multiple tenants
+ for _, id := range []string{"tenant-1", "tenant-2"} {
+ manager.connections[id] = &MongoConnection{DB: nil}
+ manager.lastAccessed[id] = time.Now()
+ manager.lastSettingsCheck[id] = time.Now()
+ }
+
+ err := manager.Close(context.Background())
+
+ require.NoError(t, err)
+
+ assert.Empty(t, manager.connections, "all connections should be removed after Close")
+ assert.Empty(t, manager.lastAccessed, "all lastAccessed should be removed after Close")
+ assert.Empty(t, manager.lastSettingsCheck, "all lastSettingsCheck should be removed after Close")
+}
+
+func TestManager_Close_WaitsForRevalidateSettings(t *testing.T) {
+ t.Parallel()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ time.Sleep(300 * time.Millisecond)
+
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{
+ "id": "tenant-slow",
+ "tenantSlug": "slow-tenant",
+ "databases": {
+ "onboarding": {
+ "mongodb": {"host": "localhost", "port": 27017, "database": "testdb", "username": "user", "password": "pass"}
+ }
+ }
+ }`))
+ }))
+ defer server.Close()
+
+ fakeDB, cleanupFake := startFakeMongoServer(t)
+ defer cleanupFake()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "test-service",
+ WithLogger(testutil.NewMockLogger()),
+ WithSettingsCheckInterval(1*time.Millisecond),
+ )
+
+ cachedConn := &MongoConnection{DB: fakeDB}
+ manager.connections["tenant-slow"] = cachedConn
+ manager.lastAccessed["tenant-slow"] = time.Now()
+ manager.lastSettingsCheck["tenant-slow"] = time.Time{}
+
+ _, err := manager.GetConnection(context.Background(), "tenant-slow")
+ require.NoError(t, err)
+
+ err = manager.Close(context.Background())
+ require.NoError(t, err)
+
+ assert.True(t, manager.closed, "manager should be closed")
+ assert.Empty(t, manager.connections, "connections should be cleared after Close")
+}
diff --git a/commons/tenant-manager/postgres/goroutine_leak_test.go b/commons/tenant-manager/postgres/goroutine_leak_test.go
new file mode 100644
index 00000000..c1ed840e
--- /dev/null
+++ b/commons/tenant-manager/postgres/goroutine_leak_test.go
@@ -0,0 +1,104 @@
+package postgres
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ "github.com/bxcodec/dbresolver/v2"
+ "go.uber.org/goleak"
+)
+
+// TestManager_Close_WaitsForRevalidateSettings proves that Close() waits for
+// active revalidatePoolSettings goroutines to finish before returning. Without the
+// WaitGroup fix, Close() would return immediately while the goroutine is still
+// running, causing a goroutine leak.
+func TestManager_Close_WaitsForRevalidateSettings(t *testing.T) {
+ logger := testutil.NewMockLogger()
+
+ // Create a slow HTTP server that simulates a Tenant Manager responding
+ // after a delay. The revalidatePoolSettings goroutine will block on this.
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ time.Sleep(500 * time.Millisecond)
+
+ config := core.TenantConfig{
+ ID: "tenant-slow",
+ TenantSlug: "slow-tenant",
+ Service: "test-service",
+ Status: "active",
+ IsolationMode: "database",
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Database: "test_db",
+ Username: "user",
+ Password: "pass",
+ SSLMode: "disable",
+ },
+ },
+ },
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 20,
+ MaxIdleConns: 5,
+ },
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+
+ if err := json.NewEncoder(w).Encode(config); err != nil {
+ http.Error(w, "encode error", http.StatusInternalServerError)
+ }
+ }))
+ defer server.Close()
+
+ tmClient, err := client.NewClient(server.URL, logger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ if err != nil {
+ t.Fatalf("NewClient() returned unexpected error: %v", err)
+ }
+
+ manager := NewManager(tmClient, "test-service",
+ WithLogger(logger),
+ WithSettingsCheckInterval(1*time.Millisecond), // Trigger revalidation immediately
+ )
+
+ // Pre-populate the connections map with a dummy connection so GetConnection
+ // returns from cache and triggers the revalidation goroutine.
+ dummyDB := &pingableDB{pingErr: nil}
+ var db dbresolver.DB = dummyDB
+
+ manager.connections["tenant-slow"] = &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.lastAccessed["tenant-slow"] = time.Now()
+ // Set lastSettingsCheck to zero time so revalidation is triggered immediately
+ manager.lastSettingsCheck["tenant-slow"] = time.Time{}
+
+ // GetConnection will hit cache, see that settingsCheckInterval has elapsed,
+ // and spawn a revalidatePoolSettings goroutine that blocks for 500ms on the server.
+ _, err = manager.GetConnection(context.Background(), "tenant-slow")
+ if err != nil {
+ t.Fatalf("GetConnection() returned unexpected error: %v", err)
+ }
+
+ // Close immediately — the revalidation goroutine is still blocked on the
+ // slow HTTP server. With the fix, Close() waits for it to finish.
+ if closeErr := manager.Close(context.Background()); closeErr != nil {
+ t.Fatalf("Close() returned unexpected error: %v", closeErr)
+ }
+
+ // If Close() properly waited, no goroutines should be leaked.
+ goleak.VerifyNone(t,
+ goleak.IgnoreTopFunction("github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/cache.(*InMemoryCache).cleanupLoop"),
+ goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
+ goleak.IgnoreTopFunction("net/http.(*persistConn).writeLoop"),
+ goleak.IgnoreTopFunction("net/http.(*persistConn).readLoop"),
+ )
+}
diff --git a/commons/tenant-manager/postgres/manager.go b/commons/tenant-manager/postgres/manager.go
new file mode 100644
index 00000000..6c3aca80
--- /dev/null
+++ b/commons/tenant-manager/postgres/manager.go
@@ -0,0 +1,1003 @@
+// Package postgres provides multi-tenant PostgreSQL connection management.
+// It fetches credentials from Tenant Manager and caches connections per tenant.
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "net/url"
+ "regexp"
+ "sync"
+ "time"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ libLog "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ libPostgres "github.com/LerianStudio/lib-commons/v4/commons/postgres"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/eviction"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+ "github.com/bxcodec/dbresolver/v2"
+ _ "github.com/jackc/pgx/v5/stdlib"
+ "go.opentelemetry.io/otel/trace"
+)
+
+// pingTimeout is the maximum duration for connection health check pings.
+// Kept short to avoid blocking requests when a cached connection is stale.
+const pingTimeout = 3 * time.Second
+
+// defaultSettingsCheckInterval is the default interval between periodic
+// connection pool settings revalidation checks. When a cached connection is
+// returned by GetConnection and this interval has elapsed since the last check,
+// fresh config is fetched from the Tenant Manager asynchronously.
+const defaultSettingsCheckInterval = 30 * time.Second
+
+// settingsRevalidationTimeout is the maximum duration for the HTTP call
+// to the Tenant Manager during async settings revalidation.
+const settingsRevalidationTimeout = 5 * time.Second
+
+// IsolationMode constants define the tenant isolation strategies.
+const (
+ // IsolationModeIsolated indicates each tenant has a dedicated database.
+ IsolationModeIsolated = "isolated"
+ // IsolationModeSchema indicates tenants share a database but have separate schemas.
+ IsolationModeSchema = "schema"
+)
+
+// fallbackMaxOpenConns is the default maximum number of open connections per tenant
+// database pool. Used when per-tenant connectionSettings are absent from the Tenant
+// Manager /connections response (i.e., the tenant has no explicit pool configuration),
+// or when no Tenant Manager client is configured. Can be overridden per-manager via
+// WithMaxOpenConns.
+const fallbackMaxOpenConns = 25
+
+// fallbackMaxIdleConns is the default maximum number of idle connections per tenant
+// database pool. Used when per-tenant connectionSettings are absent from the Tenant
+// Manager /connections response, or when no Tenant Manager client is configured.
+// Can be overridden per-manager via WithMaxIdleConns.
+const fallbackMaxIdleConns = 5
+
+const defaultMaxAllowedOpenConns = 200
+
+const defaultMaxAllowedIdleConns = 50
+
+// defaultIdleTimeout is the default duration before a tenant connection becomes
+// eligible for eviction. Connections accessed within this window are considered
+// active and will not be evicted, allowing the pool to grow beyond maxConnections.
+// Defined centrally in the eviction package; aliased here for local convenience.
+var defaultIdleTimeout = eviction.DefaultIdleTimeout
+
+// Manager manages PostgreSQL database connections per tenant.
+// It fetches credentials from Tenant Manager and caches connections.
+// Credentials are provided directly by the tenant-manager settings endpoint.
+// When maxConnections is set (> 0), the manager uses LRU eviction with an idle
+// timeout as a soft limit. Connections idle longer than the timeout are eligible
+// for eviction when the pool exceeds maxConnections. If all connections are active
+// (used within the idle timeout), the pool grows beyond the soft limit and
+// naturally shrinks back as tenants become idle.
+type Manager struct {
+ client *client.Client
+ service string
+ module string
+ logger *logcompat.Logger
+
+ mu sync.RWMutex
+ connections map[string]*PostgresConnection
+ closed bool
+
+ maxOpenConns int
+ maxIdleConns int
+ maxAllowedOpenConns int
+ maxAllowedIdleConns int
+ maxConnections int // soft limit for pool size (0 = unlimited)
+ idleTimeout time.Duration // how long before a connection is eligible for eviction
+ lastAccessed map[string]time.Time // LRU tracking per tenant
+
+ lastSettingsCheck map[string]time.Time // tracks per-tenant last settings revalidation time
+ settingsCheckInterval time.Duration // configurable interval between settings revalidation checks
+
+ // revalidateWG tracks in-flight revalidatePoolSettings goroutines so Close()
+ // can wait for them to finish before returning. Without this, goroutines
+ // spawned by GetConnection may access Manager state after Close() returns.
+ revalidateWG sync.WaitGroup
+
+ defaultConn *PostgresConnection
+}
+
+type PostgresConnection struct {
+ // Adapter type used by tenant-manager package; keep fields aligned with
+ // tenant-manager migration contract and upstream lib-commons adapter semantics.
+ ConnectionStringPrimary string `json:"-"` // contains credentials, must not be serialized
+ ConnectionStringReplica string `json:"-"` // contains credentials, must not be serialized
+ PrimaryDBName string `json:"primaryDBName,omitempty"`
+ ReplicaDBName string `json:"replicaDBName,omitempty"`
+ MaxOpenConnections int `json:"maxOpenConnections,omitempty"`
+ MaxIdleConnections int `json:"maxIdleConnections,omitempty"`
+ SkipMigrations bool `json:"skipMigrations,omitempty"`
+ Logger libLog.Logger `json:"-"`
+ ConnectionDB *dbresolver.DB `json:"-"`
+
+ client *libPostgres.Client
+}
+
+func (c *PostgresConnection) Connect(ctx context.Context) error {
+ if c == nil {
+ return errors.New("postgres connection is nil")
+ }
+
+ pgClient, err := libPostgres.New(libPostgres.Config{
+ PrimaryDSN: c.ConnectionStringPrimary,
+ ReplicaDSN: c.ConnectionStringReplica,
+ Logger: c.Logger,
+ MaxOpenConnections: c.MaxOpenConnections,
+ MaxIdleConnections: c.MaxIdleConnections,
+ })
+ if err != nil {
+ return err
+ }
+
+ if err := pgClient.Connect(ctx); err != nil {
+ return err
+ }
+
+ resolver, err := pgClient.Resolver(ctx)
+ if err != nil {
+ return err
+ }
+
+ c.client = pgClient
+ c.ConnectionDB = &resolver
+
+ return nil
+}
+
+func (c *PostgresConnection) GetDB() (dbresolver.DB, error) {
+ if c == nil || c.ConnectionDB == nil {
+ return nil, errors.New("postgres resolver not initialized")
+ }
+
+ return *c.ConnectionDB, nil
+}
+
+// Stats contains statistics for the Manager.
+type Stats struct {
+ TotalConnections int `json:"totalConnections"`
+ ActiveConnections int `json:"activeConnections"`
+ MaxConnections int `json:"maxConnections"`
+ TenantIDs []string `json:"tenantIds"`
+ Closed bool `json:"closed"`
+}
+
+// Option configures a Manager.
+type Option func(*Manager)
+
+// WithLogger sets the logger for the Manager.
+func WithLogger(logger libLog.Logger) Option {
+ return func(p *Manager) {
+ p.logger = logcompat.New(logger)
+ }
+}
+
+// WithMaxOpenConns sets max open connections per tenant.
+func WithMaxOpenConns(n int) Option {
+ return func(p *Manager) {
+ p.maxOpenConns = n
+ }
+}
+
+// WithMaxIdleConns sets max idle connections per tenant.
+func WithMaxIdleConns(n int) Option {
+ return func(p *Manager) {
+ p.maxIdleConns = n
+ }
+}
+
+// WithConnectionLimitCaps sets hard maximums for per-tenant pool settings
+// received from Tenant Manager.
+func WithConnectionLimitCaps(maxOpen, maxIdle int) Option {
+ return func(p *Manager) {
+ if maxOpen > 0 {
+ p.maxAllowedOpenConns = maxOpen
+ }
+
+ if maxIdle > 0 {
+ p.maxAllowedIdleConns = maxIdle
+ }
+ }
+}
+
+// WithModule sets the module name for the Manager (e.g., "onboarding", "transaction").
+func WithModule(module string) Option {
+ return func(p *Manager) {
+ p.module = module
+ }
+}
+
+// WithMaxTenantPools sets the soft limit for the number of tenant connections in the pool.
+// When the pool reaches this limit and a new tenant needs a connection, only connections
+// that have been idle longer than the idle timeout are eligible for eviction. If all
+// connections are active (used within the idle timeout), the pool grows beyond this limit.
+// A value of 0 (default) means unlimited.
+func WithMaxTenantPools(maxSize int) Option {
+ return func(p *Manager) {
+ p.maxConnections = maxSize
+ }
+}
+
+// WithSettingsCheckInterval sets the interval between periodic connection pool settings
+// revalidation checks. When GetConnection returns a cached connection and this interval
+// has elapsed since the last check for that tenant, fresh config is fetched from the
+// Tenant Manager asynchronously and pool settings are updated without recreating the connection.
+//
+// If d <= 0, revalidation is DISABLED (settingsCheckInterval is set to 0).
+// When disabled, no async revalidation checks are performed on cache hits.
+// Default: 30 seconds (defaultSettingsCheckInterval).
+func WithSettingsCheckInterval(d time.Duration) Option {
+ return func(p *Manager) {
+ p.settingsCheckInterval = max(d, 0)
+ }
+}
+
+// WithIdleTimeout sets the duration after which an unused tenant connection becomes
+// eligible for eviction. Only connections idle longer than this duration will be
+// evicted when the pool exceeds the soft limit (maxConnections). If all connections
+// are active (used within the idle timeout), the pool is allowed to grow beyond the
+// soft limit and naturally shrinks back as tenants become idle.
+// Default: 5 minutes.
+func WithIdleTimeout(d time.Duration) Option {
+ return func(p *Manager) {
+ p.idleTimeout = d
+ }
+}
+
+// NewManager creates a new PostgreSQL connection manager.
+func NewManager(c *client.Client, service string, opts ...Option) *Manager {
+ p := &Manager{
+ client: c,
+ service: service,
+ logger: logcompat.New(nil),
+ connections: make(map[string]*PostgresConnection),
+ lastAccessed: make(map[string]time.Time),
+ lastSettingsCheck: make(map[string]time.Time),
+ settingsCheckInterval: defaultSettingsCheckInterval,
+ maxOpenConns: fallbackMaxOpenConns,
+ maxIdleConns: fallbackMaxIdleConns,
+ maxAllowedOpenConns: defaultMaxAllowedOpenConns,
+ maxAllowedIdleConns: defaultMaxAllowedIdleConns,
+ }
+
+ for _, opt := range opts {
+ opt(p)
+ }
+
+ return p
+}
+
+// GetConnection returns a database connection for the tenant.
+// Creates a new connection if one doesn't exist.
+// If a cached connection fails a health check (e.g., due to credential rotation
+// after a tenant purge+re-associate), the stale connection is evicted and a new
+// one is created with fresh credentials from the Tenant Manager.
+func (p *Manager) GetConnection(ctx context.Context, tenantID string) (*PostgresConnection, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if tenantID == "" {
+ return nil, errors.New("tenant ID is required")
+ }
+
+ p.mu.RLock()
+
+ if p.closed {
+ p.mu.RUnlock()
+ return nil, core.ErrManagerClosed
+ }
+
+ if conn, ok := p.connections[tenantID]; ok {
+ p.mu.RUnlock()
+
+ // Validate cached connection is still healthy (e.g., credentials may have changed)
+ if conn.ConnectionDB != nil {
+ pingCtx, cancel := context.WithTimeout(ctx, pingTimeout)
+
+ pingErr := (*conn.ConnectionDB).PingContext(pingCtx)
+
+ cancel() // Release timer immediately; we no longer need the ping context.
+
+ if pingErr != nil {
+ if p.logger != nil {
+ p.logger.WarnCtx(ctx, fmt.Sprintf("cached postgres connection unhealthy for tenant %s, reconnecting: %v", tenantID, pingErr))
+ }
+
+ if closeErr := p.CloseConnection(ctx, tenantID); closeErr != nil && p.logger != nil {
+ p.logger.WarnCtx(ctx, fmt.Sprintf("failed to close stale postgres connection for tenant %s: %v", tenantID, closeErr))
+ }
+
+ // Fall through to create a new connection with fresh credentials
+ return p.createConnection(ctx, tenantID)
+ }
+ }
+
+ // Update LRU tracking on cache hit and check if settings revalidation is due
+ now := time.Now()
+
+ p.mu.Lock()
+
+ // TOCTOU re-check: connection may have been evicted while we were pinging.
+ if _, stillExists := p.connections[tenantID]; !stillExists {
+ p.mu.Unlock()
+ // Connection was evicted while we were pinging; create fresh.
+ return p.createConnection(ctx, tenantID)
+ }
+
+ p.lastAccessed[tenantID] = now
+
+ // Only revalidate if settingsCheckInterval > 0 (means revalidation is enabled)
+ shouldRevalidate := p.client != nil && p.settingsCheckInterval > 0 && time.Since(p.lastSettingsCheck[tenantID]) > p.settingsCheckInterval
+ if shouldRevalidate {
+ // Update timestamp BEFORE spawning goroutine to prevent multiple
+ // concurrent revalidation checks for the same tenant.
+ p.lastSettingsCheck[tenantID] = now
+ }
+
+ p.mu.Unlock()
+
+ if shouldRevalidate {
+ p.revalidateWG.Go(func() { //#nosec G118 -- intentional: revalidatePoolSettings creates its own timeout context; must not use request-scoped context as this outlives the request
+ p.revalidatePoolSettings(tenantID)
+ })
+ }
+
+ return conn, nil
+ }
+
+ p.mu.RUnlock()
+
+ return p.createConnection(ctx, tenantID)
+}
+
+// revalidatePoolSettings fetches fresh config from the Tenant Manager and applies
+// updated connection pool settings to the cached connection for the given tenant.
+// This runs asynchronously (in a goroutine) and must never block GetConnection.
+// If the fetch fails, a warning is logged but the connection remains usable.
+func (p *Manager) revalidatePoolSettings(tenantID string) {
+ // Guard: recover from any panic to avoid crashing the process.
+ // This goroutine runs asynchronously and must never bring down the service.
+ defer func() {
+ if r := recover(); r != nil {
+ if p.logger != nil {
+ p.logger.Warnf("recovered from panic during settings revalidation for tenant %s: %v", tenantID, r)
+ }
+ }
+ }()
+
+ revalidateCtx, cancel := context.WithTimeout(context.Background(), settingsRevalidationTimeout)
+ defer cancel()
+
+ config, err := p.client.GetTenantConfig(revalidateCtx, tenantID, p.service, client.WithSkipCache())
+ if err != nil {
+ // If tenant service was suspended/purged, evict the cached connection immediately.
+ // The next request for this tenant will call createConnection, which fetches fresh
+ // config from the Tenant Manager and receives the 403 error directly.
+ if core.IsTenantSuspendedError(err) {
+ if p.logger != nil {
+ p.logger.Warnf("tenant %s service suspended, evicting cached connection", tenantID)
+ }
+
+ _ = p.CloseConnection(context.Background(), tenantID)
+
+ return
+ }
+
+ if p.logger != nil {
+ p.logger.Warnf("failed to revalidate connection settings for tenant %s: %v", tenantID, err)
+ }
+
+ return
+ }
+
+ p.ApplyConnectionSettings(tenantID, config)
+}
+
+// createConnection fetches config from Tenant Manager and creates a connection.
+func (p *Manager) createConnection(ctx context.Context, tenantID string) (*PostgresConnection, error) {
+ if p.client == nil {
+ return nil, errors.New("tenant manager client is required for multi-tenant connections")
+ }
+
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ ctx, span := tracer.Start(ctx, "postgres.create_connection")
+ defer span.End()
+
+ p.mu.Lock()
+ if cachedConn, ok := p.tryReuseOrEvictCachedConnectionLocked(ctx, tenantID, logger); ok {
+ p.mu.Unlock()
+
+ return cachedConn, nil
+ }
+
+ if p.closed {
+ p.mu.Unlock()
+
+ return nil, core.ErrManagerClosed
+ }
+
+ p.mu.Unlock()
+
+ config, pgConfig, err := p.getPostgresConfigForTenant(ctx, tenantID, logger, span)
+ if err != nil {
+ return nil, err
+ }
+
+ conn, err := p.buildTenantPostgresConnection(ctx, tenantID, config, pgConfig, logger, span)
+ if err != nil {
+ return nil, err
+ }
+
+ return p.cacheConnection(ctx, tenantID, conn, logger, config.IsolationMode)
+}
+
+func (p *Manager) tryReuseOrEvictCachedConnectionLocked(
+ ctx context.Context,
+ tenantID string,
+ logger *logcompat.Logger,
+) (*PostgresConnection, bool) {
+ conn, ok := p.connections[tenantID]
+ if !ok {
+ return nil, false
+ }
+
+ if conn != nil && conn.ConnectionDB != nil {
+ pingCtx, cancel := context.WithTimeout(ctx, pingTimeout)
+ pingErr := (*conn.ConnectionDB).PingContext(pingCtx)
+
+ cancel()
+
+ if pingErr == nil {
+ return conn, true
+ }
+
+ logger.WarnCtx(ctx, fmt.Sprintf("cached postgres connection unhealthy for tenant %s after lock, reconnecting: %v", tenantID, pingErr))
+
+ _ = (*conn.ConnectionDB).Close()
+ }
+
+ delete(p.connections, tenantID)
+ delete(p.lastAccessed, tenantID)
+ delete(p.lastSettingsCheck, tenantID)
+
+ return nil, false
+}
+
+func (p *Manager) getPostgresConfigForTenant(
+ ctx context.Context,
+ tenantID string,
+ logger *logcompat.Logger,
+ span trace.Span,
+) (*core.TenantConfig, *core.PostgreSQLConfig, error) {
+ config, err := p.client.GetTenantConfig(ctx, tenantID, p.service)
+ if err != nil {
+ var suspErr *core.TenantSuspendedError
+ if errors.As(err, &suspErr) {
+ logger.WarnCtx(ctx, fmt.Sprintf("tenant service is %s: tenantID=%s", suspErr.Status, tenantID))
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "tenant service suspended", err)
+
+ return nil, nil, err
+ }
+
+ logger.ErrorCtx(ctx, fmt.Sprintf("failed to get tenant config: %v", err))
+ libOpentelemetry.HandleSpanError(span, "failed to get tenant config", err)
+
+ return nil, nil, fmt.Errorf("failed to get tenant config: %w", err)
+ }
+
+ pgConfig := config.GetPostgreSQLConfig(p.service, p.module)
+ if pgConfig == nil {
+ logger.ErrorCtx(ctx, fmt.Sprintf("no PostgreSQL config for tenant %s service %s module %s", tenantID, p.service, p.module))
+
+ return nil, nil, core.ErrServiceNotConfigured
+ }
+
+ return config, pgConfig, nil
+}
+
+func (p *Manager) buildTenantPostgresConnection(
+ ctx context.Context,
+ tenantID string,
+ config *core.TenantConfig,
+ pgConfig *core.PostgreSQLConfig,
+ logger *logcompat.Logger,
+ span trace.Span,
+) (*PostgresConnection, error) {
+ primaryConnStr, err := buildConnectionString(pgConfig)
+ if err != nil {
+ logger.ErrorCtx(ctx, fmt.Sprintf("invalid connection string for tenant %s: %v", tenantID, err))
+ libOpentelemetry.HandleSpanError(span, "invalid connection string", err)
+
+ return nil, fmt.Errorf("invalid connection string for tenant %s: %w", tenantID, err)
+ }
+
+ replicaConnStr, replicaDBName, err := p.resolveReplicaConnection(config, pgConfig, primaryConnStr, tenantID, logger)
+ if err != nil {
+ libOpentelemetry.HandleSpanError(span, "invalid replica connection string", err)
+
+ return nil, fmt.Errorf("invalid replica connection string for tenant %s: %w", tenantID, err)
+ }
+
+ maxOpen, maxIdle := p.resolveConnectionPoolSettings(config, tenantID, logger)
+
+ conn := &PostgresConnection{
+ ConnectionStringPrimary: primaryConnStr,
+ ConnectionStringReplica: replicaConnStr,
+ PrimaryDBName: pgConfig.Database,
+ ReplicaDBName: replicaDBName,
+ MaxOpenConnections: maxOpen,
+ MaxIdleConnections: maxIdle,
+ SkipMigrations: p.IsMultiTenant(),
+ }
+
+ if p.logger != nil {
+ conn.Logger = p.logger.Base()
+ }
+
+ if config.IsSchemaMode() && pgConfig.Schema == "" {
+ logger.ErrorCtx(ctx, "schema mode requires schema in config for tenant "+tenantID)
+
+ return nil, fmt.Errorf("schema mode requires schema in config for tenant %s", tenantID)
+ }
+
+ if err := conn.Connect(ctx); err != nil {
+ logger.ErrorCtx(ctx, fmt.Sprintf("failed to connect to tenant database: %v", err))
+ libOpentelemetry.HandleSpanError(span, "failed to connect", err)
+
+ return nil, fmt.Errorf("failed to connect to tenant database: %w", err)
+ }
+
+ if pgConfig.Schema != "" {
+ logger.InfoCtx(ctx, fmt.Sprintf("connection configured with search_path=%s for tenant %s (mode: %s)", pgConfig.Schema, tenantID, config.IsolationMode))
+ }
+
+ return conn, nil
+}
+
+func (p *Manager) cacheConnection(
+ ctx context.Context,
+ tenantID string,
+ conn *PostgresConnection,
+ logger *logcompat.Logger,
+ isolationMode string,
+) (*PostgresConnection, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.closed {
+ if conn.ConnectionDB != nil {
+ _ = (*conn.ConnectionDB).Close()
+ }
+
+ return nil, core.ErrManagerClosed
+ }
+
+ if cached, ok := p.connections[tenantID]; ok && cached != nil && cached.ConnectionDB != nil {
+ if conn.ConnectionDB != nil {
+ _ = (*conn.ConnectionDB).Close()
+ }
+
+ p.lastAccessed[tenantID] = time.Now()
+
+ return cached, nil
+ }
+
+ p.evictLRU(ctx, logger.Base())
+
+ p.connections[tenantID] = conn
+ p.lastAccessed[tenantID] = time.Now()
+
+ logger.InfoCtx(ctx, fmt.Sprintf("created connection for tenant %s (mode: %s)", tenantID, isolationMode))
+
+ return conn, nil
+}
+
+// resolveReplicaConnection resolves the replica connection string and database name.
+// If a dedicated replica config exists for the service/module, it builds a separate
+// connection string; otherwise it falls back to the primary connection string and database.
+func (p *Manager) resolveReplicaConnection(
+ config *core.TenantConfig,
+ pgConfig *core.PostgreSQLConfig,
+ primaryConnStr string,
+ tenantID string,
+ logger *logcompat.Logger,
+) (connStr string, dbName string, err error) {
+ pgReplicaConfig := config.GetPostgreSQLReplicaConfig(p.service, p.module)
+ if pgReplicaConfig == nil {
+ return primaryConnStr, pgConfig.Database, nil
+ }
+
+ replicaConnStr, buildErr := buildConnectionString(pgReplicaConfig)
+ if buildErr != nil {
+ logger.Errorf("invalid replica connection string for tenant %s: %v", tenantID, buildErr)
+ return "", "", buildErr
+ }
+
+ logger.Infof("using separate replica connection for tenant %s (replica host: %s)", tenantID, pgReplicaConfig.Host)
+
+ return replicaConnStr, pgReplicaConfig.Database, nil
+}
+
+// resolveConnectionSettingsFromConfig extracts connection settings from the tenant config,
+// checking module-level settings first, then top-level for backward compatibility.
+func (p *Manager) resolveConnectionSettingsFromConfig(config *core.TenantConfig) *core.ConnectionSettings {
+ if config == nil {
+ return nil
+ }
+
+ if p.module != "" && config.Databases != nil {
+ if db, ok := config.Databases[p.module]; ok && db.ConnectionSettings != nil {
+ return db.ConnectionSettings
+ }
+ }
+
+ return config.ConnectionSettings
+}
+
+// clampPoolSettings enforces connection pool limits set by WithConnectionLimitCaps.
+func (p *Manager) clampPoolSettings(maxOpen, maxIdle int, tenantID string, logger *logcompat.Logger) (int, int) {
+ if p.maxAllowedOpenConns > 0 && maxOpen > p.maxAllowedOpenConns {
+ if logger != nil {
+ logger.Warnf("clamping maxOpenConns for tenant %s module %s from %d to %d", tenantID, p.module, maxOpen, p.maxAllowedOpenConns)
+ }
+
+ maxOpen = p.maxAllowedOpenConns
+ }
+
+ if p.maxAllowedIdleConns > 0 && maxIdle > p.maxAllowedIdleConns {
+ if logger != nil {
+ logger.Warnf("clamping maxIdleConns for tenant %s module %s from %d to %d", tenantID, p.module, maxIdle, p.maxAllowedIdleConns)
+ }
+
+ maxIdle = p.maxAllowedIdleConns
+ }
+
+ return maxOpen, maxIdle
+}
+
+// resolveConnectionPoolSettings determines the effective maxOpen and maxIdle connection
+// settings for a tenant. It checks module-level settings first (new format), then falls
+// back to top-level settings (legacy), and finally uses global defaults.
+func (p *Manager) resolveConnectionPoolSettings(config *core.TenantConfig, tenantID string, logger *logcompat.Logger) (maxOpen, maxIdle int) {
+ maxOpen = p.maxOpenConns
+ maxIdle = p.maxIdleConns
+
+ connSettings := p.resolveConnectionSettingsFromConfig(config)
+
+ if connSettings != nil {
+ if connSettings.MaxOpenConns > 0 {
+ maxOpen = connSettings.MaxOpenConns
+ logger.Infof("applying per-module maxOpenConns=%d for tenant %s module %s (global default: %d)", maxOpen, tenantID, p.module, p.maxOpenConns)
+ } else {
+ // connectionSettings present but MaxOpenConns is zero: restore manager default
+ maxOpen = p.maxOpenConns
+ }
+
+ if connSettings.MaxIdleConns > 0 {
+ maxIdle = connSettings.MaxIdleConns
+ logger.Infof("applying per-module maxIdleConns=%d for tenant %s module %s (global default: %d)", maxIdle, tenantID, p.module, p.maxIdleConns)
+ } else {
+ // connectionSettings present but MaxIdleConns is zero: restore manager default
+ maxIdle = p.maxIdleConns
+ }
+ }
+
+ maxOpen, maxIdle = p.clampPoolSettings(maxOpen, maxIdle, tenantID, logger)
+
+ return maxOpen, maxIdle
+}
+
+// evictLRU removes the least recently used idle connection when the pool reaches the
+// soft limit. Only connections that have been idle longer than the idle timeout are
+// eligible for eviction. If all connections are active (used within the idle timeout),
+// the pool is allowed to grow beyond the soft limit.
+// Caller MUST hold p.mu write lock.
+func (p *Manager) evictLRU(_ context.Context, logger libLog.Logger) {
+ candidateID, shouldEvict := eviction.FindLRUEvictionCandidate(
+ len(p.connections), p.maxConnections, p.lastAccessed, p.idleTimeout, logger,
+ )
+ if !shouldEvict {
+ return
+ }
+
+ // Manager-specific cleanup: close the postgres connection and remove from all maps.
+ if conn, ok := p.connections[candidateID]; ok {
+ if conn.ConnectionDB != nil {
+ _ = (*conn.ConnectionDB).Close()
+ }
+
+ delete(p.connections, candidateID)
+ delete(p.lastAccessed, candidateID)
+ delete(p.lastSettingsCheck, candidateID)
+ }
+}
+
+// GetDB returns a dbresolver.DB for the tenant.
+func (p *Manager) GetDB(ctx context.Context, tenantID string) (dbresolver.DB, error) {
+ conn, err := p.GetConnection(ctx, tenantID)
+ if err != nil {
+ return nil, err
+ }
+
+ return conn.GetDB()
+}
+
+// Close closes all connections and marks the manager as closed.
+// It waits for any in-flight revalidatePoolSettings goroutines to finish
+// before returning, preventing goroutine leaks and use-after-close races.
+func (p *Manager) Close(_ context.Context) error {
+ // Phase 1: Under lock, mark closed and close all connections.
+ p.mu.Lock()
+
+ p.closed = true
+
+ var errs []error
+
+ for tenantID, conn := range p.connections {
+ if conn.ConnectionDB != nil {
+ if err := (*conn.ConnectionDB).Close(); err != nil {
+ errs = append(errs, err)
+ }
+ }
+
+ delete(p.connections, tenantID)
+ delete(p.lastAccessed, tenantID)
+ delete(p.lastSettingsCheck, tenantID)
+ }
+
+ p.mu.Unlock()
+
+ // Phase 2: Wait for in-flight revalidatePoolSettings goroutines OUTSIDE the lock.
+ // revalidatePoolSettings acquires p.mu internally (via CloseConnection and
+ // ApplyConnectionSettings), so waiting with the lock held would deadlock.
+ p.revalidateWG.Wait()
+
+ return errors.Join(errs...)
+}
+
+// CloseConnection closes the connection for a specific tenant.
+func (p *Manager) CloseConnection(_ context.Context, tenantID string) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ conn, ok := p.connections[tenantID]
+ if !ok {
+ return nil
+ }
+
+ var err error
+ if conn.ConnectionDB != nil {
+ err = (*conn.ConnectionDB).Close()
+ }
+
+ delete(p.connections, tenantID)
+ delete(p.lastAccessed, tenantID)
+ delete(p.lastSettingsCheck, tenantID)
+
+ return err
+}
+
+// Stats returns connection statistics.
+func (p *Manager) Stats() Stats {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ tenantIDs := make([]string, 0, len(p.connections))
+ for id := range p.connections {
+ tenantIDs = append(tenantIDs, id)
+ }
+
+ totalConns := len(p.connections)
+
+ now := time.Now()
+
+ idleTimeout := p.idleTimeout
+ if idleTimeout == 0 {
+ idleTimeout = defaultIdleTimeout
+ }
+
+ activeCount := 0
+
+ for id := range p.connections {
+ if t, ok := p.lastAccessed[id]; ok && now.Sub(t) < idleTimeout {
+ activeCount++
+ }
+ }
+
+ return Stats{
+ TotalConnections: totalConns,
+ ActiveConnections: activeCount,
+ MaxConnections: p.maxConnections,
+ TenantIDs: tenantIDs,
+ Closed: p.closed,
+ }
+}
+
+// validSchemaPattern validates PostgreSQL schema names to prevent injection
+// in the options=-csearch_path= connection string parameter.
+var validSchemaPattern = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
+
+func buildConnectionString(cfg *core.PostgreSQLConfig) (string, error) {
+ if cfg == nil {
+ return "", fmt.Errorf("postgres.buildConnectionString: %w", core.ErrNilConfig)
+ }
+
+ sslmode := cfg.SSLMode
+ if sslmode == "" {
+ // Default is "disable" for local development compatibility.
+ // Production deployments should set SSLMode explicitly in PostgreSQLConfig.
+ sslmode = "disable"
+ }
+
+ // Reject contradictory configuration: SSL is disabled but certificate
+ // paths are provided. This likely indicates a misconfiguration that would
+ // silently ignore the supplied certificates.
+ if sslmode == "disable" && (cfg.SSLRootCert != "" || cfg.SSLCert != "" || cfg.SSLKey != "") {
+ return "", fmt.Errorf("sslmode is %q but SSL certificate parameters are set (sslrootcert=%q, sslcert=%q, sslkey=%q); "+
+ "either remove the certificate paths or use a TLS-enabled sslmode", sslmode, cfg.SSLRootCert, cfg.SSLCert, cfg.SSLKey)
+ }
+
+ connURL := &url.URL{
+ Scheme: "postgres",
+ Host: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
+ Path: "/" + cfg.Database,
+ }
+
+ if cfg.Username != "" {
+ connURL.User = url.UserPassword(cfg.Username, cfg.Password)
+ }
+
+ values := url.Values{}
+ values.Set("sslmode", sslmode)
+
+ if cfg.SSLRootCert != "" {
+ values.Set("sslrootcert", cfg.SSLRootCert)
+ }
+
+ if cfg.SSLCert != "" {
+ values.Set("sslcert", cfg.SSLCert)
+ }
+
+ if cfg.SSLKey != "" {
+ values.Set("sslkey", cfg.SSLKey)
+ }
+
+ if cfg.Schema != "" {
+ if !validSchemaPattern.MatchString(cfg.Schema) {
+ return "", fmt.Errorf("invalid schema name %q: must match %s", cfg.Schema, validSchemaPattern.String())
+ }
+
+ values.Set("options", "-csearch_path="+cfg.Schema)
+ }
+
+ connURL.RawQuery = values.Encode()
+
+ return connURL.String(), nil
+}
+
+// ApplyConnectionSettings applies updated connection pool settings to an existing
+// cached connection for the given tenant without recreating the connection.
+// This is called during the sync loop to revalidate settings that may have changed
+// in the Tenant Manager (e.g., maxOpenConns adjusted from 10 to 30).
+//
+// Go's sql.DB.SetMaxOpenConns and SetMaxIdleConns are thread-safe and take effect
+// immediately for new connections from the pool. Existing idle connections above the
+// new limit are closed gradually.
+//
+// For MongoDB, the driver does not support changing pool size after client creation,
+// so this method only applies to PostgreSQL connections.
+func (p *Manager) ApplyConnectionSettings(tenantID string, config *core.TenantConfig) {
+ p.mu.RLock()
+
+ conn, ok := p.connections[tenantID]
+ if !ok || conn == nil || conn.ConnectionDB == nil {
+ p.mu.RUnlock()
+ return // no cached connection, settings will be applied on next creation
+ }
+
+ connSettings := p.resolveConnectionSettingsFromConfig(config)
+
+ // Determine effective settings: per-tenant if present, otherwise manager defaults.
+ var maxOpen, maxIdle int
+ if connSettings != nil {
+ maxOpen = connSettings.MaxOpenConns
+ maxIdle = connSettings.MaxIdleConns
+ }
+
+ // Fallback to manager defaults for absent/zero values
+ if maxOpen <= 0 {
+ maxOpen = p.maxOpenConns
+ }
+
+ if maxIdle <= 0 {
+ maxIdle = p.maxIdleConns
+ }
+
+ db := *conn.ConnectionDB
+
+ p.mu.RUnlock() // Release before thread-safe sql.DB operations
+
+ compatLogger := logcompat.New(p.logger.Base())
+ maxOpen, maxIdle = p.clampPoolSettings(maxOpen, maxIdle, tenantID, compatLogger)
+
+ compatLogger.Infof("applying connection settings for tenant %s module %s: maxOpenConns=%d, maxIdleConns=%d",
+ tenantID, p.module, maxOpen, maxIdle)
+
+ db.SetMaxOpenConns(maxOpen)
+ db.SetMaxIdleConns(maxIdle)
+}
+
+// WithConnectionLimits sets the default per-tenant connection limits.
+func WithConnectionLimits(maxOpen, maxIdle int) Option {
+ return func(p *Manager) {
+ p.maxOpenConns = maxOpen
+ p.maxIdleConns = maxIdle
+ }
+}
+
+// WithDefaultConnection sets a default connection used in single-tenant mode.
+func WithDefaultConnection(conn *PostgresConnection) Option {
+ return func(p *Manager) {
+ p.defaultConn = conn
+ }
+}
+
+// Deprecated: prefer NewManager(..., WithConnectionLimits(...)).
+func (p *Manager) WithConnectionLimits(maxOpen, maxIdle int) *Manager {
+ WithConnectionLimits(maxOpen, maxIdle)(p)
+ return p
+}
+
+// Deprecated: prefer NewManager(..., WithDefaultConnection(...)).
+func (p *Manager) WithDefaultConnection(conn *PostgresConnection) *Manager {
+ WithDefaultConnection(conn)(p)
+ return p
+}
+
+// GetDefaultConnection returns the default connection configured for single-tenant mode.
+func (p *Manager) GetDefaultConnection() *PostgresConnection {
+ return p.defaultConn
+}
+
+// IsMultiTenant returns true if the manager is configured with a Tenant Manager client.
+func (p *Manager) IsMultiTenant() bool {
+ return p.client != nil
+}
+
+// CreateDirectConnection creates a direct database connection from config.
+// Useful when you have config but don't need full connection management.
+// Returns an error if cfg is nil.
+func CreateDirectConnection(ctx context.Context, cfg *core.PostgreSQLConfig) (*sql.DB, error) {
+ if cfg == nil {
+ return nil, fmt.Errorf("postgres.CreateDirectConnection: %w", core.ErrNilConfig)
+ }
+
+ connStr, err := buildConnectionString(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("invalid connection config: %w", err)
+ }
+
+ db, err := sql.Open("pgx", connStr)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open connection: %w", err)
+ }
+
+ if err := db.PingContext(ctx); err != nil {
+ _ = db.Close()
+ return nil, fmt.Errorf("failed to ping database: %w", err)
+ }
+
+ return db, nil
+}
diff --git a/commons/tenant-manager/postgres/manager_test.go b/commons/tenant-manager/postgres/manager_test.go
new file mode 100644
index 00000000..07badc52
--- /dev/null
+++ b/commons/tenant-manager/postgres/manager_test.go
@@ -0,0 +1,1823 @@
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ "github.com/bxcodec/dbresolver/v2"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// mustNewTestClient creates a test client or fails the test immediately.
+// Centralises the repeated client.NewClient + error-check boilerplate.
+// Tests use httptest servers (http://), so WithAllowInsecureHTTP is applied.
+func mustNewTestClient(t testing.TB, baseURL string) *client.Client {
+ t.Helper()
+ c, err := client.NewClient(baseURL, testutil.NewMockLogger(), client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, err)
+ return c
+}
+
+// pingableDB implements dbresolver.DB with configurable PingContext behavior
+// for testing connection health check logic.
+type pingableDB struct {
+ pingErr error
+ closed bool
+}
+
+var _ dbresolver.DB = (*pingableDB)(nil)
+
+func (m *pingableDB) Begin() (dbresolver.Tx, error) { return nil, nil }
+func (m *pingableDB) BeginTx(_ context.Context, _ *sql.TxOptions) (dbresolver.Tx, error) {
+ return nil, nil
+}
+func (m *pingableDB) Close() error { m.closed = true; return nil }
+func (m *pingableDB) Conn(_ context.Context) (dbresolver.Conn, error) { return nil, nil }
+func (m *pingableDB) Driver() driver.Driver { return nil }
+func (m *pingableDB) Exec(_ string, _ ...interface{}) (sql.Result, error) { return nil, nil }
+func (m *pingableDB) ExecContext(_ context.Context, _ string, _ ...interface{}) (sql.Result, error) {
+ return nil, nil
+}
+func (m *pingableDB) Ping() error { return m.pingErr }
+func (m *pingableDB) PingContext(_ context.Context) error { return m.pingErr }
+func (m *pingableDB) Prepare(_ string) (dbresolver.Stmt, error) { return nil, nil }
+func (m *pingableDB) PrepareContext(_ context.Context, _ string) (dbresolver.Stmt, error) {
+ return nil, nil
+}
+func (m *pingableDB) Query(_ string, _ ...interface{}) (*sql.Rows, error) { return nil, nil }
+func (m *pingableDB) QueryContext(_ context.Context, _ string, _ ...interface{}) (*sql.Rows, error) {
+ return nil, nil
+}
+func (m *pingableDB) QueryRow(_ string, _ ...interface{}) *sql.Row { return nil }
+func (m *pingableDB) QueryRowContext(_ context.Context, _ string, _ ...interface{}) *sql.Row {
+ return nil
+}
+func (m *pingableDB) SetConnMaxIdleTime(_ time.Duration) {}
+func (m *pingableDB) SetConnMaxLifetime(_ time.Duration) {}
+func (m *pingableDB) SetMaxIdleConns(_ int) {}
+func (m *pingableDB) SetMaxOpenConns(_ int) {}
+func (m *pingableDB) PrimaryDBs() []*sql.DB { return nil }
+func (m *pingableDB) ReplicaDBs() []*sql.DB { return nil }
+func (m *pingableDB) Stats() sql.DBStats { return sql.DBStats{} }
+
+// trackingDB extends pingableDB to track SetMaxOpenConns/SetMaxIdleConns calls.
+// Fields use int32 with atomic operations to avoid data races when written
+// by async goroutines (revalidatePoolSettings) and read by test assertions.
+type trackingDB struct {
+ pingableDB
+ maxOpenConns int32
+ maxIdleConns int32
+}
+
+func (t *trackingDB) SetMaxOpenConns(n int) { atomic.StoreInt32(&t.maxOpenConns, int32(n)) }
+func (t *trackingDB) SetMaxIdleConns(n int) { atomic.StoreInt32(&t.maxIdleConns, int32(n)) }
+func (t *trackingDB) MaxOpenConns() int32 { return atomic.LoadInt32(&t.maxOpenConns) }
+func (t *trackingDB) MaxIdleConns() int32 { return atomic.LoadInt32(&t.maxIdleConns) }
+
+func TestNewManager(t *testing.T) {
+ t.Run("creates manager with client and service", func(t *testing.T) {
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger")
+
+ assert.NotNil(t, manager)
+ assert.Equal(t, "ledger", manager.service)
+ assert.NotNil(t, manager.connections)
+ })
+}
+
+func TestManager_GetConnection_NoTenantID(t *testing.T) {
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger")
+
+ _, err := manager.GetConnection(context.Background(), "")
+
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "tenant ID is required")
+}
+
+func TestManager_Close(t *testing.T) {
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger")
+
+ err := manager.Close(context.Background())
+
+ assert.NoError(t, err)
+ assert.True(t, manager.closed)
+}
+
+func TestManager_GetConnection_ManagerClosed(t *testing.T) {
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger")
+ manager.Close(context.Background())
+
+ _, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.Error(t, err)
+ assert.ErrorIs(t, err, core.ErrManagerClosed)
+}
+
+func TestIsolationModeConstants(t *testing.T) {
+ t.Run("isolation mode constants have expected values", func(t *testing.T) {
+ assert.Equal(t, "isolated", IsolationModeIsolated)
+ assert.Equal(t, "schema", IsolationModeSchema)
+ })
+}
+
+func TestBuildConnectionString(t *testing.T) {
+ tests := []struct {
+ name string
+ cfg *core.PostgreSQLConfig
+ expected string
+ }{
+ {
+ name: "builds connection string without schema",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "disable",
+ },
+ expected: "postgres://user:pass@localhost:5432/testdb?sslmode=disable",
+ },
+ {
+ name: "builds connection string with schema in options",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "disable",
+ Schema: "tenant_abc",
+ },
+ expected: "postgres://user:pass@localhost:5432/testdb?options=-csearch_path%3Dtenant_abc&sslmode=disable",
+ },
+ {
+ name: "defaults sslmode to disable when empty",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ },
+ expected: "postgres://user:pass@localhost:5432/testdb?sslmode=disable",
+ },
+ {
+ name: "uses provided sslmode",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "require",
+ },
+ expected: "postgres://user:pass@localhost:5432/testdb?sslmode=require",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := buildConnectionString(tt.cfg)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestBuildConnectionString_SSLCertificates(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *core.PostgreSQLConfig
+ contains []string
+ excludes []string
+ }{
+ {
+ name: "adds sslrootcert when SSLRootCert is set",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "verify-full",
+ SSLRootCert: "/etc/ssl/ca.pem",
+ },
+ contains: []string{"sslmode=verify-full", "sslrootcert=%2Fetc%2Fssl%2Fca.pem"},
+ },
+ {
+ name: "adds sslcert and sslkey when both are set",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "verify-full",
+ SSLRootCert: "/etc/ssl/ca.pem",
+ SSLCert: "/etc/ssl/client-cert.pem",
+ SSLKey: "/etc/ssl/client-key.pem",
+ },
+ contains: []string{
+ "sslmode=verify-full",
+ "sslrootcert=%2Fetc%2Fssl%2Fca.pem",
+ "sslcert=%2Fetc%2Fssl%2Fclient-cert.pem",
+ "sslkey=%2Fetc%2Fssl%2Fclient-key.pem",
+ },
+ },
+ {
+ name: "does not add ssl cert params when not set",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "require",
+ },
+ contains: []string{"sslmode=require"},
+ excludes: []string{"sslrootcert", "sslcert=", "sslkey="},
+ },
+ {
+ name: "adds only sslrootcert without client certs",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "verify-ca",
+ SSLRootCert: "/etc/ssl/ca.pem",
+ },
+ contains: []string{"sslmode=verify-ca", "sslrootcert=%2Fetc%2Fssl%2Fca.pem"},
+ excludes: []string{"sslcert=", "sslkey="},
+ },
+ {
+ name: "ssl cert params work with schema mode",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "verify-full",
+ SSLRootCert: "/etc/ssl/ca.pem",
+ Schema: "tenant_abc",
+ },
+ contains: []string{
+ "sslmode=verify-full",
+ "sslrootcert=%2Fetc%2Fssl%2Fca.pem",
+ "options=-csearch_path%3Dtenant_abc",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := buildConnectionString(tt.cfg)
+
+ require.NoError(t, err)
+
+ for _, s := range tt.contains {
+ assert.Contains(t, result, s, "connection string should contain %q", s)
+ }
+
+ for _, s := range tt.excludes {
+ assert.NotContains(t, result, s, "connection string should NOT contain %q", s)
+ }
+ })
+ }
+}
+
+func TestBuildConnectionString_SSLModeDisableWithCerts(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *core.PostgreSQLConfig
+ }{
+ {
+ name: "rejects sslmode=disable with SSLRootCert set",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "disable",
+ SSLRootCert: "/etc/ssl/ca.pem",
+ },
+ },
+ {
+ name: "rejects sslmode=disable with SSLCert set",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "disable",
+ SSLCert: "/etc/ssl/client-cert.pem",
+ },
+ },
+ {
+ name: "rejects sslmode=disable with SSLKey set",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "disable",
+ SSLKey: "/etc/ssl/client-key.pem",
+ },
+ },
+ {
+ name: "rejects sslmode=disable with all SSL cert fields set",
+ cfg: &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "disable",
+ SSLRootCert: "/etc/ssl/ca.pem",
+ SSLCert: "/etc/ssl/client-cert.pem",
+ SSLKey: "/etc/ssl/client-key.pem",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := buildConnectionString(tt.cfg)
+
+ require.Error(t, err)
+ assert.Empty(t, result)
+ assert.Contains(t, err.Error(), "sslmode is \"disable\" but SSL certificate parameters are set")
+ })
+ }
+
+ t.Run("allows sslmode=disable without cert fields", func(t *testing.T) {
+ t.Parallel()
+
+ cfg := &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "disable",
+ }
+
+ result, err := buildConnectionString(cfg)
+
+ require.NoError(t, err)
+ assert.Contains(t, result, "sslmode=disable")
+ })
+}
+
+func TestBuildConnectionString_InvalidSchema(t *testing.T) {
+ tests := []struct {
+ name string
+ schema string
+ }{
+ {
+ name: "rejects schema with SQL injection attempt",
+ schema: "public; DROP TABLE users--",
+ },
+ {
+ name: "rejects schema with spaces",
+ schema: "my schema",
+ },
+ {
+ name: "rejects schema with special characters",
+ schema: "tenant-abc",
+ },
+ {
+ name: "rejects schema starting with a digit",
+ schema: "1tenant",
+ },
+ {
+ name: "rejects schema with double quotes",
+ schema: `"public"`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &core.PostgreSQLConfig{
+ Host: "localhost",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ Schema: tt.schema,
+ }
+
+ result, err := buildConnectionString(cfg)
+
+ require.Error(t, err)
+ assert.Empty(t, result)
+ assert.Contains(t, err.Error(), "invalid schema name")
+ })
+ }
+}
+
+func TestBuildConnectionStrings_PrimaryAndReplica(t *testing.T) {
+ t.Run("builds separate connection strings for primary and replica", func(t *testing.T) {
+ primaryConfig := &core.PostgreSQLConfig{
+ Host: "primary-host",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "disable",
+ }
+ replicaConfig := &core.PostgreSQLConfig{
+ Host: "replica-host",
+ Port: 5433,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ SSLMode: "disable",
+ }
+
+ primaryConnStr, err := buildConnectionString(primaryConfig)
+ require.NoError(t, err)
+ replicaConnStr, err := buildConnectionString(replicaConfig)
+ require.NoError(t, err)
+
+ assert.Contains(t, primaryConnStr, "postgres://user:pass@primary-host:5432/")
+ assert.Contains(t, replicaConnStr, "postgres://user:pass@replica-host:5433/")
+ assert.NotEqual(t, primaryConnStr, replicaConnStr)
+ })
+
+ t.Run("fallback to primary when replica not configured", func(t *testing.T) {
+ config := &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &core.PostgreSQLConfig{
+ Host: "primary-host",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ },
+ // No PostgreSQLReplica configured
+ },
+ },
+ }
+
+ pgConfig := config.GetPostgreSQLConfig("ledger", "onboarding")
+ pgReplicaConfig := config.GetPostgreSQLReplicaConfig("ledger", "onboarding")
+
+ assert.NotNil(t, pgConfig)
+ assert.Nil(t, pgReplicaConfig)
+
+ // When replica is nil, system should use primary connection string
+ primaryConnStr, err := buildConnectionString(pgConfig)
+ require.NoError(t, err)
+
+ replicaConnStr := primaryConnStr
+ if pgReplicaConfig != nil {
+ var replicaErr error
+ replicaConnStr, replicaErr = buildConnectionString(pgReplicaConfig)
+ require.NoError(t, replicaErr)
+ }
+
+ assert.Equal(t, primaryConnStr, replicaConnStr)
+ })
+
+ t.Run("uses replica config when available", func(t *testing.T) {
+ config := &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &core.PostgreSQLConfig{
+ Host: "primary-host",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ },
+ PostgreSQLReplica: &core.PostgreSQLConfig{
+ Host: "replica-host",
+ Port: 5433,
+ Username: "user",
+ Password: "pass",
+ Database: "testdb",
+ },
+ },
+ },
+ }
+
+ pgConfig := config.GetPostgreSQLConfig("ledger", "onboarding")
+ pgReplicaConfig := config.GetPostgreSQLReplicaConfig("ledger", "onboarding")
+
+ assert.NotNil(t, pgConfig)
+ assert.NotNil(t, pgReplicaConfig)
+
+ primaryConnStr, err := buildConnectionString(pgConfig)
+ require.NoError(t, err)
+
+ replicaConnStr := primaryConnStr
+ if pgReplicaConfig != nil {
+ var replicaErr error
+ replicaConnStr, replicaErr = buildConnectionString(pgReplicaConfig)
+ require.NoError(t, replicaErr)
+ }
+
+ assert.NotEqual(t, primaryConnStr, replicaConnStr)
+ assert.Contains(t, primaryConnStr, "postgres://user:pass@primary-host:5432/")
+ assert.Contains(t, replicaConnStr, "postgres://user:pass@replica-host:5433/")
+ })
+
+ t.Run("handles replica with different database name", func(t *testing.T) {
+ config := &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &core.PostgreSQLConfig{
+ Host: "primary-host",
+ Port: 5432,
+ Username: "user",
+ Password: "pass",
+ Database: "primary_db",
+ },
+ PostgreSQLReplica: &core.PostgreSQLConfig{
+ Host: "replica-host",
+ Port: 5433,
+ Username: "user",
+ Password: "pass",
+ Database: "replica_db",
+ },
+ },
+ },
+ }
+
+ pgConfig := config.GetPostgreSQLConfig("ledger", "onboarding")
+ pgReplicaConfig := config.GetPostgreSQLReplicaConfig("ledger", "onboarding")
+
+ assert.Equal(t, "primary_db", pgConfig.Database)
+ assert.Equal(t, "replica_db", pgReplicaConfig.Database)
+ })
+}
+
+func TestManager_GetConnection_HealthyCache(t *testing.T) {
+ t.Run("returns cached connection when ping succeeds", func(t *testing.T) {
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger")
+
+ // Pre-populate cache with a healthy connection
+ healthyDB := &pingableDB{pingErr: nil}
+ var db dbresolver.DB = healthyDB
+
+ cachedConn := &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.connections["tenant-123"] = cachedConn
+
+ conn, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+ assert.Equal(t, cachedConn, conn)
+ })
+}
+
+func TestManager_GetConnection_UnhealthyCacheEvicts(t *testing.T) {
+ t.Run("evicts cached connection when ping fails", func(t *testing.T) {
+ // Set up a mock Tenant Manager that returns 500 to simulate unavailability
+ // after eviction. The key assertion is that the stale connection is evicted.
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer server.Close()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "ledger", WithLogger(testutil.NewMockLogger()))
+
+ // Pre-populate cache with an unhealthy connection (simulates auth failure after credential rotation)
+ unhealthyDB := &pingableDB{pingErr: errors.New("FATAL: password authentication failed (SQLSTATE 28P01)")}
+ var db dbresolver.DB = unhealthyDB
+
+ cachedConn := &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.connections["tenant-123"] = cachedConn
+
+ // GetConnection will try to ping, fail, evict, then call createConnection.
+ // createConnection will fail because mock Tenant Manager returns 500,
+ // but the important thing is the stale connection was evicted.
+ _, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ // Expect an error because createConnection cannot get config from Tenant Manager
+ assert.Error(t, err)
+
+ // Verify the stale connection was evicted from cache
+ manager.mu.RLock()
+ _, exists := manager.connections["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.False(t, exists, "stale connection should have been evicted from cache")
+ assert.True(t, unhealthyDB.closed, "stale connection's DB should have been closed")
+ })
+}
+
+func TestManager_GetConnection_SuspendedTenant(t *testing.T) {
+ t.Run("propagates TenantSuspendedError from client", func(t *testing.T) {
+ // Set up a mock Tenant Manager that returns 403 Forbidden for suspended tenants
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusForbidden)
+ w.Write([]byte(`{"code":"TS-SUSPENDED","error":"service ledger is suspended for this tenant","status":"suspended"}`))
+ }))
+ defer server.Close()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "ledger", WithLogger(testutil.NewMockLogger()))
+
+ _, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.Error(t, err)
+ assert.True(t, core.IsTenantSuspendedError(err), "expected TenantSuspendedError, got: %T", err)
+
+ var suspErr *core.TenantSuspendedError
+ require.ErrorAs(t, err, &suspErr)
+ assert.Equal(t, "suspended", suspErr.Status)
+ assert.Equal(t, "tenant-123", suspErr.TenantID)
+ })
+}
+
+func TestManager_GetConnection_NilConnectionDB(t *testing.T) {
+ t.Run("returns cached connection when ConnectionDB is nil without ping", func(t *testing.T) {
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger")
+
+ // Pre-populate cache with a connection that has nil ConnectionDB
+ cachedConn := &PostgresConnection{
+ ConnectionDB: nil,
+ }
+ manager.connections["tenant-123"] = cachedConn
+
+ conn, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+ assert.Equal(t, cachedConn, conn)
+ })
+}
+
+func TestManager_EvictLRU(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ maxConnections int
+ idleTimeout time.Duration
+ preloadCount int
+ oldTenantAge time.Duration // how long ago tenant-old was accessed
+ newTenantAge time.Duration // how long ago tenant-new was accessed
+ expectEviction bool
+ expectedPoolSize int
+ expectedEvictedID string
+ expectedEvictClosed bool
+ }{
+ {
+ name: "evicts oldest idle connection when pool is at soft limit",
+ maxConnections: 2,
+ idleTimeout: 5 * time.Minute,
+ preloadCount: 2,
+ oldTenantAge: 10 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: true,
+ expectedPoolSize: 1,
+ expectedEvictedID: "tenant-old",
+ expectedEvictClosed: true,
+ },
+ {
+ name: "does not evict when pool is below soft limit",
+ maxConnections: 3,
+ idleTimeout: 5 * time.Minute,
+ preloadCount: 2,
+ oldTenantAge: 10 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: false,
+ expectedPoolSize: 2,
+ },
+ {
+ name: "does not evict when maxConnections is zero (unlimited)",
+ maxConnections: 0,
+ preloadCount: 5,
+ oldTenantAge: 10 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: false,
+ expectedPoolSize: 5,
+ },
+ {
+ name: "does not evict when all connections are active (within idle timeout)",
+ maxConnections: 2,
+ idleTimeout: 5 * time.Minute,
+ preloadCount: 2,
+ oldTenantAge: 2 * time.Minute, // within 5min idle timeout
+ newTenantAge: 1 * time.Minute, // within 5min idle timeout
+ expectEviction: false,
+ expectedPoolSize: 2,
+ },
+ {
+ name: "respects custom idle timeout",
+ maxConnections: 2,
+ idleTimeout: 30 * time.Second,
+ preloadCount: 2,
+ oldTenantAge: 1 * time.Minute, // beyond 30s idle timeout
+ newTenantAge: 10 * time.Second, // within 30s idle timeout
+ expectEviction: true,
+ expectedPoolSize: 1,
+ expectedEvictedID: "tenant-old",
+ expectedEvictClosed: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ opts := []Option{
+ WithLogger(testutil.NewMockLogger()),
+ WithMaxTenantPools(tt.maxConnections),
+ }
+ if tt.idleTimeout > 0 {
+ opts = append(opts, WithIdleTimeout(tt.idleTimeout))
+ }
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger", opts...)
+
+ // Pre-populate pool with connections
+ if tt.preloadCount >= 1 {
+ oldDB := &pingableDB{}
+ var oldDBIface dbresolver.DB = oldDB
+
+ manager.connections["tenant-old"] = &PostgresConnection{
+ ConnectionDB: &oldDBIface,
+ }
+ manager.lastAccessed["tenant-old"] = time.Now().Add(-tt.oldTenantAge)
+ }
+
+ if tt.preloadCount >= 2 {
+ newDB := &pingableDB{}
+ var newDBIface dbresolver.DB = newDB
+
+ manager.connections["tenant-new"] = &PostgresConnection{
+ ConnectionDB: &newDBIface,
+ }
+ manager.lastAccessed["tenant-new"] = time.Now().Add(-tt.newTenantAge)
+ }
+
+ // For unlimited test, add more connections
+ for i := 2; i < tt.preloadCount; i++ {
+ db := &pingableDB{}
+ var dbIface dbresolver.DB = db
+
+ id := "tenant-extra-" + time.Now().Add(time.Duration(i)*time.Second).Format("150405")
+ manager.connections[id] = &PostgresConnection{
+ ConnectionDB: &dbIface,
+ }
+ manager.lastAccessed[id] = time.Now().Add(-time.Duration(i) * time.Minute)
+ }
+
+ // Call evictLRU (caller must hold write lock)
+ manager.mu.Lock()
+ manager.evictLRU(context.Background(), testutil.NewMockLogger())
+ manager.mu.Unlock()
+
+ // Verify pool size
+ assert.Equal(t, tt.expectedPoolSize, len(manager.connections),
+ "pool size mismatch after eviction")
+
+ if tt.expectEviction {
+ // Verify the oldest tenant was evicted
+ _, exists := manager.connections[tt.expectedEvictedID]
+ assert.False(t, exists,
+ "expected tenant %s to be evicted from pool", tt.expectedEvictedID)
+
+ // Verify lastAccessed was also cleaned up
+ _, accessExists := manager.lastAccessed[tt.expectedEvictedID]
+ assert.False(t, accessExists,
+ "expected lastAccessed entry for %s to be removed", tt.expectedEvictedID)
+ }
+ })
+ }
+}
+
+func TestManager_PoolGrowsBeyondSoftLimit_WhenAllActive(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithMaxTenantPools(2),
+ WithIdleTimeout(5*time.Minute),
+ )
+
+ // Pre-populate with 2 connections, both accessed recently (within idle timeout)
+ for _, id := range []string{"tenant-1", "tenant-2"} {
+ db := &pingableDB{}
+ var dbIface dbresolver.DB = db
+
+ manager.connections[id] = &PostgresConnection{
+ ConnectionDB: &dbIface,
+ }
+ manager.lastAccessed[id] = time.Now().Add(-1 * time.Minute)
+ }
+
+ // Try to evict - should not evict because all connections are active
+ manager.mu.Lock()
+ manager.evictLRU(context.Background(), testutil.NewMockLogger())
+ manager.mu.Unlock()
+
+ // Pool should remain at 2 (no eviction occurred)
+ assert.Equal(t, 2, len(manager.connections),
+ "pool should not shrink when all connections are active")
+
+ // Simulate adding a third connection (pool grows beyond soft limit)
+ db := &pingableDB{}
+ var dbIface dbresolver.DB = db
+
+ manager.connections["tenant-3"] = &PostgresConnection{
+ ConnectionDB: &dbIface,
+ }
+ manager.lastAccessed["tenant-3"] = time.Now()
+
+ assert.Equal(t, 3, len(manager.connections),
+ "pool should grow beyond soft limit when all connections are active")
+}
+
+func TestManager_WithIdleTimeout_Option(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ idleTimeout time.Duration
+ expectedTimeout time.Duration
+ }{
+ {
+ name: "sets custom idle timeout",
+ idleTimeout: 10 * time.Minute,
+ expectedTimeout: 10 * time.Minute,
+ },
+ {
+ name: "sets short idle timeout",
+ idleTimeout: 30 * time.Second,
+ expectedTimeout: 30 * time.Second,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithIdleTimeout(tt.idleTimeout),
+ )
+
+ assert.Equal(t, tt.expectedTimeout, manager.idleTimeout)
+ })
+ }
+}
+
+func TestManager_LRU_LastAccessedUpdatedOnCacheHit(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithMaxTenantPools(5),
+ )
+
+ // Pre-populate cache with a healthy connection
+ healthyDB := &pingableDB{pingErr: nil}
+ var db dbresolver.DB = healthyDB
+
+ cachedConn := &PostgresConnection{
+ ConnectionDB: &db,
+ }
+
+ initialTime := time.Now().Add(-5 * time.Minute)
+ manager.connections["tenant-123"] = cachedConn
+ manager.lastAccessed["tenant-123"] = initialTime
+
+ // Access the connection (cache hit)
+ conn, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+ assert.Equal(t, cachedConn, conn)
+
+ // Verify lastAccessed was updated to a more recent time
+ manager.mu.RLock()
+ updatedTime := manager.lastAccessed["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.True(t, updatedTime.After(initialTime),
+ "lastAccessed should be updated after cache hit: initial=%v, updated=%v",
+ initialTime, updatedTime)
+}
+
+func TestManager_CloseConnection_CleansUpLastAccessed(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ )
+
+ // Pre-populate cache
+ healthyDB := &pingableDB{pingErr: nil}
+ var db dbresolver.DB = healthyDB
+
+ manager.connections["tenant-123"] = &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.lastAccessed["tenant-123"] = time.Now()
+
+ // Close the specific tenant connection
+ err := manager.CloseConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+
+ manager.mu.RLock()
+ _, connExists := manager.connections["tenant-123"]
+ _, accessExists := manager.lastAccessed["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.False(t, connExists, "connection should be removed after CloseConnection")
+ assert.False(t, accessExists, "lastAccessed should be removed after CloseConnection")
+}
+
+func TestManager_WithMaxTenantPools_Option(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ maxConnections int
+ expectedMax int
+ }{
+ {
+ name: "sets max connections via option",
+ maxConnections: 10,
+ expectedMax: 10,
+ },
+ {
+ name: "zero means unlimited",
+ maxConnections: 0,
+ expectedMax: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithMaxTenantPools(tt.maxConnections),
+ )
+
+ assert.Equal(t, tt.expectedMax, manager.maxConnections)
+ })
+ }
+}
+
+func TestManager_Stats_IncludesMaxConnections(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithMaxTenantPools(50),
+ )
+
+ stats := manager.Stats()
+
+ assert.Equal(t, 50, stats.MaxConnections)
+ assert.Equal(t, 0, stats.TotalConnections)
+ assert.Equal(t, 0, stats.ActiveConnections)
+}
+
+func TestManager_WithSettingsCheckInterval_Option(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ interval time.Duration
+ expectedInterval time.Duration
+ }{
+ {
+ name: "sets custom settings check interval",
+ interval: 1 * time.Minute,
+ expectedInterval: 1 * time.Minute,
+ },
+ {
+ name: "sets short settings check interval",
+ interval: 5 * time.Second,
+ expectedInterval: 5 * time.Second,
+ },
+ {
+ name: "disables revalidation with zero duration",
+ interval: 0,
+ expectedInterval: 0,
+ },
+ {
+ name: "disables revalidation with negative duration",
+ interval: -1 * time.Second,
+ expectedInterval: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithSettingsCheckInterval(tt.interval),
+ )
+
+ assert.Equal(t, tt.expectedInterval, manager.settingsCheckInterval)
+ })
+ }
+}
+
+func TestManager_DefaultSettingsCheckInterval(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger")
+
+ assert.Equal(t, defaultSettingsCheckInterval, manager.settingsCheckInterval,
+ "default settings check interval should be set from named constant")
+ assert.NotNil(t, manager.lastSettingsCheck,
+ "lastSettingsCheck map should be initialized")
+}
+
+func TestManager_GetConnection_RevalidatesSettingsAfterInterval(t *testing.T) {
+ t.Parallel()
+
+ // Set up a mock Tenant Manager that returns updated connection settings
+ var callCount int32
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ atomic.AddInt32(&callCount, 1)
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ // Return config with updated connection settings (maxOpenConns changed to 50)
+ w.Write([]byte(`{
+ "id": "tenant-123",
+ "tenantSlug": "test-tenant",
+ "databases": {
+ "onboarding": {
+ "postgresql": {"host": "localhost", "port": 5432, "database": "testdb", "username": "user", "password": "pass"},
+ "connectionSettings": {"maxOpenConns": 50, "maxIdleConns": 15}
+ }
+ }
+ }`))
+ }))
+ defer server.Close()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithModule("onboarding"),
+ // Use a very short interval so the test triggers revalidation immediately
+ WithSettingsCheckInterval(1*time.Millisecond),
+ )
+
+ // Pre-populate cache with a healthy connection and an old settings check time
+ tDB := &trackingDB{}
+ var db dbresolver.DB = tDB
+
+ cachedConn := &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.connections["tenant-123"] = cachedConn
+ manager.lastAccessed["tenant-123"] = time.Now()
+ // Set lastSettingsCheck to a time well in the past so revalidation triggers
+ manager.lastSettingsCheck["tenant-123"] = time.Now().Add(-1 * time.Hour)
+
+ // Call GetConnection - should return cached conn AND trigger async revalidation
+ conn, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+ assert.Equal(t, cachedConn, conn, "should return the cached connection")
+
+ assert.Eventually(t, func() bool {
+ return atomic.LoadInt32(&callCount) > 0
+ }, 500*time.Millisecond, 20*time.Millisecond, "should have fetched fresh config from Tenant Manager")
+
+ assert.Eventually(t, func() bool {
+ return tDB.MaxOpenConns() == int32(50) && tDB.MaxIdleConns() == int32(15)
+ }, 500*time.Millisecond, 20*time.Millisecond, "connection settings should be updated from async revalidation")
+}
+
+func TestManager_GetConnection_DoesNotRevalidateBeforeInterval(t *testing.T) {
+ t.Parallel()
+
+ var callCount int32
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ atomic.AddInt32(&callCount, 1)
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{
+ "id": "tenant-123",
+ "tenantSlug": "test-tenant",
+ "databases": {
+ "onboarding": {
+ "connectionSettings": {"maxOpenConns": 50, "maxIdleConns": 15}
+ }
+ }
+ }`))
+ }))
+ defer server.Close()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithModule("onboarding"),
+ // Use a very long interval so revalidation does NOT trigger
+ WithSettingsCheckInterval(1*time.Hour),
+ )
+
+ // Pre-populate cache with a healthy connection and a recent settings check time
+ tDB := &trackingDB{}
+ var db dbresolver.DB = tDB
+
+ cachedConn := &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.connections["tenant-123"] = cachedConn
+ manager.lastAccessed["tenant-123"] = time.Now()
+ // Set lastSettingsCheck to now - should NOT trigger revalidation
+ manager.lastSettingsCheck["tenant-123"] = time.Now()
+
+ // Call GetConnection - should return cached conn without revalidation
+ conn, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+ assert.Equal(t, cachedConn, conn)
+
+ assert.Never(t, func() bool {
+ return atomic.LoadInt32(&callCount) > 0
+ }, 200*time.Millisecond, 20*time.Millisecond, "should NOT have fetched config - interval not elapsed")
+
+ // Verify that connection settings were NOT changed
+ assert.Equal(t, int32(0), tDB.MaxOpenConns(), "maxOpenConns should NOT be changed")
+ assert.Equal(t, int32(0), tDB.MaxIdleConns(), "maxIdleConns should NOT be changed")
+}
+
+func TestManager_GetConnection_FailedRevalidationDoesNotBreakConnection(t *testing.T) {
+ t.Parallel()
+
+ // Set up a mock Tenant Manager that returns 500 (simulates unavailability)
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ }))
+ defer server.Close()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithModule("onboarding"),
+ WithSettingsCheckInterval(1*time.Millisecond),
+ )
+
+ // Pre-populate cache with a healthy connection
+ tDB := &trackingDB{}
+ var db dbresolver.DB = tDB
+
+ cachedConn := &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.connections["tenant-123"] = cachedConn
+ manager.lastAccessed["tenant-123"] = time.Now()
+ // Set lastSettingsCheck to the past so revalidation triggers
+ manager.lastSettingsCheck["tenant-123"] = time.Now().Add(-1 * time.Hour)
+
+ // Call GetConnection - should return cached conn even though revalidation will fail
+ conn, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err, "GetConnection should NOT fail when revalidation fails")
+ assert.Equal(t, cachedConn, conn, "should still return the cached connection")
+
+ // Wait for the async goroutine to complete (and fail)
+ time.Sleep(200 * time.Millisecond)
+
+ // Verify that connection settings were NOT changed (fetch failed)
+ assert.Equal(t, int32(0), tDB.MaxOpenConns(), "maxOpenConns should NOT be changed on failed revalidation")
+ assert.Equal(t, int32(0), tDB.MaxIdleConns(), "maxIdleConns should NOT be changed on failed revalidation")
+}
+
+func TestManager_CloseConnection_CleansUpLastSettingsCheck(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ )
+
+ // Pre-populate cache
+ healthyDB := &pingableDB{pingErr: nil}
+ var db dbresolver.DB = healthyDB
+
+ manager.connections["tenant-123"] = &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.lastAccessed["tenant-123"] = time.Now()
+ manager.lastSettingsCheck["tenant-123"] = time.Now()
+
+ // Close the specific tenant connection
+ err := manager.CloseConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+
+ manager.mu.RLock()
+ _, connExists := manager.connections["tenant-123"]
+ _, accessExists := manager.lastAccessed["tenant-123"]
+ _, settingsCheckExists := manager.lastSettingsCheck["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.False(t, connExists, "connection should be removed after CloseConnection")
+ assert.False(t, accessExists, "lastAccessed should be removed after CloseConnection")
+ assert.False(t, settingsCheckExists, "lastSettingsCheck should be removed after CloseConnection")
+}
+
+func TestManager_Close_CleansUpLastSettingsCheck(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ )
+
+ // Pre-populate cache with multiple tenants
+ for _, id := range []string{"tenant-1", "tenant-2"} {
+ db := &pingableDB{}
+ var dbIface dbresolver.DB = db
+
+ manager.connections[id] = &PostgresConnection{
+ ConnectionDB: &dbIface,
+ }
+ manager.lastAccessed[id] = time.Now()
+ manager.lastSettingsCheck[id] = time.Now()
+ }
+
+ err := manager.Close(context.Background())
+
+ require.NoError(t, err)
+
+ assert.Empty(t, manager.connections, "all connections should be removed after Close")
+ assert.Empty(t, manager.lastAccessed, "all lastAccessed should be removed after Close")
+ assert.Empty(t, manager.lastSettingsCheck, "all lastSettingsCheck should be removed after Close")
+}
+
+func TestManager_ApplyConnectionSettings_LogsValues(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+
+ // Use a capturing logger to verify that ApplyConnectionSettings logs when it applies values
+ capLogger := testutil.NewCapturingLogger()
+ manager := NewManager(c, "ledger",
+ WithModule("onboarding"),
+ WithLogger(capLogger),
+ )
+
+ tDB := &trackingDB{}
+ var db dbresolver.DB = tDB
+
+ manager.connections["tenant-123"] = &PostgresConnection{
+ ConnectionDB: &db,
+ }
+
+ config := &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 30,
+ MaxIdleConns: 10,
+ },
+ },
+ },
+ }
+
+ manager.ApplyConnectionSettings("tenant-123", config)
+
+ assert.Equal(t, int32(30), tDB.MaxOpenConns())
+ assert.Equal(t, int32(10), tDB.MaxIdleConns())
+ assert.True(t, capLogger.ContainsSubstring("applying connection settings"),
+ "ApplyConnectionSettings should log when applying values")
+}
+
+func TestManager_GetConnection_DisabledRevalidation_WithZero(t *testing.T) {
+ t.Parallel()
+
+ var callCount int32
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ atomic.AddInt32(&callCount, 1)
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{
+ "id": "tenant-123",
+ "tenantSlug": "test-tenant",
+ "databases": {
+ "onboarding": {
+ "postgresql": {"host": "localhost", "port": 5432, "database": "testdb", "username": "user", "password": "pass"},
+ "connectionSettings": {"maxOpenConns": 50, "maxIdleConns": 15}
+ }
+ }
+ }`))
+ }))
+ defer server.Close()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithModule("onboarding"),
+ // Disable revalidation with zero duration
+ WithSettingsCheckInterval(0),
+ )
+
+ // Pre-populate cache with a healthy connection and an old settings check time
+ tDB := &trackingDB{}
+ var db dbresolver.DB = tDB
+
+ cachedConn := &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.connections["tenant-123"] = cachedConn
+ manager.lastAccessed["tenant-123"] = time.Now()
+ // Set lastSettingsCheck to the past - but should NOT trigger revalidation since disabled
+ manager.lastSettingsCheck["tenant-123"] = time.Now().Add(-1 * time.Hour)
+
+ // Call GetConnection multiple times - should NOT spawn any goroutines
+ for i := 0; i < 5; i++ {
+ conn, err := manager.GetConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+ assert.Equal(t, cachedConn, conn, "should return the cached connection")
+ }
+
+ // Wait to ensure no async goroutine fires
+ time.Sleep(200 * time.Millisecond)
+
+ // Verify that Tenant Manager was NEVER called (no revalidation)
+ assert.Equal(t, int32(0), atomic.LoadInt32(&callCount), "should NOT have fetched config - revalidation is disabled")
+
+ // Verify that connection settings were NOT changed
+ assert.Equal(t, int32(0), tDB.MaxOpenConns(), "maxOpenConns should NOT be changed")
+ assert.Equal(t, int32(0), tDB.MaxIdleConns(), "maxIdleConns should NOT be changed")
+}
+
+func TestManager_GetConnection_DisabledRevalidation_WithNegative(t *testing.T) {
+ t.Parallel()
+
+ var callCount int32
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ atomic.AddInt32(&callCount, 1)
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{
+ "id": "tenant-456",
+ "tenantSlug": "test-tenant",
+ "databases": {
+ "payment": {
+ "postgresql": {"host": "localhost", "port": 5432, "database": "testdb", "username": "user", "password": "pass"},
+ "connectionSettings": {"maxOpenConns": 40, "maxIdleConns": 12}
+ }
+ }
+ }`))
+ }))
+ defer server.Close()
+
+ tmClient := mustNewTestClient(t, server.URL)
+ manager := NewManager(tmClient, "payment",
+ WithLogger(testutil.NewMockLogger()),
+ WithModule("payment"),
+ // Disable revalidation with negative duration
+ WithSettingsCheckInterval(-5*time.Second),
+ )
+
+ // Pre-populate cache with a healthy connection
+ tDB := &trackingDB{}
+ var db dbresolver.DB = tDB
+
+ cachedConn := &PostgresConnection{
+ ConnectionDB: &db,
+ }
+ manager.connections["tenant-456"] = cachedConn
+ manager.lastAccessed["tenant-456"] = time.Now()
+ // Set lastSettingsCheck to the past
+ manager.lastSettingsCheck["tenant-456"] = time.Now().Add(-1 * time.Hour)
+
+ // Call GetConnection - should NOT trigger revalidation
+ conn, err := manager.GetConnection(context.Background(), "tenant-456")
+
+ require.NoError(t, err)
+ assert.Equal(t, cachedConn, conn)
+
+ // Wait to ensure no async goroutine fires
+ time.Sleep(100 * time.Millisecond)
+
+ // Verify that Tenant Manager was NOT called
+ assert.Equal(t, int32(0), atomic.LoadInt32(&callCount), "should NOT have fetched config - revalidation is disabled via negative interval")
+
+ // Verify that connection settings were NOT changed
+ assert.Equal(t, int32(0), tDB.MaxOpenConns(), "maxOpenConns should NOT be changed")
+ assert.Equal(t, int32(0), tDB.MaxIdleConns(), "maxIdleConns should NOT be changed")
+}
+
+func TestManager_ApplyConnectionSettings(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ module string
+ config *core.TenantConfig
+ hasCachedConn bool
+ hasConnectionDB bool
+ expectMaxOpen int
+ expectMaxIdle int
+ expectNoChange bool
+ }{
+ {
+ name: "applies module-level settings",
+ module: "onboarding",
+ config: &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 30,
+ MaxIdleConns: 10,
+ },
+ },
+ },
+ },
+ hasCachedConn: true,
+ hasConnectionDB: true,
+ expectMaxOpen: 30,
+ expectMaxIdle: 10,
+ },
+ {
+ name: "applies top-level settings as fallback",
+ module: "onboarding",
+ config: &core.TenantConfig{
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 20,
+ MaxIdleConns: 8,
+ },
+ },
+ hasCachedConn: true,
+ hasConnectionDB: true,
+ expectMaxOpen: 20,
+ expectMaxIdle: 8,
+ },
+ {
+ name: "module-level takes precedence over top-level",
+ module: "onboarding",
+ config: &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 50,
+ MaxIdleConns: 15,
+ },
+ },
+ },
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 20,
+ MaxIdleConns: 8,
+ },
+ },
+ hasCachedConn: true,
+ hasConnectionDB: true,
+ expectMaxOpen: 50,
+ expectMaxIdle: 15,
+ },
+ {
+ name: "no-op when no cached connection exists",
+ module: "onboarding",
+ config: &core.TenantConfig{},
+ hasCachedConn: false,
+ expectNoChange: true,
+ },
+ {
+ name: "no-op when ConnectionDB is nil",
+ module: "onboarding",
+ config: &core.TenantConfig{
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 30,
+ },
+ },
+ hasCachedConn: true,
+ hasConnectionDB: false,
+ expectNoChange: true,
+ },
+ {
+ name: "applies manager defaults when config has no connection settings",
+ module: "onboarding",
+ config: &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ PostgreSQL: &core.PostgreSQLConfig{Host: "localhost"},
+ },
+ },
+ },
+ hasCachedConn: true,
+ hasConnectionDB: true,
+ expectMaxOpen: fallbackMaxOpenConns, // manager default when no settings present
+ expectMaxIdle: fallbackMaxIdleConns, // manager default when no settings present
+ },
+ {
+ name: "falls back to manager default idle conns when maxIdleConns is zero",
+ module: "onboarding",
+ config: &core.TenantConfig{
+ Databases: map[string]core.DatabaseConfig{
+ "onboarding": {
+ ConnectionSettings: &core.ConnectionSettings{
+ MaxOpenConns: 40,
+ MaxIdleConns: 0,
+ },
+ },
+ },
+ },
+ hasCachedConn: true,
+ hasConnectionDB: true,
+ expectMaxOpen: 40,
+ expectMaxIdle: fallbackMaxIdleConns, // falls back to manager default
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger",
+ WithModule(tt.module),
+ WithLogger(testutil.NewMockLogger()),
+ )
+
+ tDB := &trackingDB{}
+
+ if tt.hasCachedConn {
+ conn := &PostgresConnection{}
+ if tt.hasConnectionDB {
+ var db dbresolver.DB = tDB
+ conn.ConnectionDB = &db
+ }
+ manager.connections["tenant-123"] = conn
+ }
+
+ manager.ApplyConnectionSettings("tenant-123", tt.config)
+
+ if tt.expectNoChange {
+ assert.Equal(t, int32(0), tDB.MaxOpenConns(),
+ "maxOpenConns should not be changed")
+ assert.Equal(t, int32(0), tDB.MaxIdleConns(),
+ "maxIdleConns should not be changed")
+ } else {
+ assert.Equal(t, int32(tt.expectMaxOpen), tDB.MaxOpenConns(),
+ "maxOpenConns mismatch")
+ assert.Equal(t, int32(tt.expectMaxIdle), tDB.MaxIdleConns(),
+ "maxIdleConns mismatch")
+ }
+ })
+ }
+}
+
+func TestManager_Stats_ActiveConnections(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t, "http://localhost:8080")
+ manager := NewManager(c, "ledger")
+
+ // Pre-populate with connections and mark them as recently accessed
+ now := time.Now()
+ for _, id := range []string{"tenant-1", "tenant-2", "tenant-3"} {
+ db := &pingableDB{}
+ var dbIface dbresolver.DB = db
+
+ manager.connections[id] = &PostgresConnection{
+ ConnectionDB: &dbIface,
+ }
+ manager.lastAccessed[id] = now
+ }
+
+ stats := manager.Stats()
+
+ assert.Equal(t, 3, stats.TotalConnections)
+ assert.Equal(t, 3, stats.ActiveConnections,
+ "ActiveConnections should equal TotalConnections for postgres")
+}
+
+func TestManager_RevalidateSettings_EvictsSuspendedTenant(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ responseStatus int
+ responseBody string
+ expectEviction bool
+ expectLogSubstring string
+ }{
+ {
+ name: "evicts_cached_connection_when_tenant_is_suspended",
+ responseStatus: http.StatusForbidden,
+ responseBody: `{"code":"TS-SUSPENDED","error":"service suspended","status":"suspended"}`,
+ expectEviction: true,
+ expectLogSubstring: "tenant tenant-suspended service suspended, evicting cached connection",
+ },
+ {
+ name: "evicts_cached_connection_when_tenant_is_purged",
+ responseStatus: http.StatusForbidden,
+ responseBody: `{"code":"TS-SUSPENDED","error":"service purged","status":"purged"}`,
+ expectEviction: true,
+ expectLogSubstring: "tenant tenant-suspended service suspended, evicting cached connection",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Set up a mock Tenant Manager that returns 403 with TenantSuspendedError body
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(tt.responseStatus)
+ w.Write([]byte(tt.responseBody))
+ }))
+ defer server.Close()
+
+ capLogger := testutil.NewCapturingLogger()
+ tmClient, err := client.NewClient(server.URL, capLogger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, err)
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(capLogger),
+ WithSettingsCheckInterval(1*time.Millisecond),
+ )
+
+ // Pre-populate a cached connection for the tenant
+ mockDB := &pingableDB{}
+ var dbIface dbresolver.DB = mockDB
+
+ manager.connections["tenant-suspended"] = &PostgresConnection{
+ ConnectionDB: &dbIface,
+ }
+ manager.lastAccessed["tenant-suspended"] = time.Now()
+ manager.lastSettingsCheck["tenant-suspended"] = time.Now()
+
+ // Verify the connection exists before revalidation
+ statsBefore := manager.Stats()
+ assert.Equal(t, 1, statsBefore.TotalConnections,
+ "should have 1 connection before revalidation")
+
+ // Trigger revalidatePoolSettings directly
+ manager.revalidatePoolSettings("tenant-suspended")
+
+ if tt.expectEviction {
+ // Verify the connection was evicted
+ statsAfter := manager.Stats()
+ assert.Equal(t, 0, statsAfter.TotalConnections,
+ "connection should be evicted after suspended tenant detected")
+
+ // Verify the DB was closed
+ assert.True(t, mockDB.closed,
+ "cached connection's DB should have been closed")
+
+ // Verify lastAccessed and lastSettingsCheck were cleaned up
+ manager.mu.RLock()
+ _, accessExists := manager.lastAccessed["tenant-suspended"]
+ _, settingsExists := manager.lastSettingsCheck["tenant-suspended"]
+ manager.mu.RUnlock()
+
+ assert.False(t, accessExists,
+ "lastAccessed should be removed for evicted tenant")
+ assert.False(t, settingsExists,
+ "lastSettingsCheck should be removed for evicted tenant")
+ }
+
+ // Verify the appropriate log message was produced
+ assert.True(t, capLogger.ContainsSubstring(tt.expectLogSubstring),
+ "expected log message containing %q, got: %v",
+ tt.expectLogSubstring, capLogger.GetMessages())
+ })
+ }
+}
+
+func TestManager_RevalidateSettings_BypassesClientCache(t *testing.T) {
+ t.Parallel()
+
+ // This test verifies that revalidatePoolSettings uses WithSkipCache()
+ // to bypass the client's in-memory cache. Without it, a cached "active"
+ // response would hide a subsequent 403 (suspended/purged) from tenant-manager.
+ //
+ // Setup: The httptest server returns 200 (active) on the first request
+ // and 403 (suspended) on all subsequent requests. We first call
+ // GetTenantConfig directly to populate the client cache, then trigger
+ // revalidatePoolSettings. If WithSkipCache is working, the revalidation
+ // hits the server (gets 403) and evicts the connection. If the cache
+ // were used, it would return the stale 200 and the connection would
+ // remain.
+ var requestCount atomic.Int32
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ count := requestCount.Add(1)
+ w.Header().Set("Content-Type", "application/json")
+
+ if count == 1 {
+ // First request: return active config (populates client cache)
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(`{
+ "id": "tenant-cache-test",
+ "tenantSlug": "cached-tenant",
+ "service": "ledger",
+ "status": "active",
+ "databases": {
+ "onboarding": {
+ "postgresql": {"host": "localhost", "port": 5432, "database": "testdb", "username": "user", "password": "pass"}
+ }
+ }
+ }`))
+
+ return
+ }
+
+ // Subsequent requests: return 403 (tenant suspended)
+ w.WriteHeader(http.StatusForbidden)
+ w.Write([]byte(`{"code":"TS-SUSPENDED","error":"service suspended","status":"suspended"}`))
+ }))
+ defer server.Close()
+
+ capLogger := testutil.NewCapturingLogger()
+ tmClient, err := client.NewClient(server.URL, capLogger, client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, err)
+
+ // Populate the client cache by calling GetTenantConfig directly
+ cfg, err := tmClient.GetTenantConfig(context.Background(), "tenant-cache-test", "ledger")
+ require.NoError(t, err)
+ assert.Equal(t, "tenant-cache-test", cfg.ID)
+ assert.Equal(t, int32(1), requestCount.Load(), "should have made exactly 1 HTTP request")
+
+ // Create a manager with a cached connection for this tenant
+ manager := NewManager(tmClient, "ledger",
+ WithLogger(capLogger),
+ WithModule("onboarding"),
+ WithSettingsCheckInterval(1*time.Millisecond),
+ )
+
+ mockDB := &pingableDB{}
+ var dbIface dbresolver.DB = mockDB
+
+ manager.connections["tenant-cache-test"] = &PostgresConnection{ConnectionDB: &dbIface}
+ manager.lastAccessed["tenant-cache-test"] = time.Now()
+ manager.lastSettingsCheck["tenant-cache-test"] = time.Now()
+
+ // Trigger revalidatePoolSettings -- should bypass cache and hit the server
+ manager.revalidatePoolSettings("tenant-cache-test")
+
+ // Verify a second HTTP request was made (cache was bypassed)
+ assert.Equal(t, int32(2), requestCount.Load(),
+ "revalidatePoolSettings should bypass client cache and make a fresh HTTP request")
+
+ // Verify the connection was evicted (server returned 403)
+ statsAfter := manager.Stats()
+ assert.Equal(t, 0, statsAfter.TotalConnections,
+ "connection should be evicted after revalidation detected suspended tenant via cache bypass")
+
+ // Verify the DB was closed
+ assert.True(t, mockDB.closed,
+ "cached connection's DB should have been closed on eviction")
+}
diff --git a/commons/tenant-manager/rabbitmq/manager.go b/commons/tenant-manager/rabbitmq/manager.go
new file mode 100644
index 00000000..49e11056
--- /dev/null
+++ b/commons/tenant-manager/rabbitmq/manager.go
@@ -0,0 +1,471 @@
+package rabbitmq
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "fmt"
+ "net/url"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ libCommons "github.com/LerianStudio/lib-commons/v4/commons"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ libOpentelemetry "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/eviction"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/logcompat"
+ amqp "github.com/rabbitmq/amqp091-go"
+)
+
+// Manager manages RabbitMQ connections per tenant.
+// Each tenant has a dedicated vhost, user, and credentials stored in Tenant Manager.
+// When maxConnections is set (> 0), the manager uses LRU eviction with an idle
+// timeout as a soft limit. Connections idle longer than the timeout are eligible
+// for eviction when the pool exceeds maxConnections. If all connections are active
+// (used within the idle timeout), the pool grows beyond the soft limit and
+// naturally shrinks back as tenants become idle.
+type Manager struct {
+ client *client.Client
+ service string
+ module string
+ logger *logcompat.Logger
+
+ mu sync.RWMutex
+ connections map[string]*amqp.Connection
+ closed bool
+ maxConnections int // soft limit for pool size (0 = unlimited)
+ idleTimeout time.Duration // how long before a connection is eligible for eviction
+ lastAccessed map[string]time.Time // LRU tracking per tenant
+ useTLS bool // use amqps:// scheme instead of amqp://
+}
+
+// Option configures a Manager.
+type Option func(*Manager)
+
+// WithModule sets the module name for the RabbitMQ manager.
+func WithModule(module string) Option {
+ return func(p *Manager) {
+ p.module = module
+ }
+}
+
+// WithLogger sets the logger for the RabbitMQ manager.
+func WithLogger(logger log.Logger) Option {
+ return func(p *Manager) {
+ p.logger = logcompat.New(logger)
+ }
+}
+
+// WithMaxTenantPools sets the soft limit for the number of tenant connections in the pool.
+// When the pool reaches this limit and a new tenant needs a connection, only connections
+// that have been idle longer than the idle timeout are eligible for eviction. If all
+// connections are active (used within the idle timeout), the pool grows beyond this limit.
+// A value of 0 (default) means unlimited.
+func WithMaxTenantPools(maxSize int) Option {
+ return func(p *Manager) {
+ p.maxConnections = maxSize
+ }
+}
+
+// WithIdleTimeout sets the duration after which an unused tenant connection becomes
+// eligible for eviction. Only connections idle longer than this duration will be evicted
+// when the pool exceeds the soft limit (maxConnections). If all connections are active
+// (used within the idle timeout), the pool is allowed to grow beyond the soft limit.
+// Default: 5 minutes.
+func WithIdleTimeout(d time.Duration) Option {
+ return func(p *Manager) {
+ p.idleTimeout = d
+ }
+}
+
+// WithTLS enables TLS connections (amqps:// scheme) instead of the default
+// plaintext amqp://. Use this for production deployments where RabbitMQ is
+// configured with TLS certificates.
+func WithTLS() Option {
+ return func(p *Manager) {
+ p.useTLS = true
+ }
+}
+
+// NewManager creates a new RabbitMQ connection manager.
+// Parameters:
+// - c: The Tenant Manager client for fetching tenant configurations
+// - service: The service name (e.g., "ledger")
+// - opts: Optional configuration options
+func NewManager(c *client.Client, service string, opts ...Option) *Manager {
+ p := &Manager{
+ client: c,
+ service: service,
+ logger: logcompat.New(nil),
+ connections: make(map[string]*amqp.Connection),
+ lastAccessed: make(map[string]time.Time),
+ }
+
+ for _, opt := range opts {
+ opt(p)
+ }
+
+ return p
+}
+
+// GetConnection returns a RabbitMQ connection for the tenant.
+// Creates a new connection if one doesn't exist or the existing one is closed.
+func (p *Manager) GetConnection(ctx context.Context, tenantID string) (*amqp.Connection, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ if tenantID == "" {
+ return nil, errors.New("tenant ID is required")
+ }
+
+ p.mu.RLock()
+
+ if p.closed {
+ p.mu.RUnlock()
+ return nil, core.ErrManagerClosed
+ }
+
+ if conn, ok := p.connections[tenantID]; ok && !conn.IsClosed() {
+ p.mu.RUnlock()
+
+ // Update LRU tracking on cache hit
+ p.mu.Lock()
+ // Re-read connection from map (may have been evicted and closed between locks)
+ if refreshedConn, still := p.connections[tenantID]; still && !refreshedConn.IsClosed() {
+ p.lastAccessed[tenantID] = time.Now()
+ p.mu.Unlock()
+
+ return refreshedConn, nil
+ }
+
+ p.mu.Unlock()
+
+ // Connection was evicted between RUnlock and Lock; create a new one
+ _ = conn // original reference is now potentially stale; discard it
+
+ return p.createConnection(ctx, tenantID)
+ }
+
+ p.mu.RUnlock()
+
+ return p.createConnection(ctx, tenantID)
+}
+
+// createConnection fetches config from Tenant Manager and creates a RabbitMQ connection.
+//
+// Network I/O (GetTenantConfig, amqp.Dial) is performed outside the mutex to
+// avoid blocking other goroutines on slow network calls. The pattern is:
+// 1. Under lock: double-check cache, check closed state
+// 2. Outside lock: fetch config and dial
+// 3. Re-acquire lock: evict LRU, cache new connection (with race-loss handling)
+func (p *Manager) createConnection(ctx context.Context, tenantID string) (*amqp.Connection, error) {
+ if p.client == nil {
+ return nil, errors.New("tenant manager client is required for multi-tenant connections")
+ }
+
+ baseLogger, tracer, _, _ := libCommons.NewTrackingFromContext(ctx)
+ logger := logcompat.New(baseLogger)
+
+ ctx, span := tracer.Start(ctx, "rabbitmq.create_connection")
+ defer span.End()
+
+ if p.logger != nil {
+ logger = p.logger
+ }
+
+ // Step 1: Under lock — double-check if connection exists or manager is closed.
+ p.mu.Lock()
+
+ if conn, ok := p.connections[tenantID]; ok && !conn.IsClosed() {
+ p.mu.Unlock()
+ return conn, nil
+ }
+
+ if p.closed {
+ p.mu.Unlock()
+ return nil, core.ErrManagerClosed
+ }
+
+ p.mu.Unlock()
+
+ // Step 2: Outside lock — perform network I/O (HTTP call + TCP dial).
+ config, err := p.client.GetTenantConfig(ctx, tenantID, p.service)
+ if err != nil {
+ logger.Errorf("failed to get tenant config: %v", err)
+ libOpentelemetry.HandleSpanError(span, "failed to get tenant config", err)
+
+ return nil, fmt.Errorf("failed to get tenant config: %w", err)
+ }
+
+ rabbitConfig := config.GetRabbitMQConfig()
+ if rabbitConfig == nil {
+ logger.Errorf("RabbitMQ not configured for tenant: %s", tenantID)
+ libOpentelemetry.HandleSpanBusinessErrorEvent(span, "RabbitMQ not configured", core.ErrServiceNotConfigured)
+
+ return nil, core.ErrServiceNotConfigured
+ }
+
+ // Resolve TLS: per-tenant config takes precedence over global WithTLS() setting.
+ useTLS := p.resolveTLS(rabbitConfig)
+ uri := buildRabbitMQURI(rabbitConfig, useTLS)
+
+ logger.Infof("connecting to RabbitMQ vhost: tenant=%s, vhost=%s, tls=%v", tenantID, rabbitConfig.VHost, useTLS)
+
+ conn, err := p.dialRabbitMQ(uri, useTLS, rabbitConfig.TLSCAFile)
+ if err != nil {
+ logger.Errorf("failed to connect to RabbitMQ: %v", err)
+ libOpentelemetry.HandleSpanError(span, "failed to connect to RabbitMQ", err)
+
+ return nil, fmt.Errorf("failed to connect to RabbitMQ: %w", err)
+ }
+
+ // Step 3: Re-acquire lock — evict LRU, cache connection (with race-loss check).
+ p.mu.Lock()
+
+ // If manager was closed while we were dialing, discard the new connection.
+ if p.closed {
+ p.mu.Unlock()
+
+ if closeErr := conn.Close(); closeErr != nil {
+ logger.Errorf("failed to close RabbitMQ connection on closed manager: %v", closeErr)
+ }
+
+ return nil, core.ErrManagerClosed
+ }
+
+ // If another goroutine cached a connection for this tenant while we were
+ // dialing, use the cached one and discard ours.
+ if cached, ok := p.connections[tenantID]; ok && !cached.IsClosed() {
+ p.lastAccessed[tenantID] = time.Now()
+ p.mu.Unlock()
+
+ if closeErr := conn.Close(); closeErr != nil {
+ logger.Errorf("failed to close excess RabbitMQ connection for tenant %s: %v", tenantID, closeErr)
+ }
+
+ return cached, nil
+ }
+
+ // Evict least recently used connection if pool is full
+ p.evictLRU(logger.Base())
+
+ // Cache our new connection
+ p.connections[tenantID] = conn
+ p.lastAccessed[tenantID] = time.Now()
+
+ p.mu.Unlock()
+
+ logger.Infof("RabbitMQ connection created: tenant=%s, vhost=%s", tenantID, rabbitConfig.VHost)
+
+ return conn, nil
+}
+
+// evictLRU removes the least recently used idle connection when the pool reaches the
+// soft limit. Only connections that have been idle longer than the idle timeout are
+// eligible for eviction. If all connections are active (used within the idle timeout),
+// the pool is allowed to grow beyond the soft limit.
+// Caller MUST hold p.mu write lock.
+func (p *Manager) evictLRU(logger log.Logger) {
+ candidateID, shouldEvict := eviction.FindLRUEvictionCandidate(
+ len(p.connections), p.maxConnections, p.lastAccessed, p.idleTimeout, logger,
+ )
+ if !shouldEvict {
+ return
+ }
+
+ // Manager-specific cleanup: close the AMQP connection and remove from maps.
+ if conn, ok := p.connections[candidateID]; ok {
+ if conn != nil && !conn.IsClosed() {
+ if err := conn.Close(); err != nil && logger != nil {
+ logger.Log(context.Background(), log.LevelWarn, "failed to close evicted rabbitmq connection",
+ log.String("tenant_id", candidateID),
+ log.Err(err),
+ )
+ }
+ }
+
+ delete(p.connections, candidateID)
+ delete(p.lastAccessed, candidateID)
+ }
+}
+
+// GetChannel returns a RabbitMQ channel for the tenant.
+// Creates a new connection if one doesn't exist.
+//
+// Channel ownership: The caller is responsible for closing the returned channel
+// when it is no longer needed. Failing to close channels will leak resources
+// on both the client and the RabbitMQ server.
+func (p *Manager) GetChannel(ctx context.Context, tenantID string) (*amqp.Channel, error) {
+ conn, err := p.GetConnection(ctx, tenantID)
+ if err != nil {
+ return nil, err
+ }
+
+ channel, err := conn.Channel()
+ if err != nil {
+ return nil, fmt.Errorf("failed to open channel: %w", err)
+ }
+
+ return channel, nil
+}
+
+// Close closes all RabbitMQ connections.
+func (p *Manager) Close(_ context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.closed = true
+
+ var errs []error
+
+ for tenantID, conn := range p.connections {
+ if conn != nil && !conn.IsClosed() {
+ if err := conn.Close(); err != nil {
+ errs = append(errs, err)
+ }
+ }
+
+ delete(p.connections, tenantID)
+ delete(p.lastAccessed, tenantID)
+ }
+
+ return errors.Join(errs...)
+}
+
+// CloseConnection closes the RabbitMQ connection for a specific tenant.
+func (p *Manager) CloseConnection(_ context.Context, tenantID string) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ conn, ok := p.connections[tenantID]
+ if !ok {
+ return nil
+ }
+
+ var err error
+ if conn != nil && !conn.IsClosed() {
+ err = conn.Close()
+ }
+
+ delete(p.connections, tenantID)
+ delete(p.lastAccessed, tenantID)
+
+ return err
+}
+
+// ApplyConnectionSettings is a no-op for RabbitMQ connections.
+// RabbitMQ does not support dynamic connection pool settings like databases do.
+// This method exists to satisfy a common manager interface.
+func (p *Manager) ApplyConnectionSettings(_ string, _ *core.TenantConfig) {
+ // no-op: RabbitMQ connections do not have adjustable pool settings.
+}
+
+// Stats returns connection statistics.
+//
+// ActiveConnections counts connections that are not closed.
+// Unlike Postgres/Mongo which use recency-based idle timeout to determine
+// whether a connection is "active", RabbitMQ checks actual connection liveness
+// because AMQP connections are long-lived and do not have a meaningful
+// "last accessed" recency signal for activity classification.
+func (p *Manager) Stats() Stats {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ tenantIDs := make([]string, 0, len(p.connections))
+ activeConnections := 0
+
+ for id, conn := range p.connections {
+ tenantIDs = append(tenantIDs, id)
+
+ if conn != nil && !conn.IsClosed() {
+ activeConnections++
+ }
+ }
+
+ return Stats{
+ TotalConnections: len(p.connections),
+ MaxConnections: p.maxConnections,
+ ActiveConnections: activeConnections,
+ TenantIDs: tenantIDs,
+ Closed: p.closed,
+ }
+}
+
+// Stats contains statistics for the RabbitMQ manager.
+type Stats struct {
+ TotalConnections int `json:"totalConnections"`
+ MaxConnections int `json:"maxConnections"`
+ ActiveConnections int `json:"activeConnections"`
+ TenantIDs []string `json:"tenantIds"`
+ Closed bool `json:"closed"`
+}
+
+// resolveTLS determines whether TLS should be used for a tenant connection.
+// Per-tenant TLS configuration (RabbitMQConfig.TLS) takes precedence over the
+// global WithTLS() setting. When the per-tenant value is nil (not configured),
+// the global useTLS flag is used as a fallback.
+func (p *Manager) resolveTLS(cfg *core.RabbitMQConfig) bool {
+ if cfg.TLS != nil {
+ return *cfg.TLS
+ }
+
+ return p.useTLS
+}
+
+// dialRabbitMQ connects to RabbitMQ, using TLS when enabled.
+// When a custom CA file is specified, it is loaded into the TLS config's RootCAs
+// to allow verification against private certificate authorities.
+func (p *Manager) dialRabbitMQ(uri string, useTLS bool, tlsCAFile string) (*amqp.Connection, error) {
+ if !useTLS || tlsCAFile == "" {
+ return amqp.Dial(uri)
+ }
+
+ // Load custom CA certificate for TLS verification.
+ caCert, err := os.ReadFile(tlsCAFile) // #nosec G304 -- path from tenant config
+ if err != nil {
+ return nil, fmt.Errorf("failed to read TLS CA file %q: %w", tlsCAFile, err)
+ }
+
+ certPool := x509.NewCertPool()
+ if !certPool.AppendCertsFromPEM(caCert) {
+ return nil, fmt.Errorf("failed to parse CA certificate from %q", tlsCAFile)
+ }
+
+ tlsCfg := &tls.Config{
+ RootCAs: certPool,
+ MinVersion: tls.VersionTLS12,
+ }
+
+ return amqp.DialTLS(uri, tlsCfg)
+}
+
+// buildRabbitMQURI builds RabbitMQ connection URI from config.
+// Credentials and vhost are percent-encoded to handle special characters (e.g., @, :, /).
+// Uses QueryEscape with '+' replaced by '%20' because QueryEscape encodes spaces as '+'
+// which is only valid in query strings, not in userinfo or path segments of a URI.
+// When useTLS is true, the amqps:// scheme is used instead of amqp://.
+func buildRabbitMQURI(cfg *core.RabbitMQConfig, useTLS bool) string {
+ escapedUsername := strings.ReplaceAll(url.QueryEscape(cfg.Username), "+", "%20")
+ escapedPassword := strings.ReplaceAll(url.QueryEscape(cfg.Password), "+", "%20")
+ escapedVHost := strings.ReplaceAll(url.QueryEscape(cfg.VHost), "+", "%20")
+
+ scheme := "amqp"
+ if useTLS {
+ scheme = "amqps"
+ }
+
+ return fmt.Sprintf("%s://%s:%s@%s:%d/%s",
+ scheme, escapedUsername, escapedPassword,
+ cfg.Host, cfg.Port, escapedVHost)
+}
+
+// IsMultiTenant returns true if the manager is configured with a Tenant Manager client.
+func (p *Manager) IsMultiTenant() bool {
+ return p.client != nil
+}
diff --git a/commons/tenant-manager/rabbitmq/manager_test.go b/commons/tenant-manager/rabbitmq/manager_test.go
new file mode 100644
index 00000000..c9a2146a
--- /dev/null
+++ b/commons/tenant-manager/rabbitmq/manager_test.go
@@ -0,0 +1,510 @@
+package rabbitmq
+
+import (
+ "context"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/client"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/internal/testutil"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func mustNewTestClient(t *testing.T) *client.Client {
+ t.Helper()
+
+ c, err := client.NewClient("http://localhost:8080", testutil.NewMockLogger(), client.WithAllowInsecureHTTP(), client.WithServiceAPIKey("test-key"))
+ require.NoError(t, err)
+
+ return c
+}
+
+func TestNewManager(t *testing.T) {
+ t.Run("creates manager with client and service", func(t *testing.T) {
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger")
+
+ assert.NotNil(t, manager)
+ assert.Equal(t, "ledger", manager.service)
+ assert.NotNil(t, manager.connections)
+ assert.NotNil(t, manager.lastAccessed)
+ })
+}
+
+func TestManager_EvictLRU(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ maxConnections int
+ idleTimeout time.Duration
+ preloadCount int
+ oldTenantAge time.Duration
+ newTenantAge time.Duration
+ expectEviction bool
+ expectedPoolSize int
+ expectedEvictedID string
+ }{
+ {
+ name: "evicts oldest idle connection when pool is at soft limit",
+ maxConnections: 2,
+ idleTimeout: 5 * time.Minute,
+ preloadCount: 2,
+ oldTenantAge: 10 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: true,
+ expectedPoolSize: 1,
+ expectedEvictedID: "tenant-old",
+ },
+ {
+ name: "does not evict when pool is below soft limit",
+ maxConnections: 3,
+ idleTimeout: 5 * time.Minute,
+ preloadCount: 2,
+ oldTenantAge: 10 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: false,
+ expectedPoolSize: 2,
+ },
+ {
+ name: "does not evict when maxConnections is zero (unlimited)",
+ maxConnections: 0,
+ preloadCount: 5,
+ oldTenantAge: 10 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: false,
+ expectedPoolSize: 5,
+ },
+ {
+ name: "does not evict when all connections are active (within idle timeout)",
+ maxConnections: 2,
+ idleTimeout: 5 * time.Minute,
+ preloadCount: 2,
+ oldTenantAge: 2 * time.Minute,
+ newTenantAge: 1 * time.Minute,
+ expectEviction: false,
+ expectedPoolSize: 2,
+ },
+ {
+ name: "respects custom idle timeout",
+ maxConnections: 2,
+ idleTimeout: 30 * time.Second,
+ preloadCount: 2,
+ oldTenantAge: 1 * time.Minute,
+ newTenantAge: 10 * time.Second,
+ expectEviction: true,
+ expectedPoolSize: 1,
+ expectedEvictedID: "tenant-old",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ opts := []Option{
+ WithLogger(testutil.NewMockLogger()),
+ WithMaxTenantPools(tt.maxConnections),
+ }
+ if tt.idleTimeout > 0 {
+ opts = append(opts, WithIdleTimeout(tt.idleTimeout))
+ }
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger", opts...)
+
+ // Pre-populate pool with nil connections (cannot create real amqp.Connection in unit test)
+ // evictLRU checks conn != nil && !conn.IsClosed() before closing,
+ // so nil connections are safe for testing the eviction logic.
+ if tt.preloadCount >= 1 {
+ manager.connections["tenant-old"] = nil
+ manager.lastAccessed["tenant-old"] = time.Now().Add(-tt.oldTenantAge)
+ }
+
+ if tt.preloadCount >= 2 {
+ manager.connections["tenant-new"] = nil
+ manager.lastAccessed["tenant-new"] = time.Now().Add(-tt.newTenantAge)
+ }
+
+ // For unlimited test, add more connections
+ for i := 2; i < tt.preloadCount; i++ {
+ id := "tenant-extra-" + time.Now().Add(time.Duration(i)*time.Second).Format("150405")
+ manager.connections[id] = nil
+ manager.lastAccessed[id] = time.Now().Add(-time.Duration(i) * time.Minute)
+ }
+
+ // Call evictLRU (caller must hold write lock)
+ manager.mu.Lock()
+ manager.evictLRU(testutil.NewMockLogger())
+ manager.mu.Unlock()
+
+ // Verify pool size
+ assert.Equal(t, tt.expectedPoolSize, len(manager.connections),
+ "pool size mismatch after eviction")
+
+ if tt.expectEviction {
+ // Verify the oldest tenant was evicted
+ _, exists := manager.connections[tt.expectedEvictedID]
+ assert.False(t, exists,
+ "expected tenant %s to be evicted from pool", tt.expectedEvictedID)
+
+ // Verify lastAccessed was also cleaned up
+ _, accessExists := manager.lastAccessed[tt.expectedEvictedID]
+ assert.False(t, accessExists,
+ "expected lastAccessed entry for %s to be removed", tt.expectedEvictedID)
+ }
+ })
+ }
+}
+
+func TestManager_PoolGrowsBeyondSoftLimit_WhenAllActive(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ WithMaxTenantPools(2),
+ WithIdleTimeout(5*time.Minute),
+ )
+
+ // Pre-populate with 2 nil connections, both accessed recently (within idle timeout)
+ for _, id := range []string{"tenant-1", "tenant-2"} {
+ manager.connections[id] = nil
+ manager.lastAccessed[id] = time.Now().Add(-1 * time.Minute)
+ }
+
+ // Try to evict - should not evict because all connections are active
+ manager.mu.Lock()
+ manager.evictLRU(testutil.NewMockLogger())
+ manager.mu.Unlock()
+
+ // Pool should remain at 2 (no eviction occurred)
+ assert.Equal(t, 2, len(manager.connections),
+ "pool should not shrink when all connections are active")
+
+ // Simulate adding a third connection (pool grows beyond soft limit)
+ manager.connections["tenant-3"] = nil
+ manager.lastAccessed["tenant-3"] = time.Now()
+
+ assert.Equal(t, 3, len(manager.connections),
+ "pool should grow beyond soft limit when all connections are active")
+}
+
+func TestManager_WithIdleTimeout_Option(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ idleTimeout time.Duration
+ expectedTimeout time.Duration
+ }{
+ {
+ name: "sets custom idle timeout",
+ idleTimeout: 10 * time.Minute,
+ expectedTimeout: 10 * time.Minute,
+ },
+ {
+ name: "sets short idle timeout",
+ idleTimeout: 30 * time.Second,
+ expectedTimeout: 30 * time.Second,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger",
+ WithIdleTimeout(tt.idleTimeout),
+ )
+
+ assert.Equal(t, tt.expectedTimeout, manager.idleTimeout)
+ })
+ }
+}
+
+func TestManager_CloseConnection_CleansUpLastAccessed(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ )
+
+ // Pre-populate cache with a nil connection (avoids needing real AMQP)
+ manager.connections["tenant-123"] = nil
+ manager.lastAccessed["tenant-123"] = time.Now()
+
+ // Close the specific tenant connection
+ err := manager.CloseConnection(context.Background(), "tenant-123")
+
+ require.NoError(t, err)
+
+ manager.mu.RLock()
+ _, connExists := manager.connections["tenant-123"]
+ _, accessExists := manager.lastAccessed["tenant-123"]
+ manager.mu.RUnlock()
+
+ assert.False(t, connExists, "connection should be removed after CloseConnection")
+ assert.False(t, accessExists, "lastAccessed should be removed after CloseConnection")
+}
+
+func TestManager_WithMaxTenantPools_Option(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ maxConnections int
+ expectedMax int
+ }{
+ {
+ name: "sets max connections via option",
+ maxConnections: 10,
+ expectedMax: 10,
+ },
+ {
+ name: "zero means unlimited",
+ maxConnections: 0,
+ expectedMax: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger",
+ WithMaxTenantPools(tt.maxConnections),
+ )
+
+ assert.Equal(t, tt.expectedMax, manager.maxConnections)
+ })
+ }
+}
+
+func TestManager_Stats_IncludesMaxConnections(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger",
+ WithMaxTenantPools(50),
+ )
+
+ stats := manager.Stats()
+
+ assert.Equal(t, 50, stats.MaxConnections)
+ assert.Equal(t, 0, stats.TotalConnections)
+}
+
+func TestManager_Close_CleansUpLastAccessed(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger",
+ WithLogger(testutil.NewMockLogger()),
+ )
+
+ // Pre-populate cache with nil connections
+ manager.connections["tenant-1"] = nil
+ manager.lastAccessed["tenant-1"] = time.Now()
+ manager.connections["tenant-2"] = nil
+ manager.lastAccessed["tenant-2"] = time.Now()
+
+ err := manager.Close(context.Background())
+
+ require.NoError(t, err)
+ assert.True(t, manager.closed)
+ assert.Empty(t, manager.connections, "all connections should be removed after Close")
+ assert.Empty(t, manager.lastAccessed, "all lastAccessed entries should be removed after Close")
+}
+
+func TestBuildRabbitMQURI(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *core.RabbitMQConfig
+ useTLS bool
+ expected string
+ }{
+ {
+ name: "builds URI with all fields",
+ cfg: &core.RabbitMQConfig{
+ Host: "localhost",
+ Port: 5672,
+ Username: "guest",
+ Password: "guest",
+ VHost: "tenant-abc",
+ },
+ useTLS: false,
+ expected: "amqp://guest:guest@localhost:5672/tenant-abc",
+ },
+ {
+ name: "builds URI with custom port",
+ cfg: &core.RabbitMQConfig{
+ Host: "rabbitmq.internal",
+ Port: 5673,
+ Username: "admin",
+ Password: "secret",
+ VHost: "/",
+ },
+ useTLS: false,
+ expected: "amqp://admin:secret@rabbitmq.internal:5673/%2F",
+ },
+ {
+ name: "builds TLS URI with amqps scheme",
+ cfg: &core.RabbitMQConfig{
+ Host: "rabbitmq.prod.internal",
+ Port: 5671,
+ Username: "admin",
+ Password: "secret",
+ VHost: "tenant-xyz",
+ },
+ useTLS: true,
+ expected: "amqps://admin:secret@rabbitmq.prod.internal:5671/tenant-xyz",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ uri := buildRabbitMQURI(tt.cfg, tt.useTLS)
+ assert.Equal(t, tt.expected, uri)
+ })
+ }
+}
+
+func TestManager_ResolveTLS(t *testing.T) {
+ t.Parallel()
+
+ boolPtr := func(b bool) *bool { return &b }
+
+ tests := []struct {
+ name string
+ globalTLS bool
+ tenantTLS *bool
+ expected bool
+ }{
+ {
+ name: "uses global TLS when tenant TLS is nil",
+ globalTLS: true,
+ tenantTLS: nil,
+ expected: true,
+ },
+ {
+ name: "uses global false when tenant TLS is nil",
+ globalTLS: false,
+ tenantTLS: nil,
+ expected: false,
+ },
+ {
+ name: "per-tenant true overrides global false",
+ globalTLS: false,
+ tenantTLS: boolPtr(true),
+ expected: true,
+ },
+ {
+ name: "per-tenant false overrides global true",
+ globalTLS: true,
+ tenantTLS: boolPtr(false),
+ expected: false,
+ },
+ {
+ name: "per-tenant true with global true",
+ globalTLS: true,
+ tenantTLS: boolPtr(true),
+ expected: true,
+ },
+ {
+ name: "per-tenant false with global false",
+ globalTLS: false,
+ tenantTLS: boolPtr(false),
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t)
+
+ var opts []Option
+ if tt.globalTLS {
+ opts = append(opts, WithTLS())
+ }
+
+ manager := NewManager(c, "ledger", opts...)
+
+ cfg := &core.RabbitMQConfig{
+ Host: "localhost",
+ Port: 5672,
+ Username: "guest",
+ Password: "guest",
+ VHost: "test",
+ TLS: tt.tenantTLS,
+ }
+
+ result := manager.resolveTLS(cfg)
+
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestManager_DialRabbitMQ_InvalidCAFile(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger")
+
+ // Attempt to dial with a non-existent CA file
+ _, err := manager.dialRabbitMQ("amqps://guest:guest@localhost:5671/test", true, "/nonexistent/ca.pem")
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to read TLS CA file")
+}
+
+func TestManager_DialRabbitMQ_InvalidCACert(t *testing.T) {
+ t.Parallel()
+
+ // Create a temp file with invalid PEM content
+ tmpFile, err := os.CreateTemp("", "invalid-ca-*.pem")
+ require.NoError(t, err)
+ defer os.Remove(tmpFile.Name())
+
+ _, err = tmpFile.WriteString("this is not a valid PEM certificate")
+ require.NoError(t, err)
+ require.NoError(t, tmpFile.Close())
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger")
+
+ _, err = manager.dialRabbitMQ("amqps://guest:guest@localhost:5671/test", true, tmpFile.Name())
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "failed to parse CA certificate")
+}
+
+func TestManager_ApplyConnectionSettings_IsNoOp(t *testing.T) {
+ t.Parallel()
+
+ c := mustNewTestClient(t)
+ manager := NewManager(c, "ledger")
+
+ // Should not panic or error - it's a no-op
+ manager.ApplyConnectionSettings("tenant-123", &core.TenantConfig{
+ ID: "tenant-123",
+ })
+}
diff --git a/commons/tenant-manager/s3/objectstorage.go b/commons/tenant-manager/s3/objectstorage.go
new file mode 100644
index 00000000..29fd782b
--- /dev/null
+++ b/commons/tenant-manager/s3/objectstorage.go
@@ -0,0 +1,88 @@
+// Copyright (c) 2026 Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+package s3
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+)
+
+// GetObjectStorageKey returns a tenant-prefixed object storage key: "{tenantID}/{key}".
+// If tenantID is empty, returns the key with leading slashes stripped (normalized).
+// Leading slashes are always stripped from the key to ensure clean path construction,
+// regardless of whether tenantID is present.
+// Returns an error if tenantID contains the path delimiter "/" which would create
+// ambiguous object storage paths or enable path traversal.
+func GetObjectStorageKey(tenantID, key string) (string, error) {
+ key = strings.TrimLeft(key, "/")
+
+ if tenantID == "" {
+ return key, nil
+ }
+
+ tenantID = strings.Trim(tenantID, "/")
+
+ if tenantID == "" {
+ return key, nil
+ }
+
+ if strings.Contains(tenantID, "/") {
+ return "", fmt.Errorf("tenantID must not contain path delimiter '/': %q", tenantID)
+ }
+
+ return tenantID + "/" + key, nil
+}
+
+// GetObjectStorageKeyForTenant returns a tenant-prefixed object storage key
+// using the tenantID from context.
+//
+// In multi-tenant mode (tenantID in context): "{tenantId}/{key}"
+// In single-tenant mode (no tenant in context): "{key}" (normalized, leading slashes stripped)
+//
+// If ctx is nil, behaves as single-tenant mode (no prefix).
+// Returns an error if the tenantID from context contains the path delimiter "/".
+//
+// Usage:
+//
+// key, err := s3.GetObjectStorageKeyForTenant(ctx, "reports/templateID/reportID.html")
+// // Multi-tenant: "org_01ABC.../reports/templateID/reportID.html"
+// // Single-tenant: "reports/templateID/reportID.html"
+// storage.Upload(ctx, key, reader, contentType)
+func GetObjectStorageKeyForTenant(ctx context.Context, key string) (string, error) {
+ if ctx == nil {
+ return GetObjectStorageKey("", key)
+ }
+
+ tenantID := core.GetTenantIDFromContext(ctx)
+
+ return GetObjectStorageKey(tenantID, key)
+}
+
+// StripObjectStoragePrefix removes the tenant prefix from an object storage key,
+// returning the original key. If the key doesn't have the expected prefix,
+// returns the key unchanged.
+// Returns an error if tenantID contains the path delimiter "/".
+func StripObjectStoragePrefix(tenantID, prefixedKey string) (string, error) {
+ if tenantID == "" {
+ return prefixedKey, nil
+ }
+
+ tenantID = strings.Trim(tenantID, "/")
+
+ if tenantID == "" {
+ return prefixedKey, nil
+ }
+
+ if strings.Contains(tenantID, "/") {
+ return "", fmt.Errorf("tenantID must not contain path delimiter '/': %q", tenantID)
+ }
+
+ prefix := tenantID + "/"
+
+ return strings.TrimPrefix(prefixedKey, prefix), nil
+}
diff --git a/commons/tenant-manager/s3/objectstorage_test.go b/commons/tenant-manager/s3/objectstorage_test.go
new file mode 100644
index 00000000..b11533fa
--- /dev/null
+++ b/commons/tenant-manager/s3/objectstorage_test.go
@@ -0,0 +1,268 @@
+package s3
+
+import (
+ "context"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestGetObjectStorageKey(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ key string
+ expected string
+ }{
+ {
+ name: "prefixes key with tenant ID",
+ tenantID: "org_01ABC",
+ key: "reports/templateID/reportID.html",
+ expected: "org_01ABC/reports/templateID/reportID.html",
+ },
+ {
+ name: "returns key unchanged when tenant ID is empty",
+ tenantID: "",
+ key: "reports/templateID/reportID.html",
+ expected: "reports/templateID/reportID.html",
+ },
+ {
+ name: "handles empty key with tenant ID",
+ tenantID: "org_01ABC",
+ key: "",
+ expected: "org_01ABC/",
+ },
+ {
+ name: "handles empty key without tenant ID",
+ tenantID: "",
+ key: "",
+ expected: "",
+ },
+ {
+ name: "strips leading slash from key before prefixing",
+ tenantID: "org_01ABC",
+ key: "/reports/templateID/reportID.html",
+ expected: "org_01ABC/reports/templateID/reportID.html",
+ },
+ {
+ name: "strips leading slash from key without tenant ID",
+ tenantID: "",
+ key: "/reports/templateID/reportID.html",
+ expected: "reports/templateID/reportID.html",
+ },
+ {
+ name: "handles key with multiple leading slashes",
+ tenantID: "org_01ABC",
+ key: "///reports/file.html",
+ expected: "org_01ABC/reports/file.html",
+ },
+ {
+ name: "preserves nested path structure",
+ tenantID: "tenant-456",
+ key: "a/b/c/d/file.pdf",
+ expected: "tenant-456/a/b/c/d/file.pdf",
+ },
+ {
+ name: "handles key that is just a filename",
+ tenantID: "org_01ABC",
+ key: "file.html",
+ expected: "org_01ABC/file.html",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := GetObjectStorageKey(tt.tenantID, tt.key)
+
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestGetObjectStorageKey_RejectsDelimiterInTenantID(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ }{
+ {name: "slash in middle", tenantID: "tenant/123"},
+ {name: "multiple slashes", tenantID: "a/b/c"},
+ {name: "slash in middle after trim", tenantID: "/tenant/123/"},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := GetObjectStorageKey(tt.tenantID, "reports/file.html")
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "must not contain path delimiter '/'")
+ assert.Empty(t, result)
+ })
+ }
+}
+
+func TestGetObjectStorageKey_TrimsLeadingTrailingSlashesFromTenantID(t *testing.T) {
+ t.Parallel()
+
+ // Leading/trailing slashes are trimmed, so a tenantID that is ONLY slashes
+ // becomes empty and is treated as single-tenant mode.
+ result, err := GetObjectStorageKey("/", "reports/file.html")
+
+ require.NoError(t, err)
+ assert.Equal(t, "reports/file.html", result)
+}
+
+func TestGetObjectStorageKeyForTenant(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ key string
+ expected string
+ }{
+ {
+ name: "prefixes key with tenant ID from context",
+ tenantID: "org_01ABC",
+ key: "reports/templateID/reportID.html",
+ expected: "org_01ABC/reports/templateID/reportID.html",
+ },
+ {
+ name: "returns key unchanged when no tenant in context",
+ tenantID: "",
+ key: "reports/templateID/reportID.html",
+ expected: "reports/templateID/reportID.html",
+ },
+ {
+ name: "handles empty key with tenant in context",
+ tenantID: "org_01ABC",
+ key: "",
+ expected: "org_01ABC/",
+ },
+ {
+ name: "handles empty key without tenant in context",
+ tenantID: "",
+ key: "",
+ expected: "",
+ },
+ {
+ name: "strips leading slash from key",
+ tenantID: "org_01ABC",
+ key: "/reports/templateID/reportID.html",
+ expected: "org_01ABC/reports/templateID/reportID.html",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ if tt.tenantID != "" {
+ ctx = core.SetTenantIDInContext(ctx, tt.tenantID)
+ }
+
+ result, err := GetObjectStorageKeyForTenant(ctx, tt.key)
+
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestGetObjectStorageKeyForTenant_NilContext(t *testing.T) {
+ t.Parallel()
+
+ result, err := GetObjectStorageKeyForTenant(nil, "reports/templateID/reportID.html")
+
+ require.NoError(t, err)
+ assert.Equal(t, "reports/templateID/reportID.html", result)
+}
+
+func TestGetObjectStorageKeyForTenant_UsesSameTenantID(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ tenantID := "org_consistency_check"
+
+ ctx = core.SetTenantIDInContext(ctx, tenantID)
+
+ extractedID := core.GetTenantID(ctx)
+
+ result, err := GetObjectStorageKeyForTenant(ctx, "test-key")
+
+ require.NoError(t, err)
+ assert.Equal(t, tenantID, extractedID)
+ assert.Equal(t, extractedID+"/test-key", result)
+}
+
+func TestStripObjectStoragePrefix(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ prefixedKey string
+ expected string
+ }{
+ {
+ name: "strips tenant prefix from key",
+ tenantID: "org_01ABC",
+ prefixedKey: "org_01ABC/reports/templateID/reportID.html",
+ expected: "reports/templateID/reportID.html",
+ },
+ {
+ name: "returns key unchanged when tenant ID is empty",
+ tenantID: "",
+ prefixedKey: "reports/templateID/reportID.html",
+ expected: "reports/templateID/reportID.html",
+ },
+ {
+ name: "returns key unchanged when prefix does not match",
+ tenantID: "org_01ABC",
+ prefixedKey: "other_tenant/reports/file.html",
+ expected: "other_tenant/reports/file.html",
+ },
+ {
+ name: "handles key that is just the prefix",
+ tenantID: "org_01ABC",
+ prefixedKey: "org_01ABC/",
+ expected: "",
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := StripObjectStoragePrefix(tt.tenantID, tt.prefixedKey)
+
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestStripObjectStoragePrefix_RejectsDelimiterInTenantID(t *testing.T) {
+ t.Parallel()
+
+ result, err := StripObjectStoragePrefix("tenant/123", "tenant/123/reports/file.html")
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "must not contain path delimiter '/'")
+ assert.Empty(t, result)
+}
diff --git a/commons/tenant-manager/valkey/keys.go b/commons/tenant-manager/valkey/keys.go
new file mode 100644
index 00000000..a4848206
--- /dev/null
+++ b/commons/tenant-manager/valkey/keys.go
@@ -0,0 +1,91 @@
+// Copyright (c) 2026 Lerian Studio. All rights reserved.
+// Use of this source code is governed by the Elastic License 2.0
+// that can be found in the LICENSE file.
+
+package valkey
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+)
+
+const TenantKeyPrefix = "tenant"
+
+// GetKey returns tenant-prefixed key: "tenant:{tenantID}:{key}"
+// If tenantID is empty, returns the key unchanged.
+// Returns an error if tenantID contains the delimiter character ":"
+// which would corrupt the key namespace structure.
+func GetKey(tenantID, key string) (string, error) {
+ if tenantID == "" {
+ return key, nil
+ }
+
+ if strings.Contains(tenantID, ":") {
+ return "", fmt.Errorf("tenantID must not contain delimiter character ':': %q", tenantID)
+ }
+
+ return fmt.Sprintf("%s:%s:%s", TenantKeyPrefix, tenantID, key), nil
+}
+
+// GetKeyFromContext returns tenant-prefixed key using tenantID from context.
+// If no tenantID in context, returns the key unchanged.
+// If ctx is nil, returns the key unchanged (no tenant prefix).
+// Returns an error if the tenantID from context contains the delimiter character ":".
+func GetKeyFromContext(ctx context.Context, key string) (string, error) {
+ if ctx == nil {
+ return GetKey("", key)
+ }
+
+ tenantID := core.GetTenantIDFromContext(ctx)
+
+ return GetKey(tenantID, key)
+}
+
+// GetPattern returns pattern for scanning tenant keys: "tenant:{tenantID}:{pattern}"
+// If tenantID is empty, returns the pattern unchanged.
+// Returns an error if tenantID contains the delimiter character ":".
+func GetPattern(tenantID, pattern string) (string, error) {
+ if tenantID == "" {
+ return pattern, nil
+ }
+
+ if strings.Contains(tenantID, ":") {
+ return "", fmt.Errorf("tenantID must not contain delimiter character ':': %q", tenantID)
+ }
+
+ return fmt.Sprintf("%s:%s:%s", TenantKeyPrefix, tenantID, pattern), nil
+}
+
+// GetPatternFromContext returns pattern using tenantID from context.
+// If no tenantID in context, returns the pattern unchanged.
+// If ctx is nil, returns the pattern unchanged (no tenant prefix).
+// Returns an error if the tenantID from context contains the delimiter character ":".
+func GetPatternFromContext(ctx context.Context, pattern string) (string, error) {
+ if ctx == nil {
+ return GetPattern("", pattern)
+ }
+
+ tenantID := core.GetTenantIDFromContext(ctx)
+
+ return GetPattern(tenantID, pattern)
+}
+
+// StripTenantPrefix removes tenant prefix from key, returns original key.
+// If key doesn't have the expected prefix, returns the key unchanged.
+// Returns an error if tenantID contains the delimiter character ":".
+func StripTenantPrefix(tenantID, prefixedKey string) (string, error) {
+ if tenantID == "" {
+ return prefixedKey, nil
+ }
+
+ if strings.Contains(tenantID, ":") {
+ return "", fmt.Errorf("tenantID must not contain delimiter character ':': %q", tenantID)
+ }
+
+ prefix := fmt.Sprintf("%s:%s:", TenantKeyPrefix, tenantID)
+
+ return strings.TrimPrefix(prefixedKey, prefix), nil
+}
diff --git a/commons/tenant-manager/valkey/keys_test.go b/commons/tenant-manager/valkey/keys_test.go
new file mode 100644
index 00000000..825bf343
--- /dev/null
+++ b/commons/tenant-manager/valkey/keys_test.go
@@ -0,0 +1,148 @@
+package valkey
+
+import (
+ "context"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/tenant-manager/core"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestGetKey(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ key string
+ expected string
+ }{
+ {name: "prefixes key with tenant", tenantID: "tenant-123", key: "orders", expected: "tenant:tenant-123:orders"},
+ {name: "returns key unchanged when tenant empty", tenantID: "", key: "orders", expected: "orders"},
+ {name: "handles empty key", tenantID: "tenant-123", key: "", expected: "tenant:tenant-123:"},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := GetKey(tt.tenantID, tt.key)
+
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestGetKey_RejectsDelimiterInTenantID(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ tenantID string
+ }{
+ {name: "colon in middle", tenantID: "tenant:123"},
+ {name: "colon at start", tenantID: ":tenant"},
+ {name: "colon at end", tenantID: "tenant:"},
+ {name: "multiple colons", tenantID: "a:b:c"},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ result, err := GetKey(tt.tenantID, "orders")
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "must not contain delimiter character ':'")
+ assert.Empty(t, result)
+ })
+ }
+}
+
+func TestGetKeyFromContext(t *testing.T) {
+ t.Parallel()
+
+ ctx := core.SetTenantIDInContext(context.Background(), "tenant-ctx")
+
+ result, err := GetKeyFromContext(ctx, "orders")
+ require.NoError(t, err)
+ assert.Equal(t, "tenant:tenant-ctx:orders", result)
+
+ result, err = GetKeyFromContext(context.Background(), "orders")
+ require.NoError(t, err)
+ assert.Equal(t, "orders", result)
+
+ result, err = GetKeyFromContext(nil, "orders")
+ require.NoError(t, err)
+ assert.Equal(t, "orders", result)
+}
+
+func TestGetPattern(t *testing.T) {
+ t.Parallel()
+
+ result, err := GetPattern("tenant-123", "orders:*")
+ require.NoError(t, err)
+ assert.Equal(t, "tenant:tenant-123:orders:*", result)
+
+ result, err = GetPattern("", "orders:*")
+ require.NoError(t, err)
+ assert.Equal(t, "orders:*", result)
+}
+
+func TestGetPattern_RejectsDelimiterInTenantID(t *testing.T) {
+ t.Parallel()
+
+ result, err := GetPattern("tenant:123", "orders:*")
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "must not contain delimiter character ':'")
+ assert.Empty(t, result)
+}
+
+func TestGetPatternFromContext(t *testing.T) {
+ t.Parallel()
+
+ ctx := core.SetTenantIDInContext(context.Background(), "tenant-ctx")
+
+ result, err := GetPatternFromContext(ctx, "orders:*")
+ require.NoError(t, err)
+ assert.Equal(t, "tenant:tenant-ctx:orders:*", result)
+
+ result, err = GetPatternFromContext(context.Background(), "orders:*")
+ require.NoError(t, err)
+ assert.Equal(t, "orders:*", result)
+
+ result, err = GetPatternFromContext(nil, "orders:*")
+ require.NoError(t, err)
+ assert.Equal(t, "orders:*", result)
+}
+
+func TestStripTenantPrefix(t *testing.T) {
+ t.Parallel()
+
+ result, err := StripTenantPrefix("tenant-123", "tenant:tenant-123:orders:1")
+ require.NoError(t, err)
+ assert.Equal(t, "orders:1", result)
+
+ result, err = StripTenantPrefix("", "orders:1")
+ require.NoError(t, err)
+ assert.Equal(t, "orders:1", result)
+
+ result, err = StripTenantPrefix("tenant-123", "tenant:other:orders:1")
+ require.NoError(t, err)
+ assert.Equal(t, "tenant:other:orders:1", result)
+}
+
+func TestStripTenantPrefix_RejectsDelimiterInTenantID(t *testing.T) {
+ t.Parallel()
+
+ result, err := StripTenantPrefix("tenant:123", "tenant:tenant:123:orders:1")
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "must not contain delimiter character ':'")
+ assert.Empty(t, result)
+}
diff --git a/commons/time.go b/commons/time.go
index 3af481ea..06cf83c7 100644
--- a/commons/time.go
+++ b/commons/time.go
@@ -1,14 +1,14 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package commons
import (
+ "errors"
"fmt"
"time"
)
+// ErrInvalidDateFormat indicates the date string could not be parsed by any known format.
+var ErrInvalidDateFormat = errors.New("invalid date format")
+
// IsValidDate checks if the provided date string is in the format "YYYY-MM-DD".
func IsValidDate(date string) bool {
_, err := time.Parse("2006-01-02", date)
@@ -87,7 +87,7 @@ func ParseDateTime(dateStr string, isEndDate bool) (time.Time, bool, error) {
return t, false, nil
}
- return time.Time{}, false, fmt.Errorf("invalid date format: %s", dateStr)
+ return time.Time{}, false, fmt.Errorf("%w: %s", ErrInvalidDateFormat, dateStr)
}
// IsValidDateTime checks if the provided date string is in the format "YYYY-MM-DD HH:MM:SS".
diff --git a/commons/time_test.go b/commons/time_test.go
index 3cece18c..401a459c 100644
--- a/commons/time_test.go
+++ b/commons/time_test.go
@@ -1,6 +1,4 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package commons
diff --git a/commons/transaction/doc.go b/commons/transaction/doc.go
new file mode 100644
index 00000000..6b6a18ff
--- /dev/null
+++ b/commons/transaction/doc.go
@@ -0,0 +1,9 @@
+// Package transaction provides transaction intent planning and posting validations.
+//
+// Core flow:
+// - BuildIntentPlan validates and expands allocations into postings.
+// - ValidateBalanceEligibility checks source/destination constraints.
+// - ApplyPosting applies operation/status transitions to balances.
+//
+// The package enforces deterministic behavior using typed domain errors.
+package transaction
diff --git a/commons/transaction/error_example_test.go b/commons/transaction/error_example_test.go
new file mode 100644
index 00000000..ea28989d
--- /dev/null
+++ b/commons/transaction/error_example_test.go
@@ -0,0 +1,24 @@
+//go:build unit
+
+package transaction_test
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/transaction"
+)
+
+func ExampleNewDomainError() {
+ err := transaction.NewDomainError(transaction.ErrorInvalidInput, "asset", "asset is required")
+
+ var domainErr transaction.DomainError
+ ok := errors.As(err, &domainErr)
+
+ fmt.Println(ok)
+ fmt.Println(domainErr.Code, domainErr.Field)
+
+ // Output:
+ // true
+ // 1001 asset
+}
diff --git a/commons/transaction/transaction.go b/commons/transaction/transaction.go
index dcec1bc0..b17f4056 100644
--- a/commons/transaction/transaction.go
+++ b/commons/transaction/transaction.go
@@ -1,202 +1,183 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package transaction
import (
- "strconv"
+ "fmt"
"strings"
"time"
+ constant "github.com/LerianStudio/lib-commons/v4/commons/constants"
"github.com/shopspring/decimal"
)
-// Deprecated: use model from Midaz pkg instead.
-// Balance structure for marshaling/unmarshalling JSON.
-//
-// swagger:model Balance
-// @Description Balance is the struct designed to represent the account balance.
-type Balance struct {
- ID string `json:"id" example:"00000000-0000-0000-0000-000000000000"`
- OrganizationID string `json:"organizationId" example:"00000000-0000-0000-0000-000000000000"`
- LedgerID string `json:"ledgerId" example:"00000000-0000-0000-0000-000000000000"`
- AccountID string `json:"accountId" example:"00000000-0000-0000-0000-000000000000"`
- Alias string `json:"alias" example:"@person1"`
- Key string `json:"key" example:"asset-freeze"`
- AssetCode string `json:"assetCode" example:"BRL"`
- Available decimal.Decimal `json:"available" example:"1500"`
- OnHold decimal.Decimal `json:"onHold" example:"500"`
- Version int64 `json:"version" example:"1"`
- AccountType string `json:"accountType" example:"creditCard"`
- AllowSending bool `json:"allowSending" example:"true"`
- AllowReceiving bool `json:"allowReceiving" example:"true"`
- CreatedAt time.Time `json:"createdAt" example:"2021-01-01T00:00:00Z"`
- UpdatedAt time.Time `json:"updatedAt" example:"2021-01-01T00:00:00Z"`
- DeletedAt *time.Time `json:"deletedAt" example:"2021-01-01T00:00:00Z"`
- Metadata map[string]any `json:"metadata,omitempty"`
-} // @name Balance
-
-// Deprecated: use model from Midaz pkg instead.
-type Responses struct {
- Total decimal.Decimal
- Asset string
- From map[string]Amount
- To map[string]Amount
- Sources []string
- Destinations []string
- Aliases []string
- Pending bool
- TransactionRoute string
- OperationRoutesFrom map[string]string
- OperationRoutesTo map[string]string
-}
+// Operation represents the posting operation applied to a balance.
+type Operation string
+
+const (
+ // OperationDebit decreases available balance from a source.
+ OperationDebit Operation = Operation(constant.DEBIT)
+ // OperationCredit increases available balance on a destination.
+ OperationCredit Operation = Operation(constant.CREDIT)
+ // OperationOnHold moves value from available to on-hold.
+ OperationOnHold Operation = Operation(constant.ONHOLD)
+ // OperationRelease moves value from on-hold back to available.
+ OperationRelease Operation = Operation(constant.RELEASE)
+)
-// Deprecated: use model from Midaz pkg instead.
-// Metadata structure for marshaling/unmarshalling JSON.
-//
-// swagger:model Metadata
-// @Description Metadata is the struct designed to store metadata.
-type Metadata struct {
- Key string `json:"key,omitempty"`
- Value any `json:"value,omitempty"`
-} // @name Metadata
-
-// Deprecated: use model from Midaz pkg instead.
-// Amount structure for marshaling/unmarshalling JSON.
-//
-// swagger:model Amount
-// @Description Amount is the struct designed to represent the amount of an operation.
-type Amount struct {
- Asset string `json:"asset,omitempty" validate:"required" example:"BRL"`
- Value decimal.Decimal `json:"value,omitempty" validate:"required" example:"1000"`
- Operation string `json:"operation,omitempty"`
- TransactionType string `json:"transactionType,omitempty"`
-} // @name Amount
-
-// Deprecated: use model from Midaz pkg instead.
-// Share structure for marshaling/unmarshalling JSON.
-//
-// swagger:model Share
-// @Description Share is the struct designed to represent the sharing fields of an operation.
-type Share struct {
- Percentage int64 `json:"percentage,omitempty" validate:"required"`
- PercentageOfPercentage int64 `json:"percentageOfPercentage,omitempty"`
-} // @name Share
-
-// Deprecated: use model from Midaz pkg instead.
-// Send structure for marshaling/unmarshalling JSON.
+// TransactionStatus represents the lifecycle state of a transaction intent.
//
-// swagger:model Send
-// @Description Send is the struct designed to represent the sending fields of an operation.
-type Send struct {
- Asset string `json:"asset,omitempty" validate:"required" example:"BRL"`
- Value decimal.Decimal `json:"value,omitempty" validate:"required" example:"1000"`
- Source Source `json:"source,omitempty" validate:"required"`
- Distribute Distribute `json:"distribute,omitempty" validate:"required"`
-} // @name Send
-
-// Deprecated: use model from Midaz pkg instead.
-// Source structure for marshaling/unmarshalling JSON.
+// Semantics:
+// - CREATED: intent recorded but not yet submitted for processing.
+// - APPROVED: intent approved for execution but not yet applied.
+// - PENDING: intent currently being processed (balance updates in flight).
+// - CANCELED: intent rejected or rolled back; terminal state.
//
-// swagger:model Source
-// @Description Source is the struct designed to represent the source fields of an operation.
-type Source struct {
- Remaining string `json:"remaining,omitempty" example:"remaining"`
- From []FromTo `json:"from,omitempty" validate:"singletransactiontype,required,dive"`
-} // @name Source
-
-// Deprecated: use model from Midaz pkg instead.
-// Rate structure for marshaling/unmarshalling JSON.
+// Typical transitions:
//
-// swagger:model Rate
-// @Description Rate is the struct designed to represent the rate fields of an operation.
-type Rate struct {
- From string `json:"from" validate:"required" example:"BRL"`
- To string `json:"to" validate:"required" example:"USDe"`
- Value decimal.Decimal `json:"value" validate:"required" example:"1000"`
- ExternalID string `json:"externalId" validate:"uuid,required" example:"00000000-0000-0000-0000-000000000000"`
-} // @name Rate
-
-// Deprecated: use IsEmpty method from Midaz pkg instead.
-// IsEmpty method that set empty or nil in fields
-func (r Rate) IsEmpty() bool {
- return r.ExternalID == "" && r.From == "" && r.To == "" && r.Value.IsZero()
+// CREATED → APPROVED | CANCELED
+// APPROVED → PENDING | CANCELED
+// PENDING → (terminal; see associated Posting status for settlement)
+type TransactionStatus string
+
+const (
+ // StatusCreated marks an intent as recorded but not yet approved.
+ StatusCreated TransactionStatus = TransactionStatus(constant.CREATED)
+ // StatusApproved marks an intent as approved for processing.
+ StatusApproved TransactionStatus = TransactionStatus(constant.APPROVED)
+ // StatusPending marks an intent as currently being processed.
+ StatusPending TransactionStatus = TransactionStatus(constant.PENDING)
+ // StatusCanceled marks an intent as rejected or rolled back.
+ StatusCanceled TransactionStatus = TransactionStatus(constant.CANCELED)
+)
+
+// AccountType classifies balances by ownership boundary.
+type AccountType string
+
+const (
+ // AccountTypeInternal identifies balances owned within the platform.
+ AccountTypeInternal AccountType = "internal"
+ // AccountTypeExternal identifies balances owned outside the platform.
+ AccountTypeExternal AccountType = "external"
+)
+
+// ErrorCode is a domain error code used by transaction validations.
+type ErrorCode string
+
+const (
+ // ErrorInsufficientFunds indicates the source balance cannot cover the amount.
+ ErrorInsufficientFunds ErrorCode = ErrorCode(constant.CodeInsufficientFunds)
+ // ErrorAccountIneligibility indicates the account cannot participate in the transaction.
+ ErrorAccountIneligibility ErrorCode = ErrorCode(constant.CodeAccountIneligibility)
+ // ErrorAccountStatusTransactionRestriction indicates account status blocks this transaction.
+ ErrorAccountStatusTransactionRestriction ErrorCode = ErrorCode(constant.CodeAccountStatusTransactionRestriction)
+ // ErrorAssetCodeNotFound indicates the requested asset was not found.
+ ErrorAssetCodeNotFound ErrorCode = ErrorCode(constant.CodeAssetCodeNotFound)
+ // ErrorTransactionValueMismatch indicates allocations do not match transaction total.
+ ErrorTransactionValueMismatch ErrorCode = ErrorCode(constant.CodeTransactionValueMismatch)
+ // ErrorTransactionAmbiguous indicates transaction routing cannot be determined uniquely.
+ ErrorTransactionAmbiguous ErrorCode = ErrorCode(constant.CodeTransactionAmbiguous)
+ // ErrorOnHoldExternalAccount indicates on-hold operations are not allowed for external accounts.
+ ErrorOnHoldExternalAccount ErrorCode = ErrorCode(constant.CodeOnHoldExternalAccount)
+ // ErrorDataCorruption indicates persisted transaction data is inconsistent.
+ ErrorDataCorruption ErrorCode = "0099"
+ // ErrorInvalidInput indicates request payload validation failed.
+ ErrorInvalidInput ErrorCode = "1001"
+ // ErrorInvalidStateTransition indicates an invalid transaction state transition was requested.
+ ErrorInvalidStateTransition ErrorCode = "1002"
+ // ErrorCrossScope indicates balances from different organizations or ledgers are mixed.
+ ErrorCrossScope ErrorCode = "1003"
+)
+
+// DomainError represents a structured transaction domain validation error.
+type DomainError struct {
+ Code ErrorCode
+ Field string
+ Message string
}
-// Deprecated: use model from Midaz pkg instead.
-// FromTo structure for marshaling/unmarshalling JSON.
-//
-// swagger:model FromTo
-// @Description FromTo is the struct designed to represent the from/to fields of an operation.
-type FromTo struct {
- AccountAlias string `json:"accountAlias,omitempty" example:"@person1"`
- BalanceKey string `json:"balanceKey,omitempty" example:"asset-freeze"`
- Amount *Amount `json:"amount,omitempty"`
- Share *Share `json:"share,omitempty"`
- Remaining string `json:"remaining,omitempty" example:"remaining"`
- Rate *Rate `json:"rate,omitempty"`
- Description string `json:"description,omitempty" example:"description"`
- ChartOfAccounts string `json:"chartOfAccounts" example:"1000"`
- Metadata map[string]any `json:"metadata" validate:"dive,keys,keymax=100,endkeys,nonested,valuemax=2000"`
- IsFrom bool `json:"isFrom,omitempty" example:"true"`
- Route string `json:"route,omitempty" validate:"omitempty,max=250" example:"00000000-0000-0000-0000-000000000000"`
-} // @name FromTo
-
-// Deprecated: use SplitAlias method from Midaz pkg instead.
-// SplitAlias function to split alias with index.
-func (ft FromTo) SplitAlias() string {
- if strings.Contains(ft.AccountAlias, "#") {
- return strings.Split(ft.AccountAlias, "#")[1]
+// Error returns the formatted domain error string.
+func (e DomainError) Error() string {
+ if e.Field == "" {
+ return fmt.Sprintf("%s: %s", e.Code, e.Message)
}
- return ft.AccountAlias
+ return fmt.Sprintf("%s: %s (%s)", e.Code, e.Message, e.Field)
}
-// Deprecated: use SplitAliasWithKey method from Midaz pkg instead.
-// SplitAliasWithKey extracts the substring after the '#' character from the provided alias or returns the alias if '#' is not present.
-func SplitAliasWithKey(alias string) string {
- if idx := strings.Index(alias, "#"); idx != -1 {
- return alias[idx+1:]
+// NewDomainError creates a domain error with code, field, and message.
+func NewDomainError(code ErrorCode, field, message string) error {
+ return DomainError{Code: code, Field: field, Message: message}
+}
+
+// Balance contains the balance state used during intent planning and posting.
+type Balance struct {
+ ID string `json:"id"`
+ OrganizationID string `json:"organizationId"`
+ LedgerID string `json:"ledgerId"`
+ AccountID string `json:"accountId"`
+ Asset string `json:"asset"`
+ Available decimal.Decimal `json:"available"`
+ OnHold decimal.Decimal `json:"onHold"`
+ Version int64 `json:"version"`
+ AccountType AccountType `json:"accountType"`
+ AllowSending bool `json:"allowSending"`
+ AllowReceiving bool `json:"allowReceiving"`
+ CreatedAt time.Time `json:"createdAt"`
+ UpdatedAt time.Time `json:"updatedAt"`
+ DeletedAt *time.Time `json:"deletedAt"`
+ Metadata map[string]any `json:"metadata,omitempty"`
+}
+
+// LedgerTarget identifies the account and balance affected by a posting.
+type LedgerTarget struct {
+ AccountID string `json:"accountId"`
+ BalanceID string `json:"balanceId"`
+}
+
+func (t LedgerTarget) validate(field string) error {
+ if strings.TrimSpace(t.AccountID) == "" {
+ return NewDomainError(ErrorInvalidInput, field+".accountId", "accountId is required")
}
- return alias
+ if strings.TrimSpace(t.BalanceID) == "" {
+ return NewDomainError(ErrorInvalidInput, field+".balanceId", "balanceId is required")
+ }
+
+ return nil
}
-// Deprecated: use ConcatAlias method from Midaz pkg instead.
-// ConcatAlias function to concat alias with index.
-func (ft FromTo) ConcatAlias(i int) string {
- return strconv.Itoa(i) + "#" + ft.AccountAlias + "#" + ft.BalanceKey
+// Allocation defines how part of the transaction total is assigned.
+type Allocation struct {
+ Target LedgerTarget `json:"target"`
+ Amount *decimal.Decimal `json:"amount,omitempty"`
+ Share *decimal.Decimal `json:"share,omitempty"`
+ Remainder bool `json:"remainder"`
+ Route string `json:"route,omitempty"`
}
-// Deprecated: use model from Midaz pkg instead.
-// Distribute structure for marshaling/unmarshalling JSON.
-//
-// swagger:model Distribute
-// @Description Distribute is the struct designed to represent the distribution fields of an operation.
-type Distribute struct {
- Remaining string `json:"remaining,omitempty"`
- To []FromTo `json:"to,omitempty" validate:"singletransactiontype,required,dive"`
-} // @name Distribute
-
-// Deprecated: use model from Midaz pkg instead.
-// Transaction structure for marshaling/unmarshalling JSON.
-//
-// swagger:model Transaction
-// @Description Transaction is a struct designed to store transaction data.
-type Transaction struct {
- ChartOfAccountsGroupName string `json:"chartOfAccountsGroupName,omitempty" example:"1000"`
- Description string `json:"description,omitempty" example:"Description"`
- Code string `json:"code,omitempty" example:"00000000-0000-0000-0000-000000000000"`
- Pending bool `json:"pending,omitempty" example:"false"`
- Metadata map[string]any `json:"metadata,omitempty" validate:"dive,keys,keymax=100,endkeys,nonested,valuemax=2000"`
- Route string `json:"route,omitempty" validate:"omitempty,max=250" example:"00000000-0000-0000-0000-000000000000"`
- TransactionDate time.Time `json:"transactionDate,omitempty" example:"2021-01-01T00:00:00Z"`
- Send Send `json:"send" validate:"required"`
-} // @name Transaction
-
-// Deprecated: use IsEmpty method from Midaz pkg instead.
-// IsEmpty is a func that validate if transaction is Empty.
-func (t Transaction) IsEmpty() bool {
- return t.Send.Asset == "" && t.Send.Value.IsZero()
+// TransactionIntentInput is the user input used to build a deterministic plan.
+type TransactionIntentInput struct {
+ Asset string `json:"asset"`
+ Total decimal.Decimal `json:"total"`
+ Pending bool `json:"pending"`
+ Sources []Allocation `json:"sources"`
+ Destinations []Allocation `json:"destinations"`
+}
+
+// Posting is a concrete operation to apply against a target balance.
+type Posting struct {
+ Target LedgerTarget `json:"target"`
+ Asset string `json:"asset"`
+ Amount decimal.Decimal `json:"amount"`
+ Operation Operation `json:"operation"`
+ Status TransactionStatus `json:"status"`
+ Route string `json:"route,omitempty"`
+}
+
+// IntentPlan is the validated and expanded representation of a transaction intent.
+type IntentPlan struct {
+ Asset string `json:"asset"`
+ Total decimal.Decimal `json:"total"`
+ Pending bool `json:"pending"`
+ Sources []Posting `json:"sources"`
+ Destinations []Posting `json:"destinations"`
}
diff --git a/commons/transaction/transaction_test.go b/commons/transaction/transaction_test.go
index 8b05c666..dd1ff744 100644
--- a/commons/transaction/transaction_test.go
+++ b/commons/transaction/transaction_test.go
@@ -1,279 +1,1242 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package transaction
import (
+ "errors"
+ "sync"
"testing"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
-func TestBalance_IsEmpty(t *testing.T) {
+// ---------------------------------------------------------------------------
+// ResolveOperation -- exhaustive state matrix
+// ---------------------------------------------------------------------------
+
+func TestResolveOperation(t *testing.T) {
tests := []struct {
- name string
- rate Rate
- want bool
+ name string
+ pending bool
+ isSource bool
+ status TransactionStatus
+ expected Operation
+ errorCode ErrorCode
}{
- {
- name: "Empty rate",
- rate: Rate{},
- want: true,
- },
- {
- name: "Non-empty rate",
- rate: Rate{
- From: "BRL",
- To: "USD",
- Value: decimal.NewFromInt(100),
- ExternalID: "00000000-0000-0000-0000-000000000000",
- },
- want: false,
- },
+ // Pending transactions
+ {name: "pending source PENDING", pending: true, isSource: true, status: StatusPending, expected: OperationOnHold},
+ {name: "pending destination PENDING", pending: true, isSource: false, status: StatusPending, expected: OperationCredit},
+ {name: "pending source CANCELED", pending: true, isSource: true, status: StatusCanceled, expected: OperationRelease},
+ {name: "pending destination CANCELED", pending: true, isSource: false, status: StatusCanceled, expected: OperationDebit},
+ {name: "pending source APPROVED", pending: true, isSource: true, status: StatusApproved, expected: OperationDebit},
+ {name: "pending destination APPROVED", pending: true, isSource: false, status: StatusApproved, expected: OperationCredit},
+
+ // Non-pending transactions
+ {name: "non-pending source CREATED", pending: false, isSource: true, status: StatusCreated, expected: OperationDebit},
+ {name: "non-pending destination CREATED", pending: false, isSource: false, status: StatusCreated, expected: OperationCredit},
+
+ // Invalid statuses
+ {name: "non-pending source APPROVED", pending: false, isSource: true, status: StatusApproved, errorCode: ErrorInvalidStateTransition},
+ {name: "non-pending destination APPROVED", pending: false, isSource: false, status: StatusApproved, errorCode: ErrorInvalidStateTransition},
+ {name: "non-pending source PENDING", pending: false, isSource: true, status: StatusPending, errorCode: ErrorInvalidStateTransition},
+ {name: "non-pending destination CANCELED", pending: false, isSource: false, status: StatusCanceled, errorCode: ErrorInvalidStateTransition},
+ {name: "pending source CREATED", pending: true, isSource: true, status: StatusCreated, errorCode: ErrorInvalidStateTransition},
+ {name: "pending destination CREATED", pending: true, isSource: false, status: StatusCreated, errorCode: ErrorInvalidStateTransition},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got := tt.rate.IsEmpty()
- assert.Equal(t, tt.want, got)
+ t.Parallel()
+
+ got, err := ResolveOperation(tt.pending, tt.isSource, tt.status)
+
+ if tt.errorCode != "" {
+ require.Error(t, err)
+
+ var domainErr DomainError
+ require.True(t, errors.As(err, &domainErr))
+ assert.Equal(t, tt.errorCode, domainErr.Code)
+ assert.Equal(t, "status", domainErr.Field)
+
+ return
+ }
+
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected, got)
})
}
}
-func TestFromTo_SplitAlias(t *testing.T) {
+// ---------------------------------------------------------------------------
+// ApplyPosting -- happy path operations
+// ---------------------------------------------------------------------------
+
+func TestApplyPosting(t *testing.T) {
+ balance := Balance{
+ ID: "balance-1",
+ AccountID: "account-1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ OnHold: decimal.NewFromInt(20),
+ Version: 7,
+ AllowSending: true,
+ AllowReceiving: true,
+ }
+
tests := []struct {
- name string
- accountAlias string
- want string
+ name string
+ posting Posting
+ expected Balance
+ errorCode ErrorCode
}{
{
- name: "Alias without index",
- accountAlias: "@person1",
- want: "@person1",
+ name: "ON_HOLD moves available to onHold",
+ posting: Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(30),
+ Operation: OperationOnHold,
+ Status: StatusPending,
+ },
+ expected: Balance{Available: decimal.NewFromInt(70), OnHold: decimal.NewFromInt(50), Version: 8},
+ },
+ {
+ name: "RELEASE moves onHold to available",
+ posting: Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationRelease,
+ Status: StatusCanceled,
+ },
+ expected: Balance{Available: decimal.NewFromInt(110), OnHold: decimal.NewFromInt(10), Version: 8},
},
{
- name: "Alias with index",
- accountAlias: "1#@person1",
- want: "@person1",
+ name: "DEBIT APPROVED deducts from onHold",
+ posting: Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationDebit,
+ Status: StatusApproved,
+ },
+ expected: Balance{Available: decimal.NewFromInt(100), OnHold: decimal.NewFromInt(10), Version: 8},
},
{
- name: "Alias with index and balance key",
- accountAlias: "1#@person1#savings",
- want: "@person1",
+ name: "DEBIT CREATED deducts from available",
+ posting: Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(50),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ },
+ expected: Balance{Available: decimal.NewFromInt(50), OnHold: decimal.NewFromInt(20), Version: 8},
},
{
- name: "Alias with index and empty balance key",
- accountAlias: "0#@external#",
- want: "@external",
+ name: "CREDIT CREATED adds to available",
+ posting: Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(40),
+ Operation: OperationCredit,
+ Status: StatusCreated,
+ },
+ expected: Balance{Available: decimal.NewFromInt(140), OnHold: decimal.NewFromInt(20), Version: 8},
},
{
- name: "Alias with index and default balance key",
- accountAlias: "2#@account#default",
- want: "@account",
+ name: "CREDIT APPROVED adds to available",
+ posting: Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(25),
+ Operation: OperationCredit,
+ Status: StatusApproved,
+ },
+ expected: Balance{Available: decimal.NewFromInt(125), OnHold: decimal.NewFromInt(20), Version: 8},
},
{
- name: "Complex alias with index and balance key",
- accountAlias: "5#@external/BRL#checking",
- want: "@external/BRL",
+ name: "CREDIT PENDING adds to available",
+ posting: Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(40),
+ Operation: OperationCredit,
+ Status: StatusPending,
+ },
+ expected: Balance{Available: decimal.NewFromInt(140), OnHold: decimal.NewFromInt(20), Version: 8},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ft := FromTo{
- AccountAlias: tt.accountAlias,
+ t.Parallel()
+
+ got, err := ApplyPosting(balance, tt.posting)
+
+ if tt.errorCode != "" {
+ require.Error(t, err)
+
+ var domainErr DomainError
+ require.True(t, errors.As(err, &domainErr))
+ assert.Equal(t, tt.errorCode, domainErr.Code)
+
+ return
}
- got := ft.SplitAlias()
- assert.Equal(t, tt.want, got)
+
+ require.NoError(t, err)
+ assert.Equal(t, balance.ID, got.ID)
+ assert.Equal(t, balance.AccountID, got.AccountID)
+ assert.Equal(t, balance.Asset, got.Asset)
+ assert.True(t, tt.expected.Available.Equal(got.Available),
+ "available: want=%s got=%s", tt.expected.Available, got.Available)
+ assert.True(t, tt.expected.OnHold.Equal(got.OnHold),
+ "onHold: want=%s got=%s", tt.expected.OnHold, got.OnHold)
+ assert.Equal(t, tt.expected.Version, got.Version)
})
}
}
-func TestFromTo_ConcatAlias(t *testing.T) {
- tests := []struct {
- name string
- accountAlias string
- balanceKey string
- index int
- want string
- }{
- {
- name: "Concat index with alias and balance key",
- accountAlias: "@person1",
- balanceKey: "savings",
- index: 1,
- want: "1#@person1#savings",
- },
- {
- name: "Concat index with alias and empty balance key",
- accountAlias: "@person2",
- balanceKey: "",
- index: 0,
- want: "0#@person2#",
- },
- {
- name: "Concat index with alias and default balance key",
- accountAlias: "@person3",
- balanceKey: "default",
- index: 2,
- want: "2#@person3#default",
+// ---------------------------------------------------------------------------
+// ApplyPosting -- validation errors
+// ---------------------------------------------------------------------------
+
+func TestApplyPosting_MissingTargetAccountID(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{ID: "b1", AccountID: "a1", Asset: "USD", Available: decimal.NewFromInt(100)}
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidInput, de.Code)
+ assert.Contains(t, de.Field, "accountId")
+}
+
+func TestApplyPosting_MissingTargetBalanceID(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{ID: "b1", AccountID: "a1", Asset: "USD", Available: decimal.NewFromInt(100)}
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: ""},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidInput, de.Code)
+ assert.Contains(t, de.Field, "balanceId")
+}
+
+func TestApplyPosting_RejectsMismatchedBalanceID(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{ID: "balance-A", AccountID: "account-1", Asset: "USD", Available: decimal.NewFromInt(100)}
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-B"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorAccountIneligibility, de.Code)
+ assert.Contains(t, de.Field, "balanceId")
+}
+
+func TestApplyPosting_RejectsMismatchedAccountID(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "balance-1",
+ AccountID: "account-1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ OnHold: decimal.Zero,
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{
+ AccountID: "account-2",
+ BalanceID: "balance-1",
},
+ Asset: "USD",
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ Amount: decimal.NewFromInt(10),
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ft := FromTo{
- AccountAlias: tt.accountAlias,
- BalanceKey: tt.balanceKey,
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var domainErr DomainError
+ require.True(t, errors.As(err, &domainErr))
+ assert.Equal(t, ErrorAccountIneligibility, domainErr.Code)
+ assert.Contains(t, domainErr.Field, "accountId")
+}
+
+func TestApplyPosting_RejectsAssetMismatch(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "balance-1",
+ AccountID: "account-1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "EUR",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorAssetCodeNotFound, de.Code)
+ assert.Equal(t, "posting.asset", de.Field)
+}
+
+func TestApplyPosting_RejectsZeroAmount(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "balance-1",
+ AccountID: "account-1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.Zero,
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidInput, de.Code)
+ assert.Contains(t, de.Message, "posting amount must be greater than zero")
+}
+
+func TestApplyPosting_RejectsNegativeAmount(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "balance-1",
+ AccountID: "account-1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(-5),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidInput, de.Code)
+}
+
+func TestApplyPosting_RejectsUnsupportedOperation(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "balance-1",
+ AccountID: "account-1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: Operation("UNKNOWN_OP"),
+ Status: StatusCreated,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidInput, de.Code)
+ assert.Equal(t, "posting.operation", de.Field)
+}
+
+// ---------------------------------------------------------------------------
+// ApplyPosting -- invalid state transitions
+// ---------------------------------------------------------------------------
+
+func TestApplyPosting_OnHold_RequiresPendingStatus(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ }
+
+ invalidStatuses := []TransactionStatus{StatusCreated, StatusApproved, StatusCanceled}
+ for _, status := range invalidStatuses {
+ t.Run(string(status), func(t *testing.T) {
+ t.Parallel()
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationOnHold,
+ Status: status,
}
- got := ft.ConcatAlias(tt.index)
- assert.Equal(t, tt.want, got)
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidStateTransition, de.Code)
+ assert.Contains(t, de.Message, "ON_HOLD requires PENDING status")
})
}
}
-// TestFromTo_ConcatSplitAlias_Compatibility verifies that SplitAlias can correctly parse strings generated by ConcatAlias
-func TestFromTo_ConcatSplitAlias_Compatibility(t *testing.T) {
- tests := []struct {
- name string
- accountAlias string
- balanceKey string
- index int
- }{
- {
- name: "Standard alias with balance key",
- accountAlias: "@person1",
- balanceKey: "savings",
- index: 1,
- },
- {
- name: "External alias with empty balance key",
- accountAlias: "@external/BRL",
- balanceKey: "",
- index: 0,
- },
- {
- name: "Complex alias with default balance key",
- accountAlias: "@company/accounts/primary",
- balanceKey: "default",
- index: 5,
- },
- {
- name: "Simple alias with special balance key",
- accountAlias: "@test",
- balanceKey: "checking-account",
- index: 999,
- },
+func TestApplyPosting_Release_RequiresCanceledStatus(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(50),
+ OnHold: decimal.NewFromInt(50),
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ft := FromTo{
- AccountAlias: tt.accountAlias,
- BalanceKey: tt.balanceKey,
+ invalidStatuses := []TransactionStatus{StatusCreated, StatusApproved, StatusPending}
+ for _, status := range invalidStatuses {
+ t.Run(string(status), func(t *testing.T) {
+ t.Parallel()
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationRelease,
+ Status: status,
}
- // Generate concatenated string using ConcatAlias
- concatenated := ft.ConcatAlias(tt.index)
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidStateTransition, de.Code)
+ assert.Contains(t, de.Message, "RELEASE requires CANCELED status")
+ })
+ }
+}
- // Create new FromTo with the concatenated string as AccountAlias
- ftWithConcatenated := FromTo{
- AccountAlias: concatenated,
+func TestApplyPosting_Debit_InvalidStatus(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ OnHold: decimal.NewFromInt(100),
+ }
+
+ // Only PENDING is invalid for DEBIT now; CANCELED is valid for pending destination cancellations.
+ invalidStatuses := []TransactionStatus{StatusPending}
+ for _, status := range invalidStatuses {
+ t.Run(string(status), func(t *testing.T) {
+ t.Parallel()
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationDebit,
+ Status: status,
}
- // Extract alias using SplitAlias
- extractedAlias := ftWithConcatenated.SplitAlias()
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
- // Verify that extracted alias matches the original
- assert.Equal(t, tt.accountAlias, extractedAlias,
- "SplitAlias should extract the original alias from ConcatAlias output")
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidStateTransition, de.Code)
+ assert.Contains(t, de.Message, "DEBIT only supports CREATED, APPROVED, or CANCELED status")
})
}
}
-func TestFromTo_ConcatSplitAlias_AliasContainsHash(t *testing.T) {
- ft := FromTo{
- AccountAlias: "@person#vip",
- BalanceKey: "savings",
+func TestApplyPosting_Debit_Canceled_PendingDestinationCancellation(t *testing.T) {
+ t.Parallel()
+
+ // When a pending transaction is canceled, the destination that received a CREDIT
+ // during PENDING must have that credit reversed via DEBIT+CANCELED.
+ balance := Balance{
+ ID: "dst-bal",
+ AccountID: "dst-acc",
+ Asset: "USD",
+ Available: decimal.NewFromInt(50),
+ OnHold: decimal.Zero,
+ Version: 1,
}
- concatenated := ft.ConcatAlias(1)
- ftWithConcatenated := FromTo{AccountAlias: concatenated}
- extractedAlias := ftWithConcatenated.SplitAlias()
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "dst-acc", BalanceID: "dst-bal"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(30),
+ Operation: OperationDebit,
+ Status: StatusCanceled,
+ }
- // Current behavior (documented): alias gets truncated at '#' due to ambiguous delimiter usage.
- // This test documents the ambiguity; consider changing serialization or adding escaping.
- assert.Equal(t, "@person", extractedAlias)
+ result, err := ApplyPosting(balance, posting)
+ require.NoError(t, err)
+ assert.True(t, result.Available.Equal(decimal.NewFromInt(20)),
+ "expected 20 after cancellation debit, got %s", result.Available)
+ assert.Equal(t, int64(2), result.Version)
}
-func TestFromTo_ConcatSplitAlias_BalanceKeyContainsHash(t *testing.T) {
- ft := FromTo{
- AccountAlias: "@person",
- BalanceKey: "sav#ings",
+func TestApplyPosting_Credit_InvalidStatus(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
}
- concatenated := ft.ConcatAlias(2)
- ftWithConcatenated := FromTo{AccountAlias: concatenated}
- extractedAlias := ftWithConcatenated.SplitAlias()
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationCredit,
+ Status: StatusCanceled,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
- // Alias should still be extracted correctly regardless of balanceKey content
- assert.Equal(t, ft.AccountAlias, extractedAlias)
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidStateTransition, de.Code)
+ assert.Contains(t, de.Message, "CREDIT only supports CREATED, APPROVED, or PENDING status")
}
-func TestTransaction_IsEmpty(t *testing.T) {
- tests := []struct {
- name string
- transaction Transaction
- want bool
- }{
- {
- name: "Empty transaction",
- transaction: Transaction{
- Send: Send{
- Asset: "",
- Value: decimal.NewFromInt(0),
- },
- },
- want: true,
+// ---------------------------------------------------------------------------
+// ApplyPosting -- insufficient funds (negative result guards)
+// ---------------------------------------------------------------------------
+
+func TestApplyPosting_RejectsNegativeResultingBalances(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "balance-1",
+ AccountID: "account-1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(50),
+ OnHold: decimal.NewFromInt(5),
+ }
+
+ t.Run("debit over available", func(t *testing.T) {
+ t.Parallel()
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(100),
+ Status: StatusCreated,
+ Operation: OperationDebit,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var domainErr DomainError
+ require.True(t, errors.As(err, &domainErr))
+ assert.Equal(t, ErrorInsufficientFunds, domainErr.Code)
+ assert.Contains(t, domainErr.Message, "negative available balance")
+ })
+
+ t.Run("release over on hold", func(t *testing.T) {
+ t.Parallel()
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Status: StatusCanceled,
+ Operation: OperationRelease,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var domainErr DomainError
+ require.True(t, errors.As(err, &domainErr))
+ assert.Equal(t, ErrorInsufficientFunds, domainErr.Code)
+ assert.Contains(t, domainErr.Message, "negative on-hold balance")
+ })
+
+ t.Run("on hold over available", func(t *testing.T) {
+ t.Parallel()
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(51),
+ Status: StatusPending,
+ Operation: OperationOnHold,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var domainErr DomainError
+ require.True(t, errors.As(err, &domainErr))
+ assert.Equal(t, ErrorInsufficientFunds, domainErr.Code)
+ assert.Contains(t, domainErr.Message, "negative available balance")
+ })
+
+ t.Run("debit approved over onHold", func(t *testing.T) {
+ t.Parallel()
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(6),
+ Status: StatusApproved,
+ Operation: OperationDebit,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ var domainErr DomainError
+ require.True(t, errors.As(err, &domainErr))
+ assert.Equal(t, ErrorInsufficientFunds, domainErr.Code)
+ assert.Contains(t, domainErr.Message, "negative on-hold balance")
+ })
+}
+
+func TestApplyPosting_AllowsPendingCredit(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "balance-1",
+ AccountID: "account-1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(0),
+ OnHold: decimal.NewFromInt(0),
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "account-1", BalanceID: "balance-1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(25),
+ Status: StatusPending,
+ Operation: OperationCredit,
+ }
+
+ updated, err := ApplyPosting(balance, posting)
+ require.NoError(t, err)
+ assert.True(t, updated.Available.Equal(decimal.NewFromInt(25)))
+ assert.Equal(t, int64(1), updated.Version)
+}
+
+// ---------------------------------------------------------------------------
+// ApplyPosting -- idempotency / immutability
+// ---------------------------------------------------------------------------
+
+func TestApplyPosting_DoesNotMutateInput(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ OnHold: decimal.NewFromInt(50),
+ Version: 3,
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(30),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ // Save copies of the original values.
+ origAvailable := balance.Available
+ origOnHold := balance.OnHold
+ origVersion := balance.Version
+
+ result, err := ApplyPosting(balance, posting)
+ require.NoError(t, err)
+
+ // Original balance must not be mutated.
+ assert.True(t, balance.Available.Equal(origAvailable),
+ "input balance available mutated from %s to %s", origAvailable, balance.Available)
+ assert.True(t, balance.OnHold.Equal(origOnHold),
+ "input balance onHold mutated from %s to %s", origOnHold, balance.OnHold)
+ assert.Equal(t, origVersion, balance.Version,
+ "input balance version mutated from %d to %d", origVersion, balance.Version)
+
+ // Result should reflect the operation.
+ assert.True(t, result.Available.Equal(decimal.NewFromInt(70)))
+ assert.Equal(t, int64(4), result.Version)
+}
+
+func TestApplyPosting_DoesNotMutateInputOnError(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ OnHold: decimal.NewFromInt(50),
+ Version: 3,
+ }
+
+ // Save copies of the original values.
+ origAvailable := balance.Available
+ origOnHold := balance.OnHold
+ origVersion := balance.Version
+
+ // Asset mismatch should cause an error.
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "EUR",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ _, err := ApplyPosting(balance, posting)
+ require.Error(t, err)
+
+ // Original balance must not be mutated even on error.
+ assert.True(t, balance.Available.Equal(origAvailable),
+ "input balance available mutated from %s to %s", origAvailable, balance.Available)
+ assert.True(t, balance.OnHold.Equal(origOnHold),
+ "input balance onHold mutated from %s to %s", origOnHold, balance.OnHold)
+ assert.Equal(t, origVersion, balance.Version,
+ "input balance version mutated from %d to %d", origVersion, balance.Version)
+}
+
+func TestApplyPosting_SequentialPostings_VersionIncrements(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(1000),
+ OnHold: decimal.Zero,
+ Version: 0,
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ // Apply 10 sequential postings.
+ current := balance
+ for i := 0; i < 10; i++ {
+ var err error
+ current, err = ApplyPosting(current, posting)
+ require.NoError(t, err)
+ assert.Equal(t, int64(i+1), current.Version)
+ }
+
+ // After 10 debits of 10 from 1000, available should be 900.
+ assert.True(t, current.Available.Equal(decimal.NewFromInt(900)))
+}
+
+// ---------------------------------------------------------------------------
+// ApplyPosting -- decimal precision in operations
+// ---------------------------------------------------------------------------
+
+func TestApplyPosting_DecimalPrecision(t *testing.T) {
+ t.Parallel()
+
+ d, _ := decimal.NewFromString("100.005")
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "BTC",
+ Available: d,
+ OnHold: decimal.Zero,
+ Version: 0,
+ }
+
+ amt, _ := decimal.NewFromString("0.001")
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "BTC",
+ Amount: amt,
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ result, err := ApplyPosting(balance, posting)
+ require.NoError(t, err)
+
+ expected, _ := decimal.NewFromString("100.004")
+ assert.True(t, result.Available.Equal(expected),
+ "expected %s, got %s", expected, result.Available)
+}
+
+func TestApplyPosting_VerySmallAmount(t *testing.T) {
+ t.Parallel()
+
+ avail, _ := decimal.NewFromString("0.000000000000000002")
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "ETH",
+ Available: avail,
+ OnHold: decimal.Zero,
+ }
+
+ amt, _ := decimal.NewFromString("0.000000000000000001")
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "ETH",
+ Amount: amt,
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ result, err := ApplyPosting(balance, posting)
+ require.NoError(t, err)
+
+ expected, _ := decimal.NewFromString("0.000000000000000001")
+ assert.True(t, result.Available.Equal(expected))
+}
+
+func TestApplyPosting_LargeAmount(t *testing.T) {
+ t.Parallel()
+
+ avail, _ := decimal.NewFromString("999999999999999.99")
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: avail,
+ OnHold: decimal.Zero,
+ }
+
+ amt, _ := decimal.NewFromString("0.01")
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: amt,
+ Operation: OperationCredit,
+ Status: StatusCreated,
+ }
+
+ result, err := ApplyPosting(balance, posting)
+ require.NoError(t, err)
+
+ expected, _ := decimal.NewFromString("1000000000000000.00")
+ assert.True(t, result.Available.Equal(expected),
+ "expected %s, got %s", expected, result.Available)
+}
+
+// ---------------------------------------------------------------------------
+// ApplyPosting -- full lifecycle: Created -> OnHold -> Release (cancel)
+// ---------------------------------------------------------------------------
+
+func TestApplyPosting_FullPendingLifecycle_Approved(t *testing.T) {
+ t.Parallel()
+
+ // Start with source balance: 100 available, 0 on hold.
+ source := Balance{
+ ID: "src",
+ AccountID: "src-acc",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ OnHold: decimal.Zero,
+ Version: 0,
+ }
+
+ // Step 1: ON_HOLD (PENDING) - source holds 30.
+ afterHold, err := ApplyPosting(source, Posting{
+ Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(30),
+ Operation: OperationOnHold,
+ Status: StatusPending,
+ })
+ require.NoError(t, err)
+ assert.True(t, afterHold.Available.Equal(decimal.NewFromInt(70)))
+ assert.True(t, afterHold.OnHold.Equal(decimal.NewFromInt(30)))
+ assert.Equal(t, int64(1), afterHold.Version)
+
+ // Step 2: DEBIT (APPROVED) - settlement moves from on-hold.
+ afterDebit, err := ApplyPosting(afterHold, Posting{
+ Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(30),
+ Operation: OperationDebit,
+ Status: StatusApproved,
+ })
+ require.NoError(t, err)
+ assert.True(t, afterDebit.Available.Equal(decimal.NewFromInt(70)))
+ assert.True(t, afterDebit.OnHold.Equal(decimal.Zero))
+ assert.Equal(t, int64(2), afterDebit.Version)
+}
+
+func TestApplyPosting_FullPendingLifecycle_Canceled(t *testing.T) {
+ t.Parallel()
+
+ source := Balance{
+ ID: "src",
+ AccountID: "src-acc",
+ Asset: "USD",
+ Available: decimal.NewFromInt(100),
+ OnHold: decimal.Zero,
+ Version: 0,
+ }
+
+ // Step 1: ON_HOLD (PENDING).
+ afterHold, err := ApplyPosting(source, Posting{
+ Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(30),
+ Operation: OperationOnHold,
+ Status: StatusPending,
+ })
+ require.NoError(t, err)
+ assert.True(t, afterHold.Available.Equal(decimal.NewFromInt(70)))
+ assert.True(t, afterHold.OnHold.Equal(decimal.NewFromInt(30)))
+
+ // Step 2: RELEASE (CANCELED) - funds return to available.
+ afterRelease, err := ApplyPosting(afterHold, Posting{
+ Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(30),
+ Operation: OperationRelease,
+ Status: StatusCanceled,
+ })
+ require.NoError(t, err)
+ assert.True(t, afterRelease.Available.Equal(decimal.NewFromInt(100)),
+ "expected 100, got %s", afterRelease.Available)
+ assert.True(t, afterRelease.OnHold.Equal(decimal.Zero))
+ assert.Equal(t, int64(2), afterRelease.Version)
+}
+
+// ---------------------------------------------------------------------------
+// ApplyPosting -- debit exactly to zero
+// ---------------------------------------------------------------------------
+
+func TestApplyPosting_DebitToExactlyZero(t *testing.T) {
+ t.Parallel()
+
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(42),
+ OnHold: decimal.Zero,
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(42),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }
+
+ result, err := ApplyPosting(balance, posting)
+ require.NoError(t, err)
+ assert.True(t, result.Available.Equal(decimal.Zero),
+ "expected zero, got %s", result.Available)
+}
+
+// ---------------------------------------------------------------------------
+// Concurrent posting safety
+// ---------------------------------------------------------------------------
+
+func TestApplyPosting_ConcurrentSafety(t *testing.T) {
+ t.Parallel()
+
+ // ApplyPosting is a pure function (takes value, returns value),
+ // so concurrent calls should never interfere with each other.
+ balance := Balance{
+ ID: "b1",
+ AccountID: "a1",
+ Asset: "USD",
+ Available: decimal.NewFromInt(1000),
+ OnHold: decimal.Zero,
+ Version: 0,
+ }
+
+ posting := Posting{
+ Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(10),
+ Operation: OperationCredit,
+ Status: StatusCreated,
+ }
+
+ const goroutines = 100
+
+ var wg sync.WaitGroup
+
+ wg.Add(goroutines)
+
+ results := make([]Balance, goroutines)
+ errs := make([]error, goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ go func(idx int) {
+ defer wg.Done()
+ results[idx], errs[idx] = ApplyPosting(balance, posting)
+ }(i)
+ }
+
+ wg.Wait()
+
+ // Every goroutine should succeed and produce the same deterministic result.
+ for i := 0; i < goroutines; i++ {
+ require.NoError(t, errs[i], "goroutine %d failed", i)
+ assert.True(t, results[i].Available.Equal(decimal.NewFromInt(1010)),
+ "goroutine %d: expected 1010, got %s", i, results[i].Available)
+ assert.Equal(t, int64(1), results[i].Version)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// End-to-end: Build plan, validate eligibility, apply postings
+// ---------------------------------------------------------------------------
+
+func TestEndToEnd_FullTransactionFlow(t *testing.T) {
+ t.Parallel()
+
+ total := decimal.NewFromInt(200)
+ amount := decimal.NewFromInt(200)
+
+ // 1. Build intent plan.
+ input := TransactionIntentInput{
+ Asset: "BRL",
+ Total: total,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "alice-acc", BalanceID: "alice-bal"}, Amount: &amount},
},
- {
- name: "Non-empty transaction with asset",
- transaction: Transaction{
- Send: Send{
- Asset: "BRL",
- Value: decimal.NewFromInt(0),
- },
- },
- want: false,
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "bob-acc", BalanceID: "bob-bal"}, Amount: &amount},
},
- {
- name: "Non-empty transaction with value",
- transaction: Transaction{
- Send: Send{
- Asset: "",
- Value: decimal.NewFromInt(100),
- },
- },
- want: false,
+ }
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+
+ // 2. Validate eligibility.
+ balances := map[string]Balance{
+ "alice-bal": {
+ ID: "alice-bal",
+ AccountID: "alice-acc",
+ Asset: "BRL",
+ Available: decimal.NewFromInt(500),
+ OnHold: decimal.Zero,
+ AllowSending: true,
+ AccountType: AccountTypeInternal,
},
- {
- name: "Complete non-empty transaction",
- transaction: Transaction{
- Send: Send{
- Asset: "BRL",
- Value: decimal.NewFromInt(100),
- },
- },
- want: false,
+ "bob-bal": {
+ ID: "bob-bal",
+ AccountID: "bob-acc",
+ Asset: "BRL",
+ Available: decimal.NewFromInt(100),
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := tt.transaction.IsEmpty()
- assert.Equal(t, tt.want, got)
- })
+ err = ValidateBalanceEligibility(plan, balances)
+ require.NoError(t, err)
+
+ // 3. Apply source debit.
+ aliceBalance := balances["alice-bal"]
+ aliceAfter, err := ApplyPosting(aliceBalance, plan.Sources[0])
+ require.NoError(t, err)
+ assert.True(t, aliceAfter.Available.Equal(decimal.NewFromInt(300)),
+ "expected 300, got %s", aliceAfter.Available)
+
+ // 4. Apply destination credit.
+ bobBalance := balances["bob-bal"]
+ bobAfter, err := ApplyPosting(bobBalance, plan.Destinations[0])
+ require.NoError(t, err)
+ assert.True(t, bobAfter.Available.Equal(decimal.NewFromInt(300)),
+ "expected 300, got %s", bobAfter.Available)
+}
+
+func TestEndToEnd_PendingTransactionFlow(t *testing.T) {
+ t.Parallel()
+
+ total := decimal.NewFromInt(50)
+ amount := decimal.NewFromInt(50)
+
+ // 1. Build pending plan.
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: total,
+ Pending: true,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src-bal"}, Amount: &amount},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "dst-acc", BalanceID: "dst-bal"}, Amount: &amount},
+ },
+ }
+
+ plan, err := BuildIntentPlan(input, StatusPending)
+ require.NoError(t, err)
+ assert.Equal(t, OperationOnHold, plan.Sources[0].Operation)
+ assert.Equal(t, OperationCredit, plan.Destinations[0].Operation)
+
+ // 2. Validate eligibility.
+ srcBal := Balance{
+ ID: "src-bal",
+ AccountID: "src-acc",
+ Asset: "USD",
+ Available: decimal.NewFromInt(200),
+ OnHold: decimal.Zero,
+ AllowSending: true,
+ AccountType: AccountTypeInternal,
+ }
+
+ dstBal := Balance{
+ ID: "dst-bal",
+ AccountID: "dst-acc",
+ Asset: "USD",
+ Available: decimal.Zero,
+ AllowReceiving: true,
+ AccountType: AccountTypeExternal,
}
+
+ balances := map[string]Balance{
+ "src-bal": srcBal,
+ "dst-bal": dstBal,
+ }
+
+ err = ValidateBalanceEligibility(plan, balances)
+ require.NoError(t, err)
+
+ // 3. Apply source ON_HOLD.
+ srcAfterHold, err := ApplyPosting(srcBal, plan.Sources[0])
+ require.NoError(t, err)
+ assert.True(t, srcAfterHold.Available.Equal(decimal.NewFromInt(150)))
+ assert.True(t, srcAfterHold.OnHold.Equal(decimal.NewFromInt(50)))
+
+ // 4. Apply destination CREDIT.
+ dstAfterCredit, err := ApplyPosting(dstBal, plan.Destinations[0])
+ require.NoError(t, err)
+ assert.True(t, dstAfterCredit.Available.Equal(decimal.NewFromInt(50)))
+
+ // 5. Now approve: source gets DEBIT from onHold.
+ approvePosting := Posting{
+ Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src-bal"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(50),
+ Operation: OperationDebit,
+ Status: StatusApproved,
+ }
+
+ srcAfterApproval, err := ApplyPosting(srcAfterHold, approvePosting)
+ require.NoError(t, err)
+ assert.True(t, srcAfterApproval.Available.Equal(decimal.NewFromInt(150)))
+ assert.True(t, srcAfterApproval.OnHold.Equal(decimal.Zero))
+}
+
+// ---------------------------------------------------------------------------
+// sumPostings
+// ---------------------------------------------------------------------------
+
+func TestSumPostings(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty", func(t *testing.T) {
+ t.Parallel()
+
+ result := sumPostings(nil)
+ assert.True(t, result.Equal(decimal.Zero))
+ })
+
+ t.Run("single", func(t *testing.T) {
+ t.Parallel()
+
+ postings := []Posting{{Amount: decimal.NewFromInt(42)}}
+ result := sumPostings(postings)
+ assert.True(t, result.Equal(decimal.NewFromInt(42)))
+ })
+
+ t.Run("multiple", func(t *testing.T) {
+ t.Parallel()
+
+ d1, _ := decimal.NewFromString("33.33")
+ d2, _ := decimal.NewFromString("33.33")
+ d3, _ := decimal.NewFromString("33.34")
+
+ postings := []Posting{{Amount: d1}, {Amount: d2}, {Amount: d3}}
+ result := sumPostings(postings)
+ assert.True(t, result.Equal(decimal.NewFromInt(100)),
+ "expected 100, got %s", result)
+ })
}
diff --git a/commons/transaction/validations.go b/commons/transaction/validations.go
index 7339f6fb..f0726af4 100644
--- a/commons/transaction/validations.go
+++ b/commons/transaction/validations.go
@@ -1,50 +1,98 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package transaction
import (
- "context"
- "strconv"
+ "fmt"
"strings"
- "github.com/LerianStudio/lib-commons/v2/commons"
- constant "github.com/LerianStudio/lib-commons/v2/commons/constants"
- "github.com/LerianStudio/lib-commons/v2/commons/opentelemetry"
"github.com/shopspring/decimal"
)
-// Deprecated: use ValidateBalancesRules method from Midaz pkg instead.
-// ValidateBalancesRules function with some validates in accounts and DSL operations
-func ValidateBalancesRules(ctx context.Context, transaction Transaction, validate Responses, balances []*Balance) error {
- logger, tracer, _, _ := commons.NewTrackingFromContext(ctx)
+var oneHundred = decimal.NewFromInt(100)
- _, spanValidateBalances := tracer.Start(ctx, "validations.validate_balances_rules")
- defer spanValidateBalances.End()
+// BuildIntentPlan validates input allocations and builds a normalized intent plan.
+func BuildIntentPlan(input TransactionIntentInput, status TransactionStatus) (IntentPlan, error) {
+ if strings.TrimSpace(input.Asset) == "" {
+ return IntentPlan{}, NewDomainError(ErrorInvalidInput, "asset", "asset is required")
+ }
- if len(balances) != (len(validate.From) + len(validate.To)) {
- err := commons.ValidateBusinessError(constant.ErrAccountIneligibility, "ValidateAccounts")
+ if !input.Total.IsPositive() {
+ return IntentPlan{}, NewDomainError(ErrorInvalidInput, "total", "total must be greater than zero")
+ }
- opentelemetry.HandleSpanBusinessErrorEvent(&spanValidateBalances, "validations.validate_balances_rules", err)
+ if len(input.Sources) == 0 {
+ return IntentPlan{}, NewDomainError(ErrorInvalidInput, "sources", "at least one source is required")
+ }
- return err
+ if len(input.Destinations) == 0 {
+ return IntentPlan{}, NewDomainError(ErrorInvalidInput, "destinations", "at least one destination is required")
}
- for _, balance := range balances {
- if err := validateFromBalances(balance, validate.From, validate.Asset, validate.Pending); err != nil {
- opentelemetry.HandleSpanBusinessErrorEvent(&spanValidateBalances, "validations.validate_from_balances_", err)
+ sources, err := buildPostings(input.Asset, input.Total, input.Pending, status, input.Sources, true)
+ if err != nil {
+ return IntentPlan{}, err
+ }
- logger.Errorf("validations.validate_from_balances_err: %s", err)
+ destinations, err := buildPostings(input.Asset, input.Total, input.Pending, status, input.Destinations, false)
+ if err != nil {
+ return IntentPlan{}, err
+ }
- return err
+ sourceTotal := sumPostings(sources)
+
+ destinationTotal := sumPostings(destinations)
+ if !sourceTotal.Equal(input.Total) || !destinationTotal.Equal(input.Total) {
+ return IntentPlan{}, NewDomainError(
+ ErrorTransactionValueMismatch,
+ "total",
+ fmt.Sprintf("source total=%s destination total=%s expected=%s", sourceTotal, destinationTotal, input.Total),
+ )
+ }
+
+ sourceIDs := make(map[string]struct{}, len(sources))
+ for _, source := range sources {
+ sourceIDs[source.Target.BalanceID] = struct{}{}
+ }
+
+ for _, destination := range destinations {
+ if _, exists := sourceIDs[destination.Target.BalanceID]; exists {
+ return IntentPlan{}, NewDomainError(ErrorTransactionAmbiguous, "destinations", "balance appears as source and destination")
}
+ }
+
+ return IntentPlan{
+ Asset: input.Asset,
+ Total: input.Total,
+ Pending: input.Pending,
+ Sources: sources,
+ Destinations: destinations,
+ }, nil
+}
- if err := validateToBalances(balance, validate.To, validate.Asset); err != nil {
- opentelemetry.HandleSpanBusinessErrorEvent(&spanValidateBalances, "validations.validate_to_balances_", err)
+// ValidateBalanceEligibility checks whether balances can participate in a plan.
+// It validates:
+// - All referenced balances exist in the catalog
+// - Asset codes match the transaction asset
+// - Sending/receiving permissions are met
+// - Source balances have sufficient available funds for the posting amount
+// - Account ownership: posting target AccountID matches balance AccountID
+// - All balances share the same OrganizationID and LedgerID (no cross-scope mixing)
+// - External account constraints (pending holds, zero-balance destinations)
+func ValidateBalanceEligibility(plan IntentPlan, balances map[string]Balance) error {
+ if len(balances) == 0 {
+ return NewDomainError(ErrorAccountIneligibility, "balances", "balance catalog is empty")
+ }
- logger.Errorf("validations.validate_to_balances_err: %s", err)
+ // Cross-scope validation: all balances must belong to the same org and ledger.
+ var refOrgID, refLedgerID string
+
+ for _, posting := range plan.Sources {
+ if err := validateSourcePosting(plan, posting, balances, &refOrgID, &refLedgerID); err != nil {
+ return err
+ }
+ }
+ for _, posting := range plan.Destinations {
+ if err := validateDestinationPosting(plan, posting, balances, &refOrgID, &refLedgerID); err != nil {
return err
}
}
@@ -52,41 +100,111 @@ func ValidateBalancesRules(ctx context.Context, transaction Transaction, validat
return nil
}
-func validateFromBalances(balance *Balance, from map[string]Amount, asset string, pending bool) error {
- for key := range from {
- balanceAliasKey := AliasKey(balance.Alias, balance.Key)
- if key == balance.ID || SplitAliasWithKey(key) == balanceAliasKey {
- if balance.AssetCode != asset {
- return commons.ValidateBusinessError(constant.ErrAssetCodeNotFound, "validateFromAccounts")
- }
+// validateSourcePosting validates a single source posting against its balance.
+func validateSourcePosting(plan IntentPlan, posting Posting, balances map[string]Balance, refOrgID, refLedgerID *string) error {
+ balance, ok := balances[posting.Target.BalanceID]
+ if !ok {
+ return NewDomainError(ErrorAccountIneligibility, "sources", "source balance not found")
+ }
- if !balance.AllowSending {
- return commons.ValidateBusinessError(constant.ErrAccountStatusTransactionRestriction, "validateFromAccounts")
- }
+ // Account ownership validation
+ if balance.AccountID != posting.Target.AccountID {
+ return NewDomainError(ErrorAccountIneligibility, "sources", "source posting accountId does not match balance accountId")
+ }
- if pending && balance.AccountType == constant.ExternalAccountType {
- return commons.ValidateBusinessError(constant.ErrOnHoldExternalAccount, "validateBalance", balance.Alias)
- }
- }
+ // Cross-scope check
+ if err := validateScope(refOrgID, refLedgerID, balance, "sources"); err != nil {
+ return err
+ }
+
+ if balance.Asset != plan.Asset {
+ return NewDomainError(ErrorAssetCodeNotFound, "sources", "source asset does not match transaction asset")
+ }
+
+ if !balance.AllowSending {
+ return NewDomainError(ErrorAccountStatusTransactionRestriction, "sources", "source balance is not allowed to send")
+ }
+
+ // Amount sufficiency check: source must have enough available funds
+ if balance.AllowSending && balance.Available.LessThan(posting.Amount) {
+ return NewDomainError(ErrorInsufficientFunds, "sources",
+ fmt.Sprintf("source balance available %s is less than posting amount %s", balance.Available, posting.Amount))
+ }
+
+ if plan.Pending && balance.AccountType == AccountTypeExternal {
+ return NewDomainError(ErrorOnHoldExternalAccount, "sources", "external source cannot be put on hold")
}
return nil
}
-func validateToBalances(balance *Balance, to map[string]Amount, asset string) error {
- balanceAliasKey := AliasKey(balance.Alias, balance.Key)
- for key := range to {
- if key == balance.ID || SplitAliasWithKey(key) == balanceAliasKey {
- if balance.AssetCode != asset {
- return commons.ValidateBusinessError(constant.ErrAssetCodeNotFound, "validateToAccounts")
- }
+// validateDestinationPosting validates a single destination posting against its balance.
+func validateDestinationPosting(plan IntentPlan, posting Posting, balances map[string]Balance, refOrgID, refLedgerID *string) error {
+ balance, ok := balances[posting.Target.BalanceID]
+ if !ok {
+ return NewDomainError(ErrorAccountIneligibility, "destinations", "destination balance not found")
+ }
+
+ // Account ownership validation
+ if balance.AccountID != posting.Target.AccountID {
+ return NewDomainError(ErrorAccountIneligibility, "destinations", "destination posting accountId does not match balance accountId")
+ }
+
+ // Cross-scope check
+ if err := validateScope(refOrgID, refLedgerID, balance, "destinations"); err != nil {
+ return err
+ }
+
+ if balance.Asset != plan.Asset {
+ return NewDomainError(ErrorAssetCodeNotFound, "destinations", "destination asset does not match transaction asset")
+ }
+
+ if !balance.AllowReceiving {
+ return NewDomainError(ErrorAccountStatusTransactionRestriction, "destinations", "destination balance is not allowed to receive")
+ }
+
+ if err := validateExternalDestinationBalance(balance); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// validateExternalDestinationBalance enforces constraints on external destination accounts:
+// they must have exactly zero available balance (negative indicates data corruption).
+func validateExternalDestinationBalance(balance Balance) error {
+ if balance.AccountType != AccountTypeExternal {
+ return nil
+ }
+
+ if balance.Available.IsNegative() {
+ return NewDomainError(ErrorDataCorruption, "balance", "external destination account has negative balance, indicating data corruption")
+ }
+
+ if balance.Available.IsPositive() {
+ return NewDomainError(ErrorInsufficientFunds, "destinations", "external destination must have zero available balance")
+ }
- if !balance.AllowReceiving {
- return commons.ValidateBusinessError(constant.ErrAccountStatusTransactionRestriction, "validateToAccounts")
+ return nil
+}
+
+// validateScope ensures all balances in a plan share the same OrganizationID and LedgerID.
+// On first call (ref values are empty), it captures the reference values.
+// On subsequent calls, it compares against the captured reference.
+func validateScope(refOrgID, refLedgerID *string, balance Balance, field string) error {
+ if balance.OrganizationID != "" || balance.LedgerID != "" {
+ if *refOrgID == "" && *refLedgerID == "" {
+ *refOrgID = balance.OrganizationID
+ *refLedgerID = balance.LedgerID
+ } else {
+ if balance.OrganizationID != *refOrgID {
+ return NewDomainError(ErrorCrossScope, field,
+ fmt.Sprintf("balance organizationId %q does not match expected %q", balance.OrganizationID, *refOrgID))
}
- if balance.Available.IsPositive() && balance.AccountType == constant.ExternalAccountType {
- return commons.ValidateBusinessError(constant.ErrInsufficientFunds, "validateToAccounts", balance.Alias)
+ if balance.LedgerID != *refLedgerID {
+ return NewDomainError(ErrorCrossScope, field,
+ fmt.Sprintf("balance ledgerId %q does not match expected %q", balance.LedgerID, *refLedgerID))
}
}
}
@@ -94,299 +212,359 @@ func validateToBalances(balance *Balance, to map[string]Amount, asset string) er
return nil
}
-// Deprecated: use ValidateFromToOperation method from Midaz pkg instead.
-// ValidateFromToOperation func that validate operate balance
-func ValidateFromToOperation(ft FromTo, validate Responses, balance *Balance) (Amount, Balance, error) {
- if ft.IsFrom {
- ba, err := OperateBalances(validate.From[ft.AccountAlias], *balance)
- if err != nil {
- return Amount{}, Balance{}, err
- }
+// ApplyPosting applies a posting transition to a balance and returns the new state.
+func ApplyPosting(balance Balance, posting Posting) (Balance, error) {
+ if err := validatePostingAgainstBalance(balance, posting); err != nil {
+ return Balance{}, err
+ }
- if ba.Available.IsNegative() && balance.AccountType != constant.ExternalAccountType {
- return Amount{}, Balance{}, commons.ValidateBusinessError(constant.ErrInsufficientFunds, "ValidateFromToOperation", balance.Alias)
- }
+ result := balance
- return validate.From[ft.AccountAlias], ba, nil
- } else {
- ba, err := OperateBalances(validate.To[ft.AccountAlias], *balance)
- if err != nil {
- return Amount{}, Balance{}, err
- }
+ updated, err := applyPostingOperation(result, posting)
+ if err != nil {
+ return Balance{}, err
+ }
- return validate.To[ft.AccountAlias], ba, nil
+ if err := validatePostingResult(updated); err != nil {
+ return Balance{}, err
}
+
+ updated.Version++
+
+ return updated, nil
}
-// Deprecated: use AliasKey method from Midaz pkg instead.
-// AliasKey function to concatenate alias with balance key
-func AliasKey(alias, balanceKey string) string {
- if balanceKey == "" {
- balanceKey = "default"
+func validatePostingAgainstBalance(balance Balance, posting Posting) error {
+ if err := posting.Target.validate("posting.target"); err != nil {
+ return err
}
- return alias + "#" + balanceKey
-}
+ if balance.ID != posting.Target.BalanceID {
+ return NewDomainError(ErrorAccountIneligibility, "posting.target.balanceId", "posting does not belong to the provided balance")
+ }
-// Deprecated: use SplitAlias method from Midaz pkg instead.
-// SplitAlias function to split alias with index
-func SplitAlias(alias string) string {
- if strings.Contains(alias, "#") {
- return strings.Split(alias, "#")[1]
+ if balance.AccountID != posting.Target.AccountID {
+ return NewDomainError(ErrorAccountIneligibility, "posting.target.accountId", "posting account does not match balance account")
}
- return alias
-}
+ if balance.Asset != posting.Asset {
+ return NewDomainError(ErrorAssetCodeNotFound, "posting.asset", "posting asset does not match balance asset")
+ }
-// Deprecated: use ConcatAlias method from Midaz pkg instead.
-// ConcatAlias function to concat alias with index
-func ConcatAlias(i int, alias string) string {
- return strconv.Itoa(i) + "#" + alias
+ if !posting.Amount.IsPositive() {
+ return NewDomainError(ErrorInvalidInput, "posting.amount", "posting amount must be greater than zero")
+ }
+
+ return nil
}
-// Deprecated: use OperateBalances method from Midaz pkg instead.
-// OperateBalances Function to sum or sub two balances and Normalize the scale
-func OperateBalances(amount Amount, balance Balance) (Balance, error) {
- var (
- total decimal.Decimal
- totalOnHold decimal.Decimal
- totalVersion int64
- )
-
- total = balance.Available
- totalOnHold = balance.OnHold
-
- switch {
- case amount.Operation == constant.ONHOLD && amount.TransactionType == constant.PENDING:
- total = balance.Available.Sub(amount.Value)
- totalOnHold = balance.OnHold.Add(amount.Value)
- case amount.Operation == constant.RELEASE && amount.TransactionType == constant.CANCELED:
- totalOnHold = balance.OnHold.Sub(amount.Value)
- total = balance.Available.Add(amount.Value)
- case amount.Operation == constant.DEBIT && amount.TransactionType == constant.APPROVED:
- totalOnHold = balance.OnHold.Sub(amount.Value)
- case amount.Operation == constant.CREDIT && amount.TransactionType == constant.APPROVED:
- total = balance.Available.Add(amount.Value)
- case amount.Operation == constant.DEBIT && amount.TransactionType == constant.CREATED:
- total = balance.Available.Sub(amount.Value)
- case amount.Operation == constant.CREDIT && amount.TransactionType == constant.CREATED:
- total = balance.Available.Add(amount.Value)
+func applyPostingOperation(balance Balance, posting Posting) (Balance, error) {
+ result := balance
+
+ switch posting.Operation {
+ case OperationOnHold:
+ return applyOnHold(result, posting)
+ case OperationRelease:
+ return applyRelease(result, posting)
+ case OperationDebit:
+ return applyDebit(result, posting)
+ case OperationCredit:
+ return applyCredit(result, posting)
default:
- // For unknown operations, return the original balance without changing the version.
- return balance, nil
+ return Balance{}, NewDomainError(ErrorInvalidInput, "posting.operation", "unsupported operation")
}
+}
- totalVersion = balance.Version + 1
+func applyOnHold(balance Balance, posting Posting) (Balance, error) {
+ if posting.Status != StatusPending {
+ return Balance{}, NewDomainError(ErrorInvalidStateTransition, "posting.status", "ON_HOLD requires PENDING status")
+ }
- return Balance{
- Available: total,
- OnHold: totalOnHold,
- Version: totalVersion,
- }, nil
+ balance.Available = balance.Available.Sub(posting.Amount)
+ balance.OnHold = balance.OnHold.Add(posting.Amount)
+
+ return balance, nil
}
-// Deprecated: use DetermineOperation method from Midaz pkg instead.
-// DetermineOperation Function to determine the operation
-func DetermineOperation(isPending bool, isFrom bool, transactionType string) string {
- switch {
- case isPending && transactionType == constant.PENDING:
- switch {
- case isFrom:
- return constant.ONHOLD
- default:
- return constant.CREDIT
- }
- case isPending && isFrom && transactionType == constant.CANCELED:
- return constant.RELEASE
- case isPending && transactionType == constant.APPROVED:
- switch {
- case isFrom:
- return constant.DEBIT
- default:
- return constant.CREDIT
- }
- case !isPending:
- switch {
- case isFrom:
- return constant.DEBIT
- default:
- return constant.CREDIT
- }
- default:
- return constant.CREDIT
+func applyRelease(balance Balance, posting Posting) (Balance, error) {
+ if posting.Status != StatusCanceled {
+ return Balance{}, NewDomainError(ErrorInvalidStateTransition, "posting.status", "RELEASE requires CANCELED status")
}
-}
-// Deprecated: use CalculateTotal method from Midaz pkg instead.
-// CalculateTotal Calculate total for sources/destinations based on shares, amounts and remains
-func CalculateTotal(fromTos []FromTo, transaction Transaction, transactionType string, t chan decimal.Decimal, ft chan map[string]Amount, sd chan []string, or chan map[string]string) {
- fmto := make(map[string]Amount)
- scdt := make([]string, 0)
+ balance.OnHold = balance.OnHold.Sub(posting.Amount)
+ balance.Available = balance.Available.Add(posting.Amount)
- total := decimal.NewFromInt(0)
+ return balance, nil
+}
- remaining := Amount{
- Asset: transaction.Send.Asset,
- Value: transaction.Send.Value,
- TransactionType: transactionType,
+func applyDebit(balance Balance, posting Posting) (Balance, error) {
+ switch posting.Status {
+ case StatusApproved:
+ balance.OnHold = balance.OnHold.Sub(posting.Amount)
+ case StatusCreated:
+ balance.Available = balance.Available.Sub(posting.Amount)
+ case StatusCanceled:
+ // Pending destination cancellation: the debit reverses the original credit
+ // that was applied when the destination received funds during the PENDING phase.
+ // ResolveOperation(pending=true, isSource=false, StatusCanceled) → OperationDebit.
+ balance.Available = balance.Available.Sub(posting.Amount)
+ default:
+ return Balance{}, NewDomainError(
+ ErrorInvalidStateTransition,
+ "posting.status",
+ "DEBIT only supports CREATED, APPROVED, or CANCELED status",
+ )
}
- operationRoute := make(map[string]string)
+ return balance, nil
+}
+
+func applyCredit(balance Balance, posting Posting) (Balance, error) {
+ switch posting.Status {
+ case StatusCreated, StatusApproved, StatusPending:
+ balance.Available = balance.Available.Add(posting.Amount)
+ default:
+ return Balance{}, NewDomainError(
+ ErrorInvalidStateTransition,
+ "posting.status",
+ "CREDIT only supports CREATED, APPROVED, or PENDING status",
+ )
+ }
- for i := range fromTos {
- operationRoute[fromTos[i].AccountAlias] = fromTos[i].Route
+ return balance, nil
+}
- operation := DetermineOperation(transaction.Pending, fromTos[i].IsFrom, transactionType)
+func validatePostingResult(balance Balance) error {
+ if balance.Available.IsNegative() {
+ return NewDomainError(ErrorInsufficientFunds, "posting.amount", "operation would result in negative available balance")
+ }
- if fromTos[i].Share != nil && fromTos[i].Share.Percentage != 0 {
- oneHundred := decimal.NewFromInt(100)
+ if balance.OnHold.IsNegative() {
+ return NewDomainError(ErrorInsufficientFunds, "posting.amount", "operation would result in negative on-hold balance")
+ }
- percentage := decimal.NewFromInt(fromTos[i].Share.Percentage)
+ return nil
+}
- percentageOfPercentage := decimal.NewFromInt(fromTos[i].Share.PercentageOfPercentage)
- if percentageOfPercentage.IsZero() {
- percentageOfPercentage = oneHundred
+// ResolveOperation resolves the posting operation from pending/source/status semantics.
+func ResolveOperation(pending bool, isSource bool, status TransactionStatus) (Operation, error) {
+ if pending {
+ switch status {
+ case StatusPending:
+ if isSource {
+ return OperationOnHold, nil
}
- firstPart := percentage.Div(oneHundred)
- secondPart := percentageOfPercentage.Div(oneHundred)
- shareValue := transaction.Send.Value.Mul(firstPart).Mul(secondPart)
+ return OperationCredit, nil
+ case StatusCanceled:
+ if isSource {
+ return OperationRelease, nil
+ }
- fmto[fromTos[i].AccountAlias] = Amount{
- Asset: transaction.Send.Asset,
- Value: shareValue,
- Operation: operation,
- TransactionType: transactionType,
+ return OperationDebit, nil
+ case StatusApproved:
+ if isSource {
+ return OperationDebit, nil
}
- total = total.Add(shareValue)
- remaining.Value = remaining.Value.Sub(shareValue)
+ return OperationCredit, nil
+ default:
+ return "", NewDomainError(ErrorInvalidStateTransition, "status", "pending transactions only support PENDING, APPROVED, or CANCELED status")
+ }
+ }
+
+ switch status {
+ case StatusCreated:
+ if isSource {
+ return OperationDebit, nil
}
- if fromTos[i].Amount != nil && fromTos[i].Amount.Value.IsPositive() {
- amount := Amount{
- Asset: fromTos[i].Amount.Asset,
- Value: fromTos[i].Amount.Value,
- Operation: operation,
- TransactionType: transactionType,
- }
+ return OperationCredit, nil
+ default:
+ return "", NewDomainError(ErrorInvalidStateTransition, "status", "non-pending transactions only support CREATED status")
+ }
+}
- fmto[fromTos[i].AccountAlias] = amount
- total = total.Add(amount.Value)
+func buildPostings(asset string, total decimal.Decimal, pending bool, status TransactionStatus, allocations []Allocation, isSource bool) ([]Posting, error) {
+ postings := make([]Posting, len(allocations))
+ allocated := decimal.Zero
+ remainderIndex := -1
- remaining.Value = remaining.Value.Sub(amount.Value)
+ side := "destinations"
+ if isSource {
+ side = "sources"
+ }
+
+ for i, allocation := range allocations {
+ field := fmt.Sprintf("%s[%d]", side, i)
+
+ posting, amount, usesRemainder, err := buildPostingFromAllocation(
+ asset,
+ total,
+ pending,
+ status,
+ isSource,
+ allocation,
+ field,
+ )
+ if err != nil {
+ return nil, err
}
- if !commons.IsNilOrEmpty(&fromTos[i].Remaining) {
- total = total.Add(remaining.Value)
+ postings[i] = posting
+
+ if usesRemainder {
+ if remainderIndex >= 0 {
+ return nil, NewDomainError(ErrorInvalidInput, field+".remainder", "only one remainder allocation is allowed")
+ }
+
+ remainderIndex = i
+
+ continue
+ }
- remaining.Operation = operation
+ allocated = allocated.Add(amount)
+ }
- fmto[fromTos[i].AccountAlias] = remaining
- fromTos[i].Amount = &remaining
+ if remainderIndex >= 0 {
+ remainder, err := computeRemainderAllocation(total, allocated)
+ if err != nil {
+ return nil, err
}
- scdt = append(scdt, AliasKey(fromTos[i].SplitAlias(), fromTos[i].BalanceKey))
+ postings[remainderIndex].Amount = remainder
+ allocated = allocated.Add(remainder)
+ }
+
+ if err := validateAllocatedTotal(allocated, total); err != nil {
+ return nil, err
+ }
+
+ return postings, nil
+}
+
+func buildPostingFromAllocation(
+ asset string,
+ total decimal.Decimal,
+ pending bool,
+ status TransactionStatus,
+ isSource bool,
+ allocation Allocation,
+ field string,
+) (Posting, decimal.Decimal, bool, error) {
+ if err := allocation.Target.validate(field + ".target"); err != nil {
+ return Posting{}, decimal.Zero, false, err
+ }
+
+ if err := validateAllocationStrategy(allocation, field); err != nil {
+ return Posting{}, decimal.Zero, false, err
}
- t <- total
+ operation, err := ResolveOperation(pending, isSource, status)
+ if err != nil {
+ return Posting{}, decimal.Zero, false, err
+ }
+
+ posting := Posting{
+ Target: allocation.Target,
+ Asset: asset,
+ Operation: operation,
+ Status: status,
+ Route: allocation.Route,
+ }
+
+ amount, usesRemainder, err := resolveAllocationAmount(total, allocation, field)
+ if err != nil {
+ return Posting{}, decimal.Zero, false, err
+ }
- ft <- fmto
+ if usesRemainder {
+ return posting, decimal.Zero, true, nil
+ }
- sd <- scdt
+ posting.Amount = amount
- or <- operationRoute
+ return posting, amount, false, nil
}
-// Deprecated: use AppendIfNotExist method from Midaz pkg instead.
-// AppendIfNotExist Append if not exist
-func AppendIfNotExist(slice []string, s []string) []string {
- for _, v := range s {
- if !commons.Contains(slice, v) {
- slice = append(slice, v)
- }
+func validateAllocationStrategy(allocation Allocation, field string) error {
+ strategyCount := 0
+ if allocation.Amount != nil {
+ strategyCount++
+ }
+
+ if allocation.Share != nil {
+ strategyCount++
+ }
+
+ if allocation.Remainder {
+ strategyCount++
+ }
+
+ if strategyCount != 1 {
+ return NewDomainError(ErrorInvalidInput, field, "allocation must define exactly one strategy: amount, share, or remainder")
}
- return slice
+ return nil
}
-// Deprecated: use ValidateSendSourceAndDistribute method from Midaz pkg instead.
-// ValidateSendSourceAndDistribute Validate send and distribute totals
-func ValidateSendSourceAndDistribute(ctx context.Context, transaction Transaction, transactionType string) (*Responses, error) {
- var (
- sourcesTotal decimal.Decimal
- destinationsTotal decimal.Decimal
- )
-
- logger, tracer, _, _ := commons.NewTrackingFromContext(ctx)
-
- _, span := tracer.Start(ctx, "commons.transaction.ValidateSendSourceAndDistribute")
- defer span.End()
-
- sizeFrom := len(transaction.Send.Source.From)
- sizeTo := len(transaction.Send.Distribute.To)
-
- response := &Responses{
- Total: transaction.Send.Value,
- Asset: transaction.Send.Asset,
- From: make(map[string]Amount, sizeFrom),
- To: make(map[string]Amount, sizeTo),
- Sources: make([]string, 0, sizeFrom),
- Destinations: make([]string, 0, sizeTo),
- Aliases: make([]string, 0, sizeFrom+sizeTo),
- Pending: transaction.Pending,
- TransactionRoute: transaction.Route,
- OperationRoutesFrom: make(map[string]string, sizeFrom),
- OperationRoutesTo: make(map[string]string, sizeTo),
- }
-
- tFrom := make(chan decimal.Decimal, sizeFrom)
- ftFrom := make(chan map[string]Amount, sizeFrom)
- sdFrom := make(chan []string, sizeFrom)
- orFrom := make(chan map[string]string, sizeFrom)
-
- go CalculateTotal(transaction.Send.Source.From, transaction, transactionType, tFrom, ftFrom, sdFrom, orFrom)
-
- sourcesTotal = <-tFrom
- response.From = <-ftFrom
- response.Sources = <-sdFrom
- response.OperationRoutesFrom = <-orFrom
- response.Aliases = AppendIfNotExist(response.Aliases, response.Sources)
-
- tTo := make(chan decimal.Decimal, sizeTo)
- ftTo := make(chan map[string]Amount, sizeTo)
- sdTo := make(chan []string, sizeTo)
- orTo := make(chan map[string]string, sizeTo)
-
- go CalculateTotal(transaction.Send.Distribute.To, transaction, transactionType, tTo, ftTo, sdTo, orTo)
-
- destinationsTotal = <-tTo
- response.To = <-ftTo
- response.Destinations = <-sdTo
- response.OperationRoutesTo = <-orTo
- response.Aliases = AppendIfNotExist(response.Aliases, response.Destinations)
-
- for i, source := range response.Sources {
- if _, ok := response.To[ConcatAlias(i, source)]; ok {
- logger.Errorf("ValidateSendSourceAndDistribute: Ambiguous transaction source and destination")
-
- return nil, commons.ValidateBusinessError(constant.ErrTransactionAmbiguous, "ValidateSendSourceAndDistribute")
+func resolveAllocationAmount(total decimal.Decimal, allocation Allocation, field string) (decimal.Decimal, bool, error) {
+ if allocation.Amount != nil {
+ if !allocation.Amount.IsPositive() {
+ return decimal.Zero, false, NewDomainError(ErrorInvalidInput, field+".amount", "amount must be greater than zero")
}
+
+ return *allocation.Amount, false, nil
}
- for i, destination := range response.Destinations {
- if _, ok := response.From[ConcatAlias(i, destination)]; ok {
- logger.Errorf("ValidateSendSourceAndDistribute: Ambiguous transaction source and destination")
+ if allocation.Share != nil {
+ share := *allocation.Share
+ if !share.IsPositive() || share.GreaterThan(oneHundred) {
+ return decimal.Zero, false, NewDomainError(ErrorInvalidInput, field+".share", "share must be greater than 0 and at most 100")
+ }
- return nil, commons.ValidateBusinessError(constant.ErrTransactionAmbiguous, "ValidateSendSourceAndDistribute")
+ amount := total.Mul(share.Div(oneHundred))
+ if !amount.IsPositive() {
+ return decimal.Zero, false, NewDomainError(ErrorInvalidInput, field+".share", "share produces a non-positive amount")
}
+
+ return amount, false, nil
+ }
+
+ if allocation.Remainder {
+ return decimal.Zero, true, nil
+ }
+
+ return decimal.Zero, false, NewDomainError(ErrorInvalidInput, field, "allocation must define exactly one strategy: amount, share, or remainder")
+}
+
+func computeRemainderAllocation(total decimal.Decimal, allocated decimal.Decimal) (decimal.Decimal, error) {
+ remainder := total.Sub(allocated)
+ if !remainder.IsPositive() {
+ return decimal.Zero, NewDomainError(ErrorTransactionValueMismatch, "allocations", "remainder is zero or negative")
+ }
+
+ return remainder, nil
+}
+
+func validateAllocatedTotal(allocated decimal.Decimal, total decimal.Decimal) error {
+ if !allocated.Equal(total) {
+ return NewDomainError(
+ ErrorTransactionValueMismatch,
+ "allocations",
+ fmt.Sprintf("allocated=%s expected=%s", allocated, total),
+ )
}
- if !sourcesTotal.Equal(destinationsTotal) || !destinationsTotal.Equal(response.Total) {
- logger.Errorf("ValidateSendSourceAndDistribute: Transaction value mismatch")
+ return nil
+}
+
+func sumPostings(postings []Posting) decimal.Decimal {
+ total := decimal.Zero
- return nil, commons.ValidateBusinessError(constant.ErrTransactionValueMismatch, "ValidateSendSourceAndDistribute")
+ for _, posting := range postings {
+ total = total.Add(posting.Amount)
}
- return response, nil
+ return total
}
diff --git a/commons/transaction/validations_example_test.go b/commons/transaction/validations_example_test.go
new file mode 100644
index 00000000..a69830e8
--- /dev/null
+++ b/commons/transaction/validations_example_test.go
@@ -0,0 +1,37 @@
+//go:build unit
+
+package transaction_test
+
+import (
+ "fmt"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/transaction"
+ "github.com/shopspring/decimal"
+)
+
+func ExampleBuildIntentPlan() {
+ total := decimal.NewFromInt(100)
+
+ input := transaction.TransactionIntentInput{
+ Asset: "USD",
+ Total: total,
+ Pending: false,
+ Sources: []transaction.Allocation{{
+ Target: transaction.LedgerTarget{AccountID: "acc-src", BalanceID: "bal-src"},
+ Amount: &total,
+ }},
+ Destinations: []transaction.Allocation{{
+ Target: transaction.LedgerTarget{AccountID: "acc-dst", BalanceID: "bal-dst"},
+ Amount: &total,
+ }},
+ }
+
+ plan, err := transaction.BuildIntentPlan(input, transaction.StatusCreated)
+
+ fmt.Println(err == nil)
+ fmt.Println(plan.Sources[0].Operation, plan.Destinations[0].Operation)
+
+ // Output:
+ // true
+ // DEBIT CREDIT
+}
diff --git a/commons/transaction/validations_test.go b/commons/transaction/validations_test.go
index 78d4e23e..a22a896f 100644
--- a/commons/transaction/validations_test.go
+++ b/commons/transaction/validations_test.go
@@ -1,944 +1,1681 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package transaction
import (
- "context"
+ "encoding/json"
+ "errors"
+ "strings"
"testing"
- "github.com/LerianStudio/lib-commons/v2/commons"
- constant "github.com/LerianStudio/lib-commons/v2/commons/constants"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
- "go.opentelemetry.io/otel"
+ "github.com/stretchr/testify/require"
)
-func TestValidateBalancesRules(t *testing.T) {
- // Create a context with logger and tracer
- ctx := context.Background()
- logger := &log.GoLogger{Level: log.InfoLevel}
- ctx = commons.ContextWithLogger(ctx, logger)
- tracer := otel.Tracer("test")
- ctx = commons.ContextWithTracer(ctx, tracer)
+// ---------------------------------------------------------------------------
+// Helper functions
+// ---------------------------------------------------------------------------
+
+// decPtr returns a pointer to a decimal value parsed from a string.
+func decPtr(t *testing.T, s string) *decimal.Decimal {
+ t.Helper()
+ d, err := decimal.NewFromString(s)
+ require.NoError(t, err, "decPtr: invalid decimal string %q", s)
+ return &d
+}
+
+// intDecPtr returns a pointer to a decimal created from an int64.
+func intDecPtr(v int64) *decimal.Decimal {
+ d := decimal.NewFromInt(v)
+ return &d
+}
+
+// assertDomainError extracts a DomainError from err, verifies the error code,
+// and returns it for additional assertions.
+func assertDomainError(t *testing.T, err error, expectedCode ErrorCode) DomainError {
+ t.Helper()
+
+ require.Error(t, err)
+
+ var domainErr DomainError
+ require.True(t, errors.As(err, &domainErr), "expected DomainError, got %T: %v", err, err)
+ assert.Equal(t, expectedCode, domainErr.Code)
+
+ return domainErr
+}
+
+// simplePlan creates a valid IntentPlan with the given asset, total, and
+// single source/destination postings using the provided status.
+func simplePlan(asset string, total decimal.Decimal, status TransactionStatus) IntentPlan {
+ op := OperationDebit
+ dstOp := OperationCredit
+
+ return IntentPlan{
+ Asset: asset,
+ Total: total,
+ Sources: []Posting{{
+ Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src-bal"},
+ Asset: asset,
+ Amount: total,
+ Operation: op,
+ Status: status,
+ }},
+ Destinations: []Posting{{
+ Target: LedgerTarget{AccountID: "dst-acc", BalanceID: "dst-bal"},
+ Asset: asset,
+ Amount: total,
+ Operation: dstOp,
+ Status: status,
+ }},
+ }
+}
+
+// ---------------------------------------------------------------------------
+// DomainError type tests
+// ---------------------------------------------------------------------------
+
+func TestDomainError_ErrorString(t *testing.T) {
+ t.Parallel()
+
+ t.Run("with field", func(t *testing.T) {
+ t.Parallel()
+
+ de := DomainError{Code: ErrorInvalidInput, Field: "total", Message: "must be positive"}
+ assert.Equal(t, "1001: must be positive (total)", de.Error())
+ })
+
+ t.Run("without field", func(t *testing.T) {
+ t.Parallel()
+
+ de := DomainError{Code: ErrorInsufficientFunds, Message: "not enough funds"}
+ assert.Equal(t, "0018: not enough funds", de.Error())
+ })
+}
+
+func TestNewDomainError_Implements_error(t *testing.T) {
+ t.Parallel()
+
+ err := NewDomainError(ErrorInvalidInput, "field", "message")
+ require.Error(t, err)
+
+ var de DomainError
+ require.True(t, errors.As(err, &de))
+ assert.Equal(t, ErrorInvalidInput, de.Code)
+ assert.Equal(t, "field", de.Field)
+ assert.Equal(t, "message", de.Message)
+}
+
+// ---------------------------------------------------------------------------
+// LedgerTarget.validate
+// ---------------------------------------------------------------------------
+
+func TestLedgerTarget_Validate(t *testing.T) {
+ t.Parallel()
tests := []struct {
- name string
- transaction Transaction
- validate Responses
- balances []*Balance
- expectError bool
- errorCode string
+ name string
+ target LedgerTarget
+ expectErr bool
+ field string
}{
- {
- name: "valid balances - simple transfer",
- transaction: Transaction{
- Send: Send{
- Asset: "USD",
- Value: decimal.NewFromInt(100),
- Source: Source{
- From: []FromTo{
- {AccountAlias: "@account1"},
- },
- },
- Distribute: Distribute{
- To: []FromTo{
- {AccountAlias: "@account2"},
- },
- },
- },
- },
- validate: Responses{
- Asset: "USD",
- From: map[string]Amount{
- "0#@account1#default": {Value: decimal.NewFromInt(100), Operation: constant.DEBIT, TransactionType: constant.CREATED},
- },
- To: map[string]Amount{
- "0#@account2#default": {Value: decimal.NewFromInt(100), Operation: constant.CREDIT, TransactionType: constant.CREATED},
- },
- },
- balances: []*Balance{
- {
- ID: "123",
- Alias: "@account1",
- Key: "default",
- AssetCode: "USD",
- Available: decimal.NewFromInt(200),
- OnHold: decimal.NewFromInt(0),
- AllowSending: true,
- AllowReceiving: true,
- AccountType: "internal",
- },
- {
- ID: "456",
- Alias: "@account2",
- Key: "default",
- AssetCode: "USD",
- Available: decimal.NewFromInt(50),
- OnHold: decimal.NewFromInt(0),
- AllowSending: true,
- AllowReceiving: true,
- AccountType: "internal",
- },
- },
- expectError: false,
- },
- {
- name: "invalid - wrong number of balances",
- transaction: Transaction{},
- validate: Responses{
- From: map[string]Amount{
- "0#@account1#default": {Value: decimal.NewFromInt(100), Operation: constant.DEBIT, TransactionType: constant.CREATED},
- },
- To: map[string]Amount{
- "0#@account2#default": {Value: decimal.NewFromInt(100), Operation: constant.CREDIT, TransactionType: constant.CREATED},
- },
- },
- balances: []*Balance{}, // Empty balances
- expectError: true,
- errorCode: "0019", // ErrAccountIneligibility
- },
+ {name: "valid", target: LedgerTarget{AccountID: "a", BalanceID: "b"}, expectErr: false},
+ {name: "empty accountId", target: LedgerTarget{AccountID: "", BalanceID: "b"}, expectErr: true, field: "t.accountId"},
+ {name: "whitespace accountId", target: LedgerTarget{AccountID: " ", BalanceID: "b"}, expectErr: true, field: "t.accountId"},
+ {name: "empty balanceId", target: LedgerTarget{AccountID: "a", BalanceID: ""}, expectErr: true, field: "t.balanceId"},
+ {name: "whitespace balanceId", target: LedgerTarget{AccountID: "a", BalanceID: " "}, expectErr: true, field: "t.balanceId"},
+ {name: "both empty", target: LedgerTarget{}, expectErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- err := ValidateBalancesRules(ctx, tt.transaction, tt.validate, tt.balances)
-
- if tt.expectError {
- assert.Error(t, err)
- if tt.errorCode != "" {
- // Check if the error is a Response type and contains the error code
- if respErr, ok := err.(commons.Response); ok {
- assert.Equal(t, tt.errorCode, respErr.Code)
- } else {
- assert.Contains(t, err.Error(), tt.errorCode)
- }
- }
+ t.Parallel()
+
+ err := tt.target.validate("t")
+ if tt.expectErr {
+ require.Error(t, err)
+ assertDomainError(t, err, ErrorInvalidInput)
} else {
- assert.NoError(t, err)
+ require.NoError(t, err)
}
})
}
}
-func TestValidateFromBalances(t *testing.T) {
- tests := []struct {
- name string
- balance *Balance
- from map[string]Amount
- asset string
- expectError bool
- errorCode string
- }{
- {
- name: "valid from balance",
- balance: &Balance{
- ID: "123",
- Alias: "@account1",
- Key: "default",
- AssetCode: "USD",
- Available: decimal.NewFromInt(100),
- AllowSending: true,
- AccountType: "internal",
- },
- from: map[string]Amount{
- "0#@account1#default": {Value: decimal.NewFromInt(50)},
- },
- asset: "USD",
- expectError: false,
+// ---------------------------------------------------------------------------
+// BuildIntentPlan -- Input validation
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan(t *testing.T) {
+ amount30 := decimal.NewFromInt(30)
+ share50 := decimal.NewFromInt(50)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Pending: false,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "acc-1", BalanceID: "bal-1"}, Amount: &amount30},
+ {Target: LedgerTarget{AccountID: "acc-2", BalanceID: "bal-2"}, Remainder: true},
},
- {
- name: "invalid - wrong asset code",
- balance: &Balance{
- ID: "123",
- Alias: "@account1",
- Key: "default",
- AssetCode: "EUR",
- Available: decimal.NewFromInt(100),
- AllowSending: true,
- AccountType: "internal",
- },
- from: map[string]Amount{
- "0#@account1#default": {Value: decimal.NewFromInt(50)},
- },
- asset: "USD",
- expectError: true,
- errorCode: "0034", // ErrAssetCodeNotFound
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "acc-3", BalanceID: "bal-3"}, Share: &share50},
+ {Target: LedgerTarget{AccountID: "acc-4", BalanceID: "bal-4"}, Share: &share50},
},
- {
- name: "invalid - sending not allowed",
- balance: &Balance{
- ID: "123",
- Alias: "@account1",
- Key: "default",
- AssetCode: "USD",
- Available: decimal.NewFromInt(100),
- AllowSending: false,
- AccountType: "internal",
+ }
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ assert.NoError(t, err)
+ assert.Equal(t, decimal.NewFromInt(100), plan.Total)
+ assert.Equal(t, "USD", plan.Asset)
+ assert.Len(t, plan.Sources, 2)
+ assert.Len(t, plan.Destinations, 2)
+ assert.Equal(t, decimal.NewFromInt(30), plan.Sources[0].Amount)
+ assert.Equal(t, decimal.NewFromInt(70), plan.Sources[1].Amount)
+ assert.Equal(t, OperationDebit, plan.Sources[0].Operation)
+ assert.Equal(t, OperationCredit, plan.Destinations[0].Operation)
+}
+
+func TestBuildIntentPlan_EmptyAsset(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ for _, asset := range []string{"", " ", " "} {
+ input := TransactionIntentInput{
+ Asset: asset,
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &amount},
},
- from: map[string]Amount{
- "0#@account1#default": {Value: decimal.NewFromInt(50)},
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
},
- asset: "USD",
- expectError: true,
- errorCode: "0024", // ErrAccountStatusTransactionRestriction
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Equal(t, "asset", de.Field)
+ }
+}
+
+func TestBuildIntentPlan_ZeroTotal(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(0)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.Zero,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &amount},
},
- {
- name: "valid - external account with zero balance",
- balance: &Balance{
- ID: "123",
- Alias: "@external",
- Key: "default",
- AssetCode: "USD",
- Available: decimal.NewFromInt(0),
- AllowSending: true,
- AccountType: constant.ExternalAccountType,
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Equal(t, "total", de.Field)
+}
+
+func TestBuildIntentPlan_NegativeTotal(t *testing.T) {
+ t.Parallel()
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(-50),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Remainder: true},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Remainder: true},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Equal(t, "total", de.Field)
+}
+
+func TestBuildIntentPlan_EmptySources(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{},
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Equal(t, "sources", de.Field)
+}
+
+func TestBuildIntentPlan_EmptyDestinations(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &amount},
+ },
+ Destinations: []Allocation{},
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Equal(t, "destinations", de.Field)
+}
+
+func TestBuildIntentPlan_NilSources(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: nil,
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Equal(t, "sources", de.Field)
+}
+
+// ---------------------------------------------------------------------------
+// Self-referencing (source == destination balance)
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_RejectsAmbiguousSourceDestination(t *testing.T) {
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Pending: false,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "acc-1", BalanceID: "shared"}, Amount: &amount},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "acc-2", BalanceID: "shared"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ assert.Error(t, err)
+ var domainErr DomainError
+ assert.ErrorAs(t, err, &domainErr)
+ assert.Equal(t, ErrorTransactionAmbiguous, domainErr.Code)
+}
+
+func TestBuildIntentPlan_SelfReferencing_DifferentAccounts(t *testing.T) {
+ t.Parallel()
+
+ // Even if account IDs differ, same balance ID triggers ambiguity.
+ amount := decimal.NewFromInt(50)
+
+ input := TransactionIntentInput{
+ Asset: "BRL",
+ Total: decimal.NewFromInt(50),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "account-A", BalanceID: "shared-balance"}, Amount: &amount},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "account-B", BalanceID: "shared-balance"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorTransactionAmbiguous)
+ assert.Equal(t, "destinations", de.Field)
+}
+
+// ---------------------------------------------------------------------------
+// Value mismatch
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_RejectsValueMismatch(t *testing.T) {
+ amount90 := decimal.NewFromInt(90)
+ amount100 := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Pending: false,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "acc-1", BalanceID: "bal-1"}, Amount: &amount90},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "acc-2", BalanceID: "bal-2"}, Amount: &amount100},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ assert.Error(t, err)
+ var domainErr DomainError
+ assert.ErrorAs(t, err, &domainErr)
+ assert.Equal(t, ErrorTransactionValueMismatch, domainErr.Code)
+}
+
+func TestBuildIntentPlan_SourceTotalDoesNotMatchTransaction(t *testing.T) {
+ t.Parallel()
+
+ amount60 := decimal.NewFromInt(60)
+ amount100 := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b1"}, Amount: &amount60},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d1"}, Amount: &amount100},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ assertDomainError(t, err, ErrorTransactionValueMismatch)
+}
+
+// ---------------------------------------------------------------------------
+// Allocation strategy validation
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_NoStrategy(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ // No Amount, Share, or Remainder
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Contains(t, de.Message, "exactly one strategy")
+}
+
+func TestBuildIntentPlan_MultipleStrategies(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(50)
+ share := decimal.NewFromInt(50)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {
+ Target: LedgerTarget{AccountID: "a", BalanceID: "b"},
+ Amount: &amount,
+ Share: &share,
+ Remainder: false,
},
- from: map[string]Amount{
- "0#@external#default": {Value: decimal.NewFromInt(50)},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Contains(t, de.Message, "exactly one strategy")
+}
+
+func TestBuildIntentPlan_AmountAndRemainder(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(50)
+
+ input := TransactionIntentInput{
+ Asset: "EUR",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {
+ Target: LedgerTarget{AccountID: "a", BalanceID: "b"},
+ Amount: &amount,
+ Remainder: true,
},
- asset: "USD",
- expectError: false,
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Remainder: true},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ assertDomainError(t, err, ErrorInvalidInput)
+}
+
+func TestBuildIntentPlan_DuplicateRemainder(t *testing.T) {
+ t.Parallel()
+
+ input := TransactionIntentInput{
+ Asset: "BRL",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"}, Remainder: true},
+ {Target: LedgerTarget{AccountID: "a2", BalanceID: "b2"}, Remainder: true},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Remainder: true},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Contains(t, de.Message, "only one remainder")
+}
+
+// ---------------------------------------------------------------------------
+// Zero and negative amount allocations
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_ZeroAmountAllocation(t *testing.T) {
+ t.Parallel()
+
+ zero := decimal.Zero
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &zero},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Contains(t, de.Message, "amount must be greater than zero")
+}
+
+func TestBuildIntentPlan_NegativeAmountAllocation(t *testing.T) {
+ t.Parallel()
+
+ neg := decimal.NewFromInt(-10)
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &neg},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Contains(t, de.Field, "amount")
+}
+
+// ---------------------------------------------------------------------------
+// Share validation
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_ShareZero(t *testing.T) {
+ t.Parallel()
+
+ zero := decimal.Zero
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Share: &zero},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Contains(t, de.Field, "share")
+}
+
+func TestBuildIntentPlan_ShareNegative(t *testing.T) {
+ t.Parallel()
+
+ neg := decimal.NewFromInt(-10)
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Share: &neg},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ assertDomainError(t, err, ErrorInvalidInput)
+}
+
+func TestBuildIntentPlan_ShareOver100(t *testing.T) {
+ t.Parallel()
+
+ over := decimal.NewFromInt(101)
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Share: &over},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ assertDomainError(t, err, ErrorInvalidInput)
+}
+
+func TestBuildIntentPlan_ShareExactly100(t *testing.T) {
+ t.Parallel()
+
+ share100 := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(500),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Share: &share100},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Share: &share100},
+ },
+ }
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.True(t, plan.Sources[0].Amount.Equal(decimal.NewFromInt(500)))
+ assert.True(t, plan.Destinations[0].Amount.Equal(decimal.NewFromInt(500)))
+}
+
+// ---------------------------------------------------------------------------
+// Remainder edge cases
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_RemainderIsEntireAmount(t *testing.T) {
+ t.Parallel()
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(250),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Remainder: true},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Remainder: true},
+ },
+ }
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.True(t, plan.Sources[0].Amount.Equal(decimal.NewFromInt(250)))
+ assert.True(t, plan.Destinations[0].Amount.Equal(decimal.NewFromInt(250)))
+}
+
+func TestBuildIntentPlan_RemainderBecomeZeroOrNegative(t *testing.T) {
+ t.Parallel()
+
+ // Allocating 100 via amount and having remainder when total is 100 leaves zero remainder.
+ amount := decimal.NewFromInt(100)
+ dstAmt := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"}, Amount: &amount},
+ {Target: LedgerTarget{AccountID: "a2", BalanceID: "b2"}, Remainder: true},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &dstAmt},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ assertDomainError(t, err, ErrorTransactionValueMismatch)
+}
+
+func TestBuildIntentPlan_RemainderNegative_OverAllocated(t *testing.T) {
+ t.Parallel()
+
+ // Allocating more than total leaves a negative remainder.
+ amount := decimal.NewFromInt(120)
+ dstAmt := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"}, Amount: &amount},
+ {Target: LedgerTarget{AccountID: "a2", BalanceID: "b2"}, Remainder: true},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &dstAmt},
+ },
+ }
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ assertDomainError(t, err, ErrorTransactionValueMismatch)
+}
+
+// ---------------------------------------------------------------------------
+// Decimal precision edge cases
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_HighPrecisionDecimals(t *testing.T) {
+ t.Parallel()
+
+ // 0.001 precision
+ amt := decPtr(t, "0.001")
+
+ input := TransactionIntentInput{
+ Asset: "BTC",
+ Total: *decPtr(t, "0.001"),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: amt},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: amt},
+ },
+ }
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.True(t, plan.Total.Equal(*decPtr(t, "0.001")))
+}
+
+func TestBuildIntentPlan_VeryLargeAmount(t *testing.T) {
+ t.Parallel()
+
+ amt := decPtr(t, "999999999999.99")
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: *decPtr(t, "999999999999.99"),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: amt},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: amt},
+ },
+ }
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.True(t, plan.Sources[0].Amount.Equal(*decPtr(t, "999999999999.99")))
+}
+
+func TestBuildIntentPlan_ManyDecimalPlaces(t *testing.T) {
+ t.Parallel()
+
+ // 18 decimal places - crypto-level precision
+ amt := decPtr(t, "0.000000000000000001")
+
+ input := TransactionIntentInput{
+ Asset: "ETH",
+ Total: *decPtr(t, "0.000000000000000001"),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: amt},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: amt},
+ },
+ }
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.True(t, plan.Sources[0].Amount.Equal(*decPtr(t, "0.000000000000000001")))
+}
+
+func TestBuildIntentPlan_ShareProducesDecimalAmount(t *testing.T) {
+ t.Parallel()
+
+ // 33.33% of 100 = 33.33; remainder picks up the rest.
+ share := *decPtr(t, "33.33")
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"}, Share: &share},
+ {Target: LedgerTarget{AccountID: "a2", BalanceID: "b2"}, Share: &share},
+ {Target: LedgerTarget{AccountID: "a3", BalanceID: "b3"}, Remainder: true},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Remainder: true},
+ },
+ }
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+
+ // Verify total coverage
+ srcTotal := decimal.Zero
+ for _, s := range plan.Sources {
+ srcTotal = srcTotal.Add(s.Amount)
+ }
+
+ assert.True(t, srcTotal.Equal(decimal.NewFromInt(100)),
+ "source total should be exactly 100, got %s", srcTotal)
+}
+
+// ---------------------------------------------------------------------------
+// Single source to multiple destinations
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_SingleSourceMultipleDestinations(t *testing.T) {
+ t.Parallel()
+
+ total := decimal.NewFromInt(300)
+ srcAmt := decimal.NewFromInt(300)
+ dstAmt := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: total,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &srcAmt},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c1", BalanceID: "d1"}, Amount: &dstAmt},
+ {Target: LedgerTarget{AccountID: "c2", BalanceID: "d2"}, Amount: &dstAmt},
+ {Target: LedgerTarget{AccountID: "c3", BalanceID: "d3"}, Amount: &dstAmt},
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := validateFromBalances(tt.balance, tt.from, tt.asset, false)
-
- if tt.expectError {
- assert.Error(t, err)
- if tt.errorCode != "" {
- // Check if the error is a Response type and contains the error code
- if respErr, ok := err.(commons.Response); ok {
- assert.Equal(t, tt.errorCode, respErr.Code)
- } else {
- assert.Contains(t, err.Error(), tt.errorCode)
- }
- }
- } else {
- assert.NoError(t, err)
- }
- })
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.Len(t, plan.Sources, 1)
+ assert.Len(t, plan.Destinations, 3)
+
+ for _, dst := range plan.Destinations {
+ assert.True(t, dst.Amount.Equal(decimal.NewFromInt(100)))
+ assert.Equal(t, OperationCredit, dst.Operation)
}
}
-func TestValidateToBalances(t *testing.T) {
- tests := []struct {
- name string
- balance *Balance
- to map[string]Amount
- asset string
- expectError bool
- errorCode string
- }{
- {
- name: "valid to balance",
- balance: &Balance{
- ID: "123",
- Alias: "@account1",
- Key: "default",
- AssetCode: "USD",
- Available: decimal.NewFromInt(100),
- AllowReceiving: true,
- AccountType: "internal",
- },
- to: map[string]Amount{
- "0#@account1#default": {Value: decimal.NewFromInt(50)},
- },
- asset: "USD",
- expectError: false,
- },
- {
- name: "invalid - wrong asset code",
- balance: &Balance{
- ID: "123",
- Alias: "@account1",
- Key: "default",
- AssetCode: "EUR",
- Available: decimal.NewFromInt(100),
- AllowReceiving: true,
- AccountType: "internal",
- },
- to: map[string]Amount{
- "0#@account1#default": {Value: decimal.NewFromInt(50)},
- },
- asset: "USD",
- expectError: true,
- errorCode: "0034", // ErrAssetCodeNotFound
- },
- {
- name: "invalid - receiving not allowed",
- balance: &Balance{
- ID: "123",
- Alias: "@account1",
- Key: "default",
- AssetCode: "USD",
- Available: decimal.NewFromInt(100),
- AllowReceiving: false,
- AccountType: "internal",
- },
- to: map[string]Amount{
- "0#@account1#default": {Value: decimal.NewFromInt(50)},
- },
- asset: "USD",
- expectError: true,
- errorCode: "0024", // ErrAccountStatusTransactionRestriction
+// ---------------------------------------------------------------------------
+// Multiple sources to single destination
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_MultipleSourcesSingleDestination(t *testing.T) {
+ t.Parallel()
+
+ total := decimal.NewFromInt(300)
+ srcAmt := decimal.NewFromInt(100)
+ dstAmt := decimal.NewFromInt(300)
+
+ input := TransactionIntentInput{
+ Asset: "BRL",
+ Total: total,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a1", BalanceID: "b1"}, Amount: &srcAmt},
+ {Target: LedgerTarget{AccountID: "a2", BalanceID: "b2"}, Amount: &srcAmt},
+ {Target: LedgerTarget{AccountID: "a3", BalanceID: "b3"}, Amount: &srcAmt},
},
- {
- name: "invalid - external account with positive balance",
- balance: &Balance{
- ID: "123",
- Alias: "@external",
- Key: "default",
- AssetCode: "USD",
- Available: decimal.NewFromInt(100),
- AllowReceiving: true,
- AccountType: constant.ExternalAccountType,
- },
- to: map[string]Amount{
- "0#@external#default": {Value: decimal.NewFromInt(50)},
- },
- asset: "USD",
- expectError: true,
- errorCode: "0018", // ErrInsufficientFunds
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &dstAmt},
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- err := validateToBalances(tt.balance, tt.to, tt.asset)
-
- if tt.expectError {
- assert.Error(t, err)
- if tt.errorCode != "" {
- // Check if the error is a Response type and contains the error code
- if respErr, ok := err.(commons.Response); ok {
- assert.Equal(t, tt.errorCode, respErr.Code)
- } else {
- assert.Contains(t, err.Error(), tt.errorCode)
- }
- }
- } else {
- assert.NoError(t, err)
- }
- })
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.Len(t, plan.Sources, 3)
+ assert.Len(t, plan.Destinations, 1)
+
+ for _, src := range plan.Sources {
+ assert.True(t, src.Amount.Equal(decimal.NewFromInt(100)))
+ assert.Equal(t, OperationDebit, src.Operation)
}
}
-func TestOperateBalances(t *testing.T) {
- tests := []struct {
- name string
- amount Amount
- balance Balance
- operation string
- expected Balance
- expectError bool
- }{
- {
- name: "debit operation",
- amount: Amount{
- Value: decimal.NewFromInt(50),
- Operation: constant.DEBIT,
- TransactionType: constant.CREATED,
- },
- balance: Balance{
- Available: decimal.NewFromInt(100),
- OnHold: decimal.NewFromInt(10),
- },
- expected: Balance{
- Available: decimal.NewFromInt(50), // 100 - 50 = 50
- OnHold: decimal.NewFromInt(10),
- },
- expectError: false,
+// ---------------------------------------------------------------------------
+// Allocation target validation
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_SourceMissingAccountID(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "", BalanceID: "b"}, Amount: &amount},
},
- {
- name: "credit operation",
- amount: Amount{
- Value: decimal.NewFromInt(50),
- Operation: constant.CREDIT,
- TransactionType: constant.CREATED,
- },
- balance: Balance{
- Available: decimal.NewFromInt(100),
- OnHold: decimal.NewFromInt(10),
- },
- expected: Balance{
- Available: decimal.NewFromInt(150), // 100 + 50 = 150
- OnHold: decimal.NewFromInt(10),
- },
- expectError: false,
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result, err := OperateBalances(tt.amount, tt.balance)
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Contains(t, de.Field, "accountId")
+}
- if tt.expectError {
- assert.Error(t, err)
- } else {
- assert.NoError(t, err)
- assert.Equal(t, tt.expected.Available.String(), result.Available.String())
- assert.Equal(t, tt.expected.OnHold.String(), result.OnHold.String())
- }
- })
+func TestBuildIntentPlan_DestinationMissingBalanceID(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &amount},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: ""}, Amount: &amount},
+ },
}
+
+ _, err := BuildIntentPlan(input, StatusCreated)
+ de := assertDomainError(t, err, ErrorInvalidInput)
+ assert.Contains(t, de.Field, "balanceId")
}
-func TestAliasKey(t *testing.T) {
- tests := []struct {
- name string
- alias string
- balanceKey string
- want string
- }{
- {
- name: "alias with balance key",
- alias: "@person1",
- balanceKey: "savings",
- want: "@person1#savings",
- },
- {
- name: "alias with empty balance key defaults to 'default'",
- alias: "@person1",
- balanceKey: "",
- want: "@person1#default",
+// ---------------------------------------------------------------------------
+// Valid minimum and complex plans
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_MinimumValidTransaction(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(1)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(1),
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &amount},
},
- {
- name: "alias with special characters and balance key",
- alias: "@external/BRL",
- balanceKey: "checking",
- want: "@external/BRL#checking",
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
},
- {
- name: "empty alias with balance key",
- alias: "",
- balanceKey: "current",
- want: "#current",
+ }
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.Equal(t, "USD", plan.Asset)
+ assert.True(t, plan.Total.Equal(decimal.NewFromInt(1)))
+ assert.Len(t, plan.Sources, 1)
+ assert.Len(t, plan.Destinations, 1)
+ assert.False(t, plan.Pending)
+}
+
+func TestBuildIntentPlan_ComplexMultiPartyTransaction(t *testing.T) {
+ t.Parallel()
+
+ total := decimal.NewFromInt(1000)
+ share60 := decimal.NewFromInt(60)
+ share40 := decimal.NewFromInt(40)
+ amount200 := decimal.NewFromInt(200)
+ share30 := decimal.NewFromInt(30)
+
+ input := TransactionIntentInput{
+ Asset: "BRL",
+ Total: total,
+ Pending: false,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "src1", BalanceID: "bal-src1"}, Share: &share60},
+ {Target: LedgerTarget{AccountID: "src2", BalanceID: "bal-src2"}, Share: &share40},
},
- {
- name: "empty alias with empty balance key",
- alias: "",
- balanceKey: "",
- want: "#default",
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "dst1", BalanceID: "bal-dst1"}, Amount: &amount200},
+ {Target: LedgerTarget{AccountID: "dst2", BalanceID: "bal-dst2"}, Share: &share30},
+ {Target: LedgerTarget{AccountID: "dst3", BalanceID: "bal-dst3"}, Remainder: true},
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := AliasKey(tt.alias, tt.balanceKey)
- assert.Equal(t, tt.want, got)
- })
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+
+ // Verify source amounts: 60% of 1000 = 600, 40% of 1000 = 400
+ assert.True(t, plan.Sources[0].Amount.Equal(decimal.NewFromInt(600)))
+ assert.True(t, plan.Sources[1].Amount.Equal(decimal.NewFromInt(400)))
+
+ // Verify destination amounts: 200, 30% of 1000=300, remainder=500
+ assert.True(t, plan.Destinations[0].Amount.Equal(decimal.NewFromInt(200)))
+ assert.True(t, plan.Destinations[1].Amount.Equal(decimal.NewFromInt(300)))
+ assert.True(t, plan.Destinations[2].Amount.Equal(decimal.NewFromInt(500)))
+
+ // All operations
+ for _, s := range plan.Sources {
+ assert.Equal(t, OperationDebit, s.Operation)
+ assert.Equal(t, StatusCreated, s.Status)
+ assert.Equal(t, "BRL", s.Asset)
+ }
+
+ for _, d := range plan.Destinations {
+ assert.Equal(t, OperationCredit, d.Operation)
+ assert.Equal(t, StatusCreated, d.Status)
+ assert.Equal(t, "BRL", d.Asset)
}
}
-func TestSplitAlias(t *testing.T) {
- tests := []struct {
- name string
- alias string
- want string
- }{
- {
- name: "alias without index",
- alias: "@person1",
- want: "@person1",
- },
- {
- name: "alias with index",
- alias: "1#@person1",
- want: "@person1",
+// ---------------------------------------------------------------------------
+// Pending transaction plan
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_PendingTransaction(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(75)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(75),
+ Pending: true,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &amount},
},
- {
- name: "alias with zero index",
- alias: "0#@person1",
- want: "@person1",
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := SplitAlias(tt.alias)
- assert.Equal(t, tt.want, got)
- })
+ plan, err := BuildIntentPlan(input, StatusPending)
+ require.NoError(t, err)
+ assert.True(t, plan.Pending)
+
+ // Pending source = ON_HOLD, pending destination = CREDIT
+ assert.Equal(t, OperationOnHold, plan.Sources[0].Operation)
+ assert.Equal(t, OperationCredit, plan.Destinations[0].Operation)
+}
+
+// ---------------------------------------------------------------------------
+// Route propagation
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_RoutePropagation(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Allocation{
+ {
+ Target: LedgerTarget{AccountID: "a", BalanceID: "b"},
+ Amount: &amount,
+ Route: "wire-transfer",
+ },
+ },
+ Destinations: []Allocation{
+ {
+ Target: LedgerTarget{AccountID: "c", BalanceID: "d"},
+ Amount: &amount,
+ Route: "ach",
+ },
+ },
}
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.Equal(t, "wire-transfer", plan.Sources[0].Route)
+ assert.Equal(t, "ach", plan.Destinations[0].Route)
}
-func TestConcatAlias(t *testing.T) {
- tests := []struct {
- name string
- index int
- alias string
- want string
- }{
- {
- name: "concat with positive index",
- index: 1,
- alias: "@person1",
- want: "1#@person1",
+// ---------------------------------------------------------------------------
+// Invalid status for non-pending
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_InvalidStatusForNonPending(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Pending: false,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &amount},
},
- {
- name: "concat with zero index",
- index: 0,
- alias: "@person2",
- want: "0#@person2",
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
},
- {
- name: "concat with large index",
- index: 999,
- alias: "@person3",
- want: "999#@person3",
+ }
+
+ // Non-pending only supports StatusCreated.
+ _, err := BuildIntentPlan(input, StatusApproved)
+ assertDomainError(t, err, ErrorInvalidStateTransition)
+}
+
+func TestBuildIntentPlan_InvalidStatusForPending(t *testing.T) {
+ t.Parallel()
+
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Pending: true,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "a", BalanceID: "b"}, Amount: &amount},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "c", BalanceID: "d"}, Amount: &amount},
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := ConcatAlias(tt.index, tt.alias)
- assert.Equal(t, tt.want, got)
- })
+ // Pending only supports PENDING, APPROVED, or CANCELED.
+ _, err := BuildIntentPlan(input, StatusCreated)
+ assertDomainError(t, err, ErrorInvalidStateTransition)
+}
+
+// ---------------------------------------------------------------------------
+// Stress test: many allocations
+// ---------------------------------------------------------------------------
+
+func TestBuildIntentPlan_ManyAllocations(t *testing.T) {
+ t.Parallel()
+
+ count := 50
+ share := decimal.NewFromInt(2) // 2% each = 100%
+ total := decimal.NewFromInt(10000)
+
+ sources := make([]Allocation, count)
+ dests := make([]Allocation, count)
+
+ for i := 0; i < count; i++ {
+ s := share
+ sources[i] = Allocation{
+ Target: LedgerTarget{
+ AccountID: "src-acc-" + strings.Repeat("x", 3),
+ BalanceID: "src-bal-" + string(rune('A'+i%26)) + string(rune('0'+i/26)),
+ },
+ Share: &s,
+ }
+
+ d := share
+ dests[i] = Allocation{
+ Target: LedgerTarget{
+ AccountID: "dst-acc-" + strings.Repeat("y", 3),
+ BalanceID: "dst-bal-" + string(rune('A'+i%26)) + string(rune('0'+i/26)),
+ },
+ Share: &d,
+ }
+ }
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: total,
+ Sources: sources,
+ Destinations: dests,
}
+
+ plan, err := BuildIntentPlan(input, StatusCreated)
+ require.NoError(t, err)
+ assert.Len(t, plan.Sources, count)
+ assert.Len(t, plan.Destinations, count)
+
+ // Verify total sums
+ srcSum := decimal.Zero
+ for _, s := range plan.Sources {
+ srcSum = srcSum.Add(s.Amount)
+ }
+
+ assert.True(t, srcSum.Equal(total), "expected source sum %s, got %s", total, srcSum)
}
-func TestAppendIfNotExist(t *testing.T) {
- tests := []struct {
- name string
- slice []string
- s []string
- want []string
- }{
- {
- name: "append new elements",
- slice: []string{"a", "b"},
- s: []string{"c", "d"},
- want: []string{"a", "b", "c", "d"},
- },
- {
- name: "skip existing elements",
- slice: []string{"a", "b"},
- s: []string{"b", "c"},
- want: []string{"a", "b", "c"},
+// ---------------------------------------------------------------------------
+// ValidateBalanceEligibility
+// ---------------------------------------------------------------------------
+
+func TestValidateBalanceEligibility(t *testing.T) {
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: amount,
+ Pending: true,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "source-account", BalanceID: "source-balance"}, Amount: &amount},
},
- {
- name: "all elements exist",
- slice: []string{"a", "b", "c"},
- s: []string{"a", "b"},
- want: []string{"a", "b", "c"},
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "destination-account", BalanceID: "destination-balance"}, Amount: &amount},
},
- {
- name: "empty initial slice",
- slice: []string{},
- s: []string{"a", "b"},
- want: []string{"a", "b"},
+ }
+
+ plan, err := BuildIntentPlan(input, StatusPending)
+ assert.NoError(t, err)
+
+ balances := map[string]Balance{
+ "source-balance": {
+ ID: "source-balance",
+ AccountID: "source-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(300),
+ OnHold: decimal.NewFromInt(0),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
},
- {
- name: "empty append slice",
- slice: []string{"a", "b"},
- s: []string{},
- want: []string{"a", "b"},
+ "destination-balance": {
+ ID: "destination-balance",
+ AccountID: "destination-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(0),
+ OnHold: decimal.NewFromInt(0),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeExternal,
},
}
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := AppendIfNotExist(tt.slice, tt.s)
- assert.Equal(t, tt.want, got)
- })
- }
+ err = ValidateBalanceEligibility(plan, balances)
+ assert.NoError(t, err)
+}
+
+func TestValidateBalanceEligibility_EmptyBalanceCatalog(t *testing.T) {
+ t.Parallel()
+
+ plan := simplePlan("USD", decimal.NewFromInt(100), StatusCreated)
+ err := ValidateBalanceEligibility(plan, map[string]Balance{})
+ de := assertDomainError(t, err, ErrorAccountIneligibility)
+ assert.Equal(t, "balances", de.Field)
}
-func TestValidateSendSourceAndDistribute(t *testing.T) {
+func TestValidateBalanceEligibility_NilBalanceCatalog(t *testing.T) {
+ t.Parallel()
+
+ plan := simplePlan("USD", decimal.NewFromInt(100), StatusCreated)
+ err := ValidateBalanceEligibility(plan, nil)
+ de := assertDomainError(t, err, ErrorAccountIneligibility)
+ assert.Equal(t, "balances", de.Field)
+}
+
+func TestValidateBalanceEligibility_Errors(t *testing.T) {
+ amount := decimal.NewFromInt(100)
+
+ input := TransactionIntentInput{
+ Asset: "USD",
+ Total: amount,
+ Pending: true,
+ Sources: []Allocation{
+ {Target: LedgerTarget{AccountID: "source-account", BalanceID: "source-balance"}, Amount: &amount},
+ },
+ Destinations: []Allocation{
+ {Target: LedgerTarget{AccountID: "destination-account", BalanceID: "destination-balance"}, Amount: &amount},
+ },
+ }
+
+ plan, err := BuildIntentPlan(input, StatusPending)
+ assert.NoError(t, err)
+
tests := []struct {
- name string
- transaction Transaction
- want *Responses
- expectError bool
- errorCode string
+ name string
+ balances map[string]Balance
+ errorCode ErrorCode
+ field string
}{
{
- name: "valid - simple source and distribute",
- transaction: Transaction{
- Send: Send{
- Asset: "USD",
- Value: decimal.NewFromInt(100),
- Source: Source{
- From: []FromTo{
- {
- AccountAlias: "@account1",
- Amount: &Amount{
- Asset: "USD",
- Value: decimal.NewFromInt(100),
- },
- },
- },
- },
- Distribute: Distribute{
- To: []FromTo{
- {
- AccountAlias: "@account2",
- Amount: &Amount{
- Asset: "USD",
- Value: decimal.NewFromInt(100),
- },
- },
- },
- },
+ name: "missing source balance",
+ balances: map[string]Balance{
+ "destination-balance": {
+ ID: "destination-balance",
+ AccountID: "destination-account",
+ Asset: "USD",
+ AllowReceiving: true,
},
},
- expectError: false, // Now expects success after fixing CalculateTotal
+ errorCode: ErrorAccountIneligibility,
+ field: "sources",
},
{
- name: "valid - multiple sources and distributes",
- transaction: Transaction{
- Send: Send{
- Asset: "USD",
- Value: decimal.NewFromInt(100),
- Source: Source{
- From: []FromTo{
- {
- AccountAlias: "@account1",
- Amount: &Amount{
- Asset: "USD",
- Value: decimal.NewFromInt(50),
- },
- },
- {
- AccountAlias: "@account2",
- Amount: &Amount{
- Asset: "USD",
- Value: decimal.NewFromInt(50),
- },
- },
- },
- },
- Distribute: Distribute{
- To: []FromTo{
- {
- AccountAlias: "@account3",
- Amount: &Amount{
- Asset: "USD",
- Value: decimal.NewFromInt(60),
- },
- },
- {
- AccountAlias: "@account4",
- Amount: &Amount{
- Asset: "USD",
- Value: decimal.NewFromInt(40),
- },
- },
- },
- },
+ name: "source asset mismatch",
+ balances: map[string]Balance{
+ "source-balance": {
+ ID: "source-balance",
+ AccountID: "source-account",
+ Asset: "EUR",
+ Available: decimal.NewFromInt(300),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
+ },
+ "destination-balance": {
+ ID: "destination-balance",
+ AccountID: "destination-account",
+ Asset: "USD",
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
},
},
- expectError: false, // Now expects success after fixing CalculateTotal
+ errorCode: ErrorAssetCodeNotFound,
+ field: "sources",
},
{
- name: "valid transaction with shares",
- transaction: Transaction{
- Send: Send{
- Asset: "USD",
- Value: decimal.NewFromInt(100),
- Source: Source{
- From: []FromTo{
- {
- AccountAlias: "@account1",
- Share: &Share{
- Percentage: 60,
- },
- },
- {
- AccountAlias: "@account2",
- Share: &Share{
- Percentage: 40,
- },
- },
- },
- },
- Distribute: Distribute{
- To: []FromTo{
- {
- AccountAlias: "@account3",
- Share: &Share{
- Percentage: 100,
- },
- },
- },
- },
+ name: "source cannot send",
+ balances: map[string]Balance{
+ "source-balance": {
+ ID: "source-balance",
+ AccountID: "source-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(300),
+ AllowSending: false,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
+ },
+ "destination-balance": {
+ ID: "destination-balance",
+ AccountID: "destination-account",
+ Asset: "USD",
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
},
},
- want: &Responses{
- Asset: "USD",
- From: map[string]Amount{
- "@account1": {Value: decimal.NewFromInt(60)},
- "@account2": {Value: decimal.NewFromInt(40)},
+ errorCode: ErrorAccountStatusTransactionRestriction,
+ field: "sources",
+ },
+ {
+ name: "pending source cannot be external",
+ balances: map[string]Balance{
+ "source-balance": {
+ ID: "source-balance",
+ AccountID: "source-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(300),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeExternal,
},
- To: map[string]Amount{
- "@account3": {Value: decimal.NewFromInt(100)},
+ "destination-balance": {
+ ID: "destination-balance",
+ AccountID: "destination-account",
+ Asset: "USD",
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
},
},
- expectError: false,
+ errorCode: ErrorOnHoldExternalAccount,
+ field: "sources",
},
{
- name: "valid transaction with remains",
- transaction: Transaction{
- Send: Send{
- Asset: "USD",
- Value: decimal.NewFromInt(100),
- Source: Source{
- From: []FromTo{
- {
- AccountAlias: "@account1",
- Share: &Share{
- Percentage: 50,
- },
- IsFrom: true,
- },
- {
- AccountAlias: "@account2",
- Remaining: "remaining",
- IsFrom: true,
- },
- },
- },
- Distribute: Distribute{
- To: []FromTo{
- {
- AccountAlias: "@account3",
- Remaining: "remaining",
- },
- },
- },
+ name: "missing destination balance",
+ balances: map[string]Balance{
+ "source-balance": {
+ ID: "source-balance",
+ AccountID: "source-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(300),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
},
},
- want: &Responses{
- Asset: "USD",
- From: map[string]Amount{
- "@account1": {Value: decimal.NewFromInt(50)},
- "@account2": {Value: decimal.NewFromInt(50)},
+ errorCode: ErrorAccountIneligibility,
+ field: "destinations",
+ },
+ {
+ name: "destination asset mismatch",
+ balances: map[string]Balance{
+ "source-balance": {
+ ID: "source-balance",
+ AccountID: "source-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(300),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
},
- To: map[string]Amount{
- "@account3": {Value: decimal.NewFromInt(100)},
+ "destination-balance": {
+ ID: "destination-balance",
+ AccountID: "destination-account",
+ Asset: "GBP",
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
},
},
- expectError: false,
+ errorCode: ErrorAssetCodeNotFound,
+ field: "destinations",
},
{
- name: "invalid - total mismatch",
- transaction: Transaction{
- Send: Send{
- Asset: "USD",
- Value: decimal.NewFromInt(100),
- Source: Source{
- From: []FromTo{
- {
- AccountAlias: "@account1",
- Amount: &Amount{
- Asset: "USD",
- Value: decimal.NewFromInt(60),
- },
- },
- {
- AccountAlias: "@account2",
- Amount: &Amount{
- Asset: "USD",
- Value: decimal.NewFromInt(30), // Total is 90, not 100
- },
- },
- },
- },
- Distribute: Distribute{
- To: []FromTo{
- {
- AccountAlias: "@account3",
- Amount: &Amount{
- Asset: "USD",
- Value: decimal.NewFromInt(100),
- },
- },
- },
- },
+ name: "destination cannot receive",
+ balances: map[string]Balance{
+ "source-balance": {
+ ID: "source-balance",
+ AccountID: "source-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(300),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
+ },
+ "destination-balance": {
+ ID: "destination-balance",
+ AccountID: "destination-account",
+ Asset: "USD",
+ AllowSending: true,
+ AllowReceiving: false,
+ AccountType: AccountTypeInternal,
},
},
- expectError: true,
- errorCode: "0073", // ErrTransactionValueMismatch
+ errorCode: ErrorAccountStatusTransactionRestriction,
+ field: "destinations",
},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- got, err := ValidateSendSourceAndDistribute(ctx, tt.transaction, constant.CREATED)
-
- if tt.expectError {
- assert.Error(t, err)
- if tt.errorCode != "" {
- // Check if the error is a Response type and contains the error code
- if respErr, ok := err.(commons.Response); ok {
- assert.Equal(t, tt.errorCode, respErr.Code)
- } else {
- assert.Contains(t, err.Error(), tt.errorCode)
- }
- }
- } else {
- assert.NoError(t, err)
- assert.NotNil(t, got)
- if tt.want != nil && got != nil {
- assert.Equal(t, tt.want.Asset, got.Asset)
- assert.Equal(t, len(tt.want.From), len(got.From))
- assert.Equal(t, len(tt.want.To), len(got.To))
- }
- }
- })
- }
-}
-
-func TestValidateTransactionWithPercentageAndRemaining(t *testing.T) {
- tests := []struct {
- name string
- transaction Transaction
- expectError bool
- errorCode string
- }{
{
- name: "valid transaction with percentage and remaining",
- transaction: Transaction{
- ChartOfAccountsGroupName: "PAG_CONTAS_CODE_1",
- Description: "description for the transaction person1 to person2 value of 100 reais",
- Metadata: map[string]interface{}{
- "depositType": "PIX",
- "valor": "100.00",
+ name: "external destination with positive available",
+ balances: map[string]Balance{
+ "source-balance": {
+ ID: "source-balance",
+ AccountID: "source-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(300),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
},
- Pending: false,
- Route: "00000000-0000-0000-0000-000000000000",
- Send: Send{
- Asset: "BRL",
- Value: decimal.NewFromFloat(100.00),
- Source: Source{
- From: []FromTo{
- {
- AccountAlias: "@external/BRL",
- Remaining: "remaining",
- Description: "Loan payment 1",
- Route: "00000000-0000-0000-0000-000000000000",
- Metadata: map[string]interface{}{
- "1": "m",
- "Cpf": "43049498x",
- },
- },
- },
- },
- Distribute: Distribute{
- To: []FromTo{
- {
- AccountAlias: "@mcgregor_0",
- Share: &Share{
- Percentage: 50,
- },
- Route: "00000000-0000-0000-0000-000000000000",
- Metadata: map[string]interface{}{
- "mensagem": "tks",
- },
- },
- {
- AccountAlias: "@mcgregor_1",
- Share: &Share{
- Percentage: 50,
- },
- Description: "regression test",
- Metadata: map[string]interface{}{
- "key": "value",
- },
- },
- },
- },
+ "destination-balance": {
+ ID: "destination-balance",
+ AccountID: "destination-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(50),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeExternal,
},
},
- expectError: false,
+ errorCode: ErrorInsufficientFunds,
+ field: "destinations",
},
{
- name: "transaction with value mismatch",
- transaction: Transaction{
- ChartOfAccountsGroupName: "PAG_CONTAS_CODE_1",
- Description: "transaction with value mismatch",
- Pending: false,
- Send: Send{
- Asset: "BRL",
- Value: decimal.NewFromFloat(100.00),
- Source: Source{
- From: []FromTo{
- {
- AccountAlias: "@external/BRL",
- Amount: &Amount{
- Asset: "BRL",
- // Source amount doesn't match transaction value
- Value: decimal.NewFromFloat(90.00),
- },
- },
- },
- },
- Distribute: Distribute{
- To: []FromTo{
- {
- AccountAlias: "@mcgregor_0",
- Share: &Share{
- Percentage: 100,
- },
- },
- },
- },
+ name: "source insufficient funds",
+ balances: map[string]Balance{
+ "source-balance": {
+ ID: "source-balance",
+ AccountID: "source-account",
+ Asset: "USD",
+ Available: decimal.NewFromInt(50),
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
+ },
+ "destination-balance": {
+ ID: "destination-balance",
+ AccountID: "destination-account",
+ Asset: "USD",
+ Available: decimal.Zero,
+ AllowSending: true,
+ AllowReceiving: true,
+ AccountType: AccountTypeExternal,
},
},
- expectError: true,
- errorCode: "0073", // ErrTransactionValueMismatch
+ errorCode: ErrorInsufficientFunds,
+ field: "sources",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- // Call ValidateSendSourceAndDistribute to get the responses
- responses, err := ValidateSendSourceAndDistribute(ctx, tt.transaction, constant.CREATED)
-
- if tt.expectError {
- assert.Error(t, err)
- if tt.errorCode != "" {
- errMsg := err.Error()
- assert.Contains(t, errMsg, tt.errorCode, "Error should contain the expected error code")
- }
- return
- }
+ err := ValidateBalanceEligibility(plan, tt.balances)
+ de := assertDomainError(t, err, tt.errorCode)
+ assert.Equal(t, tt.field, de.Field)
+ })
+ }
+}
+
+func TestValidateBalanceEligibility_NonPending_ExternalSourceAllowed(t *testing.T) {
+ t.Parallel()
+
+ // When not pending, external sources ARE allowed (only pending + external is prohibited).
+ plan := IntentPlan{
+ Asset: "USD",
+ Total: decimal.NewFromInt(100),
+ Sources: []Posting{{
+ Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src-bal"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(100),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }},
+ Destinations: []Posting{{
+ Target: LedgerTarget{AccountID: "dst-acc", BalanceID: "dst-bal"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(100),
+ Operation: OperationCredit,
+ Status: StatusCreated,
+ }},
+ Pending: false,
+ }
+
+ balances := map[string]Balance{
+ "src-bal": {
+ ID: "src-bal",
+ AccountID: "src-acc",
+ Asset: "USD",
+ Available: decimal.NewFromInt(500),
+ AllowSending: true,
+ AccountType: AccountTypeExternal,
+ },
+ "dst-bal": {
+ ID: "dst-bal",
+ AccountID: "dst-acc",
+ Asset: "USD",
+ Available: decimal.Zero,
+ AllowReceiving: true,
+ AccountType: AccountTypeInternal,
+ },
+ }
+
+ err := ValidateBalanceEligibility(plan, balances)
+ require.NoError(t, err)
+}
+
+func TestValidateBalanceEligibility_ExternalDestinationWithZeroAvailable(t *testing.T) {
+ t.Parallel()
- assert.NoError(t, err)
- assert.NotNil(t, responses)
+ // External destination with zero available should pass.
+ plan := IntentPlan{
+ Asset: "USD",
+ Total: decimal.NewFromInt(50),
+ Sources: []Posting{{
+ Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src-bal"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(50),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }},
+ Destinations: []Posting{{
+ Target: LedgerTarget{AccountID: "dst-acc", BalanceID: "dst-bal"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(50),
+ Operation: OperationCredit,
+ Status: StatusCreated,
+ }},
+ }
- // For successful case, validate response structure
- assert.Equal(t, tt.transaction.Send.Value, responses.Total)
- assert.Equal(t, tt.transaction.Send.Asset, responses.Asset)
+ balances := map[string]Balance{
+ "src-bal": {
+ ID: "src-bal",
+ AccountID: "src-acc",
+ Asset: "USD",
+ Available: decimal.NewFromInt(200),
+ AllowSending: true,
+ AccountType: AccountTypeInternal,
+ },
+ "dst-bal": {
+ ID: "dst-bal",
+ AccountID: "dst-acc",
+ Asset: "USD",
+ Available: decimal.Zero,
+ AllowReceiving: true,
+ AccountType: AccountTypeExternal,
+ },
+ }
- // Verify the source account is included in the response
- fromKey := "@external/BRL"
- _, exists := responses.From[fromKey]
- assert.True(t, exists, "From account should exist: %s", fromKey)
+ err := ValidateBalanceEligibility(plan, balances)
+ require.NoError(t, err)
+}
- // Verify the destination accounts are included in the response
- toKey1 := "@mcgregor_0"
- _, exists = responses.To[toKey1]
- assert.True(t, exists, "To account should exist: %s", toKey1)
+func TestValidateBalanceEligibility_ExternalDestinationNegativeAvailable(t *testing.T) {
+ t.Parallel()
- toKey2 := "@mcgregor_1"
- _, exists = responses.To[toKey2]
- assert.True(t, exists, "To account should exist: %s", toKey2)
+ // External destination with negative available should now fail (!IsZero returns true for negative).
+ plan := IntentPlan{
+ Asset: "USD",
+ Total: decimal.NewFromInt(50),
+ Sources: []Posting{{
+ Target: LedgerTarget{AccountID: "src-acc", BalanceID: "src-bal"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(50),
+ Operation: OperationDebit,
+ Status: StatusCreated,
+ }},
+ Destinations: []Posting{{
+ Target: LedgerTarget{AccountID: "dst-acc", BalanceID: "dst-bal"},
+ Asset: "USD",
+ Amount: decimal.NewFromInt(50),
+ Operation: OperationCredit,
+ Status: StatusCreated,
+ }},
+ }
- // Verify total amount is correctly distributed
- var total decimal.Decimal
- for _, amount := range responses.To {
- total = total.Add(amount.Value)
- }
- assert.True(t, responses.Total.Equal(total),
- "Total amount (%s) should equal sum of destination amounts (%s)",
- responses.Total.String(), total.String())
- })
+ balances := map[string]Balance{
+ "src-bal": {
+ ID: "src-bal",
+ AccountID: "src-acc",
+ Asset: "USD",
+ Available: decimal.NewFromInt(200),
+ AllowSending: true,
+ AccountType: AccountTypeInternal,
+ },
+ "dst-bal": {
+ ID: "dst-bal",
+ AccountID: "dst-acc",
+ Asset: "USD",
+ Available: decimal.NewFromInt(-10),
+ AllowReceiving: true,
+ AccountType: AccountTypeExternal,
+ },
}
+
+ err := ValidateBalanceEligibility(plan, balances)
+ de := assertDomainError(t, err, ErrorDataCorruption)
+ assert.Equal(t, "balance", de.Field)
+}
+
+// ---------------------------------------------------------------------------
+// Serialization round-trip (IntentPlan)
+// ---------------------------------------------------------------------------
+
+func TestIntentPlan_JSONRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ original := IntentPlan{
+ Asset: "BRL",
+ Total: *decPtr(t, "1234.56"),
+ Pending: true,
+ Sources: []Posting{{
+ Target: LedgerTarget{AccountID: "a", BalanceID: "b"},
+ Asset: "BRL",
+ Amount: *decPtr(t, "1234.56"),
+ Operation: OperationOnHold,
+ Status: StatusPending,
+ Route: "pix",
+ }},
+ Destinations: []Posting{{
+ Target: LedgerTarget{AccountID: "c", BalanceID: "d"},
+ Asset: "BRL",
+ Amount: *decPtr(t, "1234.56"),
+ Operation: OperationCredit,
+ Status: StatusPending,
+ }},
+ }
+
+ data, err := json.Marshal(original)
+ require.NoError(t, err)
+
+ var restored IntentPlan
+ err = json.Unmarshal(data, &restored)
+ require.NoError(t, err)
+
+ assert.Equal(t, original.Asset, restored.Asset)
+ assert.True(t, original.Total.Equal(restored.Total))
+ assert.Equal(t, original.Pending, restored.Pending)
+ assert.Len(t, restored.Sources, 1)
+ assert.Len(t, restored.Destinations, 1)
+ assert.True(t, original.Sources[0].Amount.Equal(restored.Sources[0].Amount))
+ assert.Equal(t, original.Sources[0].Operation, restored.Sources[0].Operation)
+ assert.Equal(t, original.Sources[0].Route, restored.Sources[0].Route)
+}
+
+func TestBalance_JSONRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ original := Balance{
+ ID: "bal-123",
+ OrganizationID: "org-1",
+ LedgerID: "led-1",
+ AccountID: "acc-1",
+ Asset: "BTC",
+ Available: *decPtr(t, "0.00123456"),
+ OnHold: *decPtr(t, "0.00000001"),
+ Version: 42,
+ AccountType: AccountTypeInternal,
+ AllowSending: true,
+ AllowReceiving: true,
+ Metadata: map[string]any{"key": "value"},
+ }
+
+ data, err := json.Marshal(original)
+ require.NoError(t, err)
+
+ var restored Balance
+ err = json.Unmarshal(data, &restored)
+ require.NoError(t, err)
+
+ assert.Equal(t, original.ID, restored.ID)
+ assert.Equal(t, original.OrganizationID, restored.OrganizationID)
+ assert.True(t, original.Available.Equal(restored.Available))
+ assert.True(t, original.OnHold.Equal(restored.OnHold))
+ assert.Equal(t, original.Version, restored.Version)
+ assert.Equal(t, original.AccountType, restored.AccountType)
+ assert.Equal(t, "value", restored.Metadata["key"])
}
diff --git a/commons/utils.go b/commons/utils.go
index 2db8c879..eb651c0e 100644
--- a/commons/utils.go
+++ b/commons/utils.go
@@ -1,34 +1,25 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package commons
import (
"context"
"encoding/json"
- "errors"
+ "fmt"
"math"
"os/exec"
"reflect"
"regexp"
"slices"
"strconv"
- "strings"
"time"
- "unicode"
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ cn "github.com/LerianStudio/lib-commons/v4/commons/constants"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/opentelemetry/metrics"
"github.com/google/uuid"
"github.com/shirou/gopsutil/cpu"
"github.com/shirou/gopsutil/mem"
- "go.opentelemetry.io/otel/metric"
)
-const beginningKey = "{"
-const keySeparator = ":"
-const endKey = "}"
-
var internalServicePattern = regexp.MustCompile(`^[\w-]+/[\d.]+\s+LerianStudio$`)
// Contains checks if an item is in a slice. This function uses type parameters to work with any slice type.
@@ -40,12 +31,14 @@ func Contains[T comparable](slice []T, item T) bool {
func CheckMetadataKeyAndValueLength(limit int, metadata map[string]any) error {
for k, v := range metadata {
if len(k) > limit {
- return errors.New("0050")
+ return cn.ErrMetadataKeyLengthExceeded
}
var value string
switch t := v.(type) {
+ case nil:
+ continue // nil values are valid, skip length check
case int:
value = strconv.Itoa(t)
case float64:
@@ -54,103 +47,23 @@ func CheckMetadataKeyAndValueLength(limit int, metadata map[string]any) error {
value = t
case bool:
value = strconv.FormatBool(t)
+ default:
+ value = fmt.Sprintf("%v", t) // convert unknown types to string for length check
}
if len(value) > limit {
- return errors.New("0051")
+ return cn.ErrMetadataValueLengthExceeded
}
}
return nil
}
-// Deprecated: use ValidateCountryAddress method from Midaz pkg instead.
-// ValidateCountryAddress validate if country in object address contains in countries list using ISO 3166-1 alpha-2
-func ValidateCountryAddress(country string) error {
- countries := []string{
- "AD", "AE", "AF", "AG", "AI", "AL", "AM", "AO", "AQ", "AR", "AS", "AT", "AU", "AW", "AX", "AZ",
- "BA", "BB", "BD", "BE", "BF", "BG", "BH", "BI", "BJ", "BL", "BM", "BN", "BO", "BQ", "BR", "BS", "BT", "BV", "BW",
- "BY", "BZ", "CA", "CC", "CD", "CF", "CG", "CH", "CI", "CK", "CL", "CM", "CN", "CO", "CR", "CU", "CV", "CW", "CX",
- "CY", "CZ", "DE", "DJ", "DK", "DM", "DO", "DZ", "EC", "EE", "EG", "EH", "ER", "ES", "ET", "FI", "FJ", "FK", "FM",
- "FO", "FR", "GA", "GB", "GD", "GE", "GF", "GG", "GH", "GI", "GL", "GM", "GN", "GP", "GQ", "GR", "GS", "GT", "GU",
- "GW", "GY", "HK", "HM", "HN", "HR", "HT", "HU", "ID", "IE", "IL", "IM", "IN", "IO", "IQ", "IR", "IS", "IT", "JE",
- "JM", "JO", "JP", "KE", "KG", "KH", "KI", "KM", "KN", "KP", "KR", "KW", "KY", "KZ", "LA", "LB", "LC", "LI", "LK",
- "LR", "LS", "LT", "LU", "LV", "LY", "MA", "MC", "MD", "ME", "MF", "MG", "MH", "MK", "ML", "MM", "MN", "MO", "MP",
- "MQ", "MR", "MS", "MT", "MU", "MV", "MW", "MX", "MY", "MZ", "NA", "NC", "NE", "NF", "NG", "NI", "NL", "NO", "NP",
- "NR", "NU", "NZ", "OM", "PA", "PE", "PF", "PG", "PH", "PK", "PL", "PM", "PN", "PR", "PS", "PT", "PW", "PY", "QA",
- "RE", "RO", "RS", "RU", "RW", "SA", "SB", "SC", "SD", "SE", "SG", "SH", "SI", "SJ", "SK", "SL", "SM", "SN", "SO",
- "SR", "SS", "ST", "SV", "SX", "SY", "SZ", "TC", "TD", "TF", "TG", "TH", "TJ", "TK", "TL", "TM", "TN", "TO", "TR",
- "TT", "TV", "TW", "TZ", "UA", "UG", "UM", "US", "UY", "UZ", "VA", "VC", "VE", "VG", "VI", "VN", "VU", "WF", "WS",
- "YE", "YT", "ZA", "ZM", "ZW",
- }
-
- if !slices.Contains(countries, country) {
- return errors.New("0032")
- }
-
- return nil
-}
-
-// Deprecated: use ValidateAccountType method from Midaz pkg instead.
-// ValidateAccountType validate type values of accounts
-func ValidateAccountType(t string) error {
- types := []string{"deposit", "savings", "loans", "marketplace", "creditCard"}
-
- if !slices.Contains(types, t) {
- return errors.New("0066")
- }
-
- return nil
-}
-
-// Deprecated: use ValidateType method from Midaz pkg instead.
-// ValidateType validate type values of currencies
-func ValidateType(t string) error {
- types := []string{"crypto", "currency", "commodity", "others"}
-
- if !slices.Contains(types, t) {
- return errors.New("0040")
- }
-
- return nil
-}
-
-// Deprecated: use ValidateCode method from Midaz pkg instead.
-func ValidateCode(code string) error {
- for _, r := range code {
- if !unicode.IsLetter(r) {
- return errors.New("0033")
- } else if !unicode.IsUpper(r) {
- return errors.New("0004")
- }
- }
-
- return nil
-}
-
-// Deprecated: use ValidateCurrency method from Midaz pkg instead.
-// ValidateCurrency validate if code contains in currencies list using ISO 4217
-func ValidateCurrency(code string) error {
- currencies := []string{
- "AED", "AFN", "ALL", "AMD", "ANG", "AOA", "ARS", "AUD", "AWG", "AZN", "BAM", "BBD", "BDT", "BGN", "BHD", "BIF", "BMD", "BND", "BOB",
- "BOV", "BRL", "BSD", "BTN", "BWP", "BYN", "BZD", "CAD", "CDF", "CHE", "CHF", "CHW", "CLF", "CLP", "CNY", "COP", "COU", "CRC", "CUC",
- "CUP", "CVE", "CZK", "DJF", "DKK", "DOP", "DZD", "EGP", "ERN", "ETB", "EUR", "FJD", "FKP", "GBP", "GEL", "GHS", "GIP", "GMD", "GNF",
- "GTQ", "GYD", "HKD", "HNL", "HTG", "HUF", "IDR", "ILS", "INR", "IQD", "IRR", "ISK", "JMD", "JOD", "JPY", "KES", "KGS", "KHR", "KMF",
- "KPW", "KRW", "KWD", "KYD", "KZT", "LAK", "LBP", "LKR", "LRD", "LSL", "LYD", "MAD", "MDL", "MGA", "MKD", "MMK", "MNT", "MOP", "MRU",
- "MUR", "MVR", "MWK", "MXN", "MXV", "MYR", "MZN", "NAD", "NGN", "NIO", "NOK", "NPR", "NZD", "OMR", "PAB", "PEN", "PGK", "PHP", "PKR",
- "PLN", "PYG", "QAR", "RON", "RSD", "RUB", "RWF", "SAR", "SBD", "SCR", "SDG", "SEK", "SGD", "SHP", "SLE", "SOS", "SRD", "SSP", "STN",
- "SVC", "SYP", "SZL", "THB", "TJS", "TMT", "TND", "TOP", "TRY", "TTD", "TWD", "TZS", "UAH", "UGX", "USD", "USN", "UYI", "UYU", "UZS",
- "VED", "VEF", "VND", "VUV", "WST", "XAF", "XCD", "XDR", "XOF", "XPF", "XSU", "XUA", "YER", "ZAR", "ZMW", "ZWL",
- }
-
- if !slices.Contains(currencies, code) {
- return errors.New("0005")
- }
-
- return nil
-}
-
-// SafeIntToUint64 safe mode to converter int to uint64
+// SafeIntToUint64 converts int to uint64 with safety clamping.
+// Negative values are mapped to 1 (not 0) because this function is typically
+// used where the result serves as a divisor or count, and zero would cause
+// a division-by-zero panic. Using 1 as the safe minimum preserves
+// arithmetic safety while signaling an unexpected input.
func SafeIntToUint64(val int) uint64 {
if val < 0 {
return uint64(1)
@@ -185,7 +98,14 @@ func SafeUintToInt(val uint) int {
func SafeIntToUint32(value int, defaultVal uint32, logger log.Logger, fieldName string) uint32 {
if value < 0 {
if logger != nil {
- logger.Debugf("Invalid %s value %d (negative), using default: %d", fieldName, value, defaultVal)
+ logger.Log(
+ context.Background(),
+ log.LevelDebug,
+ "invalid uint32 source value, using default",
+ log.String("field_name", fieldName),
+ log.Int("value", value),
+ log.Int("default", int(defaultVal)),
+ )
}
return defaultVal
@@ -195,7 +115,15 @@ func SafeIntToUint32(value int, defaultVal uint32, logger log.Logger, fieldName
if uv > uint64(math.MaxUint32) {
if logger != nil {
- logger.Debugf("%s value %d exceeds uint32 max (%d), using default %d", fieldName, value, uint64(math.MaxUint32), defaultVal)
+ logger.Log(
+ context.Background(),
+ log.LevelDebug,
+ "uint32 source value exceeds max, using default",
+ log.String("field_name", fieldName),
+ log.Int("value", value),
+ log.Any("max", uint64(math.MaxUint32)),
+ log.Int("default", int(defaultVal)),
+ )
}
return defaultVal
@@ -211,18 +139,17 @@ func IsUUID(s string) bool {
return err == nil
}
-// GenerateUUIDv7 generate a new uuid v7 using google/uuid package and return it. If an error occurs, it will return the error.
-func GenerateUUIDv7() uuid.UUID {
- u := uuid.Must(uuid.NewV7())
-
- return u
+// GenerateUUIDv7 generates a new UUID v7 using the google/uuid package.
+// Returns the generated UUID or an error if crypto/rand fails.
+func GenerateUUIDv7() (uuid.UUID, error) {
+ return uuid.NewV7()
}
// StructToJSONString convert a struct to json string
func StructToJSONString(s any) (string, error) {
jsonByte, err := json.Marshal(s)
if err != nil {
- return "", err
+ return "", fmt.Errorf("struct to JSON: %w", err)
}
return string(jsonByte), nil
@@ -246,23 +173,32 @@ func MergeMaps(source, target map[string]any) map[string]any {
return target
}
+// SyscmdI abstracts command execution for testing and composition.
type SyscmdI interface {
- ExecCmd(name string, arg ...string) ([]byte, error)
+ ExecCmd(ctx context.Context, name string, arg ...string) ([]byte, error)
}
+// Syscmd is the default SyscmdI implementation backed by os/exec.
type Syscmd struct{}
-func (r *Syscmd) ExecCmd(name string, arg ...string) ([]byte, error) {
- return exec.Command(name, arg...).Output() //#nosec G204 -- Generic command wrapper; caller responsible for safe usage
+// ExecCmd runs a command and returns its stdout bytes.
+func (r *Syscmd) ExecCmd(ctx context.Context, name string, arg ...string) ([]byte, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ // #nosec G204 -- arguments are passed directly to exec.CommandContext (no shell interpretation); callers are responsible for input validation
+ return exec.CommandContext(ctx, name, arg...).Output()
}
-// GetCPUUsage get the current CPU usage
-func GetCPUUsage(ctx context.Context, cpuGauge metric.Int64Gauge) {
+// GetCPUUsage reads the current CPU usage and records it through the MetricsFactory gauge.
+// If factory is nil, the reading is performed but metric recording is skipped.
+func GetCPUUsage(ctx context.Context, factory *metrics.MetricsFactory) {
logger := NewLoggerFromContext(ctx)
out, err := cpu.Percent(100*time.Millisecond, false)
if err != nil {
- logger.Warnf("Errot to get cpu use: %v", err)
+ logger.Log(ctx, log.LevelWarn, "error getting CPU usage", log.Err(err))
}
var percentageCPU int64 = 0
@@ -270,23 +206,38 @@ func GetCPUUsage(ctx context.Context, cpuGauge metric.Int64Gauge) {
percentageCPU = int64(out[0])
}
- cpuGauge.Record(ctx, percentageCPU)
+ if factory == nil {
+ logger.Log(ctx, log.LevelWarn, "metrics factory is nil, skipping CPU usage recording")
+ return
+ }
+
+ if err := factory.RecordSystemCPUUsage(ctx, percentageCPU); err != nil {
+ logger.Log(ctx, log.LevelWarn, "error recording CPU gauge", log.Err(err))
+ }
}
-// GetMemUsage get the current memory usage
-func GetMemUsage(ctx context.Context, memGauge metric.Int64Gauge) {
+// GetMemUsage reads the current memory usage and records it through the MetricsFactory gauge.
+// If factory is nil, the reading is performed but metric recording is skipped.
+func GetMemUsage(ctx context.Context, factory *metrics.MetricsFactory) {
logger := NewLoggerFromContext(ctx)
var percentageMem int64 = 0
out, err := mem.VirtualMemory()
if err != nil {
- logger.Warnf("Error to get info memory: %v", err)
+ logger.Log(ctx, log.LevelWarn, "error getting memory info", log.Err(err))
} else {
percentageMem = int64(out.UsedPercent)
}
- memGauge.Record(ctx, percentageMem)
+ if factory == nil {
+ logger.Log(ctx, log.LevelWarn, "metrics factory is nil, skipping memory usage recording")
+ return
+ }
+
+ if err := factory.RecordSystemMemUsage(ctx, percentageMem); err != nil {
+ logger.Log(ctx, log.LevelWarn, "error recording memory gauge", log.Err(err))
+ }
}
// GetMapNumKinds get the map of numeric kinds to use in validations and conversions.
@@ -322,61 +273,6 @@ func Reverse[T any](s []T) []T {
return s
}
-// Deprecated: use GenericInternalKey method from Midaz pkg instead.
-// GenericInternalKey returns a key with the following format to be used on redis cluster:
-// "name:{organizationID:ledgerID:key}"
-func GenericInternalKey(name, organizationID, ledgerID, key string) string {
- var builder strings.Builder
-
- builder.WriteString(name)
- builder.WriteString(keySeparator)
- builder.WriteString(beginningKey)
- builder.WriteString(organizationID)
- builder.WriteString(keySeparator)
- builder.WriteString(ledgerID)
- builder.WriteString(keySeparator)
- builder.WriteString(key)
- builder.WriteString(endKey)
-
- return builder.String()
-}
-
-// Deprecated: use TransactionInternalKey method from Midaz pkg instead.
-// TransactionInternalKey returns a key with the following format to be used on redis cluster:
-// "transaction:{organizationID:ledgerID:key}"
-func TransactionInternalKey(organizationID, ledgerID uuid.UUID, key string) string {
- transaction := GenericInternalKey("transaction", organizationID.String(), ledgerID.String(), key)
-
- return transaction
-}
-
-// Deprecated: use IdempotencyInternalKey method from Midaz pkg instead.
-// IdempotencyInternalKey returns a key with the following format to be used on redis cluster:
-// "idempotency:{organizationID:ledgerID:key}"
-func IdempotencyInternalKey(organizationID, ledgerID uuid.UUID, key string) string {
- idempotency := GenericInternalKey("idempotency", organizationID.String(), ledgerID.String(), key)
-
- return idempotency
-}
-
-// Deprecated: use BalanceInternalKey method from Midaz pkg instead.
-// BalanceInternalKey returns a key with the following format to be used on redis cluster:
-// "balance:{organizationID:ledgerID:key}"
-func BalanceInternalKey(organizationID, ledgerID, key string) string {
- balance := GenericInternalKey("balance", organizationID, ledgerID, key)
-
- return balance
-}
-
-// Deprecated: use AccountingRoutesInternalKey method from Midaz pkg instead.
-// AccountingRoutesInternalKey returns a key with the following format to be used on redis cluster:
-// "accounting_routes:{organizationID:ledgerID:key}"
-func AccountingRoutesInternalKey(organizationID, ledgerID, key uuid.UUID) string {
- accountingRoutes := GenericInternalKey("accounting_routes", organizationID.String(), ledgerID.String(), key.String())
-
- return accountingRoutes
-}
-
// UUIDsToStrings converts a slice of UUIDs to a slice of strings.
// It's optimized to minimize allocations and iterations.
func UUIDsToStrings(uuids []uuid.UUID) []string {
@@ -388,6 +284,7 @@ func UUIDsToStrings(uuids []uuid.UUID) []string {
return result
}
+// IsInternalLerianService reports whether a user-agent belongs to a Lerian internal service.
func IsInternalLerianService(userAgent string) bool {
return internalServicePattern.MatchString(userAgent)
}
diff --git a/commons/utils_test.go b/commons/utils_test.go
new file mode 100644
index 00000000..9b4f0e2d
--- /dev/null
+++ b/commons/utils_test.go
@@ -0,0 +1,347 @@
+//go:build unit
+
+package commons
+
+import (
+ "context"
+ "math"
+ "reflect"
+ "testing"
+
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/google/uuid"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestContains(t *testing.T) {
+ t.Parallel()
+
+ t.Run("found", func(t *testing.T) {
+ t.Parallel()
+ assert.True(t, Contains([]string{"a", "b", "c"}, "b"))
+ })
+
+ t.Run("not_found", func(t *testing.T) {
+ t.Parallel()
+ assert.False(t, Contains([]string{"a", "b", "c"}, "z"))
+ })
+
+ t.Run("empty_slice", func(t *testing.T) {
+ t.Parallel()
+ assert.False(t, Contains([]int{}, 1))
+ })
+}
+
+func TestCheckMetadataKeyAndValueLength(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ limit int
+ metadata map[string]any
+ wantErr string
+ }{
+ {
+ name: "key_too_long",
+ limit: 3,
+ metadata: map[string]any{"toolong": "v"},
+ wantErr: "0050",
+ },
+ {
+ name: "int_value",
+ limit: 10,
+ metadata: map[string]any{"k": 42},
+ },
+ {
+ name: "float64_value",
+ limit: 20,
+ metadata: map[string]any{"k": 3.14},
+ },
+ {
+ name: "string_value_within_limit",
+ limit: 10,
+ metadata: map[string]any{"k": "short"},
+ },
+ {
+ name: "string_value_too_long",
+ limit: 3,
+ metadata: map[string]any{"k": "toolong"},
+ wantErr: "0051",
+ },
+ {
+ name: "bool_value",
+ limit: 10,
+ metadata: map[string]any{"k": true},
+ },
+ {
+ name: "nil_value_skipped",
+ limit: 1,
+ metadata: map[string]any{"k": nil},
+ },
+ {
+ name: "unknown_type",
+ limit: 10,
+ metadata: map[string]any{"k": []int{1, 2}},
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ err := CheckMetadataKeyAndValueLength(tc.limit, tc.metadata)
+ if tc.wantErr != "" {
+ require.Error(t, err)
+ assert.Equal(t, tc.wantErr, err.Error())
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestSafeIntToUint64(t *testing.T) {
+ t.Parallel()
+
+ t.Run("negative_returns_1", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, uint64(1), SafeIntToUint64(-5))
+ })
+
+ t.Run("positive", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, uint64(42), SafeIntToUint64(42))
+ })
+
+ t.Run("zero", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, uint64(0), SafeIntToUint64(0))
+ })
+}
+
+func TestSafeInt64ToInt(t *testing.T) {
+ t.Parallel()
+
+ t.Run("normal", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, 100, SafeInt64ToInt(100))
+ })
+
+ t.Run("overflow_max", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, math.MaxInt, SafeInt64ToInt(math.MaxInt64))
+ })
+
+ t.Run("underflow_min", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, math.MinInt, SafeInt64ToInt(math.MinInt64))
+ })
+}
+
+func TestSafeUintToInt(t *testing.T) {
+ t.Parallel()
+
+ t.Run("normal", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, 10, SafeUintToInt(10))
+ })
+
+ t.Run("overflow", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, math.MaxInt, SafeUintToInt(uint(math.MaxUint)))
+ })
+}
+
+func TestSafeIntToUint32(t *testing.T) {
+ t.Parallel()
+
+ t.Run("negative_returns_default", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, uint32(99), SafeIntToUint32(-1, 99, nil, "test"))
+ })
+
+ t.Run("overflow_returns_default", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, uint32(99), SafeIntToUint32(math.MaxInt, 99, nil, "test"))
+ })
+
+ t.Run("normal", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, uint32(42), SafeIntToUint32(42, 0, nil, "test"))
+ })
+
+ t.Run("negative_with_logger", func(t *testing.T) {
+ t.Parallel()
+ logger := &log.NopLogger{}
+ assert.Equal(t, uint32(99), SafeIntToUint32(-1, 99, logger, "field"))
+ })
+
+ t.Run("overflow_with_logger", func(t *testing.T) {
+ t.Parallel()
+ logger := &log.NopLogger{}
+ assert.Equal(t, uint32(99), SafeIntToUint32(math.MaxInt, 99, logger, "field"))
+ })
+}
+
+func TestIsUUID(t *testing.T) {
+ t.Parallel()
+
+ t.Run("valid", func(t *testing.T) {
+ t.Parallel()
+ assert.True(t, IsUUID("550e8400-e29b-41d4-a716-446655440000"))
+ })
+
+ t.Run("invalid", func(t *testing.T) {
+ t.Parallel()
+ assert.False(t, IsUUID("not-a-uuid"))
+ })
+}
+
+func TestGenerateUUIDv7(t *testing.T) {
+ t.Parallel()
+
+ id, err := GenerateUUIDv7()
+ require.NoError(t, err)
+ assert.True(t, IsUUID(id.String()))
+}
+
+func TestStructToJSONString(t *testing.T) {
+ t.Parallel()
+
+ t.Run("valid_struct", func(t *testing.T) {
+ t.Parallel()
+
+ s := struct {
+ Name string `json:"name"`
+ }{Name: "test"}
+
+ result, err := StructToJSONString(s)
+ require.NoError(t, err)
+ assert.Equal(t, `{"name":"test"}`, result)
+ })
+
+ t.Run("invalid_value", func(t *testing.T) {
+ t.Parallel()
+
+ _, err := StructToJSONString(make(chan int))
+ assert.Error(t, err)
+ })
+}
+
+func TestMergeMaps(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil_target", func(t *testing.T) {
+ t.Parallel()
+
+ result := MergeMaps(map[string]any{"a": 1}, nil)
+ assert.Equal(t, 1, result["a"])
+ })
+
+ t.Run("nil_value_deletes_key", func(t *testing.T) {
+ t.Parallel()
+
+ target := map[string]any{"a": 1, "b": 2}
+ result := MergeMaps(map[string]any{"a": nil}, target)
+ _, exists := result["a"]
+ assert.False(t, exists)
+ assert.Equal(t, 2, result["b"])
+ })
+
+ t.Run("normal_merge", func(t *testing.T) {
+ t.Parallel()
+
+ target := map[string]any{"a": 1}
+ result := MergeMaps(map[string]any{"b": 2}, target)
+ assert.Equal(t, 1, result["a"])
+ assert.Equal(t, 2, result["b"])
+ })
+}
+
+func TestReverse(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty", func(t *testing.T) {
+ t.Parallel()
+ assert.Empty(t, Reverse([]int{}))
+ })
+
+ t.Run("single", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, []int{1}, Reverse([]int{1}))
+ })
+
+ t.Run("multiple", func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, []int{3, 2, 1}, Reverse([]int{1, 2, 3}))
+ })
+}
+
+func TestUUIDsToStrings(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty", func(t *testing.T) {
+ t.Parallel()
+ assert.Empty(t, UUIDsToStrings([]uuid.UUID{}))
+ })
+
+ t.Run("multiple", func(t *testing.T) {
+ t.Parallel()
+
+ u1 := uuid.MustParse("550e8400-e29b-41d4-a716-446655440000")
+ u2 := uuid.MustParse("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
+
+ result := UUIDsToStrings([]uuid.UUID{u1, u2})
+ assert.Equal(t, []string{u1.String(), u2.String()}, result)
+ })
+}
+
+func TestIsInternalLerianService(t *testing.T) {
+ t.Parallel()
+
+ t.Run("matching", func(t *testing.T) {
+ t.Parallel()
+ assert.True(t, IsInternalLerianService("my-service/1.0.0 LerianStudio"))
+ })
+
+ t.Run("non_matching", func(t *testing.T) {
+ t.Parallel()
+ assert.False(t, IsInternalLerianService("curl/7.68.0"))
+ })
+}
+
+func TestGetCPUUsage_NilFactory(t *testing.T) {
+ t.Parallel()
+
+ // Should not panic when factory is nil; metrics recording is skipped.
+ assert.NotPanics(t, func() {
+ GetCPUUsage(context.Background(), nil)
+ })
+}
+
+func TestGetMemUsage_NilFactory(t *testing.T) {
+ t.Parallel()
+
+ // Should not panic when factory is nil; metrics recording is skipped.
+ assert.NotPanics(t, func() {
+ GetMemUsage(context.Background(), nil)
+ })
+}
+
+func TestGetMapNumKinds(t *testing.T) {
+ t.Parallel()
+
+ kinds := GetMapNumKinds()
+
+ expected := []reflect.Kind{
+ reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Float32, reflect.Float64,
+ }
+
+ assert.Len(t, kinds, len(expected))
+
+ for _, k := range expected {
+ assert.True(t, kinds[k], "expected kind %v to be present", k)
+ }
+}
diff --git a/commons/zap/doc.go b/commons/zap/doc.go
new file mode 100644
index 00000000..104c0f8b
--- /dev/null
+++ b/commons/zap/doc.go
@@ -0,0 +1,5 @@
+// Package zap provides adapters and helpers around zap-based logging.
+//
+// It bridges the commons/log abstraction to zap while preserving structured
+// fields and compatibility with existing middleware/context plumbing.
+package zap
diff --git a/commons/zap/injector.go b/commons/zap/injector.go
index 3521aec8..eb84d0fd 100644
--- a/commons/zap/injector.go
+++ b/commons/zap/injector.go
@@ -1,76 +1,149 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package zap
import (
+ "errors"
"fmt"
- "log"
"os"
+ "strings"
- clog "github.com/LerianStudio/lib-commons/v2/commons/log"
"go.opentelemetry.io/contrib/bridges/otelzap"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
-// InitializeLoggerWithError initializes our log layer and returns it with error handling.
-// Returns an error instead of calling log.Fatalf on failure.
-//
-//nolint:ireturn
-func InitializeLoggerWithError() (clog.Logger, error) {
- var zapCfg zap.Config
-
- if os.Getenv("ENV_NAME") == "production" {
- zapCfg = zap.NewProductionConfig()
- zapCfg.EncoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
- zapCfg.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)
- } else {
- zapCfg = zap.NewDevelopmentConfig()
- zapCfg.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
- zapCfg.Level = zap.NewAtomicLevelAt(zapcore.DebugLevel)
- }
+const (
+ callerSkipFrames = 1
+ encodingConsole = "console"
+)
- if val, ok := os.LookupEnv("LOG_LEVEL"); ok {
- var lvl zapcore.Level
- if err := lvl.Set(val); err != nil {
- log.Printf("Invalid LOG_LEVEL, fallback to InfoLevel: %v", err)
+// Environment controls the baseline logger profile.
+type Environment string
- lvl = zapcore.InfoLevel
- }
+const (
+ // EnvironmentProduction enables production-safe logging defaults.
+ EnvironmentProduction Environment = "production"
+ // EnvironmentStaging enables staging-safe logging defaults.
+ EnvironmentStaging Environment = "staging"
+ // EnvironmentUAT enables UAT-safe logging defaults.
+ EnvironmentUAT Environment = "uat"
+ // EnvironmentDevelopment enables verbose development logging defaults.
+ EnvironmentDevelopment Environment = "development"
+ // EnvironmentLocal enables verbose local-development logging defaults.
+ EnvironmentLocal Environment = "local"
+)
- zapCfg.Level = zap.NewAtomicLevelAt(lvl)
+// Config contains all required logger initialization inputs.
+type Config struct {
+ Environment Environment
+ Level string
+ OTelLibraryName string
+}
+
+func (c Config) validate() error {
+ if c.OTelLibraryName == "" {
+ return errors.New("OTelLibraryName is required")
}
- zapCfg.DisableStacktrace = true
+ switch c.Environment {
+ case EnvironmentProduction, EnvironmentStaging, EnvironmentUAT, EnvironmentDevelopment, EnvironmentLocal:
+ return nil
+ default:
+ return fmt.Errorf("invalid environment %q", c.Environment)
+ }
+}
- logger, err := zapCfg.Build(zap.AddCallerSkip(2), zap.WrapCore(func(core zapcore.Core) zapcore.Core {
- return zapcore.NewTee(core, otelzap.NewCore(os.Getenv("OTEL_LIBRARY_NAME")))
- }))
+// New creates a structured logger from the given configuration.
+//
+// The returned Logger implements log.Logger and stores the runtime-adjustable
+// level handle internally. Use Logger.Level() to access it.
+func New(cfg Config) (*Logger, error) {
+ if err := cfg.validate(); err != nil {
+ return nil, fmt.Errorf("invalid zap config: %w", err)
+ }
+
+ baseConfig := buildConfigByEnvironment(cfg.Environment)
+
+ level, err := resolveLevel(cfg)
if err != nil {
- return nil, fmt.Errorf("can't initialize zap logger: %w", err)
+ return nil, err
}
- sugarLogger := logger.Sugar()
+ baseConfig.Level = level
+ baseConfig.DisableStacktrace = true
+
+ coreOptions := []zap.Option{
+ zap.AddCallerSkip(callerSkipFrames),
+ zap.WrapCore(func(core zapcore.Core) zapcore.Core {
+ return zapcore.NewTee(core, otelzap.NewCore(cfg.OTelLibraryName))
+ }),
+ }
- sugarLogger.Infof("Log level is (%v)", zapCfg.Level)
- sugarLogger.Infof("Logger is (%T) \n", sugarLogger)
+ built, err := baseConfig.Build(coreOptions...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to build logger: %w", err)
+ }
- return &ZapWithTraceLogger{
- Logger: sugarLogger,
+ return &Logger{
+ logger: built,
+ atomicLevel: level,
+ consoleEncoding: baseConfig.Encoding == encodingConsole,
}, nil
}
-// Deprecated: Use InitializeLoggerWithError for proper error handling.
-// InitializeLogger initializes our log layer and returns it.
-//
-//nolint:ireturn
-func InitializeLogger() clog.Logger {
- logger, err := InitializeLoggerWithError()
- if err != nil {
- log.Fatalf("%v", err)
+func resolveLevel(cfg Config) (zap.AtomicLevel, error) {
+ levelStr := cfg.Level
+ if strings.TrimSpace(levelStr) == "" {
+ levelStr = strings.TrimSpace(os.Getenv("LOG_LEVEL"))
+ }
+
+ if levelStr != "" {
+ var parsed zapcore.Level
+ if err := parsed.Set(levelStr); err != nil {
+ return zap.AtomicLevel{}, fmt.Errorf("invalid level %q: %w", levelStr, err)
+ }
+
+ return zap.NewAtomicLevelAt(parsed), nil
+ }
+
+ if cfg.Environment == EnvironmentDevelopment || cfg.Environment == EnvironmentLocal {
+ return zap.NewAtomicLevelAt(zapcore.DebugLevel), nil
+ }
+
+ return zap.NewAtomicLevelAt(zapcore.InfoLevel), nil
+}
+
+func buildConfigByEnvironment(environment Environment) zap.Config {
+ encoding := resolveEncoding(environment)
+
+ if environment == EnvironmentDevelopment || environment == EnvironmentLocal {
+ cfg := zap.NewDevelopmentConfig()
+ cfg.Encoding = encoding
+ cfg.EncoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
+
+ if encoding == encodingConsole {
+ cfg.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
+ }
+
+ return cfg
+ }
+
+ cfg := zap.NewProductionConfig()
+ cfg.Encoding = encoding
+ cfg.EncoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
+
+ return cfg
+}
+
+func resolveEncoding(environment Environment) string {
+ if enc := strings.TrimSpace(os.Getenv("LOG_ENCODING")); enc != "" {
+ if enc == "json" || enc == encodingConsole {
+ return enc
+ }
+ }
+
+ if environment == EnvironmentDevelopment || environment == EnvironmentLocal {
+ return encodingConsole
}
- return logger
+ return "json"
}
diff --git a/commons/zap/injector_test.go b/commons/zap/injector_test.go
index 40bf475e..50a095b2 100644
--- a/commons/zap/injector_test.go
+++ b/commons/zap/injector_test.go
@@ -1,86 +1,160 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package zap
-// Note on error path testing for InitializeLoggerWithError:
-// The zap logger Build() function only returns an error in cases that are
-// difficult to simulate in unit tests (e.g., invalid output paths, encoder errors).
-// With the default configuration used in InitializeLoggerWithError, the Build()
-// call is very unlikely to fail.
-//
-// The error path IS covered in InitializeLoggerWithError (injector.go):
-// When zap.Build() fails, the function returns a wrapped error via
-// fmt.Errorf("can't initialize zap logger: %w", err)
-// This ensures proper error chaining for callers using errors.Is() or errors.As().
-//
-// To trigger an actual error in Build(), one would need to:
-// - Provide an invalid output path (not possible with current implementation)
-// - Corrupt the zap configuration (not exposed)
-//
-// Therefore, error handling exists and is correct, but cannot be easily tested
-// without modifying the production code to accept external configuration.
-
import (
- "bytes"
- "log"
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/zap/zapcore"
)
-func TestInitializeLogger(t *testing.T) {
- t.Setenv("ENV_NAME", "production")
+func TestNewRejectsMissingOTelLibraryName(t *testing.T) {
+ t.Parallel()
+
+ _, err := New(Config{Environment: EnvironmentProduction})
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "OTelLibraryName is required")
+}
+
+func TestNewRejectsInvalidEnvironment(t *testing.T) {
+ t.Parallel()
- logger := InitializeLogger()
- assert.NotNil(t, logger)
+ _, err := New(Config{Environment: Environment("banana"), OTelLibraryName: "svc"})
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid environment")
}
-func TestInitializeLoggerWithError_Success(t *testing.T) {
- t.Setenv("ENV_NAME", "production")
+func TestNewAppliesEnvironmentDefaultLevel(t *testing.T) {
+ t.Parallel()
- logger, err := InitializeLoggerWithError()
+ logger, err := New(Config{Environment: EnvironmentDevelopment, OTelLibraryName: "svc"})
+ require.NoError(t, err)
+ assert.Equal(t, zapcore.DebugLevel, logger.Level().Level())
- assert.NoError(t, err)
- assert.NotNil(t, logger)
+ logger, err = New(Config{Environment: EnvironmentProduction, OTelLibraryName: "svc"})
+ require.NoError(t, err)
+ assert.Equal(t, zapcore.InfoLevel, logger.Level().Level())
}
-func TestInitializeLoggerWithError_Development(t *testing.T) {
- t.Setenv("ENV_NAME", "development")
+func TestNewAppliesCustomLevel(t *testing.T) {
+ t.Parallel()
- logger, err := InitializeLoggerWithError()
+ logger, err := New(Config{Environment: EnvironmentProduction, OTelLibraryName: "svc", Level: "error"})
+ require.NoError(t, err)
+ assert.Equal(t, zapcore.ErrorLevel, logger.Level().Level())
+}
+
+func TestNewRejectsInvalidCustomLevel(t *testing.T) {
+ t.Parallel()
- assert.NoError(t, err)
- assert.NotNil(t, logger)
+ _, err := New(Config{Environment: EnvironmentProduction, OTelLibraryName: "svc", Level: "invalid"})
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid level")
}
-func TestInitializeLoggerWithError_CustomLogLevel(t *testing.T) {
- t.Setenv("ENV_NAME", "production")
- t.Setenv("LOG_LEVEL", "warn")
+func TestCallerAttributionPointsToCallSite(t *testing.T) {
+ t.Parallel()
+
+ // Verify that caller skip is configured so that the logged caller
+ // points to the call site, not the zap wrapper internals.
+ // callerSkipFrames=1 means skip the wrapper's own frame.
+ logger, err := New(Config{
+ Environment: EnvironmentDevelopment,
+ OTelLibraryName: "test-caller",
+ })
+ require.NoError(t, err)
+
+ // The logger should not be nil and should have caller enabled
+ // (development config enables AddCaller by default).
+ raw := logger.Raw()
+ require.NotNil(t, raw, "Raw() should return the underlying zap logger")
+}
+
+func TestNewWithLocalEnvironment(t *testing.T) {
+ t.Parallel()
+
+ logger, err := New(Config{Environment: EnvironmentLocal, OTelLibraryName: "svc"})
+ require.NoError(t, err)
+ require.NotNil(t, logger)
+ assert.Equal(t, zapcore.DebugLevel, logger.Level().Level())
+}
+
+func TestNewWithStagingEnvironment(t *testing.T) {
+ t.Parallel()
+
+ logger, err := New(Config{Environment: EnvironmentStaging, OTelLibraryName: "svc"})
+ require.NoError(t, err)
+ require.NotNil(t, logger)
+ assert.Equal(t, zapcore.InfoLevel, logger.Level().Level())
+}
+
+func TestNewWithUATEnvironment(t *testing.T) {
+ t.Parallel()
+
+ logger, err := New(Config{Environment: EnvironmentUAT, OTelLibraryName: "svc"})
+ require.NoError(t, err)
+ require.NotNil(t, logger)
+ assert.Equal(t, zapcore.InfoLevel, logger.Level().Level())
+}
+
+func TestResolveLevelEmptyForProductionDefaultsToInfo(t *testing.T) {
+ t.Parallel()
+
+ level, err := resolveLevel(Config{Environment: EnvironmentProduction, Level: ""})
+ require.NoError(t, err)
+ assert.Equal(t, zapcore.InfoLevel, level.Level())
+}
+
+func TestResolveLevelEmptyForLocalDefaultsToDebug(t *testing.T) {
+ t.Parallel()
+
+ level, err := resolveLevel(Config{Environment: EnvironmentLocal, Level: ""})
+ require.NoError(t, err)
+ assert.Equal(t, zapcore.DebugLevel, level.Level())
+}
+
+func TestBuildConfigByEnvironmentDev(t *testing.T) {
+ t.Setenv("LOG_ENCODING", "")
- logger, err := InitializeLoggerWithError()
+ cfg := buildConfigByEnvironment(EnvironmentDevelopment)
+ assert.Equal(t, "console", cfg.Encoding)
+ assert.True(t, cfg.Development)
+}
+
+func TestBuildConfigByEnvironmentProd(t *testing.T) {
+ t.Setenv("LOG_ENCODING", "")
- assert.NoError(t, err)
- assert.NotNil(t, logger)
+ cfg := buildConfigByEnvironment(EnvironmentProduction)
+ assert.Equal(t, "json", cfg.Encoding)
+ assert.False(t, cfg.Development)
}
-// This test must not call t.Parallel() because it mutates the global log.Writer
-// via log.SetOutput(&buf) and relies on the defer to restore originalOutput.
-func TestInitializeLoggerWithError_InvalidLogLevel(t *testing.T) {
- t.Setenv("ENV_NAME", "production")
- t.Setenv("LOG_LEVEL", "invalid_level")
+func TestResolveEncodingFromEnvVar(t *testing.T) {
+ t.Setenv("LOG_ENCODING", "json")
+ assert.Equal(t, "json", resolveEncoding(EnvironmentLocal))
+
+ t.Setenv("LOG_ENCODING", "console")
+ assert.Equal(t, "console", resolveEncoding(EnvironmentProduction))
+
+ t.Setenv("LOG_ENCODING", "invalid")
+ assert.Equal(t, "console", resolveEncoding(EnvironmentLocal))
+ assert.Equal(t, "json", resolveEncoding(EnvironmentProduction))
+}
- var buf bytes.Buffer
- originalOutput := log.Writer()
- log.SetOutput(&buf)
+func TestResolveLevelFromEnvVar(t *testing.T) {
+ t.Setenv("LOG_LEVEL", "warn")
- defer log.SetOutput(originalOutput)
+ level, err := resolveLevel(Config{Environment: EnvironmentProduction, Level: ""})
+ require.NoError(t, err)
+ assert.Equal(t, zapcore.WarnLevel, level.Level())
+}
- logger, err := InitializeLoggerWithError()
+func TestResolveLevelConfigOverridesEnvVar(t *testing.T) {
+ t.Setenv("LOG_LEVEL", "warn")
- assert.NoError(t, err)
- assert.NotNil(t, logger)
- assert.Contains(t, buf.String(), "Invalid LOG_LEVEL")
- assert.Contains(t, buf.String(), "fallback to InfoLevel")
+ level, err := resolveLevel(Config{Environment: EnvironmentProduction, Level: "error"})
+ require.NoError(t, err)
+ assert.Equal(t, zapcore.ErrorLevel, level.Level(), "Config.Level should take precedence over LOG_LEVEL env var")
}
diff --git a/commons/zap/zap.go b/commons/zap/zap.go
index a624252b..039ec2fd 100644
--- a/commons/zap/zap.go
+++ b/commons/zap/zap.go
@@ -1,152 +1,296 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
-
package zap
import (
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ logpkg "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/runtime"
+ "github.com/LerianStudio/lib-commons/v4/commons/security"
+ "go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
+ "go.uber.org/zap/zapcore"
)
-// ZapWithTraceLogger is a wrapper of otelzap.SugaredLogger.
+// Field is a typed structured logging field (zap alias kept for convenience methods).
+type Field = zap.Field
+
+// Logger is a strict structured logger that implements log.Logger.
//
-// It implements Logger interface.
-// The shutdown function is used to close the logger provider.
-type ZapWithTraceLogger struct {
- Logger *zap.SugaredLogger
- defaultMessageTemplate string
+// It intentionally does not expose printf/line/fatal helpers.
+type Logger struct {
+ logger *zap.Logger
+ atomicLevel zap.AtomicLevel
+ // consoleEncoding is true when the logger uses console encoding.
+ // When true, messages are sanitized to prevent CWE-117 log injection,
+ // since console encoding does not inherently escape control characters
+ // the way JSON encoding does.
+ consoleEncoding bool
}
-// logWithHydration is a helper method to log messages with hydrated arguments using the default message template.
-func (l *ZapWithTraceLogger) logWithHydration(logFunc func(...any), args ...any) {
- logFunc(hydrateArgs(l.defaultMessageTemplate, args)...)
+// Compile-time assertion: *Logger implements logpkg.Logger.
+var _ logpkg.Logger = (*Logger)(nil)
+
+func (l *Logger) must() *zap.Logger {
+ if l == nil || l.logger == nil {
+ return zap.NewNop()
+ }
+
+ return l.logger
}
-// logfWithHydration is a helper method to log formatted messages with hydrated arguments using the default message template.
-func (l *ZapWithTraceLogger) logfWithHydration(logFunc func(string, ...any), format string, args ...any) {
- logFunc(l.defaultMessageTemplate+format, args...)
+// ---------------------------------------------------------------------------
+// log.Logger interface methods
+// ---------------------------------------------------------------------------
+
+// Log implements log.Logger. It dispatches to the appropriate zap level.
+// If ctx carries an active OpenTelemetry span, trace_id and span_id are
+// automatically appended so logs correlate with distributed traces.
+//
+// Unknown levels are treated as LevelInfo (consistent with GoLogger policy).
+func (l *Logger) Log(ctx context.Context, level logpkg.Level, msg string, fields ...logpkg.Field) {
+ zapFields := logFieldsToZap(fields)
+
+ if ctx != nil {
+ if sc := trace.SpanFromContext(ctx).SpanContext(); sc.IsValid() {
+ zapFields = append(zapFields,
+ zap.String("trace_id", sc.TraceID().String()),
+ zap.String("span_id", sc.SpanID().String()),
+ )
+ }
+ }
+
+ // Sanitize message for console encoding (CWE-117 prevention).
+ // JSON encoding handles this via its built-in escaping.
+ safeMsg := l.sanitizeConsoleMsg(msg)
+
+ switch level {
+ case logpkg.LevelDebug:
+ l.must().Debug(safeMsg, zapFields...)
+ case logpkg.LevelInfo:
+ l.must().Info(safeMsg, zapFields...)
+ case logpkg.LevelWarn:
+ l.must().Warn(safeMsg, zapFields...)
+ case logpkg.LevelError:
+ l.must().Error(safeMsg, zapFields...)
+ default:
+ // Unknown level policy: treat as Info. This is consistent across both
+ // GoLogger and zap backends. See log.Level documentation.
+ l.must().Info(safeMsg, zapFields...)
+ }
}
-// Info implements Info Logger interface function.
-func (l *ZapWithTraceLogger) Info(args ...any) {
- l.logWithHydration(l.Logger.Info, args...)
+// With returns a child logger with additional structured fields.
+//
+//nolint:ireturn
+func (l *Logger) With(fields ...logpkg.Field) logpkg.Logger {
+ if l == nil {
+ return &Logger{logger: zap.NewNop()}
+ }
+
+ return &Logger{
+ logger: l.must().With(logFieldsToZap(fields)...),
+ atomicLevel: l.atomicLevel,
+ consoleEncoding: l.consoleEncoding,
+ }
}
-// Infof implements Infof Logger interface function.
-func (l *ZapWithTraceLogger) Infof(format string, args ...any) {
- l.logfWithHydration(l.Logger.Infof, format, args...)
+// WithGroup returns a child logger that nests subsequent fields under a namespace.
+// Empty group names are silently ignored, consistent with GoLogger behavior.
+//
+//nolint:ireturn
+func (l *Logger) WithGroup(name string) logpkg.Logger {
+ if l == nil {
+ return &Logger{logger: zap.NewNop()}
+ }
+
+ if name == "" {
+ return l
+ }
+
+ return &Logger{
+ logger: l.must().With(zap.Namespace(name)),
+ atomicLevel: l.atomicLevel,
+ consoleEncoding: l.consoleEncoding,
+ }
}
-// Infoln implements Infoln Logger interface function.
-func (l *ZapWithTraceLogger) Infoln(args ...any) {
- l.logWithHydration(l.Logger.Infoln, args...)
+// Enabled reports whether the logger would emit a log at the given level.
+func (l *Logger) Enabled(level logpkg.Level) bool {
+ return l.must().Core().Enabled(logLevelToZap(level))
}
-// Error implements Error Logger interface function.
-func (l *ZapWithTraceLogger) Error(args ...any) {
- l.logWithHydration(l.Logger.Error, args...)
+// Sync flushes buffered logs, respecting context cancellation.
+func (l *Logger) Sync(ctx context.Context) error {
+ if ctx == nil {
+ return l.must().Sync()
+ }
+
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+
+ done := make(chan error, 1)
+
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ runtime.HandlePanicValue(ctx, nil, r, "zap", "sync")
+
+ done <- fmt.Errorf("panic during logger sync: %v", r)
+ }
+ }()
+
+ done <- l.must().Sync()
+ }()
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case err := <-done:
+ return err
+ }
}
-// Errorf implements Errorf Logger interface function.
-func (l *ZapWithTraceLogger) Errorf(format string, args ...any) {
- l.logfWithHydration(l.Logger.Errorf, format, args...)
+// ---------------------------------------------------------------------------
+// Convenience methods (direct zap.Field access for performance-sensitive code)
+// ---------------------------------------------------------------------------
+
+// WithZapFields returns a child logger with additional zap.Field values.
+// Use this when working directly with zap fields for performance.
+func (l *Logger) WithZapFields(fields ...Field) *Logger {
+ if l == nil {
+ return &Logger{logger: zap.NewNop()}
+ }
+
+ return &Logger{
+ logger: l.must().With(fields...),
+ atomicLevel: l.atomicLevel,
+ consoleEncoding: l.consoleEncoding,
+ }
}
-// Errorln implements Errorln Logger interface function.
-func (l *ZapWithTraceLogger) Errorln(args ...any) {
- l.logWithHydration(l.Logger.Errorln, args...)
+// Debug logs a message with debug severity.
+func (l *Logger) Debug(message string, fields ...Field) {
+ l.must().Debug(message, fields...)
}
-// Warn implements Warn Logger interface function.
-func (l *ZapWithTraceLogger) Warn(args ...any) {
- l.logWithHydration(l.Logger.Warn, args...)
+// Info logs a message with info severity.
+func (l *Logger) Info(message string, fields ...Field) {
+ l.must().Info(message, fields...)
}
-// Warnf implements Warnf Logger interface function.
-func (l *ZapWithTraceLogger) Warnf(format string, args ...any) {
- l.logfWithHydration(l.Logger.Warnf, format, args...)
+// Warn logs a message with warn severity.
+func (l *Logger) Warn(message string, fields ...Field) {
+ l.must().Warn(message, fields...)
}
-// Warnln implements Warnln Logger interface function.
-func (l *ZapWithTraceLogger) Warnln(args ...any) {
- l.logWithHydration(l.Logger.Warnln, args...)
+// Error logs a message with error severity.
+func (l *Logger) Error(message string, fields ...Field) {
+ l.must().Error(message, fields...)
}
-// Debug implements Debug Logger interface function.
-func (l *ZapWithTraceLogger) Debug(args ...any) {
- l.logWithHydration(l.Logger.Debug, args...)
+// Raw returns the underlying zap logger.
+func (l *Logger) Raw() *zap.Logger {
+ return l.must()
}
-// Debugf implements Debugf Logger interface function.
-func (l *ZapWithTraceLogger) Debugf(format string, args ...any) {
- l.logfWithHydration(l.Logger.Debugf, format, args...)
+// Level returns the runtime-adjustable level handle for this logger.
+// On a nil receiver, a default AtomicLevel (info) is returned.
+func (l *Logger) Level() zap.AtomicLevel {
+ if l == nil {
+ return zap.NewAtomicLevel()
+ }
+
+ return l.atomicLevel
}
-// Debugln implements Debugln Logger interface function.
-func (l *ZapWithTraceLogger) Debugln(args ...any) {
- l.logWithHydration(l.Logger.Debugln, args...)
+// Any creates a field with any value.
+func Any(key string, value any) Field {
+ return zap.Any(key, value)
}
-// Fatal implements Fatal Logger interface function.
-func (l *ZapWithTraceLogger) Fatal(args ...any) {
- l.logWithHydration(l.Logger.Fatal, args...)
+// String creates a string field.
+func String(key, value string) Field {
+ return zap.String(key, value)
}
-// Fatalf implements Fatalf Logger interface function.
-func (l *ZapWithTraceLogger) Fatalf(format string, args ...any) {
- l.logfWithHydration(l.Logger.Fatalf, format, args...)
+// Int creates an int field.
+func Int(key string, value int) Field {
+ return zap.Int(key, value)
}
-// Fatalln implements Fatalln Logger interface function.
-func (l *ZapWithTraceLogger) Fatalln(args ...any) {
- l.logWithHydration(l.Logger.Fatalln, args...)
+// Bool creates a bool field.
+func Bool(key string, value bool) Field {
+ return zap.Bool(key, value)
}
-// WithFields adds structured context to the logger. It returns a new logger and leaves the original unchanged.
-//
-//nolint:ireturn
-func (l *ZapWithTraceLogger) WithFields(fields ...any) log.Logger {
- newLogger := l.Logger.With(fields...)
+// Duration creates a duration field.
+func Duration(key string, value time.Duration) Field {
+ return zap.Duration(key, value)
+}
- return &ZapWithTraceLogger{
- Logger: newLogger,
- defaultMessageTemplate: l.defaultMessageTemplate,
- }
+// ErrorField creates an error field.
+func ErrorField(err error) Field {
+ return zap.Error(err)
}
-// Sync implements Sync Logger interface function.
-//
-// Sync calls the underlying Core's Sync method, flushing any buffered log entries as well as closing the logger provider used by open telemetry. Applications should take care to call Sync before exiting.
-//
-//nolint:ireturn
-func (l *ZapWithTraceLogger) Sync() error {
- err := l.Logger.Sync()
- if err != nil {
- return err
- }
+// ---------------------------------------------------------------------------
+// Internal conversion helpers
+// ---------------------------------------------------------------------------
- return nil
+// logLevelToZap converts a log.Level to a zapcore.Level.
+func logLevelToZap(level logpkg.Level) zapcore.Level {
+ switch level {
+ case logpkg.LevelDebug:
+ return zapcore.DebugLevel
+ case logpkg.LevelInfo:
+ return zapcore.InfoLevel
+ case logpkg.LevelWarn:
+ return zapcore.WarnLevel
+ case logpkg.LevelError:
+ return zapcore.ErrorLevel
+ default:
+ return zapcore.InfoLevel
+ }
}
-// WithDefaultMessageTemplate sets the default message template for the logger.
-// Returns a new logger instance without mutating the original.
-//
-//nolint:ireturn
-func (l *ZapWithTraceLogger) WithDefaultMessageTemplate(message string) log.Logger {
- return &ZapWithTraceLogger{
- Logger: l.Logger,
- defaultMessageTemplate: message,
+// redactedValue is the placeholder used for sensitive field values in log output.
+const redactedValue = "[REDACTED]"
+
+// consoleControlCharReplacer neutralizes control characters that can split log
+// lines or forge entries in console-encoded output (CWE-117). JSON encoding
+// handles this automatically via its escaping rules.
+var consoleControlCharReplacer = strings.NewReplacer(
+ "\n", `\n`,
+ "\r", `\r`,
+ "\t", `\t`,
+ "\x00", `\0`,
+)
+
+// sanitizeConsoleMsg escapes control characters in a message string
+// when the logger is configured with console encoding.
+func (l *Logger) sanitizeConsoleMsg(msg string) string {
+ if l != nil && l.consoleEncoding {
+ return consoleControlCharReplacer.Replace(msg)
}
-}
-func hydrateArgs(defaultTemplateMsg string, args []any) []any {
- argsHydration := make([]any, len(args)+1)
- argsHydration[0] = defaultTemplateMsg
+ return msg
+}
- for i, arg := range args {
- argsHydration[i+1] = arg
+// logFieldsToZap converts log.Field values to zap.Field values.
+// Sensitive field keys (matched via security.IsSensitiveField) are redacted.
+func logFieldsToZap(fields []logpkg.Field) []zap.Field {
+ zapFields := make([]zap.Field, len(fields))
+ for i, f := range fields {
+ if security.IsSensitiveField(f.Key) {
+ zapFields[i] = zap.String(f.Key, redactedValue)
+ } else {
+ zapFields[i] = zap.Any(f.Key, f.Value)
+ }
}
- return argsHydration
+ return zapFields
}
diff --git a/commons/zap/zap_test.go b/commons/zap/zap_test.go
index ca249351..1ea25401 100644
--- a/commons/zap/zap_test.go
+++ b/commons/zap/zap_test.go
@@ -1,191 +1,659 @@
-// Copyright (c) 2026 Lerian Studio. All rights reserved.
-// Use of this source code is governed by the Elastic License 2.0
-// that can be found in the LICENSE file.
+//go:build unit
package zap
import (
- "go.uber.org/zap"
+ "context"
+ "errors"
+ "strings"
"testing"
+ "time"
+
+ logpkg "github.com/LerianStudio/lib-commons/v4/commons/log"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "go.opentelemetry.io/otel/trace"
+ "go.uber.org/zap"
+ "go.uber.org/zap/zapcore"
+ "go.uber.org/zap/zaptest/observer"
)
-func TestZap(t *testing.T) {
- t.Run("log with hydration", func(t *testing.T) {
- l := &ZapWithTraceLogger{}
- l.logWithHydration(func(a ...any) {}, "")
- })
+func newObservedLogger(level zapcore.Level) (*Logger, *observer.ObservedLogs) {
+ core, observed := observer.New(level)
- t.Run("logf with hydration", func(t *testing.T) {
- l := &ZapWithTraceLogger{}
- l.logfWithHydration(func(s string, a ...any) {}, "", "")
- })
+ return &Logger{logger: zap.New(core)}, observed
+}
+
+// newBufferedLogger creates a Logger that writes JSON to a buffer for output
+// inspection (e.g., verifying CWE-117 sanitization in serialized output).
+func newBufferedLogger(level zapcore.Level) (*Logger, *strings.Builder) {
+ buf := &strings.Builder{}
+ ws := zapcore.AddSync(buf)
+
+ encoderCfg := zap.NewProductionEncoderConfig()
+ encoderCfg.TimeKey = "" // omit timestamp for deterministic test output
+ core := zapcore.NewCore(
+ zapcore.NewJSONEncoder(encoderCfg),
+ ws,
+ level,
+ )
+
+ return &Logger{logger: zap.New(core)}, buf
+}
- t.Run("ZapWithTraceLogger info", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+func TestLoggerNilReceiverFallsBackToNop(t *testing.T) {
+ var nilLogger *Logger
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Info(func(s string, a ...any) {}, "", "")
+ assert.NotPanics(t, func() {
+ nilLogger.Info("message")
})
+}
- t.Run("ZapWithTraceLogger infof", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+func TestLoggerNilUnderlyingFallsBackToNop(t *testing.T) {
+ logger := &Logger{}
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Infof("", "")
+ assert.NotPanics(t, func() {
+ logger.Info("message")
})
+}
- t.Run("ZapWithTraceLogger infoln", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+func TestStructuredLoggingMethods(t *testing.T) {
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Infoln("", "")
- })
+ logger.Debug("debug message")
+ logger.Info("info message", String("request_id", "req-1"))
+ logger.Warn("warn message")
+ logger.Error("error message", ErrorField(errors.New("boom")))
- t.Run("ZapWithTraceLogger Error", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ entries := observed.All()
+ require.Len(t, entries, 4)
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Error("", "")
- })
+ assert.Equal(t, zapcore.DebugLevel, entries[0].Level)
+ assert.Equal(t, "debug message", entries[0].Message)
- t.Run("ZapWithTraceLogger Errorf", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ assert.Equal(t, zapcore.InfoLevel, entries[1].Level)
+ assert.Equal(t, "info message", entries[1].Message)
+ assert.Equal(t, "req-1", entries[1].ContextMap()["request_id"])
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Errorf("", "")
- })
+ assert.Equal(t, zapcore.WarnLevel, entries[2].Level)
+ assert.Equal(t, "warn message", entries[2].Message)
- t.Run("ZapWithTraceLogger Errorln", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ assert.Equal(t, zapcore.ErrorLevel, entries[3].Level)
+ assert.Equal(t, "error message", entries[3].Message)
+}
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Errorln("", "")
- })
+func TestWithZapFieldsAddsFieldsWithoutMutatingParent(t *testing.T) {
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+ child := logger.WithZapFields(String("tenant_id", "t-1"))
- t.Run("ZapWithTraceLogger Warn", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ logger.Info("parent")
+ child.Info("child")
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Warn("", "")
- })
+ entries := observed.All()
+ require.Len(t, entries, 2)
- t.Run("ZapWithTraceLogger Warnf", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ _, parentHasTenant := entries[0].ContextMap()["tenant_id"]
+ assert.False(t, parentHasTenant)
+ assert.Equal(t, "t-1", entries[1].ContextMap()["tenant_id"])
+}
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Warnf("", "")
- })
+func TestSyncReturnsNoErrorForHealthyLogger(t *testing.T) {
+ logger, _ := newObservedLogger(zapcore.DebugLevel)
- t.Run("ZapWithTraceLogger Warnln", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ require.NoError(t, logger.Sync(context.Background()))
+}
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Warnln("", "")
- })
+// errorSink is a zapcore.WriteSyncer that always returns an error on Sync.
+type errorSink struct{}
- t.Run("ZapWithTraceLogger Debug", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+func (e errorSink) Write(p []byte) (int, error) { return len(p), nil }
+func (e errorSink) Sync() error { return errors.New("simulated sync failure") }
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Debug("", "")
- })
+// panicSink is a zapcore.WriteSyncer that panics on Sync, for testing
+// the panic-recovery branch in Logger.Sync.
+type panicSink struct{}
- t.Run("ZapWithTraceLogger Debugf", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+func (p panicSink) Write(b []byte) (int, error) { return len(b), nil }
+func (p panicSink) Sync() error { panic("boom from sync") }
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Debugf("", "")
- })
+func TestSyncReturnsErrorFromFailingSink(t *testing.T) {
+ core := zapcore.NewCore(
+ zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()),
+ errorSink{},
+ zapcore.DebugLevel,
+ )
+ logger := &Logger{logger: zap.New(core)}
- t.Run("ZapWithTraceLogger Debugln", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ err := logger.Sync(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "simulated sync failure")
+}
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Debugln("", "")
- })
+func TestSyncRecoversPanicFromSink(t *testing.T) {
+ core := zapcore.NewCore(
+ zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()),
+ panicSink{},
+ zapcore.DebugLevel,
+ )
+ logger := &Logger{logger: zap.New(core)}
+
+ err := logger.Sync(context.Background())
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "panic during logger sync")
+}
+
+func TestFieldHelpers(t *testing.T) {
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+ logger.Info(
+ "helpers",
+ String("s", "value"),
+ Int("i", 42),
+ Bool("b", true),
+ Duration("d", 2*time.Second),
+ )
+
+ entries := observed.All()
+ require.Len(t, entries, 1)
+ ctx := entries[0].ContextMap()
+
+ assert.Equal(t, "value", ctx["s"])
+ assert.Equal(t, int64(42), ctx["i"])
+ assert.Equal(t, true, ctx["b"])
+ assert.Equal(t, 2*time.Second, ctx["d"])
+}
+
+// ===========================================================================
+// CWE-117: Log Injection Prevention for Zap Adapter
+//
+// Zap serializes output as JSON, which inherently escapes control characters
+// in string values. These tests verify that behavior is preserved and that
+// injection attempts cannot split log lines or forge entries.
+// ===========================================================================
+
+// TestCWE117_ZapMessageNewlineInjection verifies that newlines in log messages
+// are properly escaped in JSON output, preventing log line splitting.
+func TestCWE117_ZapMessageNewlineInjection(t *testing.T) {
+ tests := []struct {
+ name string
+ message string
+ }{
+ {
+ name: "LF in message",
+ message: "legitimate\n{\"level\":\"error\",\"msg\":\"forged entry\"}",
+ },
+ {
+ name: "CR in message",
+ message: "legitimate\r{\"level\":\"error\",\"msg\":\"forged entry\"}",
+ },
+ {
+ name: "CRLF in message",
+ message: "legitimate\r\n{\"level\":\"error\",\"msg\":\"forged entry\"}",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ logger, buf := newBufferedLogger(zapcore.DebugLevel)
+ logger.Info(tt.message)
+ require.NoError(t, logger.Sync(context.Background()))
+
+ out := buf.String()
+ // JSON output from zap should be a single line per entry
+ lines := strings.Split(strings.TrimSpace(out), "\n")
+ assert.Len(t, lines, 1,
+ "CWE-117: zap JSON output must be a single line, got %d lines:\n%s", len(lines), out)
+
+ // The raw newline characters should not appear in the JSON output
+ // (JSON encoder escapes them as \n, \r)
+ assert.NotContains(t, out, "forged entry\"}",
+ "forged JSON entry must not appear as a separate parseable line")
+ })
+ }
+}
+
+// TestCWE117_ZapFieldValueInjection verifies field values with newlines
+// are escaped by zap's JSON encoder.
+func TestCWE117_ZapFieldValueInjection(t *testing.T) {
+ logger, buf := newBufferedLogger(zapcore.DebugLevel)
- t.Run("ZapWithTraceLogger WithFields", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ maliciousValue := "user123\n{\"level\":\"error\",\"msg\":\"ADMIN ACCESS GRANTED\"}"
+ logger.Info("login", String("user_id", maliciousValue))
+ require.NoError(t, logger.Sync(context.Background()))
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.WithFields("", "")
+ out := buf.String()
+ lines := strings.Split(strings.TrimSpace(out), "\n")
+ assert.Len(t, lines, 1,
+ "CWE-117: field value injection must not create extra JSON lines")
+}
+
+// TestCWE117_ZapFieldNameInjection verifies that field names with control
+// characters are escaped by zap's JSON encoder.
+func TestCWE117_ZapFieldNameInjection(t *testing.T) {
+ logger, buf := newBufferedLogger(zapcore.DebugLevel)
+
+ // Field name with embedded newline
+ logger.Info("event", zap.String("key\ninjected", "value"))
+ require.NoError(t, logger.Sync(context.Background()))
+
+ out := buf.String()
+ lines := strings.Split(strings.TrimSpace(out), "\n")
+ assert.Len(t, lines, 1,
+ "CWE-117: field name injection must not create extra JSON lines")
+}
+
+// TestCWE117_ZapNullByteInMessage verifies null bytes in messages are handled.
+func TestCWE117_ZapNullByteInMessage(t *testing.T) {
+ logger, buf := newBufferedLogger(zapcore.DebugLevel)
+ logger.Info("before\x00after")
+ require.NoError(t, logger.Sync(context.Background()))
+
+ out := buf.String()
+ lines := strings.Split(strings.TrimSpace(out), "\n")
+ assert.Len(t, lines, 1, "null byte must not split log output")
+}
+
+// TestCWE117_ZapANSIEscapeInMessage verifies ANSI escapes don't break output.
+func TestCWE117_ZapANSIEscapeInMessage(t *testing.T) {
+ logger, buf := newBufferedLogger(zapcore.DebugLevel)
+ logger.Info("normal \x1b[31mRED\x1b[0m normal")
+ require.NoError(t, logger.Sync(context.Background()))
+
+ out := buf.String()
+ lines := strings.Split(strings.TrimSpace(out), "\n")
+ assert.Len(t, lines, 1, "ANSI escape must not split log output")
+}
+
+// TestCWE117_ZapTabInMessage verifies tab characters are handled in JSON output.
+func TestCWE117_ZapTabInMessage(t *testing.T) {
+ logger, buf := newBufferedLogger(zapcore.DebugLevel)
+ logger.Info("col1\tcol2\tcol3")
+ require.NoError(t, logger.Sync(context.Background()))
+
+ out := buf.String()
+ lines := strings.Split(strings.TrimSpace(out), "\n")
+ assert.Len(t, lines, 1, "tabs must not split log output")
+ // JSON encoder escapes tabs as \t in the JSON string
+ assert.Contains(t, out, "col1")
+ assert.Contains(t, out, "col2")
+}
+
+// TestCWE117_ZapWithPreservesSanitization verifies that child loggers created
+// via With() still properly handle injection attempts.
+func TestCWE117_ZapWithPreservesSanitization(t *testing.T) {
+ logger, buf := newBufferedLogger(zapcore.DebugLevel)
+ child := logger.WithZapFields(String("session", "sess\n{\"forged\":true}"))
+ child.Info("child message")
+ require.NoError(t, logger.Sync(context.Background()))
+
+ out := buf.String()
+ lines := strings.Split(strings.TrimSpace(out), "\n")
+ assert.Len(t, lines, 1,
+ "CWE-117: With() must not allow field injection to split lines")
+}
+
+// TestCWE117_ZapMultipleVectorsSimultaneously combines multiple attack vectors.
+func TestCWE117_ZapMultipleVectorsSimultaneously(t *testing.T) {
+ logger, buf := newBufferedLogger(zapcore.DebugLevel)
+
+ // Message with injection
+ msg := "event\n{\"level\":\"error\",\"msg\":\"forged\"}\ttab\r\nmore"
+ // Fields with injection
+ logger.Info(msg,
+ zap.String("user\nfake", "val\nfake"),
+ zap.String("safe_key", "safe_val"))
+ require.NoError(t, logger.Sync(context.Background()))
+
+ out := buf.String()
+ lines := strings.Split(strings.TrimSpace(out), "\n")
+ assert.Len(t, lines, 1,
+ "CWE-117: combined attack vectors must not create multiple JSON lines")
+}
+
+// ===========================================================================
+// Zap Level Filtering Tests
+// ===========================================================================
+
+// TestZapLevelFiltering verifies that the observed logger correctly filters
+// by log level.
+func TestZapLevelFiltering(t *testing.T) {
+ t.Run("info level suppresses debug", func(t *testing.T) {
+ logger, observed := newObservedLogger(zapcore.InfoLevel)
+ logger.Debug("should be suppressed")
+ logger.Info("should appear")
+
+ entries := observed.All()
+ require.Len(t, entries, 1)
+ assert.Equal(t, "should appear", entries[0].Message)
})
- t.Run("ZapWithTraceLogger Sync)", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ t.Run("error level suppresses warn and below", func(t *testing.T) {
+ logger, observed := newObservedLogger(zapcore.ErrorLevel)
+ logger.Debug("suppressed")
+ logger.Info("suppressed")
+ logger.Warn("suppressed")
+ logger.Error("visible")
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.Sync()
+ entries := observed.All()
+ require.Len(t, entries, 1)
+ assert.Equal(t, "visible", entries[0].Message)
})
+}
+
+// TestZapRawReturnsUnderlyingLogger verifies Raw() returns the inner zap.Logger.
+func TestZapRawReturnsUnderlyingLogger(t *testing.T) {
+ logger, _ := newObservedLogger(zapcore.DebugLevel)
+ raw := logger.Raw()
+ assert.NotNil(t, raw)
+}
+
+// TestZapRawOnNilReturnsNop verifies Raw() on nil returns nop logger.
+func TestZapRawOnNilReturnsNop(t *testing.T) {
+ var logger *Logger
+ raw := logger.Raw()
+ assert.NotNil(t, raw, "Raw() on nil logger should return nop, not nil")
+}
+
+// TestZapErrorFieldHelper verifies the ErrorField helper.
+func TestZapErrorFieldHelper(t *testing.T) {
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+ testErr := errors.New("test error")
+ logger.Error("failed", ErrorField(testErr))
+
+ entries := observed.All()
+ require.Len(t, entries, 1)
+ assert.Equal(t, "test error", entries[0].ContextMap()["error"].(string))
+}
+
+// TestZapAnyFieldHelper verifies the Any helper with various types.
+func TestZapAnyFieldHelper(t *testing.T) {
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+ logger.Info("test",
+ Any("slice", []string{"a", "b"}),
+ Any("map", map[string]int{"x": 1}))
+
+ entries := observed.All()
+ require.Len(t, entries, 1)
+ // Verify fields exist (exact format depends on zap encoding)
+ ctx := entries[0].ContextMap()
+ assert.NotNil(t, ctx["slice"])
+ assert.NotNil(t, ctx["map"])
+}
+
+// ===========================================================================
+// log.Logger interface coverage
+// ===========================================================================
+
+func TestLogAllLevels(t *testing.T) {
+ t.Parallel()
- t.Run("ZapWithTraceLogger WithDefaultMessageTemplate)", func(t *testing.T) {
- logger, _ := zap.NewDevelopment()
- sugar := logger.Sugar()
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
- zapLogger := &ZapWithTraceLogger{
- Logger: sugar,
- defaultMessageTemplate: "default template: ",
- }
- zapLogger.WithDefaultMessageTemplate("")
+ logger.Log(context.Background(), logpkg.LevelDebug, "debug via Log")
+ logger.Log(context.Background(), logpkg.LevelInfo, "info via Log")
+ logger.Log(context.Background(), logpkg.LevelWarn, "warn via Log")
+ logger.Log(context.Background(), logpkg.LevelError, "error via Log")
+
+ entries := observed.All()
+ require.Len(t, entries, 4)
+
+ assert.Equal(t, zapcore.DebugLevel, entries[0].Level)
+ assert.Equal(t, zapcore.InfoLevel, entries[1].Level)
+ assert.Equal(t, zapcore.WarnLevel, entries[2].Level)
+ assert.Equal(t, zapcore.ErrorLevel, entries[3].Level)
+}
+
+func TestLogDefaultLevel(t *testing.T) {
+ t.Parallel()
+
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+
+ // Use an undefined level value to hit the default case
+ logger.Log(context.Background(), logpkg.Level(99), "default level")
+
+ entries := observed.All()
+ require.Len(t, entries, 1)
+ assert.Equal(t, zapcore.InfoLevel, entries[0].Level, "unknown level should default to Info")
+}
+
+func TestLogWithNilContext(t *testing.T) {
+ t.Parallel()
+
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+
+ assert.NotPanics(t, func() {
+ //nolint:staticcheck // intentionally passing nil context
+ logger.Log(nil, logpkg.LevelInfo, "nil ctx message")
})
- t.Run("ZapWithTraceLogger WithDefaultMessageTemplate)", func(t *testing.T) {
- hydrateArgs("", []any{})
+ entries := observed.All()
+ require.Len(t, entries, 1)
+ assert.Equal(t, "nil ctx message", entries[0].Message)
+ // No trace_id/span_id should be present
+ _, hasTrace := entries[0].ContextMap()["trace_id"]
+ assert.False(t, hasTrace)
+}
+
+func TestLogWithOTelSpanInjectsTraceFields(t *testing.T) {
+ t.Parallel()
+
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+
+ // Create a span context with valid trace ID and span ID
+ traceID, _ := trace.TraceIDFromHex("0af7651916cd43dd8448eb211c80319c")
+ spanID, _ := trace.SpanIDFromHex("b7ad6b7169203331")
+ sc := trace.NewSpanContext(trace.SpanContextConfig{
+ TraceID: traceID,
+ SpanID: spanID,
+ TraceFlags: trace.FlagsSampled,
})
+ ctx := trace.ContextWithSpanContext(context.Background(), sc)
+
+ logger.Log(ctx, logpkg.LevelInfo, "traced message", logpkg.String("request_id", "req-42"))
+
+ entries := observed.All()
+ require.Len(t, entries, 1)
+
+ cm := entries[0].ContextMap()
+ assert.Equal(t, traceID.String(), cm["trace_id"])
+ assert.Equal(t, spanID.String(), cm["span_id"])
+ assert.Equal(t, "req-42", cm["request_id"])
+}
+
+func TestLogWithInvalidSpanDoesNotInjectTraceFields(t *testing.T) {
+ t.Parallel()
+
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+
+ // Background context has no active span — SpanContext is invalid
+ logger.Log(context.Background(), logpkg.LevelInfo, "no span")
+
+ entries := observed.All()
+ require.Len(t, entries, 1)
+
+ _, hasTrace := entries[0].ContextMap()["trace_id"]
+ assert.False(t, hasTrace)
+}
+
+func TestWithReturnsChildLogger(t *testing.T) {
+ t.Parallel()
+
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+
+ child := logger.With(logpkg.String("component", "auth"))
+ child.Log(context.Background(), logpkg.LevelInfo, "child msg")
+
+ // Parent should not have the field
+ logger.Log(context.Background(), logpkg.LevelInfo, "parent msg")
+
+ entries := observed.All()
+ require.Len(t, entries, 2)
+
+ assert.Equal(t, "auth", entries[0].ContextMap()["component"])
+ _, parentHas := entries[1].ContextMap()["component"]
+ assert.False(t, parentHas)
+}
+
+func TestWithGroupNamespacesFields(t *testing.T) {
+ t.Parallel()
+
+ // Use a buffered JSON logger so we can inspect the serialized output
+ // and verify the namespaced field structure.
+ logger, buf := newBufferedLogger(zapcore.DebugLevel)
+
+ grouped := logger.WithGroup("http")
+ grouped.Log(context.Background(), logpkg.LevelInfo, "grouped msg", logpkg.String("method", "GET"))
+ require.NoError(t, logger.Sync(context.Background()))
+
+ out := buf.String()
+ // The JSON output should contain the "http" namespace wrapping "method"
+ assert.Contains(t, out, `"http"`)
+ assert.Contains(t, out, `"method"`)
+ assert.Contains(t, out, `"GET"`)
+ assert.Contains(t, out, "grouped msg")
+}
+
+func TestEnabledReportsCorrectly(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ coreLevel zapcore.Level
+ checkLvl logpkg.Level
+ expected bool
+ }{
+ {"debug enabled at debug", zapcore.DebugLevel, logpkg.LevelDebug, true},
+ {"info enabled at debug", zapcore.DebugLevel, logpkg.LevelInfo, true},
+ {"warn enabled at debug", zapcore.DebugLevel, logpkg.LevelWarn, true},
+ {"error enabled at debug", zapcore.DebugLevel, logpkg.LevelError, true},
+ {"debug disabled at info", zapcore.InfoLevel, logpkg.LevelDebug, false},
+ {"info enabled at info", zapcore.InfoLevel, logpkg.LevelInfo, true},
+ {"debug disabled at error", zapcore.ErrorLevel, logpkg.LevelDebug, false},
+ {"info disabled at error", zapcore.ErrorLevel, logpkg.LevelInfo, false},
+ {"warn disabled at error", zapcore.ErrorLevel, logpkg.LevelWarn, false},
+ {"error enabled at error", zapcore.ErrorLevel, logpkg.LevelError, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ logger, _ := newObservedLogger(tt.coreLevel)
+ assert.Equal(t, tt.expected, logger.Enabled(tt.checkLvl))
+ })
+ }
+}
+
+func TestSyncWithCancelledContext(t *testing.T) {
+ t.Parallel()
+
+ logger, _ := newObservedLogger(zapcore.DebugLevel)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel() // cancel immediately
+
+ err := logger.Sync(ctx)
+ require.Error(t, err)
+ assert.ErrorIs(t, err, context.Canceled)
+}
+
+func TestLevelReturnsAtomicLevel(t *testing.T) {
+ t.Parallel()
+
+ al := zap.NewAtomicLevelAt(zapcore.WarnLevel)
+ logger := &Logger{
+ logger: zap.NewNop(),
+ atomicLevel: al,
+ }
+
+ assert.Equal(t, zapcore.WarnLevel, logger.Level().Level())
+}
+
+func TestLevelOnNilReceiverReturnsDefault(t *testing.T) {
+ t.Parallel()
+
+ var logger *Logger
+ // Should not panic and should return a usable default level.
+ level := logger.Level()
+ assert.Equal(t, zapcore.InfoLevel, level.Level(),
+ "nil receiver should return default AtomicLevel (info)")
+}
+
+func TestWithGroupEmptyNameReturnsReceiver(t *testing.T) {
+ t.Parallel()
+
+ logger, _ := newObservedLogger(zapcore.DebugLevel)
+
+ // Empty name should return the same logger instance (no namespace created).
+ same := logger.WithGroup("")
+ assert.Equal(t, logger, same, "WithGroup(\"\") should return the same logger")
+}
+
+func TestSensitiveFieldRedaction(t *testing.T) {
+ t.Parallel()
+
+ logger, observed := newObservedLogger(zapcore.DebugLevel)
+ logger.Log(context.Background(), logpkg.LevelInfo, "login",
+ logpkg.String("password", "super_secret"),
+ logpkg.String("api_key", "key-12345"),
+ logpkg.String("user_id", "u-42"),
+ )
+
+ entries := observed.All()
+ require.Len(t, entries, 1)
+ ctx := entries[0].ContextMap()
+
+ assert.Equal(t, "[REDACTED]", ctx["password"],
+ "password field must be redacted")
+ assert.Equal(t, "[REDACTED]", ctx["api_key"],
+ "api_key field must be redacted")
+ assert.Equal(t, "u-42", ctx["user_id"],
+ "non-sensitive fields must pass through")
+}
+
+func TestConsoleEncodingSanitizesMessages(t *testing.T) {
+ // Create a console-encoded logger and verify newlines are sanitized
+ buf := &strings.Builder{}
+ ws := zapcore.AddSync(buf)
+
+ encoderCfg := zap.NewDevelopmentEncoderConfig()
+ encoderCfg.TimeKey = ""
+ core := zapcore.NewCore(
+ zapcore.NewConsoleEncoder(encoderCfg),
+ ws,
+ zapcore.DebugLevel,
+ )
+
+ logger := &Logger{
+ logger: zap.New(core),
+ consoleEncoding: true,
+ }
+
+ logger.Log(context.Background(), logpkg.LevelInfo, "line1\nline2\rline3")
+ require.NoError(t, logger.Sync(context.Background()))
+
+ out := buf.String()
+ lines := strings.Split(strings.TrimSpace(out), "\n")
+ assert.Len(t, lines, 1,
+ "console output with injection attempt must remain a single line, got: %q", out)
+}
+
+func TestLogLevelToZapConversions(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ input logpkg.Level
+ expected zapcore.Level
+ }{
+ {logpkg.LevelDebug, zapcore.DebugLevel},
+ {logpkg.LevelInfo, zapcore.InfoLevel},
+ {logpkg.LevelWarn, zapcore.WarnLevel},
+ {logpkg.LevelError, zapcore.ErrorLevel},
+ {logpkg.Level(42), zapcore.InfoLevel}, // default
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input.String(), func(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, tt.expected, logLevelToZap(tt.input))
+ })
+ }
}
diff --git a/docs/PROJECT_RULES.md b/docs/PROJECT_RULES.md
index f62679a2..9884f81a 100644
--- a/docs/PROJECT_RULES.md
+++ b/docs/PROJECT_RULES.md
@@ -1,6 +1,6 @@
# Project Rules - lib-commons
-This document defines the coding standards, architecture patterns, and development guidelines for the `lib-commons` library.
+This document defines the coding standards, architecture patterns, and development guidelines for the unified `lib-commons` library.
## Table of Contents
@@ -23,17 +23,43 @@ This document defines the coding standards, architecture patterns, and developme
```text
lib-commons/
-├── commons/ # All library packages
-│ ├── {package}/ # Feature package
-│ │ ├── {package}.go # Main implementation
-│ │ ├── {package}_test.go # Unit tests
-│ │ └── doc.go # Package documentation (optional)
-│ ├── context.go # Root-level utilities
-│ ├── errors.go # Error definitions
-│ └── utils.go # Utility functions
-├── docs/ # Documentation
-├── scripts/ # Build/test scripts
-└── go.mod # Module definition
+├── commons/ # All library packages
+│ ├── assert/ # Production-safe assertions with telemetry
+│ ├── backoff/ # Exponential backoff with jitter
+│ ├── circuitbreaker/ # Circuit breaker manager and health checker
+│ ├── constants/ # Shared constants (headers, errors, pagination)
+│ ├── cron/ # Cron expression parsing and scheduling
+│ ├── crypto/ # Hashing and symmetric encryption
+│ ├── errgroup/ # Goroutine coordination with panic recovery
+│ ├── jwt/ # HMAC-based JWT signing and verification
+│ ├── license/ # License validation and enforcement
+│ ├── log/ # Logging abstraction (Logger interface)
+│ ├── mongo/ # MongoDB connector
+│ ├── net/http/ # Fiber-oriented HTTP helpers and middleware
+│ │ └── ratelimit/ # Redis-backed rate limit storage
+│ ├── opentelemetry/ # Telemetry bootstrap, propagation, redaction
+│ │ └── metrics/ # Metric factory and fluent builders
+│ ├── pointers/ # Pointer conversion helpers
+│ ├── postgres/ # PostgreSQL connector with migrations
+│ ├── rabbitmq/ # RabbitMQ connector
+│ ├── redis/ # Redis connector (standalone/sentinel/cluster)
+│ ├── runtime/ # Panic recovery, metrics, safe goroutine wrappers
+│ ├── safe/ # Panic-free math/regex/slice operations
+│ ├── security/ # Sensitive field detection and handling
+│ ├── server/ # Graceful shutdown and lifecycle (ServerManager)
+│ ├── shell/ # Makefile includes and shell utilities
+│ ├── transaction/ # Typed transaction validation/posting primitives
+│ ├── zap/ # Zap logging adapter
+│ ├── app.go # Application bootstrap helpers
+│ ├── context.go # Context utilities
+│ ├── errors.go # Error definitions
+│ ├── os.go # OS utilities
+│ ├── stringUtils.go # String utilities
+│ ├── time.go # Time utilities
+│ └── utils.go # General utility functions
+├── docs/ # Documentation
+├── reports/ # Test and coverage reports
+└── go.mod # Module definition (v4)
```
### Package Design Principles
@@ -42,17 +68,19 @@ lib-commons/
2. **Minimal Dependencies**: Packages should minimize external dependencies
3. **Interface-Driven**: Define interfaces for testability and flexibility
4. **Zero Business Logic**: This is a utility library - no domain/business logic
+5. **Nil-Safe and Concurrency-Safe**: Keep behavior safe by default
+6. **Explicit Error Returns**: Prefer error returns over panic paths
### Naming Conventions
| Type | Convention | Example |
|------|------------|---------|
-| Package | lowercase, single word preferred | `postgres`, `redis`, `poolmanager` |
-| Files | snake_case matching content | `pool_manager_pg.go` |
-| Public Functions | PascalCase, descriptive | `NewPostgresConnection` |
+| Package | lowercase, single word preferred | `postgres`, `redis`, `circuitbreaker` |
+| Files | snake_case or camelCase matching content | `pool_manager_pg.go`, `stringUtils.go` |
+| Public Functions | PascalCase, descriptive | `NewClient`, `ServeReverseProxy` |
| Private Functions | camelCase | `validateConfig` |
-| Interfaces | -er suffix or descriptive | `Logger`, `ConnectionPool` |
-| Constants | PascalCase or UPPER_SNAKE_CASE | `DefaultTimeout`, `MAX_RETRIES` |
+| Interfaces | -er suffix or descriptive | `Logger`, `Manager`, `LockManager` |
+| Constants | PascalCase | `DefaultTimeout`, `LevelInfo` |
---
@@ -60,8 +88,24 @@ lib-commons/
### Go Version
-- **Minimum**: Go 1.24.0
+- **Minimum**: Go 1.25.7
- Keep `go.mod` updated with latest stable Go version
+- Module path: `github.com/LerianStudio/lib-commons/v4`
+
+### Build Tags
+
+- Unit test files **MUST** have `//go:build unit` as the first line
+- Integration test files **MUST** have `//go:build integration` as the first line
+
+```go
+//go:build unit
+
+package mypackage
+
+import "testing"
+
+func TestMyFunc(t *testing.T) { ... }
+```
### Imports Organization
@@ -77,7 +121,7 @@ import (
"go.uber.org/zap"
// Internal packages
- "github.com/LerianStudio/lib-commons/v2/commons/log"
+ "github.com/LerianStudio/lib-commons/v4/commons/log"
)
```
@@ -188,7 +232,17 @@ result, _ := doSomething()
- **Minimum Coverage**: 80% for new packages
- **Critical Paths**: 100% coverage for error handling paths
-- **Run Coverage**: `make cover`
+- **Run Coverage**: `make coverage-unit` or `make coverage-integration`
+- **Coverage Exclusions**: Defined in `.ignorecoverunit` (e.g., `*_mock.go`)
+
+### Build Tags
+
+All test files **MUST** include the appropriate build tag as the first line:
+
+| Type | Build Tag | Example |
+|------|-----------|---------|
+| Unit Tests | `//go:build unit` | All `_test.go` files |
+| Integration Tests | `//go:build integration` | All `_integration_test.go` files |
### Test File Naming
@@ -196,7 +250,15 @@ result, _ := doSomething()
|------|---------|---------|
| Unit Tests | `{file}_test.go` | `config_test.go` |
| Integration | `{file}_integration_test.go` | `postgres_integration_test.go` |
-| Benchmarks | In `_test.go` files | `BenchmarkXxx` |
+| Examples | `{feature}_example_test.go` | `cursor_example_test.go` |
+| Benchmarks | In `_test.go` or `benchmark_test.go` | `BenchmarkXxx` |
+
+### Integration Test Conventions
+
+- Test function names **MUST** start with `TestIntegration_` (e.g., `TestIntegration_MyFeature_Works`)
+- Integration tests use `testcontainers-go` to spin up ephemeral containers
+- Docker is required to run integration tests
+- Integration tests run sequentially (`-p=1`) to avoid Docker container conflicts
### Test Patterns
@@ -241,6 +303,7 @@ func TestConfig_Validate(t *testing.T) {
- Use `go.uber.org/mock` for interface mocking
- Define interfaces at point of use for testability
- Prefer dependency injection over global state
+- Mock files follow the `{type}_mock.go` pattern
---
@@ -275,6 +338,11 @@ func (c *Client) Connect(ctx context.Context) error {
- Update `README.md` API Reference when adding public APIs
- Include usage examples for new packages
+### Migration Awareness
+
+- If a task touches renamed/removed v1 symbols, update `MIGRATION_MAP.md`
+- If a task changes package-level behavior or API expectations, update `README.md`
+
---
## Dependencies
@@ -283,18 +351,22 @@ func (c *Client) Connect(ctx context.Context) error {
| Category | Allowed Packages |
|----------|-----------------|
-| Database | `pgx/v5`, `mongo-driver`, `go-redis/v9` |
+| Database | `pgx/v5`, `mongo-driver`, `go-redis/v9`, `dbresolver/v2`, `golang-migrate/v4` |
| Messaging | `amqp091-go` |
+| HTTP | `gofiber/fiber/v2` |
| Logging | `zap`, internal `log` package |
-| Testing | `testify`, `gomock`, `miniredis` |
-| Observability | `opentelemetry/*` |
-| Utilities | `google/uuid`, `shopspring/decimal` |
+| Testing | `testify`, `go.uber.org/mock`, `miniredis/v2` |
+| Observability | `opentelemetry/*`, `otelzap` |
+| Utilities | `google/uuid`, `shopspring/decimal`, `go-playground/validator/v10` |
+| Resilience | `sony/gobreaker`, `go-redsync/v4` |
+| Security | `golang.org/x/oauth2`, `google.golang.org/api` |
+| System | `shirou/gopsutil`, `joho/godotenv` |
### Forbidden Dependencies
-- `io/ioutil` - Deprecated, use `io` and `os`
+- `io/ioutil` - Deprecated, use `io` and `os` (enforced by `depguard` linter)
- Direct database drivers without connection pooling
-- Logging packages other than `zap` (use internal wrapper)
+- Logging packages other than `zap` (use internal `log` wrapper)
### Adding Dependencies
@@ -310,24 +382,32 @@ func (c *Client) Connect(ctx context.Context) error {
### Credential Handling
1. **Never hardcode credentials** - Use environment variables
-2. **Never log credentials** - Use obfuscation for sensitive fields
+2. **Never log credentials** - Use the `Redactor` for sensitive fields
3. **Mask in errors** - Never include credentials in error messages
```go
-// Good - mask DSN
-func MaskDSN(dsn string) string {
- return regexp.MustCompile(`password=[^\s]+`).ReplaceAllString(dsn, "password=***")
-}
-
-// Bad - exposes password
-log.Errorf("failed to connect: %s", dsn)
+// Use the built-in Redactor for sensitive data
+redactor := opentelemetry.NewDefaultRedactor()
+safeValue := redactor.Redact(sensitiveField)
```
+### Sensitive Field Detection
+
+- Use `commons/security` for sensitive field detection and handling
+- Use `commons/opentelemetry.Redactor` with `RedactionRule` patterns
+- Constructors: `NewDefaultRedactor()` and `NewRedactor(rules, mask)`
+
### Input Validation
1. Validate all external inputs
2. Use parameterized queries - never string concatenation
3. Sanitize user-provided identifiers
+4. Use `go-playground/validator/v10` for struct validation
+
+### Log Injection Prevention
+
+- Use `commons/log/sanitizer.go` for log-injection prevention
+- Never interpolate untrusted input into log messages without sanitization
### Environment Variables
@@ -343,7 +423,12 @@ log.Errorf("failed to connect: %s", dsn)
- **Tool**: `golangci-lint` v2
- **Config**: `.golangci.yml`
-- **Run**: `make lint`
+- **Run**: `make lint` (read-only check) or `make lint-fix` (auto-fix)
+- **Performance**: Optional `perfsprint` checks (install separately)
+
+### Enabled Linters
+
+`bodyclose`, `depguard`, `dogsled`, `dupword`, `errchkjson`, `gocognit`, `gocyclo`, `loggercheck`, `misspell`, `nakedret`, `nilerr`, `nolintlint`, `prealloc`, `predeclared`, `reassign`, `revive`, `staticcheck`, `thelper`, `tparallel`, `unconvert`, `unparam`, `usestdlibvars`, `wastedassign`, `wsl_v5`
### Formatting
@@ -354,12 +439,38 @@ log.Errorf("failed to connect: %s", dsn)
### Testing Commands
```bash
-make test # Run all tests
-make cover # Generate coverage report
-make lint # Run linters
-make format # Format code
-make sec # Security scan with gosec
-make tidy # Clean up go.mod
+make ci # Local fix + verify pipeline
+make test # Run unit tests (with -tags=unit)
+make test-unit # Run unit tests (excluding integration)
+make test-integration # Run integration tests with testcontainers (requires Docker)
+make test-all # Run all tests (unit + integration)
+make coverage-unit # Unit tests with coverage report
+make coverage-integration # Integration tests with coverage report
+make coverage # All coverage targets
+```
+
+### Testing Options
+
+| Option | Description | Example |
+|--------|-------------|---------|
+| `RUN` | Specific test name pattern | `make test-integration RUN=TestIntegration_MyFeature` |
+| `PKG` | Specific package to test | `make test-integration PKG=./commons/postgres/...` |
+| `LOW_RESOURCE` | Low-resource mode (no race, -p=1) | `make test LOW_RESOURCE=1` |
+| `RETRY_ON_FAIL` | Retry failed tests once | `make test RETRY_ON_FAIL=1` |
+
+### Code Quality Commands
+
+```bash
+make lint # Run linters (read-only)
+make lint-fix # Run linters with auto-fix
+make format # Format code
+make tidy # Clean dependencies
+make check-tests # Verify test coverage for packages
+make vet # Run go vet on all packages
+make sec # Security scan with gosec
+make sec SARIF=1 # Security scan with SARIF output
+make build # Build all packages
+make clean # Clean all build artifacts
```
### Git Hooks
@@ -367,6 +478,7 @@ make tidy # Clean up go.mod
- Pre-commit hooks available in `.githooks/`
- Setup: `make setup-git-hooks`
- Verify: `make check-hooks`
+- Environment check: `make check-envs`
### CI/CD
@@ -377,6 +489,35 @@ make tidy # Clean up go.mod
---
+## API Invariants
+
+Key v2 API contracts that must be preserved:
+
+| Package | Invariant |
+|---------|-----------|
+| `opentelemetry` | `NewTelemetry(...)` for init; `ApplyGlobals()` opt-in for global providers |
+| `log` | `Logger` 5-method interface: `Log`, `With`, `WithGroup`, `Enabled`, `Sync` |
+| `log` | Level constants: `LevelError`, `LevelWarn`, `LevelInfo`, `LevelDebug` |
+| `log` | Field constructors: `String()`, `Int()`, `Bool()`, `Err()` |
+| `zap` | `zap.New(cfg Config)` constructor; `Logger.Raw()` for underlying access |
+| `net/http` | `Respond`, `RespondStatus`, `RespondError`, `RenderError`, `FiberErrorHandler` |
+| `net/http` | `ServeReverseProxy(target, policy, res, req)` with `ReverseProxyPolicy` |
+| `server` | `ServerManager` exclusively (no `GracefulShutdown`) |
+| `circuitbreaker` | `NewManager(logger) (Manager, error)`; `GetOrCreate` returns `(CircuitBreaker, error)` |
+| `assert` | `assert.New(ctx, logger, component, operation)` returns errors, no panics |
+| `safe` | Explicit error returns for division, slice access, regex operations |
+| `jwt` | `jwt.Parse()` / `jwt.Sign()` with `AlgHS256`, `AlgHS384`, `AlgHS512` |
+| `backoff` | `ExponentialWithJitter()` and `WaitContext()` |
+| `redis` | `New(ctx, cfg)` with topology-based `Config` (standalone/sentinel/cluster) |
+| `redis` | `NewRedisLockManager()` and `LockManager` interface |
+| `postgres` | `New(cfg Config)`; `Resolver(ctx)` (not `GetDB()`); `NewMigrator(cfg)` |
+| `mongo` | `NewClient(ctx, cfg, opts...)` constructor |
+| `transaction` | `BuildIntentPlan()` + `ValidateBalanceEligibility()` + `ApplyPosting()` |
+| `rabbitmq` | `*Context()` variants for lifecycle; `HealthCheck()` returns `(bool, error)` |
+| `opentelemetry` | `Redactor` with `RedactionRule`; `NewDefaultRedactor()` / `NewRedactor(rules, mask)` |
+
+---
+
## Checklist
Before submitting code:
@@ -384,8 +525,12 @@ Before submitting code:
- [ ] Code follows naming conventions
- [ ] All public APIs are documented
- [ ] Tests achieve 80%+ coverage
+- [ ] Test files have correct build tag (`//go:build unit` or `//go:build integration`)
- [ ] No panics - all errors handled
- [ ] No hardcoded credentials
- [ ] `make lint` passes
- [ ] `make test` passes
+- [ ] `make build` passes
- [ ] Dependencies are justified
+- [ ] `MIGRATION_MAP.md` updated if v1 symbols changed
+- [ ] `README.md` updated if public API changed
diff --git a/go.mod b/go.mod
index 09c9b65d..17554367 100644
--- a/go.mod
+++ b/go.mod
@@ -1,16 +1,18 @@
-module github.com/LerianStudio/lib-commons/v2
+module github.com/LerianStudio/lib-commons/v4
-go 1.25.0
-
-toolchain go1.25.7
+go 1.25.7
require (
cloud.google.com/go/iam v1.5.3
- github.com/Masterminds/squirrel v1.5.4
- github.com/alicebob/miniredis/v2 v2.36.1
+ github.com/alicebob/miniredis/v2 v2.37.0
+ github.com/aws/aws-sdk-go-v2 v1.41.3
+ github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.3
+ github.com/aws/smithy-go v1.24.2
github.com/bxcodec/dbresolver/v2 v2.2.1
- github.com/go-redsync/redsync/v4 v4.15.0
- github.com/gofiber/fiber/v2 v2.52.11
+ github.com/go-playground/validator/v10 v10.30.1
+ github.com/go-redsync/redsync/v4 v4.16.0
+ github.com/gofiber/fiber/v2 v2.52.12
+ github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang-migrate/migrate/v4 v4.19.1
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.8.0
@@ -21,25 +23,32 @@ require (
github.com/shopspring/decimal v1.4.0
github.com/sony/gobreaker v1.0.0
github.com/stretchr/testify v1.11.1
+ github.com/testcontainers/testcontainers-go v0.41.0
+ github.com/testcontainers/testcontainers-go/modules/mongodb v0.40.0
+ github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0
+ github.com/testcontainers/testcontainers-go/modules/rabbitmq v0.40.0
+ github.com/testcontainers/testcontainers-go/modules/redis v0.41.0
go.mongodb.org/mongo-driver v1.17.9
- go.opentelemetry.io/contrib/bridges/otelzap v0.15.0
- go.opentelemetry.io/otel v1.40.0
- go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.16.0
- go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0
- go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0
- go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0
- go.opentelemetry.io/otel/log v0.16.0
- go.opentelemetry.io/otel/metric v1.40.0
- go.opentelemetry.io/otel/sdk v1.40.0
- go.opentelemetry.io/otel/sdk/log v0.16.0
- go.opentelemetry.io/otel/sdk/metric v1.40.0
- go.opentelemetry.io/otel/trace v1.40.0
+ go.opentelemetry.io/contrib/bridges/otelzap v0.17.0
+ go.opentelemetry.io/otel v1.42.0
+ go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.18.0
+ go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0
+ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0
+ go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0
+ go.opentelemetry.io/otel/log v0.18.0
+ go.opentelemetry.io/otel/metric v1.42.0
+ go.opentelemetry.io/otel/sdk v1.42.0
+ go.opentelemetry.io/otel/sdk/log v0.18.0
+ go.opentelemetry.io/otel/sdk/metric v1.42.0
+ go.opentelemetry.io/otel/trace v1.42.0
+ go.uber.org/goleak v1.3.0
go.uber.org/mock v0.6.0
go.uber.org/zap v1.27.1
- golang.org/x/oauth2 v0.35.0
+ golang.org/x/oauth2 v0.36.0
+ golang.org/x/sync v0.20.0
golang.org/x/text v0.34.0
- google.golang.org/api v0.267.0
- google.golang.org/grpc v1.79.1
+ google.golang.org/api v0.271.0
+ google.golang.org/grpc v1.79.2
google.golang.org/protobuf v1.36.11
)
@@ -47,19 +56,38 @@ require (
cloud.google.com/go/auth v0.18.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
+ dario.cat/mergo v1.0.2 // indirect
+ github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
+ github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect
+ github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
+ github.com/containerd/errdefs v1.0.0 // indirect
+ github.com/containerd/errdefs/pkg v0.3.0 // indirect
+ github.com/containerd/log v0.1.0 // indirect
+ github.com/containerd/platforms v0.2.1 // indirect
+ github.com/cpuguy83/dockercfg v0.3.2 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
+ github.com/distribution/reference v0.6.0 // indirect
+ github.com/docker/docker v28.5.2+incompatible // indirect
+ github.com/docker/go-connections v0.6.0 // indirect
+ github.com/docker/go-units v0.5.0 // indirect
+ github.com/ebitengine/purego v0.10.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
+ github.com/gabriel-vasile/mimetype v1.4.13 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
+ github.com/go-playground/locales v0.14.1 // indirect
+ github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/google/s2a-go v0.1.9 // indirect
- github.com/googleapis/enterprise-certificate-proxy v0.3.12 // indirect
+ github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.17.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
@@ -68,14 +96,30 @@ require (
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/klauspost/compress v1.18.4 // indirect
- github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect
- github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect
+ github.com/leodido/go-urn v1.4.0 // indirect
github.com/lib/pq v1.11.2 // indirect
+ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
+ github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
- github.com/mattn/go-runewidth v0.0.20 // indirect
+ github.com/mattn/go-runewidth v0.0.21 // indirect
+ github.com/mdelapenya/tlscert v0.2.0 // indirect
+ github.com/moby/docker-image-spec v1.3.1 // indirect
+ github.com/moby/go-archive v0.2.0 // indirect
+ github.com/moby/patternmatcher v0.6.0 // indirect
+ github.com/moby/sys/sequential v0.6.0 // indirect
+ github.com/moby/sys/user v0.4.0 // indirect
+ github.com/moby/sys/userns v0.1.0 // indirect
+ github.com/moby/term v0.5.2 // indirect
github.com/montanaflynn/stats v0.7.1 // indirect
+ github.com/morikuni/aec v1.0.0 // indirect
+ github.com/opencontainers/go-digest v1.0.0 // indirect
+ github.com/opencontainers/image-spec v1.1.1 // indirect
+ github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
+ github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
+ github.com/shirou/gopsutil/v4 v4.26.2 // indirect
+ github.com/sirupsen/logrus v1.9.3 // indirect
github.com/tklauser/go-sysconf v0.3.16 // indirect
github.com/tklauser/numcpus v0.11.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
@@ -87,17 +131,16 @@ require (
github.com/yuin/gopher-lua v1.1.1 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
- go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 // indirect
- go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect
+ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
+ go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
- golang.org/x/net v0.50.0 // indirect
- golang.org/x/sync v0.19.0 // indirect
+ golang.org/x/net v0.51.0 // indirect
golang.org/x/sys v0.41.0 // indirect
- golang.org/x/time v0.14.0 // indirect
- google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d // indirect
- google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d // indirect
+ golang.org/x/time v0.15.0 // indirect
+ google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 // indirect
+ google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 0d5bba62..b5e244d0 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,3 @@
-cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs=
-cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA=
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
@@ -8,38 +6,58 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc=
cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU=
-github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
-github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
+dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8=
+dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA=
+github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
+github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
+github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
+github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
-github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM=
-github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
-github.com/alicebob/miniredis/v2 v2.36.1 h1:Dvc5oAnNOr7BIfPn7tF269U8DvRW1dBG2D5n0WrfYMI=
-github.com/alicebob/miniredis/v2 v2.36.1/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
+github.com/alicebob/miniredis/v2 v2.37.0 h1:RheObYW32G1aiJIj81XVt78ZHJpHonHLHW7OLIshq68=
+github.com/alicebob/miniredis/v2 v2.37.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
+github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA=
+github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 h1:/sECfyq2JTifMI2JPyZ4bdRN77zJmr6SrS1eL3augIA=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19/go.mod h1:dMf8A5oAqr9/oxOfLkC/c2LU/uMcALP0Rgn2BD5LWn0=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 h1:AWeJMk33GTBf6J20XJe6qZoRSJo0WfUhsMdUKhoODXE=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19/go.mod h1:+GWrYoaAsV7/4pNHpwh1kiNLXkKaSoppxQq9lbH8Ejw=
+github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.3 h1:9bb0dEq1WzA0ZxIGG2EmwEgxfMAJpHyusxwbVN7f6iM=
+github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.3/go.mod h1:2z9eg35jfuRtdPE4Ci0ousrOU9PBhDBilXA1cwq9Ptk=
+github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
+github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/bxcodec/dbresolver/v2 v2.2.1 h1:bjIZm3YXK40dX36qHHj6Vhitj6C1XF88X4d3P3k8Jtw=
github.com/bxcodec/dbresolver/v2 v2.2.1/go.mod h1:xWb3HT8vrWUnoLVA7KQ+IcD9RvnzfRBqOkO9rKsg1rQ=
+github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
+github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/clipperhouse/uax29/v2 v2.6.0 h1:z0cDbUV+aPASdFb2/ndFnS9ts/WNXgTNNGFoKXuhpos=
-github.com/clipperhouse/uax29/v2 v2.6.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
-github.com/cncf/xds/go v0.0.0-20251022180443-0feb69152e9f h1:Y8xYupdHxryycyPlc9Y+bSQAYZnetRJ70VMVKm5CKI0=
-github.com/cncf/xds/go v0.0.0-20251022180443-0feb69152e9f/go.mod h1:HlzOvOjVBOfTGSRXRyY0OiCS/3J1akRGQQpRO/7zyF4=
+github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w=
+github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
+github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
+github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
+github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
+github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
+github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
+github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
+github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
+github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
@@ -50,19 +68,23 @@ github.com/dhui/dktest v0.4.6 h1:+DPKyScKSEp3VLtbMDHcUq6V5Lm5zfZZVb0Sk7Ahom4=
github.com/dhui/dktest v0.4.6/go.mod h1:JHTSYDtKkvFNFHJKqCzVzqXecyv+tKt8EzceOmQOgbU=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
-github.com/docker/docker v28.3.3+incompatible h1:Dypm25kh4rmk49v1eiVbsAtpAsYURjYkaKubwuBdxEI=
-github.com/docker/docker v28.3.3+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
-github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
-github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
+github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
+github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
+github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
+github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
-github.com/envoyproxy/go-control-plane v0.13.5-0.20251024222203-75eaa193e329 h1:K+fnvUM0VZ7ZFJf0n4L/BRlnsb9pL/GuDG6FqaH+PwM=
-github.com/envoyproxy/go-control-plane/envoy v1.35.0 h1:ixjkELDE+ru6idPxcHLj8LBVc2bFP7iBytj353BoHUo=
-github.com/envoyproxy/go-control-plane/envoy v1.35.0/go.mod h1:09qwbGVuSWWAyN5t/b3iyVfz5+z8QWGrzkoqm/8SbEs=
-github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8=
-github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU=
+github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU=
+github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
+github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA=
+github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g=
+github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98=
+github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4=
+github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
+github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
+github.com/gabriel-vasile/mimetype v1.4.13/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -71,18 +93,26 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
+github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
+github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
+github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
+github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
+github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
+github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
+github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
+github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg=
github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA=
github.com/go-redis/redis/v7 v7.4.1 h1:PASvf36gyUpr2zdOUS/9Zqc80GbM+9BDyiJSJDDOrTI=
github.com/go-redis/redis/v7 v7.4.1/go.mod h1:JDNMw23GTyLNC4GZu9njt15ctBQVn7xjRfnwdHj/Dcg=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
-github.com/go-redsync/redsync/v4 v4.15.0 h1:KH/XymuxSV7vyKs6z1Cxxj+N+N18JlPxgXeP6x4JY54=
-github.com/go-redsync/redsync/v4 v4.15.0/go.mod h1:qNp+lLs3vkfZbtA/aM/OjlZHfEr5YTAYhRktFPKHC7s=
-github.com/gofiber/fiber/v2 v2.52.11 h1:5f4yzKLcBcF8ha1GQTWB+mpblWz3Vz6nSAbTL31HkWs=
-github.com/gofiber/fiber/v2 v2.52.11/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
-github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
-github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
+github.com/go-redsync/redsync/v4 v4.16.0 h1:bNcOzeHH9d3s6pghU9NJFMPrQa41f5Nx3L4YKr3BdEU=
+github.com/go-redsync/redsync/v4 v4.16.0/go.mod h1:V4gagqgyASWBZuwx4xGzu72aZNb/6Mo05byUa3mVmKQ=
+github.com/gofiber/fiber/v2 v2.52.12 h1:0LdToKclcPOj8PktUdIKo9BUohjjwfnQl42Dhw8/WUw=
+github.com/gofiber/fiber/v2 v2.52.12/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
+github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
+github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
@@ -91,6 +121,7 @@ github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gomodule/redigo v1.9.3 h1:dNPSXeXv6HCq2jdyWfjgmhBdqnR6PRO3m/G05nvpPC8=
github.com/gomodule/redigo v1.9.3/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw=
+github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
@@ -99,14 +130,10 @@ github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao=
-github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8=
-github.com/googleapis/enterprise-certificate-proxy v0.3.12 h1:Fg+zsqzYEs1ZnvmcztTYxhgCBsx3eEhEwQ1W/lHq/sQ=
-github.com/googleapis/enterprise-certificate-proxy v0.3.12/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
+github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
+github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.17.0 h1:RksgfBpxqff0EZkDWYuz9q/uWsTVz+kf43LsZ1J6SMc=
github.com/googleapis/gax-go/v2 v2.17.0/go.mod h1:mzaqghpQp4JDh3HvADwrat+6M3MOIDp5YKHhb9PAgDY=
-github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 h1:X+2YciYSxvMQK0UZ7sg45ZVabVZBeBuvMkmuI2V3Fak=
-github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7/go.mod h1:lW34nIZuQ8UDPdkon5fmfp2l3+ZkQ2me/+oecHYLOII=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
@@ -126,38 +153,52 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
+github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
+github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw=
-github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o=
-github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk=
-github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw=
-github.com/lib/pq v1.11.1 h1:wuChtj2hfsGmmx3nf1m7xC2XpK6OtelS2shMY+bGMtI=
-github.com/lib/pq v1.11.1/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
+github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
+github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs=
github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
+github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
+github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
+github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
+github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
-github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
-github.com/mattn/go-runewidth v0.0.20 h1:WcT52H91ZUAwy8+HUkdM3THM6gXqXuLJi9O3rjcQQaQ=
-github.com/mattn/go-runewidth v0.0.20/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
+github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
+github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
+github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
+github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
-github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
-github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
+github.com/moby/go-archive v0.2.0 h1:zg5QDUM2mi0JIM9fdQZWC7U8+2ZfixfTYoHL7rWUcP8=
+github.com/moby/go-archive v0.2.0/go.mod h1:mNeivT14o8xU+5q1YnNrkQVpK+dnNe/K6fHqnTg4qPU=
+github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
+github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
+github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
+github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs=
+github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
+github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
+github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
+github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
+github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
+github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
+github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
+github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
-github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
-github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
+github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
+github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
@@ -165,32 +206,47 @@ github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
+github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/rabbitmq/amqp091-go v1.10.0 h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw=
github.com/rabbitmq/amqp091-go v1.10.0/go.mod h1:Hy4jKW5kQART1u+JkDTF9YYOQUHXqMuhrgxOEeS7G4o=
-github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4=
-github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
-github.com/redis/rueidis v1.0.69 h1:WlUefRhuDekji5LsD387ys3UCJtSFeBVf0e5yI0B8b4=
-github.com/redis/rueidis v1.0.69/go.mod h1:Lkhr2QTgcoYBhxARU7kJRO8SyVlgUuEkcJO1Y8MCluA=
-github.com/redis/rueidis/rueidiscompat v1.0.69 h1:IWVYY9lXdjNO3do2VpJT7aDFi8zbCUuQxZB6E2Grahs=
-github.com/redis/rueidis/rueidiscompat v1.0.69/go.mod h1:iC4Y8DoN0Uth0Uezg9e2trvNRC7QAgGeuP2OPLb5ccI=
+github.com/redis/rueidis v1.0.71 h1:pODtnAR5GAB7j4ekhldZ29HKOxe4Hph0GTDGk1ayEQY=
+github.com/redis/rueidis v1.0.71/go.mod h1:lfdcZzJ1oKGKL37vh9fO3ymwt+0TdjkkUCJxbgpmcgQ=
+github.com/redis/rueidis/rueidiscompat v1.0.71 h1:wNZ//kEjMZgBM0KCk7ncOX8KmAgROU2kDdDNpwheG4w=
+github.com/redis/rueidis/rueidiscompat v1.0.71/go.mod h1:esmCLJvaRzZoKlgB82G1bY7Iky5TnO9Rz+NlhbEccFI=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
+github.com/shirou/gopsutil/v4 v4.26.2 h1:X8i6sicvUFih4BmYIGT1m2wwgw2VG9YgrDTi7cIRGUI=
+github.com/shirou/gopsutil/v4 v4.26.2/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
+github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
+github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ=
github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
+github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
+github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203 h1:QVqDTf3h2WHt08YuiTGPZLls0Wq99X9bWd0Q5ZSBesM=
github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203/go.mod h1:oqN97ltKNihBbwlX8dLpwxCl3+HnXKV/R0e+sRLd9C8=
+github.com/testcontainers/testcontainers-go v0.41.0 h1:mfpsD0D36YgkxGj2LrIyxuwQ9i2wCKAD+ESsYM1wais=
+github.com/testcontainers/testcontainers-go v0.41.0/go.mod h1:pdFrEIfaPl24zmBjerWTTYaY0M6UHsqA1YSvsoU40MI=
+github.com/testcontainers/testcontainers-go/modules/mongodb v0.40.0 h1:z/1qHeliTLDKNaJ7uOHOx1FjwghbcbYfga4dTFkF0hU=
+github.com/testcontainers/testcontainers-go/modules/mongodb v0.40.0/go.mod h1:GaunAWwMXLtsMKG3xn2HYIBDbKddGArfcGsF2Aog81E=
+github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
+github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ=
+github.com/testcontainers/testcontainers-go/modules/rabbitmq v0.40.0 h1:wGznWj8ZlEoqWfMN2L+EWjQBbjZ99vhoy/S61h+cED0=
+github.com/testcontainers/testcontainers-go/modules/rabbitmq v0.40.0/go.mod h1:Y+9/8YMZo3ElEZmHZOgFnjKrxE4+H2OFrjWdYzm/jtU=
+github.com/testcontainers/testcontainers-go/modules/redis v0.41.0 h1:QlTSe4JGOnjr/37MXx0GqNLGa+8sKQst7lsn7uLjg8E=
+github.com/testcontainers/testcontainers-go/modules/redis v0.41.0/go.mod h1:5mDOIWrS/a+z8gBesXBQAAQtrqJrW2tUi9Tf46+/Luo=
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
@@ -214,42 +270,48 @@ github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
+github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
+github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.mongodb.org/mongo-driver v1.17.9 h1:IexDdCuuNJ3BHrELgBlyaH9p60JXAvdzWR128q+U5tU=
go.mongodb.org/mongo-driver v1.17.9/go.mod h1:LlOhpH5NUEfhxcAwG0UEkMqwYcc4JU18gtCdGudk/tQ=
+go.mongodb.org/mongo-driver/v2 v2.3.0 h1:sh55yOXA2vUjW1QYw/2tRlHSQViwDyPnW61AwpZ4rtU=
+go.mongodb.org/mongo-driver/v2 v2.3.0/go.mod h1:jHeEDJHJq7tm6ZF45Issun9dbogjfnPySb1vXA7EeAI=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
-go.opentelemetry.io/contrib/bridges/otelzap v0.15.0 h1:x4qzjKkTl2hXmLl+IviSXvzaTyCJSYvpFZL5SRVLBxs=
-go.opentelemetry.io/contrib/bridges/otelzap v0.15.0/go.mod h1:h7dZHJgqkzUiKFXCTJBrPWH0LEZaZXBFzKWstjWBRxw=
-go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0 h1:XmiuHzgJt067+a6kwyAzkhXooYVv3/TOw9cM2VfJgUM=
-go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.65.0/go.mod h1:KDgtbWKTQs4bM+VPUr6WlL9m/WXcmkCcBlIzqxPGzmI=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 h1:7iP2uCb7sGddAr30RRS6xjKy7AZ2JtTOPA3oolgVSw8=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0/go.mod h1:c7hN3ddxs/z6q9xwvfLPk+UHlWRQyaeR1LdgfL/66l0=
-go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms=
-go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g=
-go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.16.0 h1:ZVg+kCXxd9LtAaQNKBxAvJ5NpMf7LpvEr4MIZqb0TMQ=
-go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.16.0/go.mod h1:hh0tMeZ75CCXrHd9OXRYxTlCAdxcXioWHFIpYw2rZu8=
-go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0 h1:NOyNnS19BF2SUDApbOKbDtWZ0IK7b8FJ2uAGdIWOGb0=
-go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.40.0/go.mod h1:VL6EgVikRLcJa9ftukrHu/ZkkhFBSo1lzvdBC9CF1ss=
-go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs=
-go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI=
-go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0 h1:DvJDOPmSWQHWywQS6lKL+pb8s3gBLOZUtw4N+mavW1I=
-go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.40.0/go.mod h1:EtekO9DEJb4/jRyN4v4Qjc2yA7AtfCBuz2FynRUWTXs=
-go.opentelemetry.io/otel/log v0.16.0 h1:DeuBPqCi6pQwtCK0pO4fvMB5eBq6sNxEnuTs88pjsN4=
-go.opentelemetry.io/otel/log v0.16.0/go.mod h1:rWsmqNVTLIA8UnwYVOItjyEZDbKIkMxdQunsIhpUMes=
-go.opentelemetry.io/otel/log/logtest v0.16.0 h1:jr1CG3Z6FD9pwUaL/D0s0X4lY2ZVm1jP3JfCtzGxUmE=
-go.opentelemetry.io/otel/log/logtest v0.16.0/go.mod h1:qeeZw+cI/rAtCzZ03Kq1ozq6C4z/PCa+K+bb0eJfKNs=
-go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g=
-go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc=
-go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8=
-go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE=
-go.opentelemetry.io/otel/sdk/log v0.16.0 h1:e/b4bdlQwC5fnGtG3dlXUrNOnP7c8YLVSpSfEBIkTnI=
-go.opentelemetry.io/otel/sdk/log v0.16.0/go.mod h1:JKfP3T6ycy7QEuv3Hj8oKDy7KItrEkus8XJE6EoSzw4=
-go.opentelemetry.io/otel/sdk/log/logtest v0.16.0 h1:/XVkpZ41rVRTP4DfMgYv1nEtNmf65XPPyAdqV90TMy4=
-go.opentelemetry.io/otel/sdk/log/logtest v0.16.0/go.mod h1:iOOPgQr5MY9oac/F5W86mXdeyWZGleIx3uXO98X2R6Y=
-go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw=
-go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg=
-go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw=
-go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA=
+go.opentelemetry.io/contrib/bridges/otelzap v0.17.0 h1:oCltVHJcblcth2z9B9dRTeZIZTe2Sf9Ad9h8bcc+s8M=
+go.opentelemetry.io/contrib/bridges/otelzap v0.17.0/go.mod h1:G/VE1A/hRn6mEWdfC8rMvSdQVGM64KUPi4XilLkwcQw=
+go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
+go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
+go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
+go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
+go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
+go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
+go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.18.0 h1:deI9UQMoGFgrg5iLPgzueqFPHevDl+28YKfSpPTI6rY=
+go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.18.0/go.mod h1:PFx9NgpNUKXdf7J4Q3agRxMs3Y07QhTCVipKmLsMKnU=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0 h1:MdKucPl/HbzckWWEisiNqMPhRrAOQX8r4jTuGr636gk=
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0/go.mod h1:RolT8tWtfHcjajEH5wFIZ4Dgh5jpPdFXYV9pTAk/qjc=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 h1:THuZiwpQZuHPul65w4WcwEnkX2QIuMT+UFoOrygtoJw=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0/go.mod h1:J2pvYM5NGHofZ2/Ru6zw/TNWnEQp5crgyDeSrYpXkAw=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0 h1:zWWrB1U6nqhS/k6zYB74CjRpuiitRtLLi68VcgmOEto=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0/go.mod h1:2qXPNBX1OVRC0IwOnfo1ljoid+RD0QK3443EaqVlsOU=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.41.0 h1:inYW9ZhgqiDqh6BioM7DVHHzEGVq76Db5897WLGZ5Go=
+go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.41.0/go.mod h1:Izur+Wt8gClgMJqO/cZ8wdeeMryJ/xxiOVgFSSfpDTY=
+go.opentelemetry.io/otel/log v0.18.0 h1:XgeQIIBjZZrliksMEbcwMZefoOSMI1hdjiLEiiB0bAg=
+go.opentelemetry.io/otel/log v0.18.0/go.mod h1:KEV1kad0NofR3ycsiDH4Yjcoj0+8206I6Ox2QYFSNgI=
+go.opentelemetry.io/otel/log/logtest v0.18.0 h1:2QeyoKJdIgK2LJhG1yn78o/zmpXx1EditeyRDREqVS8=
+go.opentelemetry.io/otel/log/logtest v0.18.0/go.mod h1:v1vh3PYR9zIa5MK6HwkH2lMrLBg/Y9Of6Qc+krlesX0=
+go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
+go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
+go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
+go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
+go.opentelemetry.io/otel/sdk/log v0.18.0 h1:n8OyZr7t7otkeTnPTbDNom6rW16TBYGtvyy2Gk6buQw=
+go.opentelemetry.io/otel/sdk/log v0.18.0/go.mod h1:C0+wxkTwKpOCZLrlJ3pewPiiQwpzycPI/u6W0Z9fuYk=
+go.opentelemetry.io/otel/sdk/log/logtest v0.18.0 h1:l3mYuPsuBx6UKE47BVcPrZoZ0q/KER57vbj2qkgDLXA=
+go.opentelemetry.io/otel/sdk/log/logtest v0.18.0/go.mod h1:7cHtiVJpZebB3wybTa4NG+FUo5NPe3PROz1FqB0+qdw=
+go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
+go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
+go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
+go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A=
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
@@ -264,29 +326,28 @@ go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
-golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
-golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
-golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
-golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
-golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
-golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
-golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
+golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
+golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
+golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
+golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
-golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
+golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
+golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -294,38 +355,33 @@ golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
+golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
-golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
-golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
+golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
+golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
-google.golang.org/api v0.265.0 h1:FZvfUdI8nfmuNrE34aOWFPmLC+qRBEiNm3JdivTvAAU=
-google.golang.org/api v0.265.0/go.mod h1:uAvfEl3SLUj/7n6k+lJutcswVojHPp2Sp08jWCu8hLY=
-google.golang.org/api v0.267.0 h1:w+vfWPMPYeRs8qH1aYYsFX68jMls5acWl/jocfLomwE=
-google.golang.org/api v0.267.0/go.mod h1:Jzc0+ZfLnyvXma3UtaTl023TdhZu6OMBP9tJ+0EmFD0=
+google.golang.org/api v0.271.0 h1:cIPN4qcUc61jlh7oXu6pwOQqbJW2GqYh5PS6rB2C/JY=
+google.golang.org/api v0.271.0/go.mod h1:CGT29bhwkbF+i11qkRUJb2KMKqcJ1hdFceEIRd9u64Q=
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 h1:VQZ/yAbAtjkHgH80teYd2em3xtIkkHd7ZhqfH2N9CsM=
google.golang.org/genproto v0.0.0-20260128011058-8636f8732409/go.mod h1:rxKD3IEILWEu3P44seeNOAwZN4SaoKaQ/2eTg4mM6EM=
-google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 h1:7ei4lp52gK1uSejlA8AZl5AJjeLUOHBQscRQZUgAcu0=
-google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20/go.mod h1:ZdbssH/1SOVnjnDlXzxDHK2MCidiqXtbYccJNzNYPEE=
-google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d h1:EocjzKLywydp5uZ5tJ79iP6Q0UjDnyiHkGRWxuPBP8s=
-google.golang.org/genproto/googleapis/api v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:48U2I+QQUYhsFrg2SY6r+nJzeOtjey7j//WBESw+qyQ=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 h1:Jr5R2J6F6qWyzINc+4AM8t5pfUz6beZpHp678GNrMbE=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d h1:t/LOSXPJ9R0B6fnZNyALBRfZBH0Uy0gT+uR+SJ6syqQ=
-google.golang.org/genproto/googleapis/rpc v0.0.0-20260217215200-42d3e9bedb6d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
-google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
-google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
-google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY=
-google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
+google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 h1:tu/dtnW1o3wfaxCOjSLn5IRX4YDcJrtlpzYkhHhGaC4=
+google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171/go.mod h1:M5krXqk4GhBKvB596udGL3UyjL4I1+cTbK0orROM9ng=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ=
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
+google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU=
+google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -334,3 +390,5 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
+gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
diff --git a/mk/tests.mk b/mk/tests.mk
deleted file mode 100644
index c9999100..00000000
--- a/mk/tests.mk
+++ /dev/null
@@ -1,297 +0,0 @@
-# ------------------------------------------------------
-# Test configuration for lib-commons
-# ------------------------------------------------------
-
-# Native fuzz test controls
-# FUZZ: specific fuzz target name (e.g., FuzzValidateEmail)
-# FUZZTIME: duration per fuzz target (default: 10s)
-FUZZ ?=
-FUZZTIME ?= 10s
-
-# Integration test filter
-# RUN: specific test name pattern (e.g., TestIntegration_FeatureName)
-# PKG: specific package to test (e.g., ./commons/...)
-# Usage: make test-integration RUN=TestIntegration_FeatureName
-# make test-integration PKG=./commons/...
-RUN ?=
-PKG ?=
-
-# Computed run pattern: uses RUN if set, otherwise defaults to '^TestIntegration'
-ifeq ($(RUN),)
- RUN_PATTERN := ^TestIntegration
-else
- RUN_PATTERN := $(RUN)
-endif
-
-# Low-resource mode for limited machines (sets -p=1 -parallel=1, disables -race)
-# Usage: make test-integration LOW_RESOURCE=1
-# make coverage-integration LOW_RESOURCE=1
-LOW_RESOURCE ?= 0
-
-# Computed flags for low-resource mode
-ifeq ($(LOW_RESOURCE),1)
- LOW_RES_P_FLAG := -p 1
- LOW_RES_PARALLEL_FLAG := -parallel 1
- LOW_RES_RACE_FLAG :=
-else
- LOW_RES_P_FLAG :=
- LOW_RES_PARALLEL_FLAG :=
- LOW_RES_RACE_FLAG := -race
-endif
-
-# macOS ld64 workaround: newer ld emits noisy LC_DYSYMTAB warnings when linking test binaries with -race.
-# If available, prefer Apple's classic linker to silence them.
-UNAME_S := $(shell uname -s)
-ifeq ($(UNAME_S),Darwin)
- # Prefer classic mode to suppress LC_DYSYMTAB warnings on macOS.
- # Set DISABLE_OSX_LINKER_WORKAROUND=1 to disable this behavior.
- ifneq ($(DISABLE_OSX_LINKER_WORKAROUND),1)
- GO_TEST_LDFLAGS := -ldflags="-linkmode=external -extldflags=-ld_classic"
- else
- GO_TEST_LDFLAGS :=
- endif
-else
- GO_TEST_LDFLAGS :=
-endif
-
-# ------------------------------------------------------
-# Test tooling configuration
-# ------------------------------------------------------
-
-TEST_REPORTS_DIR ?= ./reports
-GOTESTSUM := $(shell command -v gotestsum 2>/dev/null)
-RETRY_ON_FAIL ?= 0
-
-.PHONY: tools tools-gotestsum
-tools: tools-gotestsum ## Install helpful dev/test tools
-
-tools-gotestsum:
- @if [ -z "$(GOTESTSUM)" ]; then \
- echo "Installing gotestsum..."; \
- GO111MODULE=on go install gotest.tools/gotestsum@latest; \
- else \
- echo "gotestsum already installed: $(GOTESTSUM)"; \
- fi
-
-#-------------------------------------------------------
-# Core Test Commands
-#-------------------------------------------------------
-
-.PHONY: test
-test:
- $(call print_title,Running all tests)
- $(call check_command,go,"Install Go from https://golang.org/doc/install")
- @set -e; mkdir -p $(TEST_REPORTS_DIR); \
- if [ -n "$(GOTESTSUM)" ]; then \
- echo "Running tests with gotestsum"; \
- gotestsum --format testname -- -v -race -count=1 $(GO_TEST_LDFLAGS) ./...; \
- else \
- go test -v -race -count=1 $(GO_TEST_LDFLAGS) ./...; \
- fi
- @echo "$(GREEN)$(BOLD)[ok]$(NC) All tests passed$(GREEN) ✔️$(NC)"
-
-#-------------------------------------------------------
-# Test Suite Aliases
-#-------------------------------------------------------
-
-# Unit tests (excluding integration tests)
-.PHONY: test-unit
-test-unit:
- $(call print_title,Running Go unit tests)
- $(call check_command,go,"Install Go from https://golang.org/doc/install")
- @set -e; mkdir -p $(TEST_REPORTS_DIR); \
- pkgs=$$(go list ./... | grep -v '/tests'); \
- if [ -z "$$pkgs" ]; then \
- echo "No unit test packages found"; \
- else \
- if [ -n "$(GOTESTSUM)" ]; then \
- echo "Running unit tests with gotestsum"; \
- gotestsum --format testname -- -v -race -count=1 $(GO_TEST_LDFLAGS) $$pkgs || { \
- if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
- echo "Retrying unit tests once..."; \
- gotestsum --format testname -- -v -race -count=1 $(GO_TEST_LDFLAGS) $$pkgs; \
- else \
- exit 1; \
- fi; \
- }; \
- else \
- go test -v -race -count=1 $(GO_TEST_LDFLAGS) $$pkgs; \
- fi; \
- fi
- @echo "$(GREEN)$(BOLD)[ok]$(NC) Unit tests passed$(GREEN) ✔️$(NC)"
-
-# Integration tests with testcontainers (no coverage)
-# These tests use the `integration` build tag and testcontainers-go to spin up
-# ephemeral containers. No external Docker stack is required.
-#
-# Requirements:
-# - Test files must follow the naming convention: *_integration_test.go
-# - Test functions must start with TestIntegration_ (e.g., TestIntegration_MyFeature_Works)
-.PHONY: test-integration
-test-integration:
- $(call print_title,Running integration tests with testcontainers)
- $(call check_command,go,"Install Go from https://golang.org/doc/install")
- $(call check_command,docker,"Install Docker from https://docs.docker.com/get-docker/")
- @set -e; mkdir -p $(TEST_REPORTS_DIR); \
- if [ -n "$(PKG)" ]; then \
- echo "Using specified package: $(PKG)"; \
- pkgs=$$(go list $(PKG) 2>/dev/null | tr '\n' ' '); \
- else \
- echo "Finding packages with *_integration_test.go files..."; \
- dirs=$$(find . -name '*_integration_test.go' -not -path './vendor/*' 2>/dev/null | xargs -n1 dirname 2>/dev/null | sort -u | tr '\n' ' '); \
- pkgs=$$(if [ -n "$$dirs" ]; then go list $$dirs 2>/dev/null | tr '\n' ' '; fi); \
- fi; \
- if [ -z "$$pkgs" ]; then \
- echo "No integration test packages found"; \
- else \
- echo "Packages: $$pkgs"; \
- echo "Running packages sequentially (-p=1) to avoid Docker container conflicts"; \
- if [ "$(LOW_RESOURCE)" = "1" ]; then \
- echo "LOW_RESOURCE mode: -parallel=1, race detector disabled"; \
- fi; \
- if [ -n "$(GOTESTSUM)" ]; then \
- echo "Running testcontainers integration tests with gotestsum"; \
- gotestsum --format testname -- \
- -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
- -p 1 $(LOW_RES_PARALLEL_FLAG) \
- -run '$(RUN_PATTERN)' $$pkgs || { \
- if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
- echo "Retrying integration tests once..."; \
- gotestsum --format testname -- \
- -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
- -p 1 $(LOW_RES_PARALLEL_FLAG) \
- -run '$(RUN_PATTERN)' $$pkgs; \
- else \
- exit 1; \
- fi; \
- }; \
- else \
- go test -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
- -p 1 $(LOW_RES_PARALLEL_FLAG) \
- -run '$(RUN_PATTERN)' $$pkgs; \
- fi; \
- fi
- @echo "$(GREEN)$(BOLD)[ok]$(NC) Integration tests passed$(GREEN) ✔️$(NC)"
-
-# Run all tests (unit + integration)
-.PHONY: test-all
-test-all:
- $(call print_title,Running all tests (unit + integration))
- $(call print_title,Running unit tests)
- $(MAKE) test-unit
- $(call print_title,Running integration tests)
- $(MAKE) test-integration
- @echo "$(GREEN)$(BOLD)[ok]$(NC) All tests passed$(GREEN) ✔️$(NC)"
-
-#-------------------------------------------------------
-# Coverage Commands
-#-------------------------------------------------------
-
-# Unit tests with coverage (uses covermode=atomic)
-# Supports PKG parameter to filter packages (e.g., PKG=./commons/...)
-# Supports .ignorecoverunit file to exclude patterns from coverage stats
-.PHONY: coverage-unit
-coverage-unit:
- $(call print_title,Running Go unit tests with coverage)
- $(call check_command,go,"Install Go from https://golang.org/doc/install")
- @set -e; mkdir -p $(TEST_REPORTS_DIR); \
- if [ -n "$(PKG)" ]; then \
- echo "Using specified package: $(PKG)"; \
- pkgs=$$(go list $(PKG) 2>/dev/null | grep -v '/tests' | tr '\n' ' '); \
- else \
- pkgs=$$(go list ./... | grep -v '/tests'); \
- fi; \
- if [ -z "$$pkgs" ]; then \
- echo "No unit test packages found"; \
- else \
- echo "Packages: $$pkgs"; \
- if [ -n "$(GOTESTSUM)" ]; then \
- echo "Running unit tests with gotestsum (coverage enabled)"; \
- gotestsum --format testname -- -v -race -count=1 $(GO_TEST_LDFLAGS) -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/unit_coverage.out $$pkgs || { \
- if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
- echo "Retrying unit tests once..."; \
- gotestsum --format testname -- -v -race -count=1 $(GO_TEST_LDFLAGS) -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/unit_coverage.out $$pkgs; \
- else \
- exit 1; \
- fi; \
- }; \
- else \
- go test -v -race -count=1 $(GO_TEST_LDFLAGS) -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/unit_coverage.out $$pkgs; \
- fi; \
- if [ -f .ignorecoverunit ]; then \
- echo "Filtering coverage with .ignorecoverunit patterns..."; \
- patterns=$$(grep -v '^#' .ignorecoverunit | grep -v '^$$' | tr '\n' '|' | sed 's/|$$//'); \
- if [ -n "$$patterns" ]; then \
- regex_patterns=$$(echo "$$patterns" | sed 's/\./\\./g' | sed 's/\*/.*/g'); \
- head -1 $(TEST_REPORTS_DIR)/unit_coverage.out > $(TEST_REPORTS_DIR)/unit_coverage_filtered.out; \
- tail -n +2 $(TEST_REPORTS_DIR)/unit_coverage.out | grep -vE "$$regex_patterns" >> $(TEST_REPORTS_DIR)/unit_coverage_filtered.out || true; \
- mv $(TEST_REPORTS_DIR)/unit_coverage_filtered.out $(TEST_REPORTS_DIR)/unit_coverage.out; \
- echo "Excluded patterns: $$patterns"; \
- fi; \
- fi; \
- echo "----------------------------------------"; \
- go tool cover -func=$(TEST_REPORTS_DIR)/unit_coverage.out | grep total | awk '{print "Total coverage: " $$3}'; \
- echo "----------------------------------------"; \
- fi
- @echo "$(GREEN)$(BOLD)[ok]$(NC) Unit coverage report generated$(GREEN) ✔️$(NC)"
-
-# Integration tests with testcontainers (with coverage, uses covermode=atomic)
-.PHONY: coverage-integration
-coverage-integration:
- $(call print_title,Running integration tests with testcontainers (coverage enabled))
- $(call check_command,go,"Install Go from https://golang.org/doc/install")
- $(call check_command,docker,"Install Docker from https://docs.docker.com/get-docker/")
- @set -e; mkdir -p $(TEST_REPORTS_DIR); \
- if [ -n "$(PKG)" ]; then \
- echo "Using specified package: $(PKG)"; \
- pkgs=$$(go list $(PKG) 2>/dev/null | tr '\n' ' '); \
- else \
- echo "Finding packages with *_integration_test.go files..."; \
- dirs=$$(find . -name '*_integration_test.go' -not -path './vendor/*' 2>/dev/null | xargs -n1 dirname 2>/dev/null | sort -u | tr '\n' ' '); \
- pkgs=$$(if [ -n "$$dirs" ]; then go list $$dirs 2>/dev/null | tr '\n' ' '; fi); \
- fi; \
- if [ -z "$$pkgs" ]; then \
- echo "No integration test packages found"; \
- else \
- echo "Packages: $$pkgs"; \
- echo "Running packages sequentially (-p=1) to avoid Docker container conflicts"; \
- if [ "$(LOW_RESOURCE)" = "1" ]; then \
- echo "LOW_RESOURCE mode: -parallel=1, race detector disabled"; \
- fi; \
- if [ -n "$(GOTESTSUM)" ]; then \
- echo "Running testcontainers integration tests with gotestsum (coverage enabled)"; \
- gotestsum --format testname -- \
- -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
- -p 1 $(LOW_RES_PARALLEL_FLAG) \
- -run '$(RUN_PATTERN)' -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/integration_coverage.out \
- $$pkgs || { \
- if [ "$(RETRY_ON_FAIL)" = "1" ]; then \
- echo "Retrying integration tests once..."; \
- gotestsum --format testname -- \
- -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
- -p 1 $(LOW_RES_PARALLEL_FLAG) \
- -run '$(RUN_PATTERN)' -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/integration_coverage.out \
- $$pkgs; \
- else \
- exit 1; \
- fi; \
- }; \
- else \
- go test -tags=integration -v $(LOW_RES_RACE_FLAG) -count=1 -timeout 600s $(GO_TEST_LDFLAGS) \
- -p 1 $(LOW_RES_PARALLEL_FLAG) \
- -run '$(RUN_PATTERN)' -covermode=atomic -coverprofile=$(TEST_REPORTS_DIR)/integration_coverage.out \
- $$pkgs; \
- fi; \
- echo "----------------------------------------"; \
- go tool cover -func=$(TEST_REPORTS_DIR)/integration_coverage.out | grep total | awk '{print "Total coverage: " $$3}'; \
- echo "----------------------------------------"; \
- fi
- @echo "$(GREEN)$(BOLD)[ok]$(NC) Integration coverage report generated$(GREEN) ✔️$(NC)"
-
-# Run all coverage targets
-.PHONY: coverage
-coverage:
- $(call print_title,Running all coverage targets)
- $(MAKE) coverage-unit
- $(MAKE) coverage-integration
- @echo "$(GREEN)$(BOLD)[ok]$(NC) All coverage reports generated$(GREEN) ✔️$(NC)"
diff --git a/scripts/check-license-header.sh b/scripts/check-license-header.sh
deleted file mode 100755
index 995a57a8..00000000
--- a/scripts/check-license-header.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/bin/bash
-
-# Copyright (c) 2026 Lerian Studio. All rights reserved.
-# Use of this source code is governed by the Elastic License 2.0
-# that can be found in the LICENSE file.
-
-# Check if staged files have the required license header
-# Returns 0 if all files have headers, 1 otherwise
-
-REPO_ROOT=$(git rev-parse --show-toplevel)
-source "$REPO_ROOT"/commons/shell/colors.sh 2>/dev/null || true
-
-# Get staged files by type (excluding generated files)
-# Excludes: mock_*.go, *_mock.go, *_mocks.go (mockgen)
-STAGED_FILES=$(git diff --cached --name-only --diff-filter=d | grep -E '\.(go|sh)$' | grep -v -E '(^|/)mock_.*\.go$' | grep -v -E '_mocks?\.go$' || true)
-
-if [ -z "$STAGED_FILES" ]; then
- exit 0
-fi
-
-MISSING_HEADER=""
-
-for file in $STAGED_FILES; do
- # Read STAGED content (not working directory) using git show
- FIRST_LINES=$(git show ":$file" 2>/dev/null | head -10)
- if [ -n "$FIRST_LINES" ]; then
- # Check if line STARTS with comment + Copyright (regex anchored to line start)
- # This avoids matching patterns inside string literals
- if ! echo "$FIRST_LINES" | grep -qE '^(//|#) Copyright \(c\) 2026 Lerian Studio'; then
- MISSING_HEADER="${MISSING_HEADER}${file}\n"
- fi
- fi
-done
-
-if [ -n "$MISSING_HEADER" ]; then
- echo "${red:-}Missing license header in files:${normal:-}"
- echo -e "$MISSING_HEADER"
- echo ""
- echo "Add this header to the top of each file:"
- echo ""
- echo " // Copyright (c) 2026 Lerian Studio. All rights reserved."
- echo " // Use of this source code is governed by the Elastic License 2.0"
- echo " // that can be found in the LICENSE file."
- echo ""
- echo "For shell scripts, use # instead of //"
- exit 1
-fi
-
-exit 0