diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f0698e9..aba591a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -81,4 +81,4 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: crate-ci/typos@master + - uses: crate-ci/typos@v1 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 83c92de..3bb4c23 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -25,7 +25,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Install mdBook - run: cargo install mdbook --no-default-features --features search + uses: taiki-e/install-action@mdbook - name: Build rustdoc run: cargo doc --workspace --all-features --no-deps diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index a35e79e..3240315 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -8,6 +8,11 @@ on: - "**/Cargo.toml" - "**/Cargo.lock" - "deny.toml" + pull_request: + paths: + - "**/Cargo.toml" + - "**/Cargo.lock" + - "deny.toml" jobs: audit: diff --git a/.github/workflows/semver.yml b/.github/workflows/semver.yml index 7021120..a2d8fb7 100644 --- a/.github/workflows/semver.yml +++ b/.github/workflows/semver.yml @@ -8,6 +8,8 @@ jobs: semver: name: Semver Checks runs-on: ubuntu-latest + # informational only — release-plz owns version bumps + continue-on-error: true steps: - uses: actions/checkout@v4 with: diff --git a/.github/workflows/update-readme-versions.yml b/.github/workflows/update-readme-versions.yml new file mode 100644 index 0000000..9b72439 --- /dev/null +++ b/.github/workflows/update-readme-versions.yml @@ -0,0 +1,46 @@ +name: Update README Versions + +on: + push: + branches: [main] + paths: + - "crates/*/Cargo.toml" + +permissions: + contents: write + +jobs: + update-readme: + name: Update README Versions + runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'deadcode-walker' }} + steps: + - uses: actions/checkout@v4 + with: + token: ${{ secrets.RELEASE_PLZ_TOKEN }} + + - name: Extract versions and update README + run: | + get_minor() { + grep '^version' "crates/$1/Cargo.toml" | head -1 | sed 's/.*"\(.*\)".*/\1/' | cut -d. -f1,2 + } + + UMBRELLA=$(get_minor asterisk-rs) + AMI=$(get_minor asterisk-rs-ami) + AGI=$(get_minor asterisk-rs-agi) + ARI=$(get_minor asterisk-rs-ari) + + echo "umbrella=$UMBRELLA ami=$AMI agi=$AGI ari=$ARI" + + sed -i "s/asterisk-rs = \"[0-9]*\.[0-9]*\"/asterisk-rs = \"$UMBRELLA\"/g" README.md + sed -i "s/asterisk-rs-ami = \"[0-9]*\.[0-9]*\"/asterisk-rs-ami = \"$AMI\"/g" README.md + sed -i "s/asterisk-rs-agi = \"[0-9]*\.[0-9]*\"/asterisk-rs-agi = \"$AGI\"/g" README.md + sed -i "s/asterisk-rs-ari = \"[0-9]*\.[0-9]*\"/asterisk-rs-ari = \"$ARI\"/g" README.md + + sed -i "s/version = \"[0-9]*\.[0-9]*\", default-features/version = \"$UMBRELLA\", default-features/g" README.md + + - name: Commit if changed + run: | + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + git diff --quiet README.md || (git add README.md && git commit -m "docs: update README install versions" && git push) diff --git a/.gitignore b/.gitignore index ffda1e2..cbe5629 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ Thumbs.db .idea/ .vscode/ *.iml +.claude/ diff --git a/.omp/rules/asterisk.md b/.omp/rules/asterisk.md deleted file mode 100644 index ab3facb..0000000 --- a/.omp/rules/asterisk.md +++ /dev/null @@ -1,44 +0,0 @@ ---- -description: "Asterisk protocol domain knowledge. Read when working on AMI, AGI, or ARI code." -globs: - - "crates/asterisk-rs-ami/**" - - "crates/asterisk-rs-agi/**" - - "crates/asterisk-rs-ari/**" ---- - -# Asterisk Protocols - -## AMI (Asterisk Manager Interface) - -- TCP line-based protocol on port 5038. Messages are `Key: Value\r\n` pairs terminated by `\r\n\r\n`. -- Three message types: Action (client→server), Response (server→client), Event (server→client async). -- Authentication: `Action: Login` with username/secret, or MD5 challenge-response (`Action: Challenge`). -- Every Action has an `ActionID` header for correlating responses. The server echoes it back. -- `Response: Follows` carries multi-line command output terminated by `--END COMMAND--`. Lines without `:` in this block are command output, not key-value headers. -- Event-generating actions (Status, CoreShowChannels, QueueStatus, etc.) return events with a matching `ActionID`, terminated by a `*Complete` event (e.g., `StatusComplete`). -- Events may carry `ChanVariable(name)=value` headers for channel variables set on the channel. -- Events are unsolicited and arrive at any time. Must be handled concurrently with action/response pairs. - -## AGI (Asterisk Gateway Interface) - -- Asterisk connects to an AGI server via TCP (FastAGI, port 4573) and sends environment variables as `key: value\n` lines, terminated by a blank line. -- Commands are single-line text, responses are `xxx result=data` where xxx is a 3-digit status code. -- 200 = success, 510 = invalid command, 520 = usage error. -- Channel is blocked during AGI execution — one command at a time, synchronous. - -## ARI (Asterisk REST Interface) - -- HTTP REST API + WebSocket event stream. Default port 8088. -- Auth: HTTP Basic (username:password) on every REST request and WebSocket upgrade. -- WebSocket delivers JSON events for subscribed applications. App subscribes with `?app=name` on the WS URL. -- Every event carries base fields: `application`, `timestamp`, `asterisk_id`. Events are wrapped in an `AriMessage` struct containing these fields plus the typed event payload. -- WebSocket reconnects automatically, but dynamic subscriptions created via REST (`POST /applications/{app}/subscription`) are lost on reconnect and must be re-established. -- REST endpoints use Basic Auth on every request; credentials are not session-based. -- Resources: channels, bridges, endpoints, device states, mailboxes, sounds, recordings, playbacks. -- Stasis application model: channels enter stasis via dialplan `Stasis(appname)`, controlled via REST. - - -## Testing - -- No inline tests (`#[cfg(test)]`) in any protocol crate. All tests live in `tests/` (`asterisk-rs-tests` crate). -- Mock servers for each protocol are in `tests/src/mock/` (MockAmiServer, MockAriServer, MockAgiClient). \ No newline at end of file diff --git a/.omp/rules/rust.md b/.omp/rules/rust.md deleted file mode 100644 index c443eb8..0000000 --- a/.omp/rules/rust.md +++ /dev/null @@ -1,62 +0,0 @@ ---- -description: "Rust conventions for this workspace. Read when writing or modifying Rust code." -globs: - - "*.rs" - - "Cargo.toml" ---- - -# Rust Conventions - -## Error handling - -- Use `thiserror` for error types. Every crate has its own `Error` enum in `error.rs`. -- Propagate with `?`. Map context with `.map_err()` or `.context()` where the source alone is unclear. -- No `.unwrap()` — workspace lints deny it. Use `.expect("reason")` only for provably infallible cases (static regex, known-good parse). -- Return `Result` from all public APIs. Never panic in library code. - -## Async - -- Runtime is `tokio` with full features. All async code is tokio-native. -- Connection types own their `tokio::task` handles and clean up on drop. -- Use `tokio::select!` for concurrent operations, not `futures::join!` unless all branches must complete. -- Cancel safety: document whether async functions are cancel-safe in their doc comments. -- AMI connection task re-authenticates after every reconnect (Login is re-sent automatically). -- ARI HTTP client uses 10s connect timeout + 30s request timeout. - -## Types - -- Prefer newtypes over raw primitives for domain concepts (`ActionId(String)` not bare `String`). -- Derive `Debug, Clone` on all public types. Add `Serialize, Deserialize` where wire format applies. -- `#[non_exhaustive]` on all public enums that may grow. -- `Serialize` on event types to support logging and forwarding. -- `PartialEq` on AMI event and response types for assertion and matching. - -## Crate boundaries - -- `asterisk-rs-core` owns shared types. Other crates depend on core, never on each other. -- Each protocol crate (ami, agi, ari) is independently usable. -- `asterisk-rs` is the umbrella re-export. It adds no logic, only pub use. - -## Testing - -- No `#[cfg(test)]` or inline test modules in production crates. All tests live in the external `tests/` crate (`asterisk-rs-tests`). -- Unit, mock integration, and live integration tests are separate binaries in `tests/`. -- Run tests with `cargo test -p asterisk-rs-tests`, never with per-crate `cargo test -p asterisk-rs-ami`. - -## Build - -```bash -cargo check --workspace # type check -cargo clippy --workspace -- -D warnings # lint -cargo test --workspace # test -cargo doc --workspace --no-deps # docs -``` - -## Patterns - -- `FilteredSubscription` wraps `EventSubscription` with a predicate closure for selective event delivery. -- `EventListResponse` collects multi-event action results via `send_collecting()`. -- `AriMessage` wraps `AriEvent` with common metadata fields (application, timestamp, asterisk_id). -- `url_encode()` in ARI client for percent-encoding user-provided query parameter values. -- `ShutdownHandle` returned from AGI server builder for graceful shutdown. -- AMI `RawAmiMessage.output` captures multi-line command output from `Response: Follows`. diff --git a/.omp/skills/update-docs/SKILL.md b/.omp/skills/update-docs/SKILL.md deleted file mode 100644 index 97ce3c8..0000000 --- a/.omp/skills/update-docs/SKILL.md +++ /dev/null @@ -1,121 +0,0 @@ ---- -name: update-docs -description: "Updates all documentation and rules in the asterisk-rs workspace. Run when code changes affect public API, versions, crate structure, or features." ---- - -# update-docs - -Updates project documentation to reflect current codebase state. Covers READMEs, AGENTS.md, CHANGELOG, mdBook, and omp rules. - -## Philosophy - -Documentation describes what users can **do**, not implementation metrics. Never put raw counts ("161 events", "47 commands") in user-facing docs. Describe capabilities: "typed events covering the full Asterisk 23 surface", "all AGI commands with typed async methods". Counts belong in auto-generated reference pages only. - -## Prerequisites - -- Python 3.11+ (for docs/generate.py) -- The workspace must have a root `Cargo.toml` with `[workspace]` - -## Step 1: Generate reference pages - -```bash -python3 docs/generate.py -``` - -This parses Rust source files and generates: -- `docs/src/ami/reference.md` — all AMI events and actions -- `docs/src/agi/reference.md` — all AGI commands and channel methods -- `docs/src/ari/reference.md` — all ARI events and resource operations -- `docs/src/types.md` — all domain type enums with variants -- `docs/src/SUMMARY.md` — table of contents - -These files are fully auto-generated. Never hand-edit them. - -## Step 2: Update documentation files - -For each file: read current version, apply targeted edits. Preserve accurate prose. Fix what's wrong. - -### 2.1 Root README.md - -Structure (see current file for reference): -- One-line pitch: what users can DO with the library -- Three protocol bullets: AMI, AGI, ARI with one-sentence descriptions -- Code example showing a real use case (not just ping) -- Install section with cargo add -- Capabilities list: features described as user-facing abilities, not implementation details -- Protocol table with links to docs.rs -- Links to documentation - -**Never put counts in README.** Say "typed events covering the full Asterisk 23 surface" not "161 typed events". - -### 2.2 Per-crate READMEs (`crates/*/README.md`) - -Structure: -- Badges (crates.io, docs.rs) -- One-line pitch -- Code example showing the primary use case for that protocol -- Features list (capabilities, not counts) -- One-liner: "Part of asterisk-rs. MSRV X. MIT/Apache-2.0." - -### 2.3 AGENTS.md - -Drives AI agent behavior. Accuracy is critical — wrong info here means wrong code generated. - -Update against source: -- **Architecture tree** — module names, file descriptions, type names -- **Key Directories** — verify paths exist -- **Code Conventions** — workspace lints, patterns -- **Event System** — current types (AmiEvent, AriMessage wrapping AriEvent) -- **Important Files** — add/remove as needed -- **Testing section** — total test count, coverage gaps -- **CI Matrix** — match workflow files - -**No counts in descriptions.** Say "typed variants + Unknown" not "161 typed variants + Unknown". - -### 2.4 CHANGELOG.md - -Format: [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) - -- Describe what changed in terms of capability, not implementation detail -- Each item starts with crate name in backticks -- Released sections are immutable history -- Say "typed events covering all Asterisk 23 events" not "161 typed event variants" - -### 2.5 mdBook guide pages (`docs/src/`) - -Guide pages (overview, connection, events, fastagi, stasis, resources) are manually maintained. - -Rules: -- Focus on HOW to use the library, not what it contains -- Code examples use `rust,ignore` fences -- Link to reference.md for complete lists: "see [Reference](./reference.md)" -- No manually-maintained event/action/command tables — those are auto-generated -- Use current API shapes (check source before writing examples) -- Import paths use actual crate names: `asterisk_rs_ami`, `asterisk_rs_agi`, `asterisk_rs_ari` - -### 2.6 Rules (`.omp/rules/`) - -Update the `asterisk` and `rust` rules when: -- Protocol knowledge changes (new wire format details, auth mechanisms) -- Code conventions change (new derives, new patterns) -- Architecture changes (new modules, new types) - -Rules should describe the domain and conventions, not enumerate items. - -## Step 3: Verify - -1. `python3 docs/generate.py` runs without error -2. `mdbook build docs/` succeeds (if mdbook installed) -3. No stale references to removed types or modules -4. Code examples use current API (check imports against source) -5. MSRV, license text consistent across all files -6. AGENTS.md architecture matches actual directory structure -7. No raw counts in user-facing documentation (README, CHANGELOG, guide pages) - -## Anti-patterns - -- Putting counts like "161 events" or "47 commands" in READMEs or CHANGELOGs -- Manually maintaining event/action/command tables in mdBook pages -- Writing code examples without verifying against current source -- Duplicating reference content that's auto-generated -- Using `rust,no_run` fences (still compiles; use `rust,ignore` for examples needing a live server) diff --git a/.omp/skills/update-docs/reference/file-specs.md b/.omp/skills/update-docs/reference/file-specs.md deleted file mode 100644 index 3cb5122..0000000 --- a/.omp/skills/update-docs/reference/file-specs.md +++ /dev/null @@ -1,135 +0,0 @@ -# Documentation file specifications - -Structural requirements and examples for each documentation file type. - -## Root README.md - -```markdown -# {crate_name} - -{badges} - -{one_line_pitch — what users can DO, not what the library contains} - -- **AMI** -- {one sentence} -- **AGI** -- {one sentence} -- **ARI** -- {one sentence} - -## Example - -```rust,ignore -{real use case, not just ping — show subscribing to events, originating, or similar} -``` - -## Install - -```toml -[dependencies] -{crate_name} = "{version}" -``` - -## Capabilities - -- {capability described as user-facing ability} -- {not "161 typed events" but "typed events covering the full Asterisk 23 surface"} - -## Protocols - -| Protocol | Default Port | Transport | Use Case | -|----------|-------------|-----------|----------| - -## Documentation - -- [API Reference](https://docs.rs/{crate_name}) -- [User Guide]({pages_url}) - -## MSRV / License -``` - -## Per-crate README.md - -```markdown -# {crate_name} - -{badges: crates.io, docs.rs} - -{one_line_pitch} - -{2-3 sentence description of what this protocol does} - -```rust,ignore -{code example showing primary use case} -``` - -## Features - -- {capability, not count} - -Part of [asterisk-rs]({repo_url}). MSRV {msrv}. MIT/Apache-2.0. -``` - -## AGENTS.md - -The most important file. Drives AI agent behavior. - -Structure: -- Project Overview (one paragraph) -- Architecture tree (module names + one-line descriptions, no counts) -- Key Directories table -- Development Commands -- Code Conventions (lints, formatting, comment style, error handling) -- Pattern descriptions (builder, event system, handle, action trait, handler trait) -- Important Files table -- Runtime & Tooling -- CI Matrix -- Testing section (total count is fine here since it's for development, not users) -- Examples table - -**Critical**: every path, type name, and pattern description must match source. -Prefer describing patterns over enumerating items. - -## CHANGELOG.md - -```markdown -## [Unreleased] - -### Added - -- `{crate}`: {what users can now do, not implementation detail} -``` - -Describe capability changes: "event-collecting actions for multi-event responses" -not "added EventListResponse type and PendingEventList tracking". - -## mdBook guide pages - -Guide pages explain HOW. Reference pages (auto-generated) explain WHAT. - -Guide page structure: -```markdown -# {Topic} - -{1-2 paragraphs explaining the concept} - -## {Usage Pattern} - -```rust,ignore -{code showing the pattern} -``` - -{explanation of what the code does} - -See [Reference](./reference.md) for the complete list. -``` - -Never duplicate reference tables in guide pages. Link to reference.md. - -## Rules (.omp/rules/) - -Rules describe domain knowledge and conventions for AI agents. - -- `asterisk`: protocol wire format details, auth mechanisms, message types -- `rust`: error handling, async patterns, type conventions, build commands - -Rules should be stable knowledge that rarely changes. Don't enumerate -items that change with every commit (event counts, action lists). diff --git a/.omp/skills/update-docs/scripts/extract-workspace-meta.py b/.omp/skills/update-docs/scripts/extract-workspace-meta.py deleted file mode 100755 index 2b29f32..0000000 --- a/.omp/skills/update-docs/scripts/extract-workspace-meta.py +++ /dev/null @@ -1,541 +0,0 @@ -#!/usr/bin/env -S uv run --script -# /// script -# requires-python = ">=3.11" -# /// -"""scan a rust workspace and emit structured JSON metadata for doc generation.""" - -from __future__ import annotations - -import json -import re -import sys -import tomllib -from pathlib import Path - - -# --------------------------------------------------------------------------- -# cargo.toml parsing -# --------------------------------------------------------------------------- - - -def parse_toml(path: Path) -> dict: - with open(path, "rb") as f: - return tomllib.load(f) - - -def resolve_members(root: Path, patterns: list[str]) -> list[str]: - """resolve workspace member globs to actual crate directories (relative to root).""" - members: list[str] = [] - for pat in patterns: - if "*" in pat or "?" in pat: - for p in sorted(root.glob(pat)): - if (p / "Cargo.toml").is_file(): - members.append(str(p.relative_to(root))) - else: - candidate = root / pat - if (candidate / "Cargo.toml").is_file(): - members.append(pat) - return members - - -def extract_workspace(root: Path) -> dict: - cargo = parse_toml(root / "Cargo.toml") - ws = cargo.get("workspace", {}) - pkg = ws.get("package", {}) - lints = ws.get("lints", {}) - deps = ws.get("dependencies", {}) - - dep_versions: dict[str, str] = {} - for name, spec in deps.items(): - if isinstance(spec, str): - dep_versions[name] = spec - elif isinstance(spec, dict): - dep_versions[name] = spec.get("version", "") - - members = resolve_members(root, ws.get("members", [])) - - return { - "root": str(root.resolve()), - "members": members, - "package": { - "edition": pkg.get("edition", ""), - "rust_version": pkg.get("rust-version", ""), - "license": pkg.get("license", ""), - "repository": pkg.get("repository", ""), - "homepage": pkg.get("homepage", ""), - "keywords": pkg.get("keywords", []), - "categories": pkg.get("categories", []), - }, - "lints": { - section: {k: v for k, v in rules.items()} - for section, rules in lints.items() - }, - "dependencies": dep_versions, - } - - -# --------------------------------------------------------------------------- -# per-crate metadata -# --------------------------------------------------------------------------- - - -def extract_crate(root: Path, member: str, ws_pkg: dict) -> dict: - crate_dir = root / member - cargo = parse_toml(crate_dir / "Cargo.toml") - pkg = cargo.get("package", {}) - - name = pkg.get("name", "") - version = pkg.get("version", "") - description = pkg.get("description", "") - - # resolve inherited edition - edition_val = pkg.get("edition", {}) - if isinstance(edition_val, dict) and edition_val.get("workspace"): - edition = ws_pkg.get("edition", "") - elif isinstance(edition_val, str): - edition = edition_val - else: - edition = "" - - # dependencies - deps_section = cargo.get("dependencies", {}) - workspace_deps: list[str] = [] - external_deps: list[str] = [] - for dep_name, spec in deps_section.items(): - if isinstance(spec, dict) and (spec.get("workspace") or spec.get("path")): - workspace_deps.append(dep_name) - else: - external_deps.append(dep_name) - - features = cargo.get("features", {}) - dev_deps = list(cargo.get("dev-dependencies", {}).keys()) - - src_dir = crate_dir / "src" - api = scan_api_surface(src_dir) if src_dir.is_dir() else empty_api() - tests = count_tests(src_dir) if src_dir.is_dir() else {"total": 0, "files": {}} - - examples = sorted( - p.name for p in (crate_dir / "examples").glob("*.rs") - ) if (crate_dir / "examples").is_dir() else [] - - lib_rs = src_dir / "lib.rs" - modules, reexports = parse_lib_rs(lib_rs) if lib_rs.is_file() else ([], []) - - return { - "name": name, - "version": version, - "description": description, - "edition": edition, - "workspace_deps": sorted(workspace_deps), - "external_deps": sorted(external_deps), - "dependencies": sorted(workspace_deps + external_deps), - "features": features, - "dev_dependencies": sorted(dev_deps), - "api": api, - "tests": tests, - "examples": examples, - "modules": modules, - "reexports": reexports, - } - - -# --------------------------------------------------------------------------- -# public api surface scanning -# --------------------------------------------------------------------------- - -# patterns for pub items (not pub(crate), pub(super), etc.) -RE_PUB_STRUCT = re.compile(r"^pub\s+struct\s+(\w+)") -RE_PUB_ENUM = re.compile(r"^pub\s+enum\s+(\w+)") -RE_PUB_FN = re.compile(r"^pub\s+(?:async\s+)?fn\s+(\w+)") -RE_PUB_TRAIT = re.compile(r"^pub\s+trait\s+(\w+)") -RE_PUB_CONST = re.compile(r"^pub\s+const\s+(\w+)") -RE_PUB_TYPE = re.compile(r"^pub\s+type\s+(\w+)") - -# restricted visibility — skip these -RE_PUB_RESTRICTED = re.compile(r"^pub\s*\(") - -# cfg(test) module detection -RE_CFG_TEST = re.compile(r"#\[cfg\(test\)]") - - -def empty_api() -> dict: - return { - "structs": 0, - "enums": [], - "traits": 0, - "functions": 0, - "constants": 0, - "type_aliases": 0, - } - - -def scan_api_surface(src_dir: Path) -> dict: - structs = 0 - enums: list[dict] = [] - traits = 0 - functions = 0 - constants = 0 - type_aliases = 0 - - for rs_file in sorted(src_dir.rglob("*.rs")): - result = scan_file_api(rs_file) - structs += result["structs"] - enums.extend(result["enums"]) - traits += result["traits"] - functions += result["functions"] - constants += result["constants"] - type_aliases += result["type_aliases"] - - return { - "structs": structs, - "enums": enums, - "traits": traits, - "functions": functions, - "constants": constants, - "type_aliases": type_aliases, - } - - -def _extract_variant_name(stripped: str) -> str | None: - """extract enum variant name from a line known to be at enum top-level depth.""" - # skip noise - if not stripped or stripped.startswith("//") or stripped.startswith("#[") or stripped == "}": - return None - # first word before any punctuation is the variant name - word = stripped.split("(")[0].split("{")[0].split(",")[0].split("<")[0].strip() - if word and word[0].isupper() and word.isidentifier(): - return word - return None - - -def scan_file_api(path: Path) -> dict: - """scan a single .rs file for pub declarations, skipping test modules.""" - lines = path.read_text(errors="replace").splitlines() - - structs = 0 - enums: list[dict] = [] - traits = 0 - functions = 0 - constants = 0 - type_aliases = 0 - - in_test_module = False - test_brace_depth = 0 - in_block_comment = False - in_enum_body = False - enum_name = "" - enum_brace_depth = 0 - enum_variants: list[str] = [] - cfg_test_next = False - - for line in lines: - stripped = line.strip() - - # track block comments - if in_block_comment: - if "*/" in stripped: - in_block_comment = False - continue - if "/*" in stripped and "*/" not in stripped: - in_block_comment = True - continue - - # skip line comments - if stripped.startswith("//"): - continue - - # detect #[cfg(test)] - if RE_CFG_TEST.search(stripped): - cfg_test_next = True - continue - - # entering a test module - if cfg_test_next: - if "mod " in stripped: - in_test_module = True - test_brace_depth = stripped.count("{") - stripped.count("}") - cfg_test_next = False - continue - elif stripped == "" or stripped.startswith("#[") or stripped.startswith("///"): - # attribute or doc comment between cfg(test) and mod — keep flag - continue - else: - cfg_test_next = False - - # track test module brace depth - if in_test_module: - test_brace_depth += stripped.count("{") - stripped.count("}") - if test_brace_depth <= 0: - in_test_module = False - continue - - # track enum body for variant extraction - if in_enum_body: - depth_before = enum_brace_depth - enum_brace_depth += stripped.count("{") - stripped.count("}") - - if enum_brace_depth <= 0: - # enum closed — check if closing line also started a variant - if depth_before == 1: - name = _extract_variant_name(stripped) - if name: - enum_variants.append(name) - enums.append({"name": enum_name, "variants": len(enum_variants)}) - in_enum_body = False - continue - - # count variants at the top level of the enum (depth was 1 before this line) - if depth_before == 1: - name = _extract_variant_name(stripped) - if name: - enum_variants.append(name) - continue - - # skip restricted visibility - if RE_PUB_RESTRICTED.match(stripped): - continue - - # pub struct - m = RE_PUB_STRUCT.match(stripped) - if m: - structs += 1 - continue - - # pub enum — start variant tracking - m = RE_PUB_ENUM.match(stripped) - if m: - enum_name = m.group(1) - enum_variants = [] - if "{" in stripped: - in_enum_body = True - enum_brace_depth = stripped.count("{") - stripped.count("}") - if enum_brace_depth <= 0: - enums.append({"name": enum_name, "variants": 0}) - in_enum_body = False - continue - - # pub fn / pub async fn - m = RE_PUB_FN.match(stripped) - if m: - functions += 1 - continue - - # pub trait - m = RE_PUB_TRAIT.match(stripped) - if m: - traits += 1 - continue - - # pub const - m = RE_PUB_CONST.match(stripped) - if m: - constants += 1 - continue - - # pub type - m = RE_PUB_TYPE.match(stripped) - if m: - type_aliases += 1 - continue - - return { - "structs": structs, - "enums": enums, - "traits": traits, - "functions": functions, - "constants": constants, - "type_aliases": type_aliases, - } - - -# --------------------------------------------------------------------------- -# test inventory -# --------------------------------------------------------------------------- - -RE_TEST_ATTR = re.compile(r"#\[(tokio::)?test") - - -def count_tests(src_dir: Path) -> dict: - total = 0 - files: dict[str, int] = {} - - for rs_file in sorted(src_dir.rglob("*.rs")): - content = rs_file.read_text(errors="replace") - count = len(RE_TEST_ATTR.findall(content)) - if count > 0: - rel = str(rs_file.relative_to(src_dir)) - files[rel] = count - total += count - - return {"total": total, "files": files} - - -# --------------------------------------------------------------------------- -# lib.rs module structure -# --------------------------------------------------------------------------- - -RE_PUB_MOD = re.compile(r"^pub\s+mod\s+(\w+)") -RE_PUB_USE = re.compile(r"^pub\s+use\s+\w+::(?:\{([^}]+)\}|(\w+))") - - -def parse_lib_rs(path: Path) -> tuple[list[str], list[str]]: - content = path.read_text(errors="replace") - modules: list[str] = [] - reexports: list[str] = [] - - for line in content.splitlines(): - stripped = line.strip() - if stripped.startswith("//"): - continue - - m = RE_PUB_MOD.match(stripped) - if m: - modules.append(m.group(1)) - continue - - m = RE_PUB_USE.match(stripped) - if m: - if m.group(1): - for item in m.group(1).split(","): - name = item.strip() - if name: - reexports.append(name) - elif m.group(2): - reexports.append(m.group(2)) - - return modules, reexports - - -# --------------------------------------------------------------------------- -# documentation files -# --------------------------------------------------------------------------- - - -def find_docs(root: Path, members: list[str]) -> dict: - root_files = sorted( - p.name for p in root.glob("*.md") if p.is_file() - ) - - crate_readmes: dict[str, str] = {} - crate_changelogs: dict[str, str] = {} - for member in members: - crate_dir = root / member - cargo = parse_toml(crate_dir / "Cargo.toml") - crate_name = cargo.get("package", {}).get("name", Path(member).name) - - readme = crate_dir / "README.md" - if readme.is_file(): - crate_readmes[crate_name] = str(readme.relative_to(root)) - - changelog = crate_dir / "CHANGELOG.md" - if changelog.is_file(): - crate_changelogs[crate_name] = str(changelog.relative_to(root)) - - mdbook: dict = {} - summary = root / "docs" / "src" / "SUMMARY.md" - if summary.is_file(): - mdbook["summary"] = str(summary.relative_to(root)) - mdbook["pages"] = sorted( - str(p.relative_to(root)) - for p in (root / "docs" / "src").rglob("*.md") - if p.is_file() - ) - - return { - "root_files": root_files, - "crate_readmes": crate_readmes, - "crate_changelogs": crate_changelogs, - "mdbook": mdbook, - } - - -# --------------------------------------------------------------------------- -# ci workflows -# --------------------------------------------------------------------------- - - -def scan_workflows(root: Path) -> list[dict]: - wf_dir = root / ".github" / "workflows" - if not wf_dir.is_dir(): - return [] - - workflows: list[dict] = [] - for yml in sorted(wf_dir.glob("*.yml")): - content = yml.read_text(errors="replace") - name_match = re.search(r"^name:\s*(.+)$", content, re.MULTILINE) - wf_name = name_match.group(1).strip().strip('"').strip("'") if name_match else yml.stem - - jobs: list[str] = [] - in_jobs = False - job_indent: int | None = None - for line in content.splitlines(): - if line.rstrip() == "jobs:" or re.match(r"^jobs:\s*$", line): - in_jobs = True - job_indent = None - continue - if in_jobs: - if not line or not line.strip(): - continue - if job_indent is None: - stripped = line.lstrip() - if stripped and not stripped.startswith("#"): - job_indent = len(line) - len(stripped) - if job_indent is not None: - if line[0] not in (" ", "\t"): - break - indent = len(line) - len(line.lstrip()) - if indent == job_indent: - m = re.match(r"\s+(\w[\w-]*):", line) - if m: - jobs.append(m.group(1)) - - workflows.append({ - "file": yml.name, - "name": wf_name, - "jobs": jobs, - }) - - return workflows - - -# --------------------------------------------------------------------------- -# main -# --------------------------------------------------------------------------- - - -def main() -> None: - if len(sys.argv) > 1: - root = Path(sys.argv[1]).resolve() - else: - root = Path.cwd().resolve() - - root_cargo = root / "Cargo.toml" - if not root_cargo.is_file(): - print(f"error: no Cargo.toml found at {root}", file=sys.stderr) - sys.exit(1) - - workspace = extract_workspace(root) - ws_pkg = workspace["package"] - members = workspace["members"] - - crates: dict[str, dict] = {} - for member in members: - crate_data = extract_crate(root, member, ws_pkg) - crates[crate_data["name"]] = crate_data - - docs = find_docs(root, members) - ci = scan_workflows(root) - - output = { - "workspace": workspace, - "crates": crates, - "docs": docs, - "ci": {"workflows": ci}, - } - - json.dump(output, sys.stdout, indent=2) - sys.stdout.write("\n") - - -if __name__ == "__main__": - main() diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..90b0908 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,11 @@ +# asterisk-rs + +See [AGENTS.md](./AGENTS.md) for full project guidelines, architecture, conventions, and commands. + +## Agent Instructions + +- after editing any `.rs` file, run `cargo test -p asterisk-rs-tests --test unit` to catch regressions fast +- after editing codec, connection, or transport modules, also run `cargo test -p asterisk-rs-tests --test mock_integration` +- use `cargo clippy --workspace --all-targets --all-features -- -D warnings` before committing +- all tests live in the external `tests/` crate — never add `#[cfg(test)]` to production code +- breaking API changes require a version bump (CI runs cargo-semver-checks on PRs) diff --git a/Cargo.lock b/Cargo.lock index cf99d11..01b5fb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -924,9 +924,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.9" +version = "0.103.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" dependencies = [ "ring", "rustls-pki-types", diff --git a/crates/asterisk-rs-agi/src/channel.rs b/crates/asterisk-rs-agi/src/channel.rs index b6024ad..ee08e2d 100644 --- a/crates/asterisk-rs-agi/src/channel.rs +++ b/crates/asterisk-rs-agi/src/channel.rs @@ -50,14 +50,14 @@ impl AgiChannel { /// /// the command should already be formatted with a trailing newline. pub async fn send_command(&mut self, command: &str) -> Result { + if self.hung_up { + return Err(AgiError::ChannelHungUp); + } match self.state { ChannelState::InFlight => return Err(AgiError::CommandInFlight), ChannelState::Poisoned => return Err(AgiError::ChannelPoisoned), ChannelState::Ready => {} } - if self.hung_up { - return Err(AgiError::ChannelHungUp); - } // mark in-flight before write so that a cancellation between write and // read is visible to the next caller @@ -83,7 +83,7 @@ impl AgiChannel { if bytes_read == 0 { self.hung_up = true; - self.state = ChannelState::Ready; + self.state = ChannelState::Poisoned; return Err(AgiError::ChannelHungUp); } @@ -103,7 +103,7 @@ impl AgiChannel { }; if n == 0 { self.hung_up = true; - self.state = ChannelState::Ready; + self.state = ChannelState::Poisoned; return Err(AgiError::ChannelHungUp); } let trimmed = next.trim(); @@ -133,7 +133,7 @@ impl AgiChannel { // 511 means the channel is dead if response.code == 511 { self.hung_up = true; - self.state = ChannelState::Ready; + self.state = ChannelState::Poisoned; return Err(AgiError::ChannelHungUp); } diff --git a/crates/asterisk-rs-agi/src/command.rs b/crates/asterisk-rs-agi/src/command.rs index 8dd9661..9bf7b81 100644 --- a/crates/asterisk-rs-agi/src/command.rs +++ b/crates/asterisk-rs-agi/src/command.rs @@ -82,6 +82,9 @@ pub fn format_command(name: &str, args: &[&str]) -> Result { if arg.is_empty() || arg.contains(' ') || arg.contains('"') { cmd.push('"'); for ch in arg.chars() { + if ch == '\\' { + cmd.push('\\'); + } if ch == '"' { cmd.push('\\'); } diff --git a/crates/asterisk-rs-agi/src/error.rs b/crates/asterisk-rs-agi/src/error.rs index 267154b..3074317 100644 --- a/crates/asterisk-rs-agi/src/error.rs +++ b/crates/asterisk-rs-agi/src/error.rs @@ -27,6 +27,9 @@ pub enum AgiError { #[error("protocol error: {0}")] Protocol(#[from] ProtocolError), + + #[error("invalid configuration: {details}")] + InvalidConfig { details: String }, } pub type Result = std::result::Result; diff --git a/crates/asterisk-rs-agi/src/request.rs b/crates/asterisk-rs-agi/src/request.rs index 267dbb0..35cd335 100644 --- a/crates/asterisk-rs-agi/src/request.rs +++ b/crates/asterisk-rs-agi/src/request.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use tokio::io::AsyncBufReadExt; +use tokio::io::{AsyncBufReadExt, AsyncReadExt}; /// parsed AGI request environment sent by asterisk on connection #[derive(Debug, Clone)] @@ -22,15 +22,17 @@ impl AgiRequest { loop { line.clear(); - let bytes_read = reader.read_line(&mut line).await?; + // limit bytes read per line to prevent OOM from a malicious client + // sending an unbounded line without a newline + let bytes_read = (&mut *reader).take(8193).read_line(&mut line).await?; // eof or blank line terminates the environment block if bytes_read == 0 || line.trim().is_empty() { break; } - // guard against maliciously large lines exhausting memory - if line.len() > 8192 { + // reject lines that hit the byte limit without a newline + if line.len() >= 8193 && !line.ends_with('\n') { return Err(crate::error::AgiError::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, "agi prelude line exceeds 8192 bytes", diff --git a/crates/asterisk-rs-agi/src/response.rs b/crates/asterisk-rs-agi/src/response.rs index 4f0a59d..96ee92b 100644 --- a/crates/asterisk-rs-agi/src/response.rs +++ b/crates/asterisk-rs-agi/src/response.rs @@ -63,9 +63,9 @@ impl AgiResponse { // extract optional parenthesized data let (data, remainder) = if let Some(start) = remainder.find('(') { - if let Some(end) = remainder[start..].find(')') { - let data_str = &remainder[start + 1..start + end]; - let after = remainder[start + end + 1..].trim(); + if let Some(end) = remainder.rfind(')') { + let data_str = &remainder[start + 1..end]; + let after = remainder[end + 1..].trim(); (Some(data_str.to_owned()), after) } else { (None, remainder) diff --git a/crates/asterisk-rs-agi/src/server.rs b/crates/asterisk-rs-agi/src/server.rs index 6b09f39..4922c72 100644 --- a/crates/asterisk-rs-agi/src/server.rs +++ b/crates/asterisk-rs-agi/src/server.rs @@ -4,6 +4,7 @@ use std::time::Duration; use tokio::io::BufReader; use tokio::net::TcpListener; use tokio::sync::{watch, Semaphore}; +use tokio::task::JoinHandle; use crate::channel::AgiChannel; use crate::error::{AgiError, Result}; @@ -43,7 +44,7 @@ impl AgiServer { /// create a new builder for configuring the server pub fn builder() -> AgiServerBuilder { AgiServerBuilder { - bind_addr: "0.0.0.0:4573".to_owned(), + bind_addr: "127.0.0.1:4573".to_owned(), handler: None, max_connections: None, } @@ -54,6 +55,7 @@ impl AgiServer { /// runs until shutdown is signaled or an unrecoverable error occurs pub async fn run(mut self) -> Result<()> { let semaphore = self.max_connections.map(|n| Arc::new(Semaphore::new(n))); + let mut handles: Vec> = Vec::new(); loop { tokio::select! { @@ -96,14 +98,17 @@ impl AgiServer { None }; - tokio::spawn(async move { + // prune completed handles to prevent unbounded growth + handles.retain(|h| !h.is_finished()); + + handles.push(tokio::spawn(async move { // permit is held until the task completes, then dropped automatically let _permit = permit; if let Err(err) = handle_connection(handler, stream).await { tracing::warn!(%peer, %err, "AGI session error"); } - }); + })); } result = self.shutdown_rx.changed() => { // Err means all senders were dropped — treat as shutdown signal @@ -166,11 +171,8 @@ impl AgiServerBuilder { /// /// returns the server and a handle that can signal graceful shutdown pub async fn build(self) -> Result<(AgiServer, ShutdownHandle)> { - let handler = self.handler.ok_or_else(|| { - AgiError::Io(std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "handler is required", - )) + let handler = self.handler.ok_or_else(|| AgiError::InvalidConfig { + details: "handler is required".to_owned(), })?; let listener = TcpListener::bind(&self.bind_addr).await?; diff --git a/crates/asterisk-rs-ami/Cargo.toml b/crates/asterisk-rs-ami/Cargo.toml index c49c92c..f5876c3 100644 --- a/crates/asterisk-rs-ami/Cargo.toml +++ b/crates/asterisk-rs-ami/Cargo.toml @@ -24,6 +24,7 @@ md-5.workspace = true futures-util.workspace = true pin-project-lite.workspace = true serde.workspace = true +zeroize.workspace = true [dev-dependencies] tracing-subscriber.workspace = true diff --git a/crates/asterisk-rs-ami/src/action.rs b/crates/asterisk-rs-ami/src/action.rs index db029e8..97b3200 100644 --- a/crates/asterisk-rs-ami/src/action.rs +++ b/crates/asterisk-rs-ami/src/action.rs @@ -4,8 +4,11 @@ use crate::codec::RawAmiMessage; use std::sync::atomic::{AtomicU64, Ordering}; +use zeroize::Zeroizing; -/// global action ID counter +// relaxed is sufficient: fetch_add is an atomic RMW — it cannot return +// the same value to two threads. no other memory operations need +// ordering relative to this counter static ACTION_ID_COUNTER: AtomicU64 = AtomicU64::new(1); /// generate a unique action ID @@ -45,7 +48,29 @@ pub trait AmiAction { /// login with plaintext credentials pub struct LoginAction { pub username: String, - pub secret: String, + secret: Zeroizing, +} + +impl LoginAction { + pub fn new(username: impl Into, secret: impl Into) -> Self { + Self { + username: username.into(), + secret: Zeroizing::new(secret.into()), + } + } + + pub fn secret(&self) -> &str { + &self.secret + } +} + +impl std::fmt::Debug for LoginAction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LoginAction") + .field("username", &self.username) + .field("secret", &"[REDACTED]") + .finish() + } } impl AmiAction for LoginAction { @@ -56,7 +81,7 @@ impl AmiAction for LoginAction { fn to_headers(&self) -> Vec<(String, String)> { vec![ ("Username".into(), self.username.clone()), - ("Secret".into(), self.secret.clone()), + ("Secret".into(), self.secret().to_string()), ] } } diff --git a/crates/asterisk-rs-ami/src/client.rs b/crates/asterisk-rs-ami/src/client.rs index ceadc9a..6396d83 100644 --- a/crates/asterisk-rs-ami/src/client.rs +++ b/crates/asterisk-rs-ami/src/client.rs @@ -172,6 +172,7 @@ pub struct AmiClientBuilder { timeout: Duration, event_capacity: usize, ping_interval: Option, + require_challenge: bool, } impl Default for AmiClientBuilder { @@ -184,6 +185,7 @@ impl Default for AmiClientBuilder { timeout: DEFAULT_TIMEOUT, event_capacity: 1024, ping_interval: None, + require_challenge: true, } } } @@ -225,6 +227,17 @@ impl AmiClientBuilder { self } + /// allow plaintext login fallback when MD5 challenge auth fails + /// + /// when `true` (the default), login fails if the server does not + /// support challenge-response authentication. set to `false` only + /// for connections over a trusted loopback — plaintext login sends + /// the secret in cleartext. + pub fn require_challenge(mut self, require: bool) -> Self { + self.require_challenge = require; + self + } + /// set the interval for keep-alive pings /// /// when set, the client sends periodic Ping actions to detect @@ -263,6 +276,7 @@ impl AmiClientBuilder { event_bus.clone(), self.reconnect_policy, self.ping_interval, + self.require_challenge, ); // wait for connection + login to complete diff --git a/crates/asterisk-rs-ami/src/codec.rs b/crates/asterisk-rs-ami/src/codec.rs index be93239..41da080 100644 --- a/crates/asterisk-rs-ami/src/codec.rs +++ b/crates/asterisk-rs-ami/src/codec.rs @@ -96,12 +96,10 @@ impl Decoder for AmiCodec { let line = &src[..pos]; // validate it looks like an AMI banner if !line.starts_with(b"Asterisk Call Manager") { + let preview = String::from_utf8_lossy(&line[..line.len().min(64)]); return Err(AmiError::Protocol( asterisk_rs_core::error::ProtocolError::MalformedMessage { - details: format!( - "expected AMI banner, got: {}", - String::from_utf8_lossy(line) - ), + details: format!("expected AMI banner, got: {}", preview), }, )); } @@ -119,98 +117,101 @@ impl Decoder for AmiCodec { // we must not accept a \r\n\r\n that appears before the end marker. const END_MARKER: &[u8] = b"--END COMMAND--"; - let first_blank = match find_double_crlf(src) { - Some(pos) => pos, - None => return Ok(None), - }; + // loop to skip empty frames instead of recursing + loop { + let first_blank = match find_double_crlf(src) { + Some(pos) => pos, + None => return Ok(None), + }; - // peek: does this frame contain a Follows header? - // if so, the real terminator is \r\n\r\n *after* --END COMMAND-- - let frame_end = if is_follows_response(&src[..first_blank]) { - // the marker may appear after the first \r\n\r\n because the - // output body can contain blank lines in some edge cases. - // scan the entire buffer for --END COMMAND--\r\n\r\n - match find_subsequence(src, END_MARKER) { - Some(marker_pos) => { - let after_marker = marker_pos + END_MARKER.len(); - // expect \r\n after the marker (Asterisk always sends it) - if src.len() < after_marker + 2 { - return Ok(None); - } - // then look for \r\n\r\n immediately after the marker line - if &src[after_marker..after_marker + 2] != b"\r\n" { - return Ok(None); + // peek: does this frame contain a Follows header? + // if so, the real terminator is \r\n\r\n *after* --END COMMAND-- + let frame_end = if is_follows_response(&src[..first_blank]) { + // the marker may appear after the first \r\n\r\n because the + // output body can contain blank lines in some edge cases. + // scan the entire buffer for --END COMMAND--\r\n\r\n + match find_subsequence(src, END_MARKER) { + Some(marker_pos) => { + let after_marker = marker_pos + END_MARKER.len(); + // expect \r\n after the marker (Asterisk always sends it) + if src.len() < after_marker + 2 { + return Ok(None); + } + // then look for \r\n\r\n immediately after the marker line + if &src[after_marker..after_marker + 2] != b"\r\n" { + return Ok(None); + } + // frame ends after marker + \r\n + after_marker + 2 } - // frame ends after marker + \r\n - after_marker + 2 + None => return Ok(None), } - None => return Ok(None), - } - } else { - // regular message: frame ends at first \r\n\r\n + 4 - first_blank + 4 - }; + } else { + // regular message: frame ends at first \r\n\r\n + 4 + first_blank + 4 + }; - // size check on the individual message, not the whole buffer - if frame_end > MAX_MESSAGE_SIZE { - return Err(AmiError::Protocol( - asterisk_rs_core::error::ProtocolError::MalformedMessage { - details: format!("message exceeds {} byte limit", MAX_MESSAGE_SIZE), - }, - )); - } + // size check on the individual message, not the whole buffer + if frame_end > MAX_MESSAGE_SIZE { + return Err(AmiError::Protocol( + asterisk_rs_core::error::ProtocolError::MalformedMessage { + details: format!("message exceeds {} byte limit", MAX_MESSAGE_SIZE), + }, + )); + } - // parse all lines in the frame: key:value pairs go to headers, - // everything else goes to output (command body for Response: Follows) - let message_bytes = &src[..frame_end]; - let mut headers = Vec::new(); - let mut output = Vec::new(); - let mut channel_variables = HashMap::new(); + // parse all lines in the frame: key:value pairs go to headers, + // everything else goes to output (command body for Response: Follows) + let message_bytes = &src[..frame_end]; + let mut headers = Vec::new(); + let mut output = Vec::new(); + let mut channel_variables = HashMap::new(); - for line in message_bytes.split(|&b| b == b'\n') { - let line = line.strip_suffix(b"\r").unwrap_or(line); - if line.is_empty() { - continue; - } - if line == END_MARKER { - continue; - } - if let Some(colon_pos) = line.iter().position(|&b| b == b':') { - let key = String::from_utf8_lossy(&line[..colon_pos]) - .trim() - .to_string(); - let value_start = colon_pos + 1; - let value = if value_start < line.len() { - String::from_utf8_lossy(&line[value_start..]) + for line in message_bytes.split(|&b| b == b'\n') { + let line = line.strip_suffix(b"\r").unwrap_or(line); + if line.is_empty() { + continue; + } + if line == END_MARKER { + continue; + } + if let Some(colon_pos) = line.iter().position(|&b| b == b':') { + let key = String::from_utf8_lossy(&line[..colon_pos]) .trim() - .to_string() - } else { - String::new() - }; - if key.starts_with("ChanVariable(") && key.ends_with(')') { - let var_name = &key["ChanVariable(".len()..key.len() - 1]; - channel_variables.insert(var_name.to_string(), value); + .to_string(); + let value_start = colon_pos + 1; + let value = if value_start < line.len() { + String::from_utf8_lossy(&line[value_start..]) + .trim() + .to_string() + } else { + String::new() + }; + if key.starts_with("ChanVariable(") && key.ends_with(')') { + let var_name = &key["ChanVariable(".len()..key.len() - 1]; + channel_variables.insert(var_name.to_string(), value); + } else { + headers.push((key, value)); + } } else { - headers.push((key, value)); + // non-key-value line: command output + output.push(String::from_utf8_lossy(line).into_owned()); } - } else { - // non-key-value line: command output - output.push(String::from_utf8_lossy(line).into_owned()); } - } - src.advance(frame_end); + src.advance(frame_end); - if headers.is_empty() { - // empty message, try next - return self.decode(src); - } + if headers.is_empty() { + // empty frame, skip and try next + continue; + } - Ok(Some(RawAmiMessage { - headers, - output, - channel_variables, - })) + return Ok(Some(RawAmiMessage { + headers, + output, + channel_variables, + })); + } } } @@ -285,11 +286,19 @@ fn find_double_crlf(buf: &[u8]) -> Option { buf.windows(4).position(|w| w == b"\r\n\r\n") } -/// returns true if the header block contains `Response: Follows` +/// returns true if the header block contains a `Response: Follows` header, +/// tolerating optional whitespace after the colon (e.g. `Response:Follows`) fn is_follows_response(header_bytes: &[u8]) -> bool { header_bytes.split(|&b| b == b'\n').any(|line| { let line = line.strip_suffix(b"\r").unwrap_or(line); - line.eq_ignore_ascii_case(b"response: follows") + if let Some(colon_pos) = line.iter().position(|&b| b == b':') { + let key = &line[..colon_pos]; + let value = &line[colon_pos + 1..]; + let value_trimmed = value.strip_prefix(b" ").unwrap_or(value); + key.eq_ignore_ascii_case(b"response") && value_trimmed.eq_ignore_ascii_case(b"follows") + } else { + false + } }) } diff --git a/crates/asterisk-rs-ami/src/connection.rs b/crates/asterisk-rs-ami/src/connection.rs index 8f3a35c..be4c186 100644 --- a/crates/asterisk-rs-ami/src/connection.rs +++ b/crates/asterisk-rs-ami/src/connection.rs @@ -10,6 +10,7 @@ use asterisk_rs_core::config::{ConnectionState, ReconnectPolicy}; use asterisk_rs_core::event::EventBus; use futures_util::{SinkExt, StreamExt}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; @@ -48,6 +49,7 @@ impl ConnectionManager { event_bus: EventBus, reconnect_policy: ReconnectPolicy, ping_interval: Option, + require_challenge: bool, ) -> Self { let (command_tx, command_rx) = mpsc::channel(256); let (state_tx, state_rx) = watch::channel(ConnectionState::Disconnected); @@ -60,6 +62,7 @@ impl ConnectionManager { state_tx, reconnect_policy, ping_interval, + require_challenge, )); Self { @@ -96,6 +99,7 @@ impl ConnectionManager { } } +#[allow(clippy::too_many_arguments)] async fn connection_task( address: String, credentials: Credentials, @@ -104,6 +108,7 @@ async fn connection_task( state_tx: watch::Sender, reconnect_policy: ReconnectPolicy, ping_interval: Option, + require_challenge: bool, ) { let pending = Arc::new(Mutex::new(PendingActions::new())); let mut attempt: u32 = 0; @@ -124,7 +129,7 @@ async fn connection_task( // 30s covers the full login exchange (challenge + auth RTTs) let login_result = tokio::time::timeout( Duration::from_secs(30), - perform_login(&credentials, &mut reader, &mut writer), + perform_login(&credentials, &mut reader, &mut writer, require_challenge), ) .await; match login_result { @@ -147,17 +152,25 @@ async fn connection_task( if let Some(ref mut timer) = ping_timer { timer.tick().await; // consume the immediate first tick } - // tracks the receiver for the most recently sent ping - let mut pending_pong_rx: Option> = None; + // shared flag set by dispatch_message when a pong arrives; + // avoids the try_recv race where select! picks the timer + // arm before the reader arm has dispatched a buffered pong + let pong_received = Arc::new(AtomicBool::new(false)); + let mut awaiting_pong = false; // process messages until disconnect loop { + // biased: always drain incoming frames before checking + // the ping timer, preventing false "missed pong" when the + // response is buffered but not yet dispatched tokio::select! { - // incoming message from AMI + biased; + + // incoming message from AMI (highest priority) frame = reader.next() => { match frame { Some(Ok(raw)) => { - dispatch_message(raw, &pending, &event_bus).await; + dispatch_message(raw, &pending, &event_bus, &pong_received).await; } Some(Err(e)) => { tracing::error!(error = %e, "AMI codec error"); @@ -198,31 +211,23 @@ async fn connection_task( } } } - // keep-alive ping + // keep-alive ping (lowest priority due to biased select) _ = async { match ping_timer.as_mut() { Some(timer) => timer.tick().await, None => std::future::pending().await, } } => { - // if we sent a ping and haven't received the pong yet, - // the connection is dead - if let Some(mut rx) = pending_pong_rx.take() { - match rx.try_recv() { - Ok(_) => {} // pong received in time - Err(tokio::sync::oneshot::error::TryRecvError::Empty) => { - tracing::warn!("keep-alive pong not received, treating connection as dead"); - break; - } - Err(tokio::sync::oneshot::error::TryRecvError::Closed) => { - tracing::warn!("keep-alive pong channel closed unexpectedly"); - break; - } - } + if awaiting_pong && !pong_received.load(Ordering::Acquire) { + tracing::warn!("keep-alive pong not received, treating connection as dead"); + break; } + // send a new ping + pong_received.store(false, Ordering::Release); let (action_id, ping_msg) = PingAction.to_message(); - let pong_rx = pending.lock().await.register(action_id); - pending_pong_rx = Some(pong_rx); + // register so the response routes through dispatch_message + let _pong_rx = pending.lock().await.register(action_id); + awaiting_pong = true; if let Err(e) = writer.send(ping_msg).await { tracing::warn!(error = %e, "keep-alive ping failed, reconnecting"); break; @@ -266,9 +271,16 @@ async fn connection_task( let _ = state_tx.send(ConnectionState::Disconnected); return; } - Some(_) => { - // non-shutdown command while disconnected; it will time out on the - // caller side — nothing we can do until reconnected + Some(ConnectionCommand::SendAction { response_tx, .. }) => { + // fail fast: drop the sender so the caller gets + // ResponseChannelClosed immediately instead of waiting + // for the action timeout to expire + tracing::debug!("dropping action received during reconnect backoff"); + drop(response_tx); + } + Some(ConnectionCommand::SendEventGeneratingAction { response_tx, .. }) => { + tracing::debug!("dropping event-list action received during reconnect backoff"); + drop(response_tx); } } } @@ -279,11 +291,14 @@ async fn connection_task( /// perform the AMI login sequence over the raw framed connection /// -/// tries MD5 challenge-response first, falls back to plaintext +/// tries MD5 challenge-response first. when `require_challenge` is +/// false, falls back to plaintext login (only safe over trusted +/// loopback connections). async fn perform_login( credentials: &Credentials, reader: &mut FramedRead, writer: &mut FramedWrite, + require_challenge: bool, ) -> Result<()> { // try MD5 challenge-response first let (_, challenge_msg) = ChallengeAction.to_message(); @@ -313,11 +328,20 @@ async fn perform_login( } } + // challenge auth did not produce a Challenge field + if require_challenge { + return Err(AmiError::Auth( + asterisk_rs_core::error::AuthError::Rejected { + reason: "server did not provide MD5 challenge; plaintext fallback is disabled \ + (set require_challenge(false) for trusted loopback connections)" + .to_owned(), + }, + )); + } + // fall back to plaintext - let login = LoginAction { - username: credentials.username().to_string(), - secret: credentials.secret().to_string(), - }; + tracing::warn!("MD5 challenge auth unavailable, falling back to plaintext login"); + let login = LoginAction::new(credentials.username(), credentials.secret()); let (_, login_msg) = login.to_message(); writer.send(login_msg).await?; @@ -362,19 +386,24 @@ async fn dispatch_message( raw: RawAmiMessage, pending: &Arc>, event_bus: &EventBus, + pong_received: &AtomicBool, ) { // try as response first if let Some(response) = AmiResponse::from_raw(&raw) { let mut guard = pending.lock().await; - // check if this is for an event-generating action - if guard.deliver_event_list_response(response.clone()) { + if guard.contains_event_list(&response.action_id) { + // any inbound response proves the connection is alive + pong_received.store(true, Ordering::Release); + guard.deliver_event_list_response(response); return; } // regular action response let action_id = response.action_id.clone(); - if !guard.deliver(response) { + if guard.deliver(response) { + pong_received.store(true, Ordering::Release); + } else { tracing::debug!(action_id, "received response for unknown action"); } return; diff --git a/crates/asterisk-rs-ami/src/lib.rs b/crates/asterisk-rs-ami/src/lib.rs index 6b5440d..cd315e5 100644 --- a/crates/asterisk-rs-ami/src/lib.rs +++ b/crates/asterisk-rs-ami/src/lib.rs @@ -27,14 +27,15 @@ pub mod action; pub mod client; pub mod codec; -pub mod connection; +pub(crate) mod connection; pub mod error; pub mod event; pub mod response; pub mod tracker; pub use client::{AmiClient, AmiClientBuilder}; +pub use codec::{AmiCodec, RawAmiMessage}; pub use error::AmiError; pub use event::AmiEvent; -pub use response::EventListResponse; +pub use response::{AmiResponse, EventListResponse}; pub use tracker::{CallTracker, CompletedCall}; diff --git a/crates/asterisk-rs-ami/src/response.rs b/crates/asterisk-rs-ami/src/response.rs index 7a7b21a..02b6bf2 100644 --- a/crates/asterisk-rs-ami/src/response.rs +++ b/crates/asterisk-rs-ami/src/response.rs @@ -156,6 +156,11 @@ impl PendingActions { ); } + /// check whether an action_id has a pending event list + pub fn contains_event_list(&self, action_id: &str) -> bool { + self.pending_event_lists.contains_key(action_id) + } + /// deliver the initial response for an event-generating action /// /// returns true if this action_id has a pending event list diff --git a/crates/asterisk-rs-ami/src/tracker.rs b/crates/asterisk-rs-ami/src/tracker.rs index 9042be5..00f29b4 100644 --- a/crates/asterisk-rs-ami/src/tracker.rs +++ b/crates/asterisk-rs-ami/src/tracker.rs @@ -1,6 +1,8 @@ //! call correlation engine — tracks AMI events by UniqueID into call lifecycle objects. use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::{mpsc, watch}; @@ -47,6 +49,7 @@ struct ActiveCall { pub struct CallTracker { shutdown_tx: watch::Sender, task_handle: tokio::task::JoinHandle<()>, + dropped_count: Arc, } impl std::fmt::Debug for CallTracker { @@ -60,22 +63,30 @@ impl CallTracker { pub fn new(subscription: EventSubscription) -> (Self, mpsc::Receiver) { let (completed_tx, completed_rx) = mpsc::channel(256); let (shutdown_tx, shutdown_rx) = watch::channel(false); + let dropped_count = Arc::new(AtomicU64::new(0)); let task_handle = tokio::spawn(track_loop( subscription, completed_tx, shutdown_rx, DEFAULT_CALL_TTL, + Arc::clone(&dropped_count), )); let tracker = Self { shutdown_tx, task_handle, + dropped_count, }; (tracker, completed_rx) } + /// number of completed calls dropped because the receiver channel was full or closed + pub fn dropped_count(&self) -> u64 { + self.dropped_count.load(Ordering::Relaxed) + } + /// stop the background tracking task pub fn shutdown(&self) { let _ = self.shutdown_tx.send(true); @@ -96,6 +107,7 @@ async fn track_loop( completed_tx: mpsc::Sender, mut shutdown_rx: watch::Receiver, ttl: Duration, + dropped_count: Arc, ) { let mut active: HashMap = HashMap::new(); @@ -103,8 +115,8 @@ async fn track_loop( tokio::select! { event = subscription.recv() => { let Some(event) = event else { break }; - evict_stale(&mut active, &completed_tx, ttl); - handle_event(&mut active, &completed_tx, event); + evict_stale(&mut active, &completed_tx, ttl, &dropped_count); + handle_event(&mut active, &completed_tx, event, &dropped_count); } _ = shutdown_rx.changed() => { break; @@ -117,6 +129,7 @@ fn handle_event( active: &mut HashMap, completed_tx: &mpsc::Sender, event: AmiEvent, + dropped_count: &AtomicU64, ) { // handle new channel creation if let AmiEvent::NewChannel { @@ -176,6 +189,7 @@ fn handle_event( }; // receiver may have been dropped or channel full — drop rather than block the tracker if completed_tx.try_send(completed).is_err() { + dropped_count.fetch_add(1, Ordering::Relaxed); tracing::warn!("completed_tx full or closed, dropping completed call"); } } @@ -399,6 +413,7 @@ fn evict_stale( active: &mut HashMap, completed_tx: &mpsc::Sender, ttl: Duration, + dropped_count: &AtomicU64, ) { let now = Instant::now(); active.retain(|_, call| { @@ -417,6 +432,7 @@ fn evict_stale( events: std::mem::take(&mut call.events), }; if completed_tx.try_send(completed).is_err() { + dropped_count.fetch_add(1, Ordering::Relaxed); tracing::warn!(unique_id = %call.unique_id, "completed_tx full, dropping stale evicted call"); } false diff --git a/crates/asterisk-rs-ari/src/client.rs b/crates/asterisk-rs-ari/src/client.rs index c5e1b41..e2969ba 100644 --- a/crates/asterisk-rs-ari/src/client.rs +++ b/crates/asterisk-rs-ari/src/client.rs @@ -27,8 +27,8 @@ pub struct AriClient { impl std::fmt::Debug for AriClient { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AriClient") - .field("base_url", &self.config.base_url) - .field("transport_mode", &self.config.transport_mode) + .field("base_url", self.config.base_url()) + .field("transport_mode", &self.config.transport_mode()) .finish_non_exhaustive() } } @@ -41,23 +41,22 @@ impl AriClient { pub async fn connect(config: AriConfig) -> Result { let event_bus = EventBus::new(256); - let transport = match config.transport_mode { + let transport = match config.transport_mode() { TransportMode::Http => { let http = HttpTransport::new( - config.base_url.as_str(), - config.username.clone(), - config.password.clone(), - config.ws_url.to_string(), + config.base_url().as_str(), + config.credentials().clone(), + config.ws_url().to_string(), event_bus.clone(), - config.reconnect_policy.clone(), + config.reconnect_policy().clone(), )?; TransportInner::Http(http) } TransportMode::WebSocket => { let ws = WsTransport::spawn( - config.ws_url.to_string(), + config.ws_url().to_string(), event_bus.clone(), - config.reconnect_policy.clone(), + config.reconnect_policy().clone(), ); TransportInner::WebSocket(ws) } diff --git a/crates/asterisk-rs-ari/src/config.rs b/crates/asterisk-rs-ari/src/config.rs index f291d7d..a10f7a0 100644 --- a/crates/asterisk-rs-ari/src/config.rs +++ b/crates/asterisk-rs-ari/src/config.rs @@ -1,5 +1,6 @@ //! ARI client configuration and builder. +use asterisk_rs_core::auth::Credentials; use asterisk_rs_core::config::ReconnectPolicy; use url::Url; @@ -19,22 +20,65 @@ pub enum TransportMode { } /// ARI connection configuration -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct AriConfig { /// http base url for rest requests - pub base_url: Url, - /// ari username - pub username: String, - /// ari password - pub password: String, + pub(crate) base_url: Url, + /// ari credentials + pub(crate) credentials: Credentials, /// stasis application name - pub app_name: String, + pub(crate) app_name: String, /// websocket url for event subscription - pub ws_url: Url, + pub(crate) ws_url: Url, /// policy controlling reconnect behavior - pub reconnect_policy: ReconnectPolicy, + pub(crate) reconnect_policy: ReconnectPolicy, /// transport mode for rest communication - pub transport_mode: TransportMode, + pub(crate) transport_mode: TransportMode, +} + +impl std::fmt::Debug for AriConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AriConfig") + .field("base_url", &self.base_url) + .field("credentials", &self.credentials) + .field("app_name", &self.app_name) + .field("ws_url", &"[redacted]") + .field("reconnect_policy", &self.reconnect_policy) + .field("transport_mode", &self.transport_mode) + .finish() + } +} + +impl AriConfig { + /// http base url for rest requests + pub fn base_url(&self) -> &Url { + &self.base_url + } + + /// ari credentials + pub fn credentials(&self) -> &Credentials { + &self.credentials + } + + /// stasis application name + pub fn app_name(&self) -> &str { + &self.app_name + } + + /// websocket url for event subscription (internal only — contains credentials) + pub(crate) fn ws_url(&self) -> &Url { + &self.ws_url + } + + /// policy controlling reconnect behavior + pub fn reconnect_policy(&self) -> &ReconnectPolicy { + &self.reconnect_policy + } + + /// transport mode for rest communication + pub fn transport_mode(&self) -> TransportMode { + self.transport_mode + } } /// builder for constructing an [`AriConfig`] with validation @@ -118,7 +162,7 @@ impl AriConfigBuilder { /// build the config, constructing base and websocket URLs /// - /// fails if app_name or username is empty, or URLs cannot be parsed + /// fails if app_name, username, or password is empty, or URLs cannot be parsed pub fn build(self) -> Result { if self.app_name.is_empty() { return Err(AriError::InvalidUrl( @@ -130,6 +174,11 @@ impl AriConfigBuilder { "username must not be empty".to_owned(), )); } + if self.password.is_empty() { + return Err(AriError::InvalidUrl( + "password must not be empty".to_owned(), + )); + } let http_scheme = if self.secure { "https" } else { "http" }; let ws_scheme = if self.secure { "wss" } else { "ws" }; @@ -145,10 +194,11 @@ impl AriConfigBuilder { ); let ws_url = Url::parse(&ws_url_str).map_err(|e| AriError::InvalidUrl(e.to_string()))?; + let credentials = Credentials::new(self.username, self.password); + Ok(AriConfig { base_url, - username: self.username, - password: self.password, + credentials, app_name: self.app_name, ws_url, reconnect_policy: self.reconnect_policy, diff --git a/crates/asterisk-rs-ari/src/lib.rs b/crates/asterisk-rs-ari/src/lib.rs index be7b9ee..1ea16e3 100644 --- a/crates/asterisk-rs-ari/src/lib.rs +++ b/crates/asterisk-rs-ari/src/lib.rs @@ -14,6 +14,7 @@ pub mod server; pub(crate) mod transport; pub(crate) mod util; pub mod websocket; +pub(crate) mod ws_proto; pub(crate) mod ws_transport; pub use client::AriClient; diff --git a/crates/asterisk-rs-ari/src/media.rs b/crates/asterisk-rs-ari/src/media.rs index b70208b..3b056e3 100644 --- a/crates/asterisk-rs-ari/src/media.rs +++ b/crates/asterisk-rs-ari/src/media.rs @@ -324,6 +324,14 @@ impl Drop for MediaChannel { } } +/// whether an event is a critical control event that must not be dropped +fn is_critical_event(event: &MediaEvent) -> bool { + matches!( + event, + MediaEvent::MediaStart { .. } | MediaEvent::MediaBufferingCompleted { .. } + ) +} + /// background task that bridges a websocket stream into typed channels. /// /// generic over the stream type so it works for both outbound @@ -348,15 +356,21 @@ async fn media_loop( Some(Ok(Message::Text(text))) => { match serde_json::from_str::(&text) { Ok(event) => { - match event_tx.try_send(event) { - Ok(()) => {} - Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { - tracing::warn!("media event channel full, dropping event"); - } - Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { - // receiver dropped + if is_critical_event(&event) { + // critical control events must not be dropped + if event_tx.send(event).await.is_err() { return; } + } else { + match event_tx.try_send(event) { + Ok(()) => {} + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + tracing::warn!("media event channel full, dropping event"); + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + return; + } + } } } Err(e) => { diff --git a/crates/asterisk-rs-ari/src/resources/asterisk.rs b/crates/asterisk-rs-ari/src/resources/asterisk.rs index c2fd9af..e9ee856 100644 --- a/crates/asterisk-rs-ari/src/resources/asterisk.rs +++ b/crates/asterisk-rs-ari/src/resources/asterisk.rs @@ -81,28 +81,28 @@ pub async fn list_modules(client: &AriClient) -> Result> { /// get details for a specific module pub async fn get_module(client: &AriClient, module_name: &str) -> Result { client - .get(&format!("/asterisk/modules/{module_name}")) + .get(&format!("/asterisk/modules/{}", url_encode(module_name))) .await } /// load a module pub async fn load_module(client: &AriClient, module_name: &str) -> Result<()> { client - .post_empty(&format!("/asterisk/modules/{module_name}")) + .post_empty(&format!("/asterisk/modules/{}", url_encode(module_name))) .await } /// unload a module pub async fn unload_module(client: &AriClient, module_name: &str) -> Result<()> { client - .delete(&format!("/asterisk/modules/{module_name}")) + .delete(&format!("/asterisk/modules/{}", url_encode(module_name))) .await } /// reload a module pub async fn reload_module(client: &AriClient, module_name: &str) -> Result<()> { client - .put_empty(&format!("/asterisk/modules/{module_name}")) + .put_empty(&format!("/asterisk/modules/{}", url_encode(module_name))) .await } diff --git a/crates/asterisk-rs-ari/src/resources/bridge.rs b/crates/asterisk-rs-ari/src/resources/bridge.rs index 6db0278..e73ed0b 100644 --- a/crates/asterisk-rs-ari/src/resources/bridge.rs +++ b/crates/asterisk-rs-ari/src/resources/bridge.rs @@ -28,7 +28,7 @@ impl BridgeHandle { self.client .post_empty(&format!( "/bridges/{}/addChannel?channel={}", - self.id, + url_encode(&self.id), url_encode(channel_id) )) .await @@ -39,7 +39,7 @@ impl BridgeHandle { self.client .post_empty(&format!( "/bridges/{}/removeChannel?channel={}", - self.id, + url_encode(&self.id), url_encode(channel_id) )) .await @@ -49,7 +49,7 @@ impl BridgeHandle { pub async fn play(&self, media: &str) -> Result { self.client .post( - &format!("/bridges/{}/play", self.id), + &format!("/bridges/{}/play", url_encode(&self.id)), &serde_json::json!({"media": media}), ) .await @@ -59,7 +59,7 @@ impl BridgeHandle { pub async fn record(&self, name: &str, format: &str) -> Result { self.client .post( - &format!("/bridges/{}/record", self.id), + &format!("/bridges/{}/record", url_encode(&self.id)), &serde_json::json!({"name": name, "format": format}), ) .await @@ -67,14 +67,16 @@ impl BridgeHandle { /// destroy this bridge pub async fn destroy(&self) -> Result<()> { - self.client.delete(&format!("/bridges/{}", self.id)).await + self.client + .delete(&format!("/bridges/{}", url_encode(&self.id))) + .await } /// start music on hold for the bridge pub async fn start_moh(&self, moh_class: Option<&str>) -> Result<()> { let path = match moh_class { - Some(c) => format!("/bridges/{}/moh?mohClass={}", self.id, url_encode(c)), - None => format!("/bridges/{}/moh", self.id), + Some(c) => format!("/bridges/{}/moh?mohClass={}", url_encode(&self.id), url_encode(c)), + None => format!("/bridges/{}/moh", url_encode(&self.id)), }; self.client.post_empty(&path).await } @@ -82,7 +84,7 @@ impl BridgeHandle { /// stop music on hold for the bridge pub async fn stop_moh(&self) -> Result<()> { self.client - .delete(&format!("/bridges/{}/moh", self.id)) + .delete(&format!("/bridges/{}/moh", url_encode(&self.id))) .await } @@ -90,7 +92,7 @@ impl BridgeHandle { pub async fn play_with_id(&self, playback_id: &str, media: &str) -> Result { self.client .post( - &format!("/bridges/{}/play/{}", self.id, playback_id), + &format!("/bridges/{}/play/{}", url_encode(&self.id), url_encode(playback_id)), &serde_json::json!({"media": media}), ) .await @@ -101,7 +103,7 @@ impl BridgeHandle { self.client .post_empty(&format!( "/bridges/{}/videoSource/{}", - self.id, + url_encode(&self.id), url_encode(channel_id) )) .await @@ -110,7 +112,7 @@ impl BridgeHandle { /// clear the video source for the bridge pub async fn clear_video_source(&self) -> Result<()> { self.client - .delete(&format!("/bridges/{}/videoSource", self.id)) + .delete(&format!("/bridges/{}/videoSource", url_encode(&self.id))) .await } } diff --git a/crates/asterisk-rs-ari/src/resources/channel.rs b/crates/asterisk-rs-ari/src/resources/channel.rs index 2df0635..9857636 100644 --- a/crates/asterisk-rs-ari/src/resources/channel.rs +++ b/crates/asterisk-rs-ari/src/resources/channel.rs @@ -147,15 +147,15 @@ impl ChannelHandle { /// answer the channel pub async fn answer(&self) -> Result<()> { self.client - .post_empty(&format!("/channels/{}/answer", self.id)) + .post_empty(&format!("/channels/{}/answer", url_encode(&self.id))) .await } /// hang up the channel with an optional reason pub async fn hangup(&self, reason: Option<&str>) -> Result<()> { let path = match reason { - Some(r) => format!("/channels/{}?reason={}", self.id, url_encode(r)), - None => format!("/channels/{}", self.id), + Some(r) => format!("/channels/{}?reason={}", url_encode(&self.id), url_encode(r)), + None => format!("/channels/{}", url_encode(&self.id)), }; self.client.delete(&path).await } @@ -164,7 +164,7 @@ impl ChannelHandle { pub async fn play(&self, media: &str) -> Result { self.client .post( - &format!("/channels/{}/play", self.id), + &format!("/channels/{}/play", url_encode(&self.id)), &serde_json::json!({"media": media}), ) .await @@ -174,7 +174,7 @@ impl ChannelHandle { pub async fn record(&self, name: &str, format: &str) -> Result { self.client .post( - &format!("/channels/{}/record", self.id), + &format!("/channels/{}/record", url_encode(&self.id)), &serde_json::json!({"name": name, "format": format}), ) .await @@ -183,8 +183,8 @@ impl ChannelHandle { /// mute the channel, optionally specifying direction (both, in, out) pub async fn mute(&self, direction: Option<&str>) -> Result<()> { let path = match direction { - Some(d) => format!("/channels/{}/mute?direction={}", self.id, url_encode(d)), - None => format!("/channels/{}/mute", self.id), + Some(d) => format!("/channels/{}/mute?direction={}", url_encode(&self.id), url_encode(d)), + None => format!("/channels/{}/mute", url_encode(&self.id)), }; self.client.post_empty(&path).await } @@ -192,8 +192,8 @@ impl ChannelHandle { /// unmute the channel, optionally specifying direction pub async fn unmute(&self, direction: Option<&str>) -> Result<()> { let path = match direction { - Some(d) => format!("/channels/{}/mute?direction={}", self.id, url_encode(d)), - None => format!("/channels/{}/mute", self.id), + Some(d) => format!("/channels/{}/mute?direction={}", url_encode(&self.id), url_encode(d)), + None => format!("/channels/{}/mute", url_encode(&self.id)), }; self.client.delete(&path).await } @@ -201,14 +201,14 @@ impl ChannelHandle { /// place the channel on hold pub async fn hold(&self) -> Result<()> { self.client - .post_empty(&format!("/channels/{}/hold", self.id)) + .post_empty(&format!("/channels/{}/hold", url_encode(&self.id))) .await } /// remove the channel from hold pub async fn unhold(&self) -> Result<()> { self.client - .delete(&format!("/channels/{}/hold", self.id)) + .delete(&format!("/channels/{}/hold", url_encode(&self.id))) .await } @@ -217,7 +217,7 @@ impl ChannelHandle { self.client .post_empty(&format!( "/channels/{}/dtmf?dtmf={}", - self.id, + url_encode(&self.id), url_encode(dtmf) )) .await @@ -228,7 +228,7 @@ impl ChannelHandle { self.client .get(&format!( "/channels/{}/variable?variable={}", - self.id, + url_encode(&self.id), url_encode(name) )) .await @@ -239,7 +239,7 @@ impl ChannelHandle { self.client .post_empty(&format!( "/channels/{}/variable?variable={}&value={}", - self.id, + url_encode(&self.id), url_encode(name), url_encode(value) )) @@ -253,7 +253,7 @@ impl ChannelHandle { extension: Option<&str>, priority: Option, ) -> Result<()> { - let mut path = format!("/channels/{}/continue", self.id); + let mut path = format!("/channels/{}/continue", url_encode(&self.id)); let mut params = Vec::new(); if let Some(c) = context { params.push(format!("context={}", url_encode(c))); @@ -288,7 +288,7 @@ impl ChannelHandle { let query = params.join("&"); self.client .post( - &format!("/channels/{}/snoop?{}", self.id, query), + &format!("/channels/{}/snoop?{}", url_encode(&self.id), query), &serde_json::json!({}), ) .await @@ -299,7 +299,7 @@ impl ChannelHandle { self.client .post_empty(&format!( "/channels/{}/redirect?context={}&extension={}&priority={}", - self.id, + url_encode(&self.id), url_encode(context), url_encode(extension), priority @@ -310,28 +310,28 @@ impl ChannelHandle { /// start ringing on the channel pub async fn ring(&self) -> Result<()> { self.client - .post_empty(&format!("/channels/{}/ring", self.id)) + .post_empty(&format!("/channels/{}/ring", url_encode(&self.id))) .await } /// stop ringing on the channel pub async fn ring_stop(&self) -> Result<()> { self.client - .delete(&format!("/channels/{}/ring", self.id)) + .delete(&format!("/channels/{}/ring", url_encode(&self.id))) .await } /// start silence on the channel pub async fn start_silence(&self) -> Result<()> { self.client - .post_empty(&format!("/channels/{}/silence", self.id)) + .post_empty(&format!("/channels/{}/silence", url_encode(&self.id))) .await } /// stop silence on the channel pub async fn stop_silence(&self) -> Result<()> { self.client - .delete(&format!("/channels/{}/silence", self.id)) + .delete(&format!("/channels/{}/silence", url_encode(&self.id))) .await } @@ -339,7 +339,7 @@ impl ChannelHandle { pub async fn play_with_id(&self, playback_id: &str, media: &str) -> Result { self.client .post( - &format!("/channels/{}/play/{}", self.id, playback_id), + &format!("/channels/{}/play/{}", url_encode(&self.id), url_encode(playback_id)), &serde_json::json!({"media": media}), ) .await @@ -360,14 +360,14 @@ impl ChannelHandle { format!("?{}", params.join("&")) }; self.client - .post_empty(&format!("/channels/{}/dial{}", self.id, query)) + .post_empty(&format!("/channels/{}/dial{}", url_encode(&self.id), query)) .await } /// get rtp statistics for the channel pub async fn rtp_statistics(&self) -> Result { self.client - .get(&format!("/channels/{}/rtp_statistics", self.id)) + .get(&format!("/channels/{}/rtp_statistics", url_encode(&self.id))) .await } diff --git a/crates/asterisk-rs-ari/src/resources/device_state.rs b/crates/asterisk-rs-ari/src/resources/device_state.rs index d28dbdf..36d1964 100644 --- a/crates/asterisk-rs-ari/src/resources/device_state.rs +++ b/crates/asterisk-rs-ari/src/resources/device_state.rs @@ -1,6 +1,6 @@ //! device state operations. -use crate::client::AriClient; +use crate::client::{url_encode, AriClient}; use crate::error::Result; /// ari device state representation @@ -17,7 +17,9 @@ pub async fn list(client: &AriClient) -> Result> { /// get a specific device state pub async fn get(client: &AriClient, name: &str) -> Result { - client.get(&format!("/deviceStates/{name}")).await + client + .get(&format!("/deviceStates/{}", url_encode(name))) + .await } /// update a device state @@ -26,11 +28,17 @@ pub async fn get(client: &AriClient, name: &str) -> Result { /// also accepts POST for compatibility pub async fn update(client: &AriClient, name: &str, state: &str) -> Result<()> { client - .post_empty(&format!("/deviceStates/{name}?deviceState={state}")) + .post_empty(&format!( + "/deviceStates/{}?deviceState={}", + url_encode(name), + url_encode(state) + )) .await } /// delete a device state pub async fn delete(client: &AriClient, name: &str) -> Result<()> { - client.delete(&format!("/deviceStates/{name}")).await + client + .delete(&format!("/deviceStates/{}", url_encode(name))) + .await } diff --git a/crates/asterisk-rs-ari/src/resources/mailbox.rs b/crates/asterisk-rs-ari/src/resources/mailbox.rs index 8400f3b..cceba50 100644 --- a/crates/asterisk-rs-ari/src/resources/mailbox.rs +++ b/crates/asterisk-rs-ari/src/resources/mailbox.rs @@ -1,6 +1,6 @@ //! mailbox operations. -use crate::client::AriClient; +use crate::client::{url_encode, AriClient}; use crate::error::Result; /// ari mailbox representation @@ -18,7 +18,9 @@ pub async fn list(client: &AriClient) -> Result> { /// get a specific mailbox pub async fn get(client: &AriClient, name: &str) -> Result { - client.get(&format!("/mailboxes/{name}")).await + client + .get(&format!("/mailboxes/{}", url_encode(name))) + .await } /// update a mailbox message count @@ -33,12 +35,15 @@ pub async fn update( ) -> Result<()> { client .post_empty(&format!( - "/mailboxes/{name}?oldMessages={old_messages}&newMessages={new_messages}" + "/mailboxes/{}?oldMessages={old_messages}&newMessages={new_messages}", + url_encode(name) )) .await } /// delete a mailbox pub async fn delete(client: &AriClient, name: &str) -> Result<()> { - client.delete(&format!("/mailboxes/{name}")).await + client + .delete(&format!("/mailboxes/{}", url_encode(name))) + .await } diff --git a/crates/asterisk-rs-ari/src/resources/playback.rs b/crates/asterisk-rs-ari/src/resources/playback.rs index 48ef3ee..b1033f1 100644 --- a/crates/asterisk-rs-ari/src/resources/playback.rs +++ b/crates/asterisk-rs-ari/src/resources/playback.rs @@ -1,6 +1,6 @@ //! playback control operations. -use crate::client::AriClient; +use crate::client::{url_encode, AriClient}; use crate::error::Result; use crate::event::Playback; @@ -28,18 +28,23 @@ impl PlaybackHandle { self.client .post_empty(&format!( "/playbacks/{}/control?operation={}", - self.id, operation + url_encode(&self.id), + url_encode(operation) )) .await } /// stop the playback pub async fn stop(&self) -> Result<()> { - self.client.delete(&format!("/playbacks/{}", self.id)).await + self.client + .delete(&format!("/playbacks/{}", url_encode(&self.id))) + .await } /// get current playback state pub async fn get(&self) -> Result { - self.client.get(&format!("/playbacks/{}", self.id)).await + self.client + .get(&format!("/playbacks/{}", url_encode(&self.id))) + .await } } diff --git a/crates/asterisk-rs-ari/src/resources/recording.rs b/crates/asterisk-rs-ari/src/resources/recording.rs index 3e856fb..59de483 100644 --- a/crates/asterisk-rs-ari/src/resources/recording.rs +++ b/crates/asterisk-rs-ari/src/resources/recording.rs @@ -1,6 +1,6 @@ //! recording control operations — live and stored. -use crate::client::AriClient; +use crate::client::{url_encode, AriClient}; use crate::error::Result; use crate::event::LiveRecording; @@ -33,42 +33,48 @@ impl RecordingHandle { /// stop the live recording pub async fn stop(&self) -> Result<()> { self.client - .post_empty(&format!("/recordings/live/{}/stop", self.name)) + .post_empty(&format!("/recordings/live/{}/stop", url_encode(&self.name))) .await } /// pause the live recording pub async fn pause(&self) -> Result<()> { self.client - .post_empty(&format!("/recordings/live/{}/pause", self.name)) + .post_empty(&format!( + "/recordings/live/{}/pause", + url_encode(&self.name) + )) .await } /// unpause the live recording pub async fn unpause(&self) -> Result<()> { self.client - .delete(&format!("/recordings/live/{}/pause", self.name)) + .delete(&format!( + "/recordings/live/{}/pause", + url_encode(&self.name) + )) .await } /// mute the live recording pub async fn mute(&self) -> Result<()> { self.client - .post_empty(&format!("/recordings/live/{}/mute", self.name)) + .post_empty(&format!("/recordings/live/{}/mute", url_encode(&self.name))) .await } /// unmute the live recording pub async fn unmute(&self) -> Result<()> { self.client - .delete(&format!("/recordings/live/{}/mute", self.name)) + .delete(&format!("/recordings/live/{}/mute", url_encode(&self.name))) .await } /// get current live recording state pub async fn get(&self) -> Result { self.client - .get(&format!("/recordings/live/{}", self.name)) + .get(&format!("/recordings/live/{}", url_encode(&self.name))) .await } } diff --git a/crates/asterisk-rs-ari/src/resources/sound.rs b/crates/asterisk-rs-ari/src/resources/sound.rs index 879c724..3eeadf6 100644 --- a/crates/asterisk-rs-ari/src/resources/sound.rs +++ b/crates/asterisk-rs-ari/src/resources/sound.rs @@ -1,6 +1,6 @@ //! sound query operations (read-only). -use crate::client::AriClient; +use crate::client::{url_encode, AriClient}; use crate::error::Result; /// format information for a sound @@ -27,5 +27,7 @@ pub async fn list(client: &AriClient) -> Result> { /// get a specific sound pub async fn get(client: &AriClient, sound_id: &str) -> Result { - client.get(&format!("/sounds/{sound_id}")).await + client + .get(&format!("/sounds/{}", url_encode(sound_id))) + .await } diff --git a/crates/asterisk-rs-ari/src/server.rs b/crates/asterisk-rs-ari/src/server.rs index 9d7d502..17106e8 100644 --- a/crates/asterisk-rs-ari/src/server.rs +++ b/crates/asterisk-rs-ari/src/server.rs @@ -23,6 +23,7 @@ use asterisk_rs_core::event::{EventBus, EventSubscription, FilteredSubscription} use crate::error::{AriError, Result}; use crate::event::{AriEvent, AriMessage}; +use crate::ws_proto::WsRestRequest; /// per-session request id counter — only needs uniqueness within a session, /// but a global counter keeps ids distinct across sessions for tracing @@ -36,22 +37,6 @@ fn next_request_id() -> String { /// default timeout for REST-over-WS requests const REQUEST_TIMEOUT: Duration = Duration::from_secs(30); -// --- REST-over-WS protocol types (duplicated from ws_transport to keep modules decoupled) --- - -/// REST request envelope sent over websocket -#[derive(serde::Serialize)] -struct WsRestRequest { - #[serde(rename = "type")] - type_field: &'static str, - request_id: String, - method: String, - uri: String, - #[serde(skip_serializing_if = "Option::is_none")] - content_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - message_body: Option, -} - /// internal command sent from request methods to the session background task struct SessionCommand { request_id: String, @@ -474,11 +459,8 @@ fn route_message( } } Err(e) => { - tracing::warn!( - error = %e, - payload = %text, - "failed to deserialize ARI message in session" - ); + tracing::warn!(error = %e, "failed to deserialize ARI message in session"); + tracing::trace!(payload = %text, "raw ARI session message payload"); } } } diff --git a/crates/asterisk-rs-ari/src/transport.rs b/crates/asterisk-rs-ari/src/transport.rs index 7cf6ff0..3eea85f 100644 --- a/crates/asterisk-rs-ari/src/transport.rs +++ b/crates/asterisk-rs-ari/src/transport.rs @@ -6,6 +6,7 @@ use crate::error::{AriError, Result}; use crate::event::AriMessage; use crate::websocket::WsEventListener; use crate::ws_transport::WsTransport; +use asterisk_rs_core::auth::Credentials; use asterisk_rs_core::config::ReconnectPolicy; use asterisk_rs_core::event::EventBus; @@ -48,16 +49,14 @@ impl TransportInner { pub(crate) struct HttpTransport { client: reqwest::Client, base_url: String, - username: String, - password: String, + credentials: Credentials, ws_listener: WsEventListener, } impl HttpTransport { pub fn new( base_url: &str, - username: String, - password: String, + credentials: Credentials, ws_url: String, event_bus: EventBus, reconnect: ReconnectPolicy, @@ -73,8 +72,7 @@ impl HttpTransport { Ok(Self { client, base_url: base_url.trim_end_matches('/').to_owned(), - username, - password, + credentials, ws_listener, }) } @@ -91,7 +89,7 @@ impl HttpTransport { let mut req = self .client .request(http_method, &url) - .basic_auth(&self.username, Some(&self.password)); + .basic_auth(self.credentials.username(), Some(self.credentials.secret())); if let Some(json_body) = body { req = req diff --git a/crates/asterisk-rs-ari/src/websocket.rs b/crates/asterisk-rs-ari/src/websocket.rs index 40508d6..4522eca 100644 --- a/crates/asterisk-rs-ari/src/websocket.rs +++ b/crates/asterisk-rs-ari/src/websocket.rs @@ -4,7 +4,7 @@ use std::time::Duration; use asterisk_rs_core::config::ReconnectPolicy; use asterisk_rs_core::event::EventBus; -use futures_util::StreamExt; +use futures_util::{SinkExt, StreamExt}; use tokio::sync::watch; use crate::event::AriMessage; @@ -124,7 +124,7 @@ async fn read_messages( event_bus: &EventBus, shutdown_rx: &mut watch::Receiver, ) -> std::result::Result<(), bool> { - let (_write, mut read) = ws_stream.split(); + let (mut write, mut read) = ws_stream.split(); loop { tokio::select! { @@ -145,6 +145,9 @@ async fn read_messages( } _ = shutdown_rx.changed() => { if *shutdown_rx.borrow() { + if let Err(e) = write.send(tokio_tungstenite::tungstenite::Message::Close(None)).await { + tracing::debug!(error = %e, "failed to send websocket close frame"); + } return Err(true); } } @@ -166,7 +169,8 @@ fn handle_message( event_bus.publish(event); } Err(e) => { - tracing::warn!(error = %e, payload = %text, "failed to deserialize ARI event"); + tracing::warn!(error = %e, "failed to deserialize ARI event"); + tracing::trace!(payload = %text, "raw ARI event payload"); } }, Message::Close(_) => { diff --git a/crates/asterisk-rs-ari/src/ws_proto.rs b/crates/asterisk-rs-ari/src/ws_proto.rs new file mode 100644 index 0000000..f35fe19 --- /dev/null +++ b/crates/asterisk-rs-ari/src/ws_proto.rs @@ -0,0 +1,15 @@ +//! shared protocol types for REST-over-WebSocket communication. + +/// REST request envelope sent over websocket +#[derive(serde::Serialize)] +pub(crate) struct WsRestRequest { + #[serde(rename = "type")] + pub type_field: &'static str, + pub request_id: String, + pub method: String, + pub uri: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub content_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub message_body: Option, +} diff --git a/crates/asterisk-rs-ari/src/ws_transport.rs b/crates/asterisk-rs-ari/src/ws_transport.rs index 8fcca74..7bbf11c 100644 --- a/crates/asterisk-rs-ari/src/ws_transport.rs +++ b/crates/asterisk-rs-ari/src/ws_transport.rs @@ -20,6 +20,7 @@ use crate::error::{AriError, Result}; use crate::event::{AriEvent, AriMessage}; use crate::transport::TransportResponse; use crate::util::redact_url; +use crate::ws_proto::WsRestRequest; static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(1); @@ -28,20 +29,6 @@ fn next_request_id() -> String { format!("wsreq-{id}") } -/// REST request envelope sent over websocket -#[derive(serde::Serialize)] -struct WsRestRequest { - #[serde(rename = "type")] - type_field: &'static str, - request_id: String, - method: String, - uri: String, - #[serde(skip_serializing_if = "Option::is_none")] - content_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - message_body: Option, -} - /// internal command sent from request() to the background task struct RestCommand { request_id: String, @@ -324,11 +311,8 @@ fn route_text_message( } } Err(e) => { - tracing::warn!( - error = %e, - payload = %text, - "failed to deserialize ARI message" - ); + tracing::warn!(error = %e, "failed to deserialize ARI message"); + tracing::trace!(payload = %text, "raw ARI message payload"); } } } diff --git a/crates/asterisk-rs-core/src/config.rs b/crates/asterisk-rs-core/src/config.rs index 08afdf8..c0cacc2 100644 --- a/crates/asterisk-rs-core/src/config.rs +++ b/crates/asterisk-rs-core/src/config.rs @@ -100,15 +100,18 @@ impl std::fmt::Display for ConnectionState { } } -/// jitter factor between 0.5 and 1.0 using system time for entropy +/// jitter factor in [0.5, 1.5) using OS-seeded entropy via RandomState fn jitter_factor() -> f64 { - // mix system time nanos with thread id for per-instance variation + use std::hash::{BuildHasher, Hasher}; + let nanos = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .subsec_nanos(); - let thread_id = std::thread::current().id(); - let hash = nanos ^ (format!("{thread_id:?}").len() as u32).wrapping_mul(0x9E3779B9); - let normalized = (hash as f64) / (u32::MAX as f64); - 0.5 + 0.5 * normalized + .map(|d| d.subsec_nanos()) + .unwrap_or(0); + + let mut hasher = std::collections::hash_map::RandomState::new().build_hasher(); + hasher.write_u32(nanos); + let hash = hasher.finish(); + + 0.5 + (hash % 1000) as f64 / 1000.0 } diff --git a/crates/asterisk-rs/src/pbx.rs b/crates/asterisk-rs/src/pbx.rs index d0df7dc..088eb1f 100644 --- a/crates/asterisk-rs/src/pbx.rs +++ b/crates/asterisk-rs/src/pbx.rs @@ -41,7 +41,14 @@ impl Call { /// wait for this channel to reach "Up" state (answered) /// /// listens for Newstate events with channel_state_desc "Up". - /// returns Err if the channel hangs up before answering + /// returns Err if the channel hangs up before answering. + /// + /// the inner subscription is protected by a tokio Mutex so that + /// [`Call`] can remain `Clone`. if multiple clones call this + /// concurrently, only one acquires the lock at a time — the + /// winner consumes events while the others block on the mutex. + /// callers that need concurrent waiting should create separate + /// subscriptions via [`Pbx::client`]. pub async fn wait_for_answer(&self, timeout: Duration) -> Result<(), PbxError> { let uid = self.unique_id.clone(); diff --git a/deny.toml b/deny.toml index 1aa3181..986ffd6 100644 --- a/deny.toml +++ b/deny.toml @@ -18,11 +18,15 @@ allow = [ [bans] multiple-versions = "warn" -wildcards = "allow" +wildcards = "deny" highlight = "all" +# test crate uses workspace wildcards and is never published +[[bans.skip-tree]] +crate = "asterisk-rs-tests" + [sources] -unknown-registry = "warn" -unknown-git = "warn" +unknown-registry = "deny" +unknown-git = "deny" allow-registry = ["https://github.com/rust-lang/crates.io-index"] allow-git = [] diff --git a/tests/src/helpers.rs b/tests/src/helpers.rs index 2db479d..1718586 100644 --- a/tests/src/helpers.rs +++ b/tests/src/helpers.rs @@ -3,6 +3,16 @@ pub fn init_tracing() { let _ = tracing_subscriber::fmt::try_init(); } +/// re-raise panics from spawned server tasks so test failures point at the +/// actual panic location instead of producing misleading messages +pub fn assert_server_ok(result: Result<(), tokio::task::JoinError>) { + if let Err(e) = result { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } + } +} + /// read test config from environment or use defaults pub fn ami_host() -> String { std::env::var("ASTERISK_AMI_HOST").unwrap_or_else(|_| "127.0.0.1".into()) diff --git a/tests/src/mock/ari_server.rs b/tests/src/mock/ari_server.rs index 62d80d7..d880c49 100644 --- a/tests/src/mock/ari_server.rs +++ b/tests/src/mock/ari_server.rs @@ -27,7 +27,7 @@ struct ServerState { routes: HashMap<(String, String), MockRoute>, event_tx: broadcast::Sender, ws_clients: AtomicUsize, - ws_connected: Arc, + ws_connected: Notify, } /// mock ARI server binding HTTP and WebSocket on one port @@ -66,8 +66,14 @@ impl MockAriServer { /// wait until at least one websocket client has connected pub async fn wait_for_ws_client(&self) { - while self.state.ws_clients.load(Ordering::Acquire) == 0 { - self.state.ws_connected.notified().await; + loop { + // register the notification future before the load to avoid a + // race where the signal fires between the check and the await + let notified = self.state.ws_connected.notified(); + if self.state.ws_clients.load(Ordering::Acquire) > 0 { + return; + } + notified.await; } } } @@ -118,7 +124,7 @@ impl MockAriServerBuilder { routes: self.routes, event_tx: event_tx.clone(), ws_clients: AtomicUsize::new(0), - ws_connected: Arc::new(Notify::new()), + ws_connected: Notify::new(), }); let task = tokio::spawn(accept_loop(listener, Arc::clone(&state), shutdown_rx)); diff --git a/tests/tests/mock_tests/ami.rs b/tests/tests/mock_tests/ami.rs index e409873..7375d3e 100644 --- a/tests/tests/mock_tests/ami.rs +++ b/tests/tests/mock_tests/ami.rs @@ -4,19 +4,11 @@ use std::time::Duration; use asterisk_rs_ami::client::AmiClient; use asterisk_rs_core::config::{ConnectionState, ReconnectPolicy}; -use asterisk_rs_tests::helpers::init_tracing; +use asterisk_rs_tests::helpers::{assert_server_ok, init_tracing}; use asterisk_rs_tests::mock::ami_server::{ get_header, handle_login, handle_login_reject, MockAmiServer, }; -fn assert_server_ok(result: Result<(), tokio::task::JoinError>) { - if let Err(e) = result { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } - } -} - #[tokio::test] async fn connect_and_login() { init_tracing(); @@ -88,6 +80,7 @@ async fn login_rejected() { .host("127.0.0.1") .port(port) .credentials("admin", "wrong") + .require_challenge(false) .reconnect(ReconnectPolicy::none()) .timeout(Duration::from_secs(5)) .build() @@ -487,6 +480,7 @@ async fn plaintext_fallback() { .host("127.0.0.1") .port(port) .credentials("admin", "mysecret") + .require_challenge(false) .reconnect(ReconnectPolicy::none()) .timeout(Duration::from_secs(5)) .build() @@ -953,7 +947,7 @@ async fn reconnect_on_disconnect() { assert_eq!(response.get("Ping"), Some("Pong")); client.disconnect().await.expect("disconnect"); - let _ = mock_handle.await; + assert_server_ok(mock_handle.await); } #[tokio::test] @@ -2244,6 +2238,7 @@ async fn reconnect_login_failure_gives_up() { .host("127.0.0.1") .port(port) .credentials("admin", "wrong") + .require_challenge(false) .reconnect(ReconnectPolicy::none()) .timeout(Duration::from_secs(5)) .build() @@ -2296,6 +2291,7 @@ async fn reconnect_login_failure_retries_then_gives_up() { .host("127.0.0.1") .port(port) .credentials("admin", "wrong") + .require_challenge(false) .reconnect(ReconnectPolicy::fixed(Duration::from_millis(50)).with_max_retries(2)) .timeout(Duration::from_secs(5)) .build(), diff --git a/tests/tests/mock_tests/ari.rs b/tests/tests/mock_tests/ari.rs index 39faf21..c69a6a3 100644 --- a/tests/tests/mock_tests/ari.rs +++ b/tests/tests/mock_tests/ari.rs @@ -4,7 +4,7 @@ use asterisk_rs_ari::config::AriConfigBuilder; use asterisk_rs_ari::{AriClient, AriError}; use asterisk_rs_core::config::ReconnectPolicy; -use asterisk_rs_tests::helpers::init_tracing; +use asterisk_rs_tests::helpers::{assert_server_ok, init_tracing}; use asterisk_rs_tests::mock::ari_server::MockAriServerBuilder; /// build an ARI client pointed at the mock server @@ -1574,7 +1574,7 @@ async fn mailbox_get() { let body = r#"{"name":"2000@default","old_messages":0,"new_messages":1}"#; let server = MockAriServerBuilder::new() - .route("GET", "/ari/mailboxes/2000@default", 200, body) + .route("GET", "/ari/mailboxes/2000%40default", 200, body) .start() .await; @@ -1597,7 +1597,7 @@ async fn mailbox_update() { let server = MockAriServerBuilder::new() .route( "POST", - "/ari/mailboxes/1000@default?oldMessages=3&newMessages=7", + "/ari/mailboxes/1000%40default?oldMessages=3&newMessages=7", 204, "", ) @@ -1618,7 +1618,7 @@ async fn mailbox_delete() { init_tracing(); let server = MockAriServerBuilder::new() - .route("DELETE", "/ari/mailboxes/1000@default", 204, "") + .route("DELETE", "/ari/mailboxes/1000%40default", 204, "") .start() .await; @@ -1664,7 +1664,7 @@ async fn device_state_get() { let body = r#"{"name":"Stasis:phone2","state":"INUSE"}"#; let server = MockAriServerBuilder::new() - .route("GET", "/ari/deviceStates/Stasis:phone2", 200, body) + .route("GET", "/ari/deviceStates/Stasis%3Aphone2", 200, body) .start() .await; @@ -1687,7 +1687,7 @@ async fn device_state_update() { let server = MockAriServerBuilder::new() .route( "POST", - "/ari/deviceStates/Stasis:phone1?deviceState=INUSE", + "/ari/deviceStates/Stasis%3Aphone1?deviceState=INUSE", 204, "", ) @@ -1708,7 +1708,7 @@ async fn device_state_delete() { init_tracing(); let server = MockAriServerBuilder::new() - .route("DELETE", "/ari/deviceStates/Stasis:phone1", 204, "") + .route("DELETE", "/ari/deviceStates/Stasis%3Aphone1", 204, "") .start() .await; @@ -3181,7 +3181,7 @@ async fn outbound_ws_server_accepts_connection() { .expect("session channel should not be closed"); handle.shutdown(); - let _ = server_task.await; + assert_server_ok(server_task.await); } #[tokio::test] @@ -3261,7 +3261,7 @@ async fn outbound_ws_server_delivers_events_to_session() { ); handle.shutdown(); - let _ = server_task.await; + assert_server_ok(server_task.await); } #[tokio::test] @@ -3291,7 +3291,7 @@ async fn media_channel_connect_and_disconnect() { .expect("media channel connect should succeed"); media.disconnect(); - let _ = server_task.await; + assert_server_ok(server_task.await); } #[tokio::test] @@ -3344,7 +3344,7 @@ async fn media_channel_sends_command() { ); media.disconnect(); - let _ = server_task.await; + assert_server_ok(server_task.await); } #[tokio::test] @@ -3411,7 +3411,7 @@ async fn media_channel_receives_event() { } media.disconnect(); - let _ = server_task.await; + assert_server_ok(server_task.await); } #[tokio::test] @@ -3466,5 +3466,5 @@ async fn media_channel_sends_and_receives_audio() { ); media.disconnect(); - let _ = server_task.await; + assert_server_ok(server_task.await); } diff --git a/tests/tests/unit/agi.rs b/tests/tests/unit/agi.rs index 701716d..bc6feab 100644 --- a/tests/tests/unit/agi.rs +++ b/tests/tests/unit/agi.rs @@ -1094,3 +1094,42 @@ async fn channel_send_command_on_511_response_sets_hung_up() { "expected ChannelHungUp, got {err:?}" ); } + +// --------------------------------------------------------------------------- +// backslash escaping in format_command +// --------------------------------------------------------------------------- + +#[test] +fn command_format_escapes_backslash_in_quoted_arg() { + // arg contains both backslash and quote — both must be escaped + let cmd = command::format_command(command::EXEC, &["a\\\"b"]).expect("valid command"); + assert_eq!(cmd, "EXEC \"a\\\\\\\"b\"\n"); +} + +#[test] +fn command_format_backslash_in_unquoted_arg() { + // backslash in an arg without spaces or quotes passes through unquoted + let cmd = command::format_command(command::SET_VARIABLE, &["path\\to", "val"]) + .expect("valid command"); + assert_eq!(cmd, "SET VARIABLE path\\to val\n"); +} + +// --------------------------------------------------------------------------- +// nested parentheses in response parsing +// --------------------------------------------------------------------------- + +#[test] +fn response_parse_nested_parentheses() { + let resp = AgiResponse::parse("200 result=1 (outer(inner))").expect("should parse"); + assert_eq!(resp.code, 200); + assert_eq!(resp.result, 1); + assert_eq!(resp.data.as_deref(), Some("outer(inner)")); +} + +#[test] +fn response_parse_simple_parentheses_still_works() { + let resp = AgiResponse::parse("200 result=1 (hello)").expect("should parse"); + assert_eq!(resp.code, 200); + assert_eq!(resp.result, 1); + assert_eq!(resp.data.as_deref(), Some("hello")); +} diff --git a/tests/tests/unit/ami_actions.rs b/tests/tests/unit/ami_actions.rs index 3c88eed..44eff37 100644 --- a/tests/tests/unit/ami_actions.rs +++ b/tests/tests/unit/ami_actions.rs @@ -446,10 +446,7 @@ fn challenge_action_headers() { #[test] fn login_action_headers() { - let action = LoginAction { - username: "admin".into(), - secret: "pass".into(), - }; + let action = LoginAction::new("admin", "pass"); let (id, msg) = action.to_message(); assert!(!id.is_empty()); assert_eq!(get_header(&msg, "Action"), Some("Login".into())); @@ -2575,3 +2572,37 @@ fn to_message_output_and_channel_variables_empty() { assert!(msg.output.is_empty()); assert!(msg.channel_variables.is_empty()); } + +#[test] +fn login_action_new_creates_correctly() { + let action = LoginAction::new("admin", "secret123"); + assert_eq!(action.username, "admin"); + assert_eq!(action.secret(), "secret123"); +} + +#[test] +fn login_action_debug_redacts_secret() { + let action = LoginAction::new("admin", "super_secret_password"); + let debug_output = format!("{:?}", action); + assert!( + !debug_output.contains("super_secret_password"), + "debug output must not contain the secret: {debug_output}" + ); + assert!( + debug_output.contains("[REDACTED]"), + "debug output should contain [REDACTED]: {debug_output}" + ); + assert!( + debug_output.contains("admin"), + "debug output should contain username: {debug_output}" + ); +} + +#[test] +fn login_action_new_headers_match() { + let action = LoginAction::new("testuser", "testpass"); + let (_, msg) = action.to_message(); + assert_eq!(get_header(&msg, "Action"), Some("Login".into())); + assert_eq!(get_header(&msg, "Username"), Some("testuser".into())); + assert_eq!(get_header(&msg, "Secret"), Some("testpass".into())); +} diff --git a/tests/tests/unit/ami_codec.rs b/tests/tests/unit/ami_codec.rs index 953621d..d4fcaf9 100644 --- a/tests/tests/unit/ami_codec.rs +++ b/tests/tests/unit/ami_codec.rs @@ -805,3 +805,85 @@ fn empty_message_skipped() { .expect("should skip empty and return real message"); assert_eq!(msg.get("Event"), Some("Real")); } + +#[test] +fn decode_follows_with_space() { + let mut codec = AmiCodec::new(); + let raw = with_banner( + "Response: Follows\r\n\ + ActionID: 1\r\n\ + line one\r\n\ + --END COMMAND--\r\n\ + \r\n", + ); + let mut buf = BytesMut::from(raw.as_str()); + let msg = codec + .decode(&mut buf) + .expect("decode should succeed") + .expect("should produce a message"); + assert_eq!(msg.get("Response"), Some("Follows")); + assert_eq!(msg.output, vec!["line one"]); +} + +#[test] +fn decode_follows_without_space() { + let mut codec = AmiCodec::new(); + let raw = with_banner( + "Response:Follows\r\n\ + ActionID: 1\r\n\ + line one\r\n\ + --END COMMAND--\r\n\ + \r\n", + ); + let mut buf = BytesMut::from(raw.as_str()); + let msg = codec + .decode(&mut buf) + .expect("decode should succeed") + .expect("should produce a message"); + assert_eq!(msg.get("Response"), Some("Follows")); + assert_eq!(msg.output, vec!["line one"]); +} + +#[test] +fn decode_follows_case_insensitive() { + let mut codec = AmiCodec::new(); + let raw = with_banner( + "response: follows\r\n\ + ActionID: 1\r\n\ + output\r\n\ + --END COMMAND--\r\n\ + \r\n", + ); + let mut buf = BytesMut::from(raw.as_str()); + let msg = codec + .decode(&mut buf) + .expect("decode should succeed") + .expect("should produce a message"); + assert_eq!(msg.get("response"), Some("follows")); + assert_eq!(msg.output, vec!["output"]); +} + +#[test] +fn banner_error_truncates_long_content() { + let mut codec = AmiCodec::new(); + let long_banner = format!("{}\r\n", "X".repeat(200)); + let mut buf = BytesMut::from(long_banner.as_str()); + let err = codec + .decode(&mut buf) + .expect_err("should reject non-AMI banner"); + let msg = err.to_string(); + assert!( + msg.contains("expected AMI banner"), + "error should mention banner: {msg}" + ); + // the error message should not contain the full 200-char string + assert!( + !msg.contains(&"X".repeat(200)), + "error should truncate long banner content" + ); + // but should contain the truncated version (64 X's) + assert!( + msg.contains(&"X".repeat(64)), + "error should contain truncated banner: {msg}" + ); +} diff --git a/tests/tests/unit/ami_tracker.rs b/tests/tests/unit/ami_tracker.rs index 2170777..76ccc06 100644 --- a/tests/tests/unit/ami_tracker.rs +++ b/tests/tests/unit/ami_tracker.rs @@ -256,3 +256,14 @@ async fn test_tracker_shutdown_stops_processing() { "no CompletedCall should be produced after shutdown" ); } + +#[tokio::test] +async fn test_tracker_dropped_count_starts_at_zero() { + let bus = EventBus::::new(64); + let sub = bus.subscribe(); + let (tracker, _rx) = CallTracker::new(sub); + + assert_eq!(tracker.dropped_count(), 0); + + tracker.shutdown(); +} diff --git a/tests/tests/unit/ari.rs b/tests/tests/unit/ari.rs index eb8ed71..fd09f09 100644 --- a/tests/tests/unit/ari.rs +++ b/tests/tests/unit/ari.rs @@ -36,12 +36,7 @@ fn build_default_config() { .build() .expect("default config should build"); - assert_eq!(config.base_url.as_str(), "http://127.0.0.1:8088/ari"); - assert!( - config.ws_url.as_str().starts_with("ws://"), - "ws_url should start with ws://, got: {}", - config.ws_url - ); + assert_eq!(config.base_url().as_str(), "http://127.0.0.1:8088/ari"); } #[test] @@ -55,14 +50,9 @@ fn build_with_custom_host_port() { .expect("custom host/port should build"); assert!( - config.base_url.as_str().contains("10.0.0.1:9999"), + config.base_url().as_str().contains("10.0.0.1:9999"), "base_url should contain custom host:port, got: {}", - config.base_url - ); - assert!( - config.ws_url.as_str().contains("10.0.0.1:9999"), - "ws_url should contain custom host:port, got: {}", - config.ws_url + config.base_url() ); } @@ -76,14 +66,9 @@ fn build_secure_uses_https_wss() { .expect("secure config should build"); assert!( - config.base_url.as_str().starts_with("https://"), + config.base_url().as_str().starts_with("https://"), "base_url should use https, got: {}", - config.base_url - ); - assert!( - config.ws_url.as_str().starts_with("wss://"), - "ws_url should use wss, got: {}", - config.ws_url + config.base_url() ); } @@ -127,33 +112,26 @@ fn build_empty_app_name_via_setter_fails() { } #[test] -fn ws_url_contains_app_name() { +fn config_preserves_app_name() { let config = AriConfigBuilder::new("test_app") .username("admin") .password("secret") .build() .expect("config should build"); - assert!( - config.ws_url.as_str().contains("app=test_app"), - "ws_url should contain app=test_app, got: {}", - config.ws_url - ); + assert_eq!(config.app_name(), "test_app"); } #[test] -fn ws_url_contains_credentials() { +fn config_preserves_credentials() { let config = AriConfigBuilder::new("myapp") .username("admin") .password("secret") .build() .expect("config with credentials should build"); - assert!( - config.ws_url.as_str().contains("api_key=admin:secret"), - "ws_url should contain api_key=admin:secret, got: {}", - config.ws_url - ); + assert_eq!(config.credentials().username(), "admin"); + assert_eq!(config.credentials().secret(), "secret"); } #[test] @@ -168,10 +146,10 @@ fn build_with_custom_reconnect_policy() { .expect("config with reconnect policy should build"); assert_eq!( - config.reconnect_policy.initial_delay, + config.reconnect_policy().initial_delay, Duration::from_secs(5) ); - assert_eq!(config.reconnect_policy.max_delay, Duration::from_secs(5)); + assert_eq!(config.reconnect_policy().max_delay, Duration::from_secs(5)); } #[test] @@ -185,13 +163,15 @@ fn config_fields_accessible() { .build() .expect("full config should build"); - assert_eq!(config.app_name, "myapp"); - assert_eq!(config.username, "user1"); - assert_eq!(config.password, "pass1"); - assert_eq!(config.base_url.as_str(), "https://asterisk.local:5080/ari"); - assert!(config.ws_url.as_str().starts_with("wss://")); + assert_eq!(config.app_name(), "myapp"); + assert_eq!(config.credentials().username(), "user1"); + assert_eq!(config.credentials().secret(), "pass1"); + assert_eq!( + config.base_url().as_str(), + "https://asterisk.local:5080/ari" + ); // reconnect_policy is accessible (default) - let _ = &config.reconnect_policy; + let _ = config.reconnect_policy(); } #[test] @@ -219,9 +199,9 @@ fn default_host_is_localhost() { .expect("default config should build"); assert!( - config.base_url.as_str().contains("127.0.0.1"), + config.base_url().as_str().contains("127.0.0.1"), "default host should be 127.0.0.1, got: {}", - config.base_url + config.base_url() ); } @@ -234,9 +214,9 @@ fn default_port_is_8088() { .expect("default config should build"); assert!( - config.base_url.as_str().contains(":8088"), + config.base_url().as_str().contains(":8088"), "default port should be 8088, got: {}", - config.base_url + config.base_url() ); } @@ -2027,7 +2007,7 @@ fn builder_with_transport_mode() { .transport(TransportMode::WebSocket) .build() .expect("should build config"); - assert_eq!(config.transport_mode, TransportMode::WebSocket); + assert_eq!(config.transport_mode(), TransportMode::WebSocket); } // ── external media / originate params tests (migrated from channel.rs) ──── @@ -2464,7 +2444,7 @@ fn config_default_transport_mode_is_http() { .password("secret") .build() .expect("default config should build"); - assert_eq!(config.transport_mode, TransportMode::Http); + assert_eq!(config.transport_mode(), TransportMode::Http); } #[test] @@ -2531,3 +2511,127 @@ fn media_command_get_status_serialization() { let obj = json.as_object().expect("should be an object"); assert_eq!(obj.len(), 1, "GetStatus should only have 'command' key"); } + +// ── url_encode path-safety tests ──────────────────────────────────────────── + +#[test] +fn url_encode_encodes_slash() { + assert_eq!(url_encode("chan/123"), "chan%2F123"); +} + +#[test] +fn url_encode_encodes_question_mark() { + assert_eq!(url_encode("chan?id=1"), "chan%3Fid%3D1"); +} + +#[test] +fn url_encode_encodes_hash() { + assert_eq!(url_encode("chan#frag"), "chan%23frag"); +} + +#[test] +fn url_encode_encodes_percent() { + assert_eq!(url_encode("100%done"), "100%25done"); +} + +#[test] +fn url_encode_preserves_safe_chars() { + // alphanumerics, hyphens, underscores, dots, tildes are safe + assert_eq!(url_encode("chan-123_test.0~ok"), "chan-123_test.0~ok"); +} + +#[test] +fn url_encode_encodes_space() { + assert_eq!(url_encode("my channel"), "my%20channel"); +} + +#[test] +fn url_encode_encodes_ampersand() { + assert_eq!(url_encode("a&b"), "a%26b"); +} + +// ── credential encoding tests ────────────────────────────────────────────── + +#[test] +fn config_build_encodes_special_chars_in_credentials() { + let config = AriConfigBuilder::new("my app") + .username("user&name") + .password("pass=word#1") + .build() + .expect("config with special chars should build"); + + assert_eq!(config.base_url().as_str(), "http://127.0.0.1:8088/ari"); + assert_eq!(config.app_name(), "my app"); +} + +#[test] +fn form_urlencoded_round_trips_special_chars() { + let encoded = url::form_urlencoded::Serializer::new(String::new()) + .append_pair("app", "my app") + .append_pair("api_key", "user&name:pass=word#1") + .finish(); + + // verify special chars are percent-encoded, not raw + assert!( + !encoded.contains(' '), + "spaces must be encoded, got: {encoded}" + ); + assert_eq!( + encoded.matches('&').count(), + 1, + "only the pair separator & should be literal, got: {encoded}" + ); + + // round-trip: parse the query string back and verify values + let pairs: HashMap = + url::form_urlencoded::parse(encoded.as_bytes()) + .map(|(k, v)| (k.into_owned(), v.into_owned())) + .collect(); + assert_eq!(pairs.get("app").unwrap(), "my app"); + assert_eq!(pairs.get("api_key").unwrap(), "user&name:pass=word#1"); +} + +// ── uri slash stripping tests ────────────────────────────────────────────── + +#[test] +fn strip_prefix_removes_single_leading_slash() { + let path = "/channels/12345"; + let result = path.strip_prefix('/').unwrap_or(path); + assert_eq!(result, "channels/12345"); +} + +#[test] +fn strip_prefix_preserves_double_leading_slashes() { + let path = "//channels/12345"; + let result = path.strip_prefix('/').unwrap_or(path); + assert_eq!(result, "/channels/12345", "should only strip one slash"); +} + +#[test] +fn strip_prefix_no_leading_slash_unchanged() { + let path = "channels/12345"; + let result = path.strip_prefix('/').unwrap_or(path); + assert_eq!(result, "channels/12345"); +} + +// ── server builder default bind address ──────────────────────────────────── + +#[test] +fn server_builder_defaults_to_loopback() { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); + rt.block_on(async { + let (server, _handle) = AriServerBuilder::new() + .bind(([127, 0, 0, 1], 0).into()) + .build() + .await + .unwrap(); + let addr = server.local_addr().unwrap(); + assert!( + addr.ip().is_loopback(), + "default should bind to loopback, got: {addr}" + ); + }); +} diff --git a/tests/tests/unit/core_tests.rs b/tests/tests/unit/core_tests.rs index 8eb10a5..3a8745b 100644 --- a/tests/tests/unit/core_tests.rs +++ b/tests/tests/unit/core_tests.rs @@ -1144,13 +1144,11 @@ fn max_retries_returns_zero_after_exhausted() { #[test] fn jitter_stays_in_range() { + // jitter_factor is in [0.5, 1.5), so 10s base -> [5s, 15s) let policy = ReconnectPolicy::exponential(Duration::from_secs(10), Duration::from_secs(60)); let delay = policy.delay_for_attempt(0); assert!(delay >= Duration::from_secs(5), "delay too low: {delay:?}"); - assert!( - delay <= Duration::from_secs(10), - "delay too high: {delay:?}" - ); + assert!(delay < Duration::from_secs(15), "delay too high: {delay:?}"); } #[test] @@ -1280,15 +1278,15 @@ fn zero_duration_policy_no_panic() { fn jitter_produces_values_in_expected_range() { // run multiple times to get some statistical coverage let policy = ReconnectPolicy::exponential(Duration::from_secs(100), Duration::from_secs(1000)); - for _ in 0..20 { + for _ in 0..100 { let delay = policy.delay_for_attempt(0); - // base = 100s, jitter range = [0.5 * 100, 1.0 * 100] = [50, 100] + // base = 100s, jitter range = [0.5 * 100, 1.5 * 100) = [50, 150) assert!( delay >= Duration::from_secs(50), "jitter delay too low: {delay:?}" ); assert!( - delay <= Duration::from_secs(100), + delay < Duration::from_secs(150), "jitter delay too high: {delay:?}" ); } @@ -1623,3 +1621,22 @@ fn reconnect_policy_attempt_at_max_retries_returns_zero() { // attempt below max_retries returns non-zero assert_ne!(policy.delay_for_attempt(2), Duration::ZERO); } + +// ============================================================================= +// jitter entropy tests +// ============================================================================= + +#[test] +fn jitter_produces_varying_delays() { + let policy = ReconnectPolicy::exponential(Duration::from_secs(1), Duration::from_secs(60)); + let delays: Vec = (0..20).map(|_| policy.delay_for_attempt(0)).collect(); + // with 20 samples from a range of 1000 discrete values, at least 2 should differ + let unique = delays + .iter() + .collect::>() + .len(); + assert!( + unique > 1, + "expected varying jitter values, got {unique} unique out of 20" + ); +}