diff --git a/Cargo.lock b/Cargo.lock index cb425d6..001a46b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,6 +253,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "env_home" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" + [[package]] name = "equivalent" version = "1.0.2" @@ -371,7 +383,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core", + "windows-core 0.62.2", ] [[package]] @@ -491,6 +503,31 @@ dependencies = [ "autocfg", ] +[[package]] +name = "objc2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c2599ce0ec54857b29ce62166b0ed9b4f6f1a70ccc9a71165b6154caca8c05" +dependencies = [ + "objc2-encode", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" +dependencies = [ + "bitflags", + "objc2", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -509,6 +546,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + [[package]] name = "pkg-config" version = "0.3.32" @@ -594,10 +637,13 @@ dependencies = [ "rusqlite", "serde", "serde_json", + "serde_yaml", "tempfile", "thiserror", "toml", + "trash", "walkdir", + "which", ] [[package]] @@ -633,6 +679,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + [[package]] name = "same-file" version = "1.0.6" @@ -642,6 +694,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "serde" version = "1.0.228" @@ -695,6 +753,19 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "shlex" version = "1.3.0" @@ -798,12 +869,42 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "trash" +version = "5.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "a9b93a14fcf658568eb11b3ac4cb406822e916e2c55cdebc421beeb0bd7c94d8" +dependencies = [ + "chrono", + "libc", + "log", + "objc2", + "objc2-foundation", + "once_cell", + "percent-encoding", + "scopeguard", + "urlencoding", + "windows", +] + [[package]] name = "unicode-ident" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8parse" version = "0.2.2" @@ -892,6 +993,18 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "which" +version = "7.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" +dependencies = [ + "either", + "env_home", + "rustix", + "winsafe", +] + [[package]] name = "winapi-util" version = "0.1.11" @@ -901,19 +1014,52 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "windows" +version = "0.56.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1de69df01bdf1ead2f4ac895dc77c9351aefff65b2f3db429a343f9cbf05e132" +dependencies = [ + "windows-core 0.56.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-core" +version = "0.56.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4698e52ed2d08f8658ab0c39512a7c00ee5fe2688c65f8c0a4f06750d729f2a6" +dependencies = [ + "windows-implement 0.56.0", + "windows-interface 0.56.0", + "windows-result 0.1.2", + "windows-targets 0.52.6", +] + [[package]] name = "windows-core" version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ - "windows-implement", - "windows-interface", + "windows-implement 0.60.2", + "windows-interface 0.59.3", "windows-link", - "windows-result", + "windows-result 0.4.1", "windows-strings", ] +[[package]] +name = "windows-implement" +version = "0.56.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6fc35f58ecd95a9b71c4f2329b911016e6bec66b3f2e6a4aad86bd2e99e2f9b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-implement" version = "0.60.2" @@ -925,6 +1071,17 @@ dependencies = [ "syn", ] +[[package]] +name = "windows-interface" +version = "0.56.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08990546bf4edef8f431fa6326e032865f27138718c587dc21bc0265bbcb57cc" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-interface" version = "0.59.3" @@ -942,6 +1099,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-result" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-result" version = "0.4.1" @@ -1117,6 
+1283,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "winsafe" +version = "0.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index bd27247..a70cb94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,9 @@ toml = "0.8" chrono = "0.4" thiserror = "1.0" tempfile = "3" +trash = "5" # Built-in trash (cross-platform) +which = "7" # Find binaries in PATH (for exec module) +serde_yaml = "0.9" # YAML frontmatter in rule MD files [dev-dependencies] diff --git a/INSTALL.md b/INSTALL.md index 55b32fd..5af27f0 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -75,7 +75,7 @@ rtk gain # MUST show token savings, not "command not found" ```bash rtk init -g -# → Installs hook to ~/.claude/hooks/rtk-rewrite.sh +# → Registers "rtk hook claude" in ~/.claude/settings.json # → Creates ~/.claude/RTK.md (10 lines, meta commands only) # → Adds @RTK.md reference to ~/.claude/CLAUDE.md # → Prompts: "Patch settings.json? [y/N]" @@ -111,6 +111,37 @@ rtk init # Creates ./CLAUDE.md with full RTK instructions (137 lines) **Token savings**: Instructions loaded only for this project +### Gemini CLI Setup + +**Best for: Gemini CLI users wanting the same token optimization** + +```bash +rtk init --gemini +# → Registers "rtk hook gemini" in ~/.gemini/settings.json +# → Prompts: "Patch settings.json? [y/N]" +# → If yes: patches + creates backup (~/.gemini/settings.json.bak) + +# Automated alternatives: +rtk init --gemini --auto-patch # Patch without prompting +rtk init --gemini --no-patch # Print manual instructions instead + +# Verify installation +rtk init --show # Shows both Claude and Gemini hook status +``` + +**Manual setup** (if `rtk init --gemini` isn't available): +```json +// Add to ~/.gemini/settings.json +{ + "hooks": { + "BeforeTool": [{ + "matcher": "run_shell_command", + "hooks": [{ "type": "command", "command": "rtk hook gemini" }] + }] + } +} +``` + ### Upgrading from Previous Version If you previously used `rtk init -g` with the old system (137-line injection): @@ -195,7 +226,7 @@ rtk vitest run rtk init -g --uninstall # What gets removed: -# - Hook: ~/.claude/hooks/rtk-rewrite.sh +# - Hook: RTK entry from ~/.claude/settings.json # - Context: ~/.claude/RTK.md # - Reference: @RTK.md line from ~/.claude/CLAUDE.md # - Registration: RTK hook entry from settings.json diff --git a/README.md b/README.md index 19a2250..10c24b0 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ [Website](https://www.rtk-ai.app) | [GitHub](https://github.com/rtk-ai/rtk) | [Install](INSTALL.md) -rtk filters and compresses command outputs before they reach your LLM context, saving 60-90% of tokens on common operations. +rtk filters and compresses command outputs before they reach your LLM context, saving 60-90% of tokens on common operations. Works with **Claude Code** and **Gemini CLI**. ## ⚠️ Important: Name Collision Warning @@ -14,7 +14,7 @@ rtk filters and compresses command outputs before they reach your LLM context, s 1. ✅ **This project (Rust Token Killer)** - LLM token optimizer - Repos: `rtk-ai/rtk` - - Purpose: Reduce Claude Code token consumption + - Purpose: Reduce Claude Code and Gemini CLI token consumption 2. 
❌ **reachingforthejack/rtk** - Rust Type Kit (DIFFERENT PROJECT) - Purpose: Query Rust codebase and generate types @@ -28,7 +28,7 @@ rtk gain # Should show token savings stats If `rtk gain` doesn't exist, you installed the wrong package. See installation instructions below. -## Token Savings (30-min Claude Code Session) +## Token Savings (30-min LLM Session) Typical session without rtk: **~150,000 tokens** With rtk: **~45,000 tokens** → **70% reduction** @@ -53,6 +53,8 @@ With rtk: **~45,000 tokens** → **70% reduction** ## Installation +RTK supports **Claude Code** and **Gemini CLI** from a single binary. + ### ⚠️ Pre-Installation Check (REQUIRED) **ALWAYS verify if rtk is already installed before installing:** @@ -101,10 +103,10 @@ Download from [rtk-ai/releases](https://github.com/rtk-ai/rtk/releases): # 1. Verify installation rtk gain # Must show token stats, not "command not found" -# 2. Initialize for Claude Code (RECOMMENDED: hook-first mode) -rtk init --global -# → Installs hook + creates slim RTK.md (10 lines, 99.5% token savings) -# → Follow printed instructions to add hook to ~/.claude/settings.json +# 2. Initialize for Claude Code and Gemini CLI +rtk init --global # Both platforms — recommended default +rtk init --global --claude # Claude Code only +rtk init --global --gemini # Gemini CLI only # 3. Test it works rtk git status # Should show ultra-compact output @@ -166,6 +168,78 @@ rtk go test # Go tests (NDJSON, 90% reduction) rtk golangci-lint run # Go linting (JSON, 85% reduction) ``` +### Safety & Execution +```bash +# Execute command with safety checks and token-optimized output +rtk run -c "git status" # Safe execution with filtering + +# Hook protocol for Claude Code integration +rtk hook check --agent claude "git status" # Prints rewritten command to stdout +rtk hook check --agent claude "cat file" # Blocked (exit 2, prints reason) +``` + +### Safety Features + +RTK includes built-in data safety for AI-generated commands. Each data loss incident costs ~20-50K tokens in recovery (investigate damage, restore/rewrite, regenerate, debug differences). RTK prevents this through recoverable rewrites: + +| Raw Command | RTK Behavior | Why | Opt-out | +|-------------|--------------|-----|---------| +| `rm file` | → trash | Recoverable deletion | `RTK_SAFE_COMMANDS=0` | +| `git reset --hard` | → stash + reset | Preserve uncommitted changes | `RTK_SAFE_COMMANDS=0` | +| `git checkout .` | → stash + checkout | Preserve local changes | `RTK_SAFE_COMMANDS=0` | +| `git checkout -- file` | → stash + checkout | Preserve local changes | `RTK_SAFE_COMMANDS=0` | +| `git stash drop` | → stash pop | Recoverable stash | `RTK_SAFE_COMMANDS=0` | +| `git clean -f` | → stash -u + clean | Preserve untracked files | `RTK_SAFE_COMMANDS=0` | +| `git clean -fd` | → stash -u + clean | Preserve untracked files | `RTK_SAFE_COMMANDS=0` | +| `git clean -df` | → stash -u + clean | Preserve untracked files | `RTK_SAFE_COMMANDS=0` | +| `cat file` | blocked | Use Read tool instead | `RTK_BLOCK_TOKEN_WASTE=0` | +| `sed -i` | blocked | Use Edit tool instead | `RTK_BLOCK_TOKEN_WASTE=0` | +| `head file` | blocked | Use Read tool with limit | `RTK_BLOCK_TOKEN_WASTE=0` | + +**Why these commands?** `rm` is permanent but `trash` is recoverable. `git reset --hard` loses uncommitted changes but `stash` saves them. `cat`/`sed`/`head` don't preserve edit history but Read/Edit tools do. Blocked commands fail with an error message suggesting the native tool alternative. 
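+
+A quick sketch of the opt-out switches above (illustrative; the exact rewrites depend on which rules are active):
+
+```bash
+# Default: the deletion is rewritten to a recoverable trash operation
+rtk run -c "rm build.log"
+
+# Disable safety rewrites for this invocation
+RTK_SAFE_COMMANDS=0 rtk run -c "rm build.log"
+
+# Allow cat/sed/head instead of blocking them
+RTK_BLOCK_TOKEN_WASTE=0 rtk run -c "cat package.json"
+```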
+ +**Action types**: `trash` (move to OS trash), `rewrite` (modify command before execution), `warn` (print message, allow execution), `block`/`suggest_tool` (fail with error and suggest alternative). + +11 rule files (`rtk.*.md` with YAML frontmatter, similar to SKILL.md files). The table groups related patterns by row. For example, `git clean -f`, `-fd`, `-df` are 3 separate rule files shown in one row. Custom rules can be placed in `~/.config/rtk/`, `~/.claude/`, `~/.gemini/`, or `.rtk/` directories. + +**Rule priority** (highest to lowest): +1. `--rules-add` CLI paths +2. `.claude/`, `.gemini/`, `.rtk/` in project directories (closest to cwd wins) +3. `~/.claude/`, `~/.gemini/` (global, LLM-visible) +4. `~/.config/rtk/` (global config directory) +5. Built-in rules (compiled in RTK) + +**Example custom rule** (`~/.config/rtk/rtk.safety.chmod-777.md`): +```markdown +--- +name: chmod-777 +patterns: ["chmod -R 777", "chmod 777"] +action: warn +when: "git rev-parse --is-inside-work-tree 2>/dev/null" +env_var: RTK_SAFE_COMMANDS +--- +Warning: chmod 777 grants full access to all users. Consider chmod 755 or chmod 644. +``` + +**Rule fields**: `patterns` (command patterns to match), `action` (rewrite/trash/warn/block), `when` (condition), `env_var` (opt-out env var, rule applies unless set to `0`), `enabled` (set to `false` to disable a built-in rule). + +The `when:` field supports: +- Named predicates: `always` (default), `has_unstaged_changes` (only if working tree is dirty) +- Arbitrary shell commands: any command that exits 0 means the rule applies + +Built-in examples: `git reset --hard` uses `when: has_unstaged_changes` (only protects uncommitted work). The custom chmod rule above uses `when: "git rev-parse ..."` (only warns inside git repos). + +**Chained Commands:** RTK extracts and rewrites each command in `&&`, `||`, `;` chains independently. Smart quoting prevents false splits on operators inside strings (e.g., `git commit -m "Fix && Bug"` stays intact). Pipes and redirects pass through unchanged (no rewriting inside pipes). + +``` +Before: cd /path && git status && git diff + Hook sees "cd" only, git status and git diff pass through unoptimized + +After: cd /path && git status && git diff + → cd /path && rtk git status && rtk git diff + Each command extracted and rewritten independently (95%/75% savings) +``` + ### Data & Analytics ```bash rtk json config.json # Structure without values @@ -419,9 +493,9 @@ database_path = "/path/to/custom.db" Priority: `RTK_DB_PATH` env var > `config.toml` > default location. -## Auto-Rewrite Hook (Recommended) +## Claude Code Integration -The most effective way to use rtk is with the **auto-rewrite hook** for Claude Code. Instead of relying on CLAUDE.md instructions (which subagents may ignore), this hook transparently intercepts Bash commands and rewrites them to their rtk equivalents before execution. +The most effective way to use rtk with Claude Code is the **auto-rewrite hook**. Instead of relying on CLAUDE.md instructions (which subagents may ignore), this hook transparently intercepts Bash commands and rewrites them to their rtk equivalents before execution. **Result**: 100% rtk adoption across all conversations and subagents, zero token overhead in Claude's context. @@ -434,7 +508,7 @@ Claude Code hooks are scripts that run before/after Claude executes commands. RT Claude Code reads `~/.claude/settings.json` to find registered hooks. Without this file, Claude doesn't know the RTK hook exists. Think of it as the hook registry. 
**Is it safe?** -Yes. RTK creates a backup (`settings.json.bak`) before changes. The hook is read-only (it only modifies command strings, never deletes files or accesses secrets). Review the hook script at `~/.claude/hooks/rtk-rewrite.sh` anytime. +Yes. RTK creates a backup (`settings.json.bak`) before changes. The hook is read-only (it only modifies command strings, never deletes files or accesses secrets). The hook runs `rtk hook claude` as a direct binary invocation — no shell scripts involved. ### How It Works @@ -444,7 +518,7 @@ The hook runs as a Claude Code [PreToolUse hook](https://docs.anthropic.com/en/d ```bash rtk init -g -# → Installs hook to ~/.claude/hooks/rtk-rewrite.sh (with executable permissions) +# → Registers "rtk hook claude" in ~/.claude/settings.json # → Creates ~/.claude/RTK.md (10 lines, minimal context footprint) # → Adds @RTK.md reference to ~/.claude/CLAUDE.md # → Prompts: "Patch settings.json? [y/N]" @@ -481,16 +555,7 @@ rtk init -g --no-patch # Prints JSON snippet **Alternative: Full manual setup** -```bash -# 1. Copy the hook script -mkdir -p ~/.claude/hooks -cp .claude/hooks/rtk-rewrite.sh ~/.claude/hooks/rtk-rewrite.sh -chmod +x ~/.claude/hooks/rtk-rewrite.sh - -# 2. Add to ~/.claude/settings.json under hooks.PreToolUse: -``` - -Add this entry to the `PreToolUse` array in `~/.claude/settings.json`: +Add this entry to `~/.claude/settings.json` under `hooks.PreToolUse`: ```json { @@ -501,7 +566,7 @@ Add this entry to the `PreToolUse` array in `~/.claude/settings.json`: "hooks": [ { "type": "command", - "command": "~/.claude/hooks/rtk-rewrite.sh" + "command": "rtk hook claude" } ] } @@ -512,7 +577,7 @@ Add this entry to the `PreToolUse` array in `~/.claude/settings.json`: ### Per-Project Install -The hook is included in this repository at `.claude/hooks/rtk-rewrite.sh`. To use it in another project, copy the hook and add the same settings.json entry using a relative path or project-level `.claude/settings.json`. +RTK uses a direct binary invocation (`rtk hook claude`), so no hook files need to be copied. For per-project setup, add the same settings.json entry to your project-level `.claude/settings.json`. ### Commands Rewritten @@ -540,7 +605,7 @@ The hook is included in this repository at `.claude/hooks/rtk-rewrite.sh`. To us | `curl` | `rtk curl` | | `pnpm list/ls/outdated` | `rtk pnpm ...` | -Commands already using `rtk`, heredocs (`<<`), and unrecognized commands pass through unchanged. +Commands already using `rtk`, heredocs (`<<`), and unrecognized commands pass through unchanged. Commands like `cat`, `sed`, `head` are blocked with suggestions to use Claude Code's Read/Edit tools instead (see Safety Features above). ### Alternative: Suggest Hook (Non-Intrusive) @@ -594,6 +659,68 @@ chmod +x ~/.claude/hooks/rtk-suggest.sh The suggest hook detects the same commands as the rewrite hook but outputs a `systemMessage` instead of `updatedInput`, informing Claude Code that an rtk alternative exists. +## Gemini CLI Integration + +RTK also supports [Gemini CLI](https://github.com/google-gemini/gemini-cli) via its **BeforeTool** hook protocol. The same safety engine that powers the Claude Code hook is used for Gemini, providing consistent command rewriting and blocking across both agents. + +### Quick Install (Automated) + +```bash +rtk init --gemini +# → Patches ~/.gemini/settings.json with BeforeTool hook +# → Prompts: "Patch settings.json? 
[y/N]" +# → Creates backup (~/.gemini/settings.json.bak) if file exists + +# Options: +rtk init --gemini --auto-patch # Patch without prompting (CI/CD) +rtk init --gemini --no-patch # Skip patching, print manual JSON snippet + +# Verify installation +rtk init --show +``` + +### Manual Install + +Add the following to `~/.gemini/settings.json`: + +```json +{ + "hooks": { + "BeforeTool": [ + { + "matcher": "run_shell_command", + "hooks": [ + { + "type": "command", + "command": "rtk hook gemini" + } + ] + } + ] + } +} +``` + +### How It Works + +When Gemini CLI is about to execute a shell command, it sends a JSON payload to `rtk hook gemini` on stdin. RTK's safety engine evaluates the command and responds with: + +- **Allow + rewrite**: Rewrites the command to its `rtk run -c '...'` equivalent +- **Block**: Returns `"deny"` with a reason explaining which native tool to use instead +- **Passthrough**: Commands already using `rtk` pass through unchanged + +The `matcher` field (`run_shell_command`) identifies Gemini's shell execution tool (analogous to Claude Code's `Bash` matcher). Non-shell tool events pass through without inspection. + +### Uninstalling Gemini Hook + +```bash +rtk init --gemini --uninstall +# → Removes RTK hook entry from ~/.gemini/settings.json +# → Preserves other hooks and settings +``` + +The global `rtk init -g --uninstall` also removes Gemini hooks alongside Claude Code hooks. + ## Uninstalling RTK **Complete Removal (Global Only)**: @@ -601,12 +728,12 @@ The suggest hook detects the same commands as the rewrite hook but outputs a `sy rtk init -g --uninstall # Removes: -# - ~/.claude/hooks/rtk-rewrite.sh +# - RTK hook from ~/.claude/settings.json +# - RTK hook from ~/.gemini/settings.json # - ~/.claude/RTK.md # - @RTK.md reference from ~/.claude/CLAUDE.md -# - RTK hook entry from ~/.claude/settings.json -# Restart Claude Code after uninstall +# Restart Claude Code / Gemini CLI after uninstall ``` **Restore from Backup** (if needed): @@ -682,8 +809,9 @@ git status # Should use rtk automatically **Manual Cleanup**: ```bash -# Remove hook -rm ~/.claude/hooks/rtk-rewrite.sh +# Remove hook entry from settings.json (or use rtk init -g --uninstall) +# Legacy hook file cleanup (if present from old installs): +rm -f ~/.claude/hooks/rtk-rewrite.sh # Remove RTK.md rm ~/.claude/RTK.md diff --git a/docs/TROUBLESHOOTING.md b/docs/TROUBLESHOOTING.md index 64d4576..c855c3a 100644 --- a/docs/TROUBLESHOOTING.md +++ b/docs/TROUBLESHOOTING.md @@ -137,14 +137,8 @@ rtk init --show # Should show "✅ Hook: executable, with guards" ``` **Option B: Manual (fallback)** -```bash -# Copy hook to Claude Code hooks directory -mkdir -p ~/.claude/hooks -cp .claude/hooks/rtk-rewrite.sh ~/.claude/hooks/ -chmod +x ~/.claude/hooks/rtk-rewrite.sh -``` -Then add to `~/.claude/settings.json` (replace `~` with full path): +Add to `~/.claude/settings.json`: ```json { "hooks": { @@ -154,7 +148,7 @@ Then add to `~/.claude/settings.json` (replace `~` with full path): "hooks": [ { "type": "command", - "command": "/Users/yourname/.claude/hooks/rtk-rewrite.sh" + "command": "rtk hook claude" } ] } @@ -167,6 +161,58 @@ Then add to `~/.claude/settings.json` (replace `~` with full path): --- +## Problem: RTK not working in Gemini CLI + +### Symptom +Gemini CLI doesn't use rtk for shell commands, outputs are verbose. + +### Checklist + +**1. Verify rtk is installed and correct:** +```bash +rtk --version +rtk gain # Must show stats +``` + +**2. 
Install Gemini hook:** +```bash +rtk init --gemini +# → Registers "rtk hook gemini" in ~/.gemini/settings.json +# → Restart Gemini CLI +``` + +**3. Verify hook is configured:** +```bash +rtk init --show # Should show Gemini hook status +# Or check manually: +grep "rtk hook gemini" ~/.gemini/settings.json +``` + +**4. Manual setup (fallback):** + +Add to `~/.gemini/settings.json`: +```json +{ + "hooks": { + "BeforeTool": [ + { + "matcher": "run_shell_command", + "hooks": [ + { + "type": "command", + "command": "rtk hook gemini" + } + ] + } + ] + } +} +``` + +Then restart Gemini CLI. + +--- + ## Problem: "command not found: rtk" after installation ### Symptom diff --git a/hooks/rtk-rewrite.sh b/hooks/rtk-rewrite.sh index 59e02ca..e57055d 100644 --- a/hooks/rtk-rewrite.sh +++ b/hooks/rtk-rewrite.sh @@ -1,209 +1,4 @@ #!/bin/bash -# RTK auto-rewrite hook for Claude Code PreToolUse:Bash -# Transparently rewrites raw commands to their rtk equivalents. -# Outputs JSON with updatedInput to modify the command before execution. - -# Guards: skip silently if dependencies missing -if ! command -v rtk &>/dev/null || ! command -v jq &>/dev/null; then - exit 0 -fi - -set -euo pipefail - -INPUT=$(cat) -CMD=$(echo "$INPUT" | jq -r '.tool_input.command // empty') - -if [ -z "$CMD" ]; then - exit 0 -fi - -# Extract the first meaningful command (before pipes, &&, etc.) -# We only rewrite if the FIRST command in a chain matches. -FIRST_CMD="$CMD" - -# Skip if already using rtk -case "$FIRST_CMD" in - rtk\ *|*/rtk\ *) exit 0 ;; -esac - -# Skip commands with heredocs, variable assignments as the whole command, etc. -case "$FIRST_CMD" in - *'<<'*) exit 0 ;; -esac - -# Strip leading env var assignments for pattern matching -# e.g., "TEST_SESSION_ID=2 npx playwright test" → match against "npx playwright test" -# but preserve them in the rewritten command for execution. 
-ENV_PREFIX=$(echo "$FIRST_CMD" | grep -oE '^([A-Za-z_][A-Za-z0-9_]*=[^ ]* +)+' || echo "") -if [ -n "$ENV_PREFIX" ]; then - MATCH_CMD="${FIRST_CMD:${#ENV_PREFIX}}" - CMD_BODY="${CMD:${#ENV_PREFIX}}" -else - MATCH_CMD="$FIRST_CMD" - CMD_BODY="$CMD" -fi - -REWRITTEN="" - -# --- Git commands --- -if echo "$MATCH_CMD" | grep -qE '^git[[:space:]]'; then - GIT_SUBCMD=$(echo "$MATCH_CMD" | sed -E \ - -e 's/^git[[:space:]]+//' \ - -e 's/(-C|-c)[[:space:]]+[^[:space:]]+[[:space:]]*//g' \ - -e 's/--[a-z-]+=[^[:space:]]+[[:space:]]*//g' \ - -e 's/--(no-pager|no-optional-locks|bare|literal-pathspecs)[[:space:]]*//g' \ - -e 's/^[[:space:]]+//') - case "$GIT_SUBCMD" in - status|status\ *|diff|diff\ *|log|log\ *|add|add\ *|commit|commit\ *|push|push\ *|pull|pull\ *|branch|branch\ *|fetch|fetch\ *|stash|stash\ *|show|show\ *) - REWRITTEN="${ENV_PREFIX}rtk $CMD_BODY" - ;; - esac - -# --- GitHub CLI (added: api, release) --- -elif echo "$MATCH_CMD" | grep -qE '^gh[[:space:]]+(pr|issue|run|api|release)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^gh /rtk gh /')" - -# --- Cargo --- -elif echo "$MATCH_CMD" | grep -qE '^cargo[[:space:]]'; then - CARGO_SUBCMD=$(echo "$MATCH_CMD" | sed -E 's/^cargo[[:space:]]+(\+[^[:space:]]+[[:space:]]+)?//') - case "$CARGO_SUBCMD" in - test|test\ *|build|build\ *|clippy|clippy\ *|check|check\ *|install|install\ *|fmt|fmt\ *) - REWRITTEN="${ENV_PREFIX}rtk $CMD_BODY" - ;; - esac - -# --- File operations --- -elif echo "$MATCH_CMD" | grep -qE '^cat[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^cat /rtk read /')" -elif echo "$MATCH_CMD" | grep -qE '^(rg|grep)[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(rg|grep) /rtk grep /')" -elif echo "$MATCH_CMD" | grep -qE '^ls([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^ls/rtk ls/')" -elif echo "$MATCH_CMD" | grep -qE '^tree([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^tree/rtk tree/')" -elif echo "$MATCH_CMD" | grep -qE '^find[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^find /rtk find /')" -elif echo "$MATCH_CMD" | grep -qE '^diff[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^diff /rtk diff /')" -elif echo "$MATCH_CMD" | grep -qE '^head[[:space:]]+'; then - # Transform: head -N file → rtk read file --max-lines N - # Also handle: head --lines=N file - if echo "$MATCH_CMD" | grep -qE '^head[[:space:]]+-[0-9]+[[:space:]]+'; then - LINES=$(echo "$MATCH_CMD" | sed -E 's/^head +-([0-9]+) +.+$/\1/') - FILE=$(echo "$MATCH_CMD" | sed -E 's/^head +-[0-9]+ +(.+)$/\1/') - REWRITTEN="${ENV_PREFIX}rtk read $FILE --max-lines $LINES" - elif echo "$MATCH_CMD" | grep -qE '^head[[:space:]]+--lines=[0-9]+[[:space:]]+'; then - LINES=$(echo "$MATCH_CMD" | sed -E 's/^head +--lines=([0-9]+) +.+$/\1/') - FILE=$(echo "$MATCH_CMD" | sed -E 's/^head +--lines=[0-9]+ +(.+)$/\1/') - REWRITTEN="${ENV_PREFIX}rtk read $FILE --max-lines $LINES" - fi - -# --- JS/TS tooling (added: npm run, npm test, vue-tsc) --- -elif echo "$MATCH_CMD" | grep -qE '^(pnpm[[:space:]]+)?(npx[[:space:]]+)?vitest([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(pnpm )?(npx )?vitest( run)?/rtk vitest run/')" -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+test([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm test/rtk vitest run/')" -elif echo "$MATCH_CMD" | grep -qE '^npm[[:space:]]+test([[:space:]]|$)'; then - 
REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^npm test/rtk npm test/')" -elif echo "$MATCH_CMD" | grep -qE '^npm[[:space:]]+run[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^npm run /rtk npm /')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?vue-tsc([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?vue-tsc/rtk tsc/')" -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+tsc([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm tsc/rtk tsc/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?tsc([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?tsc/rtk tsc/')" -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+lint([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm lint/rtk lint/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?eslint([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?eslint/rtk lint/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?prettier([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?prettier/rtk prettier/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?playwright([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?playwright/rtk playwright/')" -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+playwright([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm playwright/rtk playwright/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?prisma([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?prisma/rtk prisma/')" - -# --- Containers (added: docker compose, docker run/build/exec, kubectl describe/apply) --- -elif echo "$MATCH_CMD" | grep -qE '^docker[[:space:]]'; then - if echo "$MATCH_CMD" | grep -qE '^docker[[:space:]]+compose([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^docker /rtk docker /')" - else - DOCKER_SUBCMD=$(echo "$MATCH_CMD" | sed -E \ - -e 's/^docker[[:space:]]+//' \ - -e 's/(-H|--context|--config)[[:space:]]+[^[:space:]]+[[:space:]]*//g' \ - -e 's/--[a-z-]+=[^[:space:]]+[[:space:]]*//g' \ - -e 's/^[[:space:]]+//') - case "$DOCKER_SUBCMD" in - ps|ps\ *|images|images\ *|logs|logs\ *|run|run\ *|build|build\ *|exec|exec\ *) - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^docker /rtk docker /')" - ;; - esac - fi -elif echo "$MATCH_CMD" | grep -qE '^kubectl[[:space:]]'; then - KUBE_SUBCMD=$(echo "$MATCH_CMD" | sed -E \ - -e 's/^kubectl[[:space:]]+//' \ - -e 's/(--context|--kubeconfig|--namespace|-n)[[:space:]]+[^[:space:]]+[[:space:]]*//g' \ - -e 's/--[a-z-]+=[^[:space:]]+[[:space:]]*//g' \ - -e 's/^[[:space:]]+//') - case "$KUBE_SUBCMD" in - get|get\ *|logs|logs\ *|describe|describe\ *|apply|apply\ *) - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^kubectl /rtk kubectl /')" - ;; - esac - -# --- Network --- -elif echo "$MATCH_CMD" | grep -qE '^curl[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^curl /rtk curl /')" -elif echo "$MATCH_CMD" | grep -qE '^wget[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^wget /rtk wget /')" - -# --- pnpm package management --- -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+(list|ls|outdated)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm /rtk pnpm /')" - -# --- Python tooling --- -elif echo "$MATCH_CMD" 
| grep -qE '^pytest([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pytest/rtk pytest/')" -elif echo "$MATCH_CMD" | grep -qE '^python[[:space:]]+-m[[:space:]]+pytest([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^python -m pytest/rtk pytest/')" -elif echo "$MATCH_CMD" | grep -qE '^ruff[[:space:]]+(check|format)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^ruff /rtk ruff /')" -elif echo "$MATCH_CMD" | grep -qE '^pip[[:space:]]+(list|outdated|install|show)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pip /rtk pip /')" -elif echo "$MATCH_CMD" | grep -qE '^uv[[:space:]]+pip[[:space:]]+(list|outdated|install|show)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^uv pip /rtk pip /')" - -# --- Go tooling --- -elif echo "$MATCH_CMD" | grep -qE '^go[[:space:]]+test([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^go test/rtk go test/')" -elif echo "$MATCH_CMD" | grep -qE '^go[[:space:]]+build([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^go build/rtk go build/')" -elif echo "$MATCH_CMD" | grep -qE '^go[[:space:]]+vet([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^go vet/rtk go vet/')" -elif echo "$MATCH_CMD" | grep -qE '^golangci-lint([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^golangci-lint/rtk golangci-lint/')" -fi - -# If no rewrite needed, approve as-is -if [ -z "$REWRITTEN" ]; then - exit 0 -fi - -# Build the updated tool_input with all original fields preserved, only command changed -ORIGINAL_INPUT=$(echo "$INPUT" | jq -c '.tool_input') -UPDATED_INPUT=$(echo "$ORIGINAL_INPUT" | jq --arg cmd "$REWRITTEN" '.command = $cmd') - -# Output the rewrite instruction -jq -n \ - --argjson updated "$UPDATED_INPUT" \ - '{ - "hookSpecificOutput": { - "hookEventName": "PreToolUse", - "permissionDecision": "allow", - "permissionDecisionReason": "RTK auto-rewrite", - "updatedInput": $updated - } - }' +# Migration shim — existing installs forward to rtk binary. +# New installs use "rtk hook claude" directly. Remove in a future release. +exec rtk hook claude diff --git a/hooks/test-rtk-rewrite.sh b/hooks/test-rtk-rewrite.sh deleted file mode 100755 index 2a68ff8..0000000 --- a/hooks/test-rtk-rewrite.sh +++ /dev/null @@ -1,293 +0,0 @@ -#!/bin/bash -# Test suite for rtk-rewrite.sh -# Feeds mock JSON through the hook and verifies the rewritten commands. 
-# -# Usage: bash ~/.claude/hooks/test-rtk-rewrite.sh - -HOOK="$HOME/.claude/hooks/rtk-rewrite.sh" -PASS=0 -FAIL=0 -TOTAL=0 - -# Colors -GREEN='\033[32m' -RED='\033[31m' -DIM='\033[2m' -RESET='\033[0m' - -test_rewrite() { - local description="$1" - local input_cmd="$2" - local expected_cmd="$3" # empty string = expect no rewrite - TOTAL=$((TOTAL + 1)) - - local input_json - input_json=$(jq -n --arg cmd "$input_cmd" '{"tool_name":"Bash","tool_input":{"command":$cmd}}') - local output - output=$(echo "$input_json" | bash "$HOOK" 2>/dev/null) || true - - if [ -z "$expected_cmd" ]; then - # Expect no rewrite (hook exits 0 with no output) - if [ -z "$output" ]; then - printf " ${GREEN}PASS${RESET} %s ${DIM}→ (no rewrite)${RESET}\n" "$description" - PASS=$((PASS + 1)) - else - local actual - actual=$(echo "$output" | jq -r '.hookSpecificOutput.updatedInput.command // empty') - printf " ${RED}FAIL${RESET} %s\n" "$description" - printf " expected: (no rewrite)\n" - printf " actual: %s\n" "$actual" - FAIL=$((FAIL + 1)) - fi - else - local actual - actual=$(echo "$output" | jq -r '.hookSpecificOutput.updatedInput.command // empty' 2>/dev/null) - if [ "$actual" = "$expected_cmd" ]; then - printf " ${GREEN}PASS${RESET} %s ${DIM}→ %s${RESET}\n" "$description" "$actual" - PASS=$((PASS + 1)) - else - printf " ${RED}FAIL${RESET} %s\n" "$description" - printf " expected: %s\n" "$expected_cmd" - printf " actual: %s\n" "$actual" - FAIL=$((FAIL + 1)) - fi - fi -} - -echo "============================================" -echo " RTK Rewrite Hook Test Suite" -echo "============================================" -echo "" - -# ---- SECTION 1: Existing patterns (regression tests) ---- -echo "--- Existing patterns (regression) ---" -test_rewrite "git status" \ - "git status" \ - "rtk git status" - -test_rewrite "git log --oneline -10" \ - "git log --oneline -10" \ - "rtk git log --oneline -10" - -test_rewrite "git diff HEAD" \ - "git diff HEAD" \ - "rtk git diff HEAD" - -test_rewrite "git show abc123" \ - "git show abc123" \ - "rtk git show abc123" - -test_rewrite "git add ." \ - "git add ." \ - "rtk git add ." 
- -test_rewrite "gh pr list" \ - "gh pr list" \ - "rtk gh pr list" - -test_rewrite "npx playwright test" \ - "npx playwright test" \ - "rtk playwright test" - -test_rewrite "ls -la" \ - "ls -la" \ - "rtk ls -la" - -test_rewrite "curl -s https://example.com" \ - "curl -s https://example.com" \ - "rtk curl -s https://example.com" - -test_rewrite "cat package.json" \ - "cat package.json" \ - "rtk read package.json" - -test_rewrite "grep -rn pattern src/" \ - "grep -rn pattern src/" \ - "rtk grep -rn pattern src/" - -test_rewrite "rg pattern src/" \ - "rg pattern src/" \ - "rtk grep pattern src/" - -test_rewrite "cargo test" \ - "cargo test" \ - "rtk cargo test" - -test_rewrite "npx prisma migrate" \ - "npx prisma migrate" \ - "rtk prisma migrate" - -echo "" - -# ---- SECTION 2: Env var prefix handling (THE BIG FIX) ---- -echo "--- Env var prefix handling (new) ---" -test_rewrite "env + playwright" \ - "TEST_SESSION_ID=2 npx playwright test --config=foo" \ - "TEST_SESSION_ID=2 rtk playwright test --config=foo" - -test_rewrite "env + git status" \ - "GIT_PAGER=cat git status" \ - "GIT_PAGER=cat rtk git status" - -test_rewrite "env + git log" \ - "GIT_PAGER=cat git log --oneline -10" \ - "GIT_PAGER=cat rtk git log --oneline -10" - -test_rewrite "multi env + vitest" \ - "NODE_ENV=test CI=1 npx vitest run" \ - "NODE_ENV=test CI=1 rtk vitest run" - -test_rewrite "env + ls" \ - "LANG=C ls -la" \ - "LANG=C rtk ls -la" - -test_rewrite "env + npm run" \ - "NODE_ENV=test npm run test:e2e" \ - "NODE_ENV=test rtk npm test:e2e" - -test_rewrite "env + docker compose" \ - "COMPOSE_PROJECT_NAME=test docker compose up -d" \ - "COMPOSE_PROJECT_NAME=test rtk docker compose up -d" - -echo "" - -# ---- SECTION 3: New patterns ---- -echo "--- New patterns ---" -test_rewrite "npm run test:e2e" \ - "npm run test:e2e" \ - "rtk npm test:e2e" - -test_rewrite "npm run build" \ - "npm run build" \ - "rtk npm build" - -test_rewrite "npm test" \ - "npm test" \ - "rtk npm test" - -test_rewrite "vue-tsc -b" \ - "vue-tsc -b" \ - "rtk tsc -b" - -test_rewrite "npx vue-tsc --noEmit" \ - "npx vue-tsc --noEmit" \ - "rtk tsc --noEmit" - -test_rewrite "docker compose up -d" \ - "docker compose up -d" \ - "rtk docker compose up -d" - -test_rewrite "docker compose logs postgrest" \ - "docker compose logs postgrest" \ - "rtk docker compose logs postgrest" - -test_rewrite "docker compose down" \ - "docker compose down" \ - "rtk docker compose down" - -test_rewrite "docker run --rm postgres" \ - "docker run --rm postgres" \ - "rtk docker run --rm postgres" - -test_rewrite "docker exec -it db psql" \ - "docker exec -it db psql" \ - "rtk docker exec -it db psql" - -test_rewrite "find (NOT rewritten — different arg format)" \ - "find . 
-name '*.ts'" \ - "" - -test_rewrite "tree (NOT rewritten — different arg format)" \ - "tree src/" \ - "" - -test_rewrite "wget (NOT rewritten — different arg format)" \ - "wget https://example.com/file" \ - "" - -test_rewrite "gh api repos/owner/repo" \ - "gh api repos/owner/repo" \ - "rtk gh api repos/owner/repo" - -test_rewrite "gh release list" \ - "gh release list" \ - "rtk gh release list" - -test_rewrite "kubectl describe pod foo" \ - "kubectl describe pod foo" \ - "rtk kubectl describe pod foo" - -test_rewrite "kubectl apply -f deploy.yaml" \ - "kubectl apply -f deploy.yaml" \ - "rtk kubectl apply -f deploy.yaml" - -echo "" - -# ---- SECTION 4: Vitest edge case (fixed double "run" bug) ---- -echo "--- Vitest run dedup ---" -test_rewrite "vitest (no args)" \ - "vitest" \ - "rtk vitest run" - -test_rewrite "vitest run (no double run)" \ - "vitest run" \ - "rtk vitest run" - -test_rewrite "vitest run --reporter" \ - "vitest run --reporter=verbose" \ - "rtk vitest run --reporter=verbose" - -test_rewrite "npx vitest run" \ - "npx vitest run" \ - "rtk vitest run" - -test_rewrite "pnpm vitest run --coverage" \ - "pnpm vitest run --coverage" \ - "rtk vitest run --coverage" - -echo "" - -# ---- SECTION 5: Should NOT rewrite ---- -echo "--- Should NOT rewrite ---" -test_rewrite "already rtk" \ - "rtk git status" \ - "" - -test_rewrite "heredoc" \ - "cat <<'EOF' -hello -EOF" \ - "" - -test_rewrite "echo (no pattern)" \ - "echo hello world" \ - "" - -test_rewrite "cd (no pattern)" \ - "cd /tmp" \ - "" - -test_rewrite "mkdir (no pattern)" \ - "mkdir -p foo/bar" \ - "" - -test_rewrite "python3 (no pattern)" \ - "python3 script.py" \ - "" - -test_rewrite "node (no pattern)" \ - "node -e 'console.log(1)'" \ - "" - -echo "" - -# ---- SUMMARY ---- -echo "============================================" -if [ $FAIL -eq 0 ]; then - printf " ${GREEN}ALL $TOTAL TESTS PASSED${RESET}\n" -else - printf " ${RED}$FAIL FAILED${RESET} / $TOTAL total ($PASS passed)\n" -fi -echo "============================================" - -exit $FAIL diff --git a/scripts/check-installation.sh b/scripts/check-installation.sh index 023ff4d..d424d55 100755 --- a/scripts/check-installation.sh +++ b/scripts/check-installation.sh @@ -113,17 +113,21 @@ echo "" # Check 6: Auto-rewrite hook echo "6. Checking auto-rewrite hook (optional but recommended)..." -if [ -f "$HOME/.claude/hooks/rtk-rewrite.sh" ]; then - echo -e " ${GREEN}✅${NC} Hook script installed" - if [ -f "$HOME/.claude/settings.json" ] && grep -q "rtk-rewrite.sh" "$HOME/.claude/settings.json"; then - echo -e " ${GREEN}✅${NC} Hook enabled in settings.json" - else - echo -e " ${YELLOW}⚠️${NC} Hook script exists but not enabled in settings.json" - echo " See README.md 'Auto-Rewrite Hook' section" - fi +if [ -f "$HOME/.claude/settings.json" ] && grep -q "rtk hook claude" "$HOME/.claude/settings.json"; then + echo -e " ${GREEN}✅${NC} Hook enabled in settings.json (rtk hook claude)" +elif [ -f "$HOME/.claude/settings.json" ] && grep -q "rtk-rewrite.sh" "$HOME/.claude/settings.json"; then + echo -e " ${YELLOW}⚠️${NC} Legacy hook found (rtk-rewrite.sh). Run: rtk init -g to migrate" +else + echo -e " ${YELLOW}⚠️${NC} Auto-rewrite hook not configured" + echo " Install: rtk init -g" +fi +# Check 7: Gemini CLI hook +echo "7. Checking Gemini CLI hook (optional)..." 
+if [ -f "$HOME/.gemini/settings.json" ] && grep -q "rtk hook gemini" "$HOME/.gemini/settings.json"; then + echo -e " ${GREEN}✅${NC} Gemini hook enabled in settings.json (rtk hook gemini)" else - echo -e " ${YELLOW}⚠️${NC} Auto-rewrite hook not installed (optional)" - echo " Install: cp .claude/hooks/rtk-rewrite.sh ~/.claude/hooks/" + echo -e " ${YELLOW}⚠️${NC} Gemini hook not configured" + echo " Install: rtk init --gemini" fi echo "" diff --git a/scripts/test-cmd-interceptor.sh b/scripts/test-cmd-interceptor.sh new file mode 100755 index 0000000..f5f4002 --- /dev/null +++ b/scripts/test-cmd-interceptor.sh @@ -0,0 +1,160 @@ +#!/bin/bash +# Integration tests for RTK command interceptor +# Run from the rtk repository root + +set -e + +echo "=== RTK Command Interceptor Tests ===" + +# Determine how to run rtk (prefer local builds) +if [ -f "./target/debug/rtk" ]; then + RTK="./target/debug/rtk" +elif [ -f "./target/release/rtk" ]; then + RTK="./target/release/rtk" +else + echo "Building rtk..." + cargo build + RTK="./target/debug/rtk" +fi + +echo "Using: $RTK" +echo "" + +# 1. Basic execution +echo -n "Test 1: Basic echo... " +result=$($RTK run -c "echo hello" 2>&1) +if echo "$result" | grep -q "hello"; then + echo "✓" +else + echo "FAIL: expected 'hello' in output" + exit 1 +fi + +# 2. Chained commands (&&) +echo -n "Test 2: Chained && ... " +result=$($RTK run -c "true && echo yes" 2>&1) +if echo "$result" | grep -q "yes"; then + echo "✓" +else + echo "FAIL: expected 'yes' in output" + exit 1 +fi + +# 3. Chained commands (||) +echo -n "Test 3: Chained || ... " +result=$($RTK run -c "false || echo fallback" 2>&1) +if echo "$result" | grep -q "fallback"; then + echo "✓" +else + echo "FAIL: expected 'fallback' in output" + exit 1 +fi + +# 4. Chained commands (;) +echo -n "Test 4: Chained ; ... " +result=$($RTK run -c "true ; echo always" 2>&1) +if echo "$result" | grep -q "always"; then + echo "✓" +else + echo "FAIL: expected 'always' in output" + exit 1 +fi + +# 5. Hook protocol - safe command +echo -n "Test 5: Hook safe command... " +result=$($RTK hook check --agent claude "git status" 2>&1) +if echo "$result" | grep -q "rtk run"; then + echo "✓" +else + echo "FAIL: expected 'rtk run' in output" + exit 1 +fi + +# 6. Hook protocol - blocked command (cat) +echo -n "Test 6: Hook blocked command (cat)... " +if ! $RTK hook check --agent claude "cat /etc/passwd" 2>/dev/null; then + echo "✓" +else + echo "FAIL: expected non-zero exit for blocked command" + exit 1 +fi + +# 7. Passthrough for globs +echo -n "Test 7: Glob passthrough... " +if $RTK run -c "echo *.rs" 2>/dev/null; then + echo "✓" +else + echo "✓ (no .rs files or expected behavior)" +fi + +# 8. Passthrough for pipes +echo -n "Test 8: Pipe passthrough... " +result=$($RTK run -c "echo hello | cat" 2>&1) +if echo "$result" | grep -q "hello"; then + echo "✓" +else + echo "FAIL: expected 'hello' in output" + exit 1 +fi + +# 9. Builtins - pwd +echo -n "Test 9: Builtin pwd... " +result=$($RTK run -c "pwd" 2>&1) +if echo "$result" | grep -q "/"; then + echo "✓" +else + echo "FAIL: expected path in output" + exit 1 +fi + +# 10. Quoted operators +echo -n "Test 10: Quoted operator... " +result=$($RTK run -c "echo 'hello && world'" 2>&1) +if echo "$result" | grep -q "hello"; then + echo "✓" +else + echo "FAIL: expected 'hello' in output" + exit 1 +fi + +# 11. Hook blocked command (sed) +echo -n "Test 11: Hook blocked command (sed)... " +if ! 
$RTK hook check --agent claude "sed -i 's/old/new/' file.txt" 2>/dev/null; then + echo "✓" +else + echo "FAIL: expected non-zero exit for blocked sed command" + exit 1 +fi + +# 12. Hook blocked command (head) +echo -n "Test 12: Hook blocked command (head)... " +if ! $RTK hook check --agent claude "head -n 10 file.txt" 2>/dev/null; then + echo "✓" +else + echo "FAIL: expected non-zero exit for blocked head command" + exit 1 +fi + +# 13. Hook exit code for rewrite is 0 +echo -n "Test 13: Hook rewrite exit code 0... " +$RTK hook check --agent claude "git status" > /dev/null 2>&1 +exit_code=$? +if [ $exit_code -eq 0 ]; then + echo "✓" +else + echo "FAIL: expected exit code 0, got $exit_code" + exit 1 +fi + +# 14. Hook exit code for blocked is 2 +echo -n "Test 14: Hook blocked exit code 2... " +$RTK hook check --agent claude "cat file.txt" > /dev/null 2>&1 || exit_code=$? +if [ "$exit_code" -eq 2 ]; then + echo "✓" +else + echo "FAIL: expected exit code 2, got ${exit_code:-0}" + exit 1 +fi + +echo "" +echo "=== All 14 tests passed ===" diff --git a/src/cc_economics.rs b/src/cc_economics.rs index b38bba2..a6e600f 100644 --- a/src/cc_economics.rs +++ b/src/cc_economics.rs @@ -14,8 +14,6 @@ use crate::utils::{format_cpt, format_tokens, format_usd}; // ── Constants ── -const BILLION: f64 = 1e9; - // API pricing ratios (verified Feb 2026, consistent across Claude models <=200K context) // Source: https://docs.anthropic.com/en/docs/about-claude/models const WEIGHT_OUTPUT: f64 = 5.0; // Output = 5x input diff --git a/src/ccusage.rs b/src/ccusage.rs index 822cca1..db692b4 100644 --- a/src/ccusage.rs +++ b/src/ccusage.rs @@ -114,11 +114,6 @@ fn build_command() -> Option { None } -/// Check if ccusage CLI is available (binary or via npx) -pub fn is_available() -> bool { - build_command().is_some() -} - /// Fetch usage data from ccusage for the last 90 days /// /// Returns `Ok(None)` if ccusage is unavailable (graceful degradation) @@ -330,11 +325,4 @@ mod tests { assert_eq!(periods[0].metrics.cache_creation_tokens, 0); // default assert_eq!(periods[0].metrics.cache_read_tokens, 0); } - - #[test] - fn test_is_available() { - // Just smoke test - actual availability depends on system - let _available = is_available(); - // No assertion - just ensure it doesn't panic - } } diff --git a/src/cmd/analysis.rs b/src/cmd/analysis.rs new file mode 100644 index 0000000..1d816e2 --- /dev/null +++ b/src/cmd/analysis.rs @@ -0,0 +1,249 @@ +//! Analyzes tokens to decide: Native execution or Passthrough? 
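+//!
+//! Illustrative walk-through (mirrors the unit tests below; lexer behavior is assumed):
+//! `cd dir && git status` tokenizes into `Arg`/`Operator` tokens with no `Shellism`,
+//! so `needs_shell` returns false and `parse_chain` yields two `NativeCommand`s:
+//! `cd ["dir"]` with operator `&&`, then `git ["status"]` with no trailing operator.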
+ +use super::lexer::{strip_quotes, ParsedToken, TokenKind}; + +/// Represents a single command in a chain +#[derive(Debug, Clone, PartialEq)] +pub struct NativeCommand { + pub binary: String, + pub args: Vec, + pub operator: Option, // &&, ||, ;, or None for last command +} + +/// Check if command needs real shell (has shellisms, pipes, redirects) +pub fn needs_shell(tokens: &[ParsedToken]) -> bool { + tokens.iter().any(|t| { + matches!( + t.kind, + TokenKind::Shellism | TokenKind::Pipe | TokenKind::Redirect + ) + }) +} + +/// Parse tokens into native command chain +/// Returns error if syntax is invalid (e.g., operator with no preceding command) +pub fn parse_chain(tokens: Vec) -> Result, String> { + let mut commands = Vec::new(); + let mut current_args = Vec::new(); + + for token in tokens { + match token.kind { + TokenKind::Arg => { + // Strip quotes from the argument + current_args.push(strip_quotes(&token.value)); + } + TokenKind::Operator => { + if current_args.is_empty() { + return Err(format!( + "Syntax error: operator {} with no command", + token.value + )); + } + // First arg is the binary, rest are args + let binary = current_args.remove(0); + commands.push(NativeCommand { + binary, + args: current_args.clone(), + operator: Some(token.value.clone()), + }); + current_args.clear(); + } + TokenKind::Pipe | TokenKind::Redirect | TokenKind::Shellism => { + // Should not reach here if needs_shell() was checked first + // But handle gracefully + return Err(format!( + "Unexpected {:?} in native mode - use passthrough", + token.kind + )); + } + } + } + + // Handle last command (no trailing operator) + if !current_args.is_empty() { + let binary = current_args.remove(0); + commands.push(NativeCommand { + binary, + args: current_args, + operator: None, + }); + } + + Ok(commands) +} + +/// Should the next command run based on operator and last result? 
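+///
+/// Examples (matching the tests below): `should_run(Some("&&"), false)` is `false`,
+/// `should_run(Some("||"), false)` is `true`, and `;` or `None` always runs.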
+pub fn should_run(operator: Option<&str>, last_success: bool) -> bool { + match operator { + Some("&&") => last_success, + Some("||") => !last_success, + Some(";") | None => true, + _ => true, // Unknown operator, just run + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cmd::lexer::tokenize; + + // === NEEDS_SHELL TESTS === + + #[test] + fn test_needs_shell_simple() { + let tokens = tokenize("git status"); + assert!(!needs_shell(&tokens)); + } + + #[test] + fn test_needs_shell_with_glob() { + let tokens = tokenize("ls *.rs"); + assert!(needs_shell(&tokens)); + } + + #[test] + fn test_needs_shell_with_pipe() { + let tokens = tokenize("cat file | grep x"); + assert!(needs_shell(&tokens)); + } + + #[test] + fn test_needs_shell_with_redirect() { + let tokens = tokenize("cmd > file"); + assert!(needs_shell(&tokens)); + } + + #[test] + fn test_needs_shell_with_chain() { + let tokens = tokenize("cd dir && git status"); + // && is an Operator, not a Shellism - should NOT need shell + assert!(!needs_shell(&tokens)); + } + + // === PARSE_CHAIN TESTS === + + #[test] + fn test_parse_simple_command() { + let tokens = tokenize("git status"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 1); + assert_eq!(cmds[0].binary, "git"); + assert_eq!(cmds[0].args, vec!["status"]); + assert_eq!(cmds[0].operator, None); + } + + #[test] + fn test_parse_command_with_multiple_args() { + let tokens = tokenize("git commit -m message"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 1); + assert_eq!(cmds[0].binary, "git"); + assert_eq!(cmds[0].args, vec!["commit", "-m", "message"]); + } + + #[test] + fn test_parse_chained_and() { + let tokens = tokenize("cd dir && git status"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 2); + assert_eq!(cmds[0].binary, "cd"); + assert_eq!(cmds[0].args, vec!["dir"]); + assert_eq!(cmds[0].operator, Some("&&".to_string())); + assert_eq!(cmds[1].binary, "git"); + assert_eq!(cmds[1].args, vec!["status"]); + assert_eq!(cmds[1].operator, None); + } + + #[test] + fn test_parse_chained_or() { + let tokens = tokenize("cmd1 || cmd2"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 2); + assert_eq!(cmds[0].operator, Some("||".to_string())); + } + + #[test] + fn test_parse_chained_semicolon() { + let tokens = tokenize("cmd1 ; cmd2 ; cmd3"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 3); + assert_eq!(cmds[0].operator, Some(";".to_string())); + assert_eq!(cmds[1].operator, Some(";".to_string())); + assert_eq!(cmds[2].operator, None); + } + + #[test] + fn test_parse_triple_chain() { + let tokens = tokenize("a && b && c"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 3); + } + + #[test] + fn test_parse_operator_at_start() { + let tokens = tokenize("&& cmd"); + let result = parse_chain(tokens); + assert!(result.is_err()); + } + + #[test] + fn test_parse_operator_at_end() { + let tokens = tokenize("cmd &&"); + let cmds = parse_chain(tokens).unwrap(); + // cmd is parsed, && triggers flush but no second command + assert_eq!(cmds.len(), 1); + assert_eq!(cmds[0].operator, Some("&&".to_string())); + } + + #[test] + fn test_parse_quoted_arg() { + let tokens = tokenize("git commit -m \"Fix && Bug\""); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 1); + // The && inside quotes should be in the arg, not an operator + // args are: commit, -m, "Fix && Bug" + assert_eq!(cmds[0].args.len(), 3); + assert_eq!(cmds[0].args[2], "Fix && Bug"); + } + + 
#[test] + fn test_parse_empty() { + let tokens = tokenize(""); + let cmds = parse_chain(tokens).unwrap(); + assert!(cmds.is_empty()); + } + + // === SHOULD_RUN TESTS === + + #[test] + fn test_should_run_and_success() { + assert!(should_run(Some("&&"), true)); + } + + #[test] + fn test_should_run_and_failure() { + assert!(!should_run(Some("&&"), false)); + } + + #[test] + fn test_should_run_or_success() { + assert!(!should_run(Some("||"), true)); + } + + #[test] + fn test_should_run_or_failure() { + assert!(should_run(Some("||"), false)); + } + + #[test] + fn test_should_run_semicolon() { + assert!(should_run(Some(";"), true)); + assert!(should_run(Some(";"), false)); + } + + #[test] + fn test_should_run_none() { + assert!(should_run(None, true)); + assert!(should_run(None, false)); + } +} diff --git a/src/cmd/builtins.rs b/src/cmd/builtins.rs new file mode 100644 index 0000000..fa11be6 --- /dev/null +++ b/src/cmd/builtins.rs @@ -0,0 +1,246 @@ +//! Built-in commands that RTK handles natively. +//! These maintain session state across hook calls. + +use super::predicates::{expand_tilde, get_home}; +use anyhow::{Context, Result}; + +/// Change directory (persists in RTK process) +pub fn builtin_cd(args: &[String]) -> Result { + let target = args + .first() + .map(|s| expand_tilde(s)) + .unwrap_or_else(get_home); + + std::env::set_current_dir(&target) + .with_context(|| format!("cd: {}: No such file or directory", target))?; + + Ok(true) +} + +/// Export environment variable +pub fn builtin_export(args: &[String]) -> Result { + for arg in args { + if let Some((key, value)) = arg.split_once('=') { + // Handle quoted values: export FOO="bar baz" + let clean_value = value + .strip_prefix('"') + .and_then(|v| v.strip_suffix('"')) + .or_else(|| value.strip_prefix('\'').and_then(|v| v.strip_suffix('\''))) + .unwrap_or(value); + std::env::set_var(key, clean_value); + } + } + Ok(true) +} + +/// Check if a binary is a builtin +pub fn is_builtin(binary: &str) -> bool { + matches!( + binary, + "cd" | "export" | "pwd" | "echo" | "true" | "false" | ":" + ) +} + +/// Execute a builtin command +pub fn execute(binary: &str, args: &[String]) -> Result { + match binary { + "cd" => builtin_cd(args), + "export" => builtin_export(args), + "pwd" => { + println!("{}", std::env::current_dir()?.display()); + Ok(true) + } + "echo" => { + let (print_args, no_newline) = if args.first().map(|s| s.as_str()) == Some("-n") { + (&args[1..], true) + } else { + (args, false) + }; + print!("{}", print_args.join(" ")); + if !no_newline { + println!(); + } + Ok(true) + } + "true" | ":" => Ok(true), + "false" => Ok(false), + _ => anyhow::bail!("Unknown builtin: {}", binary), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + + // === CD TESTS === + // Consolidated into one test: cwd is process-global, so parallel tests race. + + #[test] + fn test_cd_all_cases() { + let original = env::current_dir().unwrap(); + let home = get_home(); + + // 1. cd to existing dir + let result = builtin_cd(&["/tmp".to_string()]).unwrap(); + assert!(result); + let new_dir = env::current_dir().unwrap(); + // On macOS /tmp symlinks to /private/tmp — canonicalize both sides + let canon_tmp = std::fs::canonicalize("/tmp").unwrap(); + let canon_new = std::fs::canonicalize(&new_dir).unwrap(); + assert_eq!(canon_new, canon_tmp, "cd /tmp should land in /tmp"); + + // 2. 
cd to nonexistent dir + let result = builtin_cd(&["/nonexistent/path/xyz".to_string()]); + assert!(result.is_err()); + // cwd unchanged after failed cd + assert_eq!( + std::fs::canonicalize(env::current_dir().unwrap()).unwrap(), + canon_tmp + ); + + // 3. cd with no args → home + let result = builtin_cd(&[]).unwrap(); + assert!(result); + let cwd = env::current_dir().unwrap(); + let canon_home = std::fs::canonicalize(&home).unwrap(); + let canon_cwd = std::fs::canonicalize(&cwd).unwrap(); + assert_eq!(canon_cwd, canon_home, "cd with no args should go home"); + + // 4. cd ~ → home + let _ = env::set_current_dir("/tmp"); + let result = builtin_cd(&["~".to_string()]).unwrap(); + assert!(result); + let cwd = std::fs::canonicalize(env::current_dir().unwrap()).unwrap(); + assert_eq!(cwd, canon_home, "cd ~ should go home"); + + // 5. cd ~/nonexistent-subpath — may fail, just verify no panic + let _ = builtin_cd(&["~/nonexistent_rtk_test_subpath_xyz".to_string()]); + + // Restore original cwd + let _ = env::set_current_dir(&original); + } + + // === EXPORT TESTS === + + #[test] + fn test_export_simple() { + builtin_export(&["RTK_TEST_SIMPLE=value".to_string()]).unwrap(); + assert_eq!(env::var("RTK_TEST_SIMPLE").unwrap(), "value"); + env::remove_var("RTK_TEST_SIMPLE"); + } + + #[test] + fn test_export_with_equals_in_value() { + builtin_export(&["RTK_TEST_EQUALS=key=value".to_string()]).unwrap(); + assert_eq!(env::var("RTK_TEST_EQUALS").unwrap(), "key=value"); + env::remove_var("RTK_TEST_EQUALS"); + } + + #[test] + fn test_export_quoted_value() { + builtin_export(&["RTK_TEST_QUOTED=\"hello world\"".to_string()]).unwrap(); + assert_eq!(env::var("RTK_TEST_QUOTED").unwrap(), "hello world"); + env::remove_var("RTK_TEST_QUOTED"); + } + + #[test] + fn test_export_multiple() { + builtin_export(&["RTK_TEST_A=1".to_string(), "RTK_TEST_B=2".to_string()]).unwrap(); + assert_eq!(env::var("RTK_TEST_A").unwrap(), "1"); + assert_eq!(env::var("RTK_TEST_B").unwrap(), "2"); + env::remove_var("RTK_TEST_A"); + env::remove_var("RTK_TEST_B"); + } + + #[test] + fn test_export_no_equals() { + // Should be silently ignored (like bash) + let result = builtin_export(&["NO_EQUALS_HERE".to_string()]).unwrap(); + assert!(result); + } + + // === IS_BUILTIN TESTS === + + #[test] + fn test_is_builtin_cd() { + assert!(is_builtin("cd")); + } + + #[test] + fn test_is_builtin_export() { + assert!(is_builtin("export")); + } + + #[test] + fn test_is_builtin_pwd() { + assert!(is_builtin("pwd")); + } + + #[test] + fn test_is_builtin_echo() { + assert!(is_builtin("echo")); + } + + #[test] + fn test_is_builtin_true() { + assert!(is_builtin("true")); + } + + #[test] + fn test_is_builtin_false() { + assert!(is_builtin("false")); + } + + #[test] + fn test_is_builtin_external() { + assert!(!is_builtin("git")); + assert!(!is_builtin("ls")); + assert!(!is_builtin("cargo")); + } + + // === EXECUTE TESTS === + + #[test] + fn test_execute_pwd() { + let result = execute("pwd", &[]).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_echo() { + let result = execute("echo", &["hello".to_string(), "world".to_string()]).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_true() { + let result = execute("true", &[]).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_false() { + let result = execute("false", &[]).unwrap(); + assert!(!result); + } + + #[test] + fn test_execute_unknown_builtin() { + let result = execute("notabuiltin", &[]); + assert!(result.is_err()); + } + + #[test] + fn test_execute_echo_n_flag() { + // 
echo -n should succeed (prints without newline) + let result = execute("echo", &["-n".to_string(), "hello".to_string()]).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_echo_empty_args() { + let result = execute("echo", &[]).unwrap(); + assert!(result); + } +} diff --git a/src/cmd/claude_hook.rs b/src/cmd/claude_hook.rs new file mode 100644 index 0000000..1918c16 --- /dev/null +++ b/src/cmd/claude_hook.rs @@ -0,0 +1,506 @@ +//! Claude Code PreToolUse hook protocol handler. +//! +//! Reads JSON from stdin, applies safety checks and rewrites, +//! outputs JSON to stdout. +//! +//! Protocol: https://docs.anthropic.com/en/docs/claude-code/hooks +//! +//! ## Exit Code Behavior +//! +//! - Exit 0 = success (allow/rewrite) — tool proceeds +//! - Exit 2 = blocking error (deny) — tool rejected +//! +//! ## Claude Code Stderr Rule (CRITICAL) +//! +//! **Source:** See `/Users/athundt/.claude/clautorun/.worktrees/claude-stable-pre-v0.8.0/notes/hooks_api_reference.md:720-728` +//! +//! ```text +//! CRITICAL: ANY stderr output at exit 0 = hook error = fail-open +//! ``` +//! +//! **Implication:** +//! - Exit 0 + ANY stderr → Claude Code treats hook as FAILED → tool executes anyway (fail-open) +//! - Exit 2 + stderr → Claude Code treats stderr as the block reason → tool blocked, AI sees reason +//! +//! **This module's stderr usage:** +//! - ✅ Exit 0 paths (NoOpinion, Allow): **NEVER write to stderr** +//! - ✅ Exit 2 path (Deny): **stderr ONLY** for bug #4669 workaround (see below) +//! +//! ## Bug #4669 Workaround (Dual-Path Deny) +//! +//! **Issue:** https://github.com/anthropics/claude-code/issues/4669 +//! **Versions:** v1.0.62+ through current (not fixed) +//! **Problem:** `permissionDecision: "deny"` at exit 0 is IGNORED — tool executes anyway +//! +//! **Workaround:** +//! ```text +//! stdout: JSON with permissionDecision "deny" (documented main path, but broken) +//! stderr: plain text reason (fallback path that actually works) +//! exit code: 2 (triggers Claude Code to read stderr as error) +//! ``` +//! +//! This ensures deny works regardless of which path Claude Code processes. +//! +//! ## I/O Enforcement (Module-Specific) +//! +//! **This restriction applies ONLY to claude_hook.rs and gemini_hook.rs.** +//! All other RTK modules (main.rs, git.rs, etc.) use `println!`/`eprintln!` normally. +//! +//! **Why restricted here:** +//! - Hook protocol requires JSON-only stdout +//! - Claude Code's "ANY stderr = hook error" rule (see above) +//! - Accidental prints corrupt the JSON protocol +//! +//! **Enforcement mechanism:** +//! - `#![deny(clippy::print_stdout, clippy::print_stderr)]` at module level (line 52) +//! - `run_inner()` returns `HookResponse` enum — pure logic, no I/O +//! - `run()` is the ONLY function that writes output — single I/O point +//! - Uses `write!`/`writeln!` which are NOT caught by the clippy lint +//! +//! **Pathway:** main.rs → Commands::Hook → claude_hook::run() [DENY ENFORCED HERE] +//! +//! Fail-open: Any parse error or unexpected input → exit 0, no output. + +// Compile-time I/O enforcement for THIS MODULE ONLY. +// Other RTK modules (main.rs, git.rs, etc.) use println!/eprintln! normally. +// +// Why restrict here: +// - Claude Code hook protocol requires JSON-only stdout +// - Claude Code rule: "ANY stderr at exit 0 = hook error = fail-open" +// (Source: clautorun hooks_api_reference.md:720-728) +// - Accidental prints would corrupt the JSON response +// +// Mechanism: +// - Denies println!/eprintln! at compile-time +// - Allows write!/writeln! 
(used only in run() for controlled output) +// - run_inner() returns HookResponse (no I/O) +// - run() is the single I/O point +#![deny(clippy::print_stdout, clippy::print_stderr)] + +use super::hook::{ + check_for_hook, is_hook_disabled, should_passthrough, update_command_in_tool_input, + HookResponse, HookResult, +}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::io::{self, Read, Write}; + +// --- Wire format structs (field names must match Claude Code spec exactly) --- + +#[derive(Deserialize)] +pub(crate) struct ClaudePayload { + tool_input: Option, + // Claude Code also sends: tool_name, session_id, session_cwd, + // transcript_path — serde silently ignores unknown fields. + // The settings.json matcher already filters to Bash-only events. +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ClaudeResponse { + hook_specific_output: HookOutput, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct HookOutput { + hook_event_name: &'static str, + permission_decision: &'static str, + permission_decision_reason: String, + #[serde(skip_serializing_if = "Option::is_none")] + updated_input: Option, +} + +// --- Guard logic (extracted for testability) --- + +/// Extract the command string from a parsed payload. +/// Returns None if payload has no tool_input or no command field. +pub(crate) fn extract_command(payload: &ClaudePayload) -> Option<&str> { + payload + .tool_input + .as_ref()? + .get("command")? + .as_str() + .filter(|s| !s.is_empty()) +} + +// Guard functions `is_hook_disabled()` and `should_passthrough()` are shared +// with gemini_hook.rs via hook.rs to avoid duplication (DRY). + +/// Build a ClaudeResponse for an allowed/rewritten command. +pub(crate) fn allow_response(reason: String, updated_input: Option) -> ClaudeResponse { + ClaudeResponse { + hook_specific_output: HookOutput { + hook_event_name: "PreToolUse", + permission_decision: "allow", + permission_decision_reason: reason, + updated_input, + }, + } +} + +/// Build a ClaudeResponse for a blocked command. +pub(crate) fn deny_response(reason: String) -> ClaudeResponse { + ClaudeResponse { + hook_specific_output: HookOutput { + hook_event_name: "PreToolUse", + permission_decision: "deny", + permission_decision_reason: reason, + updated_input: None, + }, + } +} + +// --- Entry point --- + +/// Run the Claude Code hook handler. +/// +/// This is the ONLY function that performs I/O (stdout/stderr). +/// `run_inner()` returns a `HookResponse` enum — pure logic, no I/O. +/// Combined with `#![deny(clippy::print_stdout, clippy::print_stderr)]`, +/// this ensures no stray output corrupts the JSON hook protocol. +/// +/// Fail-open design: malformed input → exit 0, no output. +/// Claude Code interprets this as "no opinion" and proceeds normally. +pub fn run() -> anyhow::Result<()> { + // Fail-open: wrap entire handler so ANY error → exit 0 (no opinion). + let response = match run_inner() { + Ok(r) => r, + Err(_) => HookResponse::NoOpinion, // Fail-open: swallow errors + }; + + // ┌────────────────────────────────────────────────────────────────┐ + // │ SINGLE I/O POINT - All stdout/stderr output happens here only │ + // │ │ + // │ Why: Claude Code rule "ANY stderr at exit 0 = hook error" │ + // │ (Source: hooks_api_reference.md:720-728) │ + // │ │ + // │ Enforcement: #![deny(...)] at line 52 prevents println!/eprintln! │ + // │ write!/writeln! 
are not caught by lint (allowed) │ + // └────────────────────────────────────────────────────────────────┘ + match response { + HookResponse::NoOpinion => { + // Exit 0, NO stdout, NO stderr + // Claude Code sees no output → proceeds with original command + } + HookResponse::Allow(json) => { + // Exit 0, JSON to stdout, NO stderr + // CRITICAL: No stderr at exit 0 (would cause fail-open) + writeln!(io::stdout(), "{json}")?; + } + HookResponse::Deny(json, reason) => { + // Exit 2, JSON to stdout, reason to stderr + // This is the ONLY path that writes to stderr (valid at exit 2 only) + // + // Dual-path deny for bug #4669 workaround: + // - stdout: JSON with permissionDecision "deny" (documented path, but ignored) + // - stderr: plain text reason (actual blocking mechanism via exit 2) + // - exit 2: Triggers Claude Code to read stderr and block tool + writeln!(io::stdout(), "{json}")?; + writeln!(io::stderr(), "{reason}")?; + std::process::exit(2); + } + } + Ok(()) +} + +/// Inner handler: pure decision logic, no I/O. +/// Returns `HookResponse` for `run()` to output. +fn run_inner() -> anyhow::Result { + let mut buffer = String::new(); + io::stdin().read_to_string(&mut buffer)?; + + let payload: ClaudePayload = match serde_json::from_str(&buffer) { + Ok(p) => p, + Err(_) => return Ok(HookResponse::NoOpinion), + }; + + let cmd = match extract_command(&payload) { + Some(c) => c, + None => return Ok(HookResponse::NoOpinion), + }; + + if is_hook_disabled() || should_passthrough(cmd) { + return Ok(HookResponse::NoOpinion); + } + + let result = check_for_hook(cmd, "claude"); + + match result { + HookResult::Rewrite(new_cmd) => { + // Preserve all original tool_input fields, only replace "command" + // Shared helper (DRY with gemini_hook.rs via hook.rs) + let updated = update_command_in_tool_input(payload.tool_input, new_cmd); + + let response = allow_response("RTK safety rewrite applied".into(), Some(updated)); + let json = serde_json::to_string(&response)?; + Ok(HookResponse::Allow(json)) + } + HookResult::Blocked(msg) => { + let response = deny_response(msg.clone()); + let json = serde_json::to_string(&response)?; + Ok(HookResponse::Deny(json, msg)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ========================================================================= + // CLAUDE CODE WIRE FORMAT CONFORMANCE + // https://docs.anthropic.com/en/docs/claude-code/hooks + // + // These tests verify exact JSON field names per the Claude Code spec. + // A wrong field name means Claude Code silently ignores the response. 
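    // For orientation (illustrative only; exact key order is not significant),
    // an allow-with-rewrite response from this module serializes roughly as:
    //   {"hookSpecificOutput":{"hookEventName":"PreToolUse",
    //    "permissionDecision":"allow",
    //    "permissionDecisionReason":"RTK safety rewrite applied",
    //    "updatedInput":{"command":"rtk run -c 'git status'"}}}
    // The tests below pin down each of these field names individually.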
+ // ========================================================================= + + // --- Output: field name conformance --- + + #[test] + fn test_output_uses_hook_specific_output() { + // Claude expects "hookSpecificOutput" (camelCase), NOT "hook_specific_output" + let response = allow_response("test".into(), None); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert!( + parsed.get("hookSpecificOutput").is_some(), + "must have 'hookSpecificOutput' field" + ); + assert!( + parsed.get("hook_specific_output").is_none(), + "must NOT have snake_case field" + ); + } + + #[test] + fn test_output_uses_permission_decision() { + // Claude expects "permissionDecision", NOT "decision" + let response = allow_response("test".into(), None); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + let output = &parsed["hookSpecificOutput"]; + + assert!( + output.get("permissionDecision").is_some(), + "must have 'permissionDecision' field" + ); + assert!( + output.get("decision").is_none(), + "must NOT have Gemini-style 'decision' field" + ); + } + + #[test] + fn test_output_uses_permission_decision_reason() { + let response = deny_response("blocked".into()); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + let output = &parsed["hookSpecificOutput"]; + + assert!( + output.get("permissionDecisionReason").is_some(), + "must have 'permissionDecisionReason'" + ); + } + + #[test] + fn test_output_uses_hook_event_name() { + let response = allow_response("test".into(), None); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed["hookSpecificOutput"]["hookEventName"], "PreToolUse"); + } + + #[test] + fn test_output_uses_updated_input_for_rewrite() { + let input = serde_json::json!({"command": "rtk run -c 'git status'"}); + let response = allow_response("rewrite".into(), Some(input)); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert!( + parsed["hookSpecificOutput"].get("updatedInput").is_some(), + "must have 'updatedInput' for rewrites" + ); + } + + #[test] + fn test_allow_omits_updated_input_when_none() { + let response = allow_response("passthrough".into(), None); + let json = serde_json::to_string(&response).unwrap(); + + assert!( + !json.contains("updatedInput"), + "updatedInput must be omitted when None" + ); + } + + #[test] + fn test_rewrite_preserves_other_tool_input_fields() { + let original = serde_json::json!({ + "command": "git status", + "timeout": 30, + "description": "check repo" + }); + + let mut updated = original.clone(); + if let Some(obj) = updated.as_object_mut() { + obj.insert( + "command".into(), + Value::String("rtk run -c 'git status'".into()), + ); + } + + assert_eq!(updated["timeout"], 30); + assert_eq!(updated["description"], "check repo"); + assert_eq!(updated["command"], "rtk run -c 'git status'"); + } + + #[test] + fn test_output_decision_values() { + let allow = allow_response("test".into(), None); + let deny = deny_response("blocked".into()); + + let allow_json: Value = + serde_json::from_str(&serde_json::to_string(&allow).unwrap()).unwrap(); + let deny_json: Value = + serde_json::from_str(&serde_json::to_string(&deny).unwrap()).unwrap(); + + assert_eq!( + allow_json["hookSpecificOutput"]["permissionDecision"], + "allow" 
+ ); + assert_eq!( + deny_json["hookSpecificOutput"]["permissionDecision"], + "deny" + ); + } + + // --- Input: payload parsing --- + + #[test] + fn test_input_extra_fields_ignored() { + // Claude sends session_id, tool_name, transcript_path, etc. + let json = r#"{"tool_input": {"command": "ls"}, "tool_name": "Bash", "session_id": "abc-123", "session_cwd": "/tmp", "transcript_path": "/path/to/transcript.jsonl"}"#; + let payload: ClaudePayload = serde_json::from_str(json).unwrap(); + assert_eq!(extract_command(&payload), Some("ls")); + } + + #[test] + fn test_input_tool_input_is_object() { + let json = r#"{"tool_input": {"command": "git status", "timeout": 30}}"#; + let payload: ClaudePayload = serde_json::from_str(json).unwrap(); + let input = payload.tool_input.unwrap(); + assert_eq!(input["command"].as_str().unwrap(), "git status"); + assert_eq!(input["timeout"].as_i64().unwrap(), 30); + } + + // --- Guard function tests --- + + #[test] + fn test_extract_command_basic() { + let payload: ClaudePayload = + serde_json::from_str(r#"{"tool_input": {"command": "git status"}}"#).unwrap(); + assert_eq!(extract_command(&payload), Some("git status")); + } + + #[test] + fn test_extract_command_missing_tool_input() { + let payload: ClaudePayload = serde_json::from_str(r#"{}"#).unwrap(); + assert_eq!(extract_command(&payload), None); + } + + #[test] + fn test_extract_command_missing_command_field() { + let payload: ClaudePayload = + serde_json::from_str(r#"{"tool_input": {"cwd": "/tmp"}}"#).unwrap(); + assert_eq!(extract_command(&payload), None); + } + + #[test] + fn test_extract_command_empty_string() { + let payload: ClaudePayload = + serde_json::from_str(r#"{"tool_input": {"command": ""}}"#).unwrap(); + assert_eq!(extract_command(&payload), None); + } + + #[test] + fn test_shared_should_passthrough_rtk_prefix() { + assert!(should_passthrough("rtk run -c 'ls'")); + assert!(should_passthrough("rtk cargo test")); + assert!(should_passthrough("/usr/local/bin/rtk run -c 'ls'")); + } + + #[test] + fn test_shared_should_passthrough_heredoc() { + assert!(should_passthrough("cat <(input); + } + } + + // --- Fail-open behavior --- + + #[test] + fn test_run_inner_returns_no_opinion_for_empty_payload() { + // "{}" has no tool_input → no command → NoOpinion + let payload: ClaudePayload = serde_json::from_str("{}").unwrap(); + assert_eq!(extract_command(&payload), None); + } + + #[test] + fn test_shared_is_hook_disabled_hook_enabled_zero() { + std::env::set_var("RTK_HOOK_ENABLED", "0"); + assert!(is_hook_disabled()); + std::env::remove_var("RTK_HOOK_ENABLED"); + } + + #[test] + fn test_shared_is_hook_disabled_rtk_active() { + std::env::set_var("RTK_ACTIVE", "1"); + assert!(is_hook_disabled()); + std::env::remove_var("RTK_ACTIVE"); + } + + // --- Integration: Bug #4669 workaround verification --- + + #[test] + fn test_deny_response_includes_reason_for_stderr() { + // Bug #4669 workaround: deny must provide plain text reason + // that can be output to stderr alongside the JSON stdout. + // The msg is cloned for both paths in run_inner(). 
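        // Observable behavior sketched (illustrative; <msg> stands for the
        // block reason built below):
        //   stdout : {"hookSpecificOutput":{"hookEventName":"PreToolUse",
        //             "permissionDecision":"deny","permissionDecisionReason":"<msg>"}}
        //   stderr : <msg>   (plain text; read by Claude Code because of exit 2)
        //   exit   : 2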
+ let msg = "RTK: cat is blocked (use rtk read instead)"; + let response = deny_response(msg.to_string()); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + // JSON stdout path + assert_eq!(parsed["hookSpecificOutput"]["permissionDecision"], "deny"); + assert_eq!( + parsed["hookSpecificOutput"]["permissionDecisionReason"], + msg + ); + // The same msg string is used for stderr in run() via HookResponse::Deny + } + + // Note: Integration tests for check_for_hook() safety decisions are in + // src/cmd/hook.rs (test_safe_commands_rewrite, test_blocked_commands, etc.) + // to avoid duplication. This module focuses on Claude Code wire format. +} diff --git a/src/cmd/exec.rs b/src/cmd/exec.rs new file mode 100644 index 0000000..59be99b --- /dev/null +++ b/src/cmd/exec.rs @@ -0,0 +1,459 @@ +//! Command executor: runs simple chains natively, delegates complex shell to /bin/sh. + +use anyhow::{Context, Result}; +use std::process::{Command, Stdio}; + +use super::{analysis, builtins, filters, lexer, safety, trash_cmd}; +use crate::tracking; + +/// Check if RTK is already active (recursion guard) +fn is_rtk_active() -> bool { + std::env::var("RTK_ACTIVE").is_ok() +} + +/// RAII guard: sets RTK_ACTIVE on creation, removes on drop (even on panic). +struct RtkActiveGuard; + +impl RtkActiveGuard { + fn new() -> Self { + std::env::set_var("RTK_ACTIVE", "1"); + RtkActiveGuard + } +} + +impl Drop for RtkActiveGuard { + fn drop(&mut self) { + std::env::remove_var("RTK_ACTIVE"); + } +} + +/// Execute a raw command string +pub fn execute(raw: &str, verbose: u8) -> Result { + // Recursion guard + if is_rtk_active() { + if verbose > 0 { + eprintln!("rtk: Recursion detected, passing through"); + } + return run_passthrough(raw, verbose); + } + + // Handle empty input + if raw.trim().is_empty() { + return Ok(true); + } + + let _guard = RtkActiveGuard::new(); + execute_inner(raw, verbose) +} + +fn execute_inner(raw: &str, verbose: u8) -> Result { + // === STEP 0: Remap expansion (aliases like "t" → "cargo test") === + if let Some(expanded) = crate::config::rules::try_remap(raw) { + if verbose > 0 { + eprintln!( + "rtk remap: {} → {}", + raw.split_whitespace().next().unwrap_or(raw), + expanded + ); + } + return execute_inner(&expanded, verbose); + } + + let tokens = lexer::tokenize(raw); + + // === STEP 1: Decide Native vs Passthrough === + if analysis::needs_shell(&tokens) { + // Even in passthrough, check safety on raw string + if let safety::SafetyResult::Blocked(msg) = safety::check_raw(raw) { + eprintln!("{}", msg); + return Ok(false); + } + return run_passthrough(raw, verbose); + } + + // === STEP 2: Parse into native command chain === + let commands = + analysis::parse_chain(tokens).map_err(|e| anyhow::anyhow!("Parse error: {}", e))?; + + // === STEP 3: Execute native chain === + run_native(&commands, verbose) +} + +/// Run commands in native mode (iterate, check safety, filter output) +fn run_native(commands: &[analysis::NativeCommand], verbose: u8) -> Result { + let mut last_success = true; + let mut prev_operator: Option<&str> = None; + + for cmd in commands { + // === SHORT-CIRCUIT LOGIC === + // Check if we should run based on PREVIOUS operator and result + // The operator stored in cmd is the one AFTER it, so we use prev_operator + if !analysis::should_run(prev_operator, last_success) { + // For && with failure or || with success, skip this command + prev_operator = cmd.operator.as_deref(); + continue; + } + + // === RECURSION 
PREVENTION === + // Handle "rtk run" or "rtk" binary specially + if cmd.binary == "rtk" && cmd.args.first().map(|s| s.as_str()) == Some("run") { + // Flatten: execute the inner command directly + // rtk run -c "git status" → args = ["run", "-c", "git status"] + let inner = if cmd.args.get(1).map(|s| s.as_str()) == Some("-c") { + cmd.args.get(2).cloned().unwrap_or_default() + } else { + cmd.args.get(1).cloned().unwrap_or_default() + }; + if verbose > 0 { + eprintln!("rtk: Flattening nested rtk run"); + } + return execute(&inner, verbose); + } + // Other rtk commands: spawn as external (they have their own filters) + + // === SAFETY CHECK === + match safety::check(&cmd.binary, &cmd.args) { + safety::SafetyResult::Blocked(msg) => { + eprintln!("{}", msg); + return Ok(false); + } + safety::SafetyResult::Rewritten(new_cmd) => { + // Re-execute the rewritten command + if verbose > 0 { + eprintln!("rtk safety: Rewrote command"); + } + return execute(&new_cmd, verbose); + } + safety::SafetyResult::TrashRequested(paths) => { + last_success = trash_cmd::execute(&paths)?; + prev_operator = cmd.operator.as_deref(); + continue; + } + safety::SafetyResult::Safe => {} + } + + // === BUILTINS === + if builtins::is_builtin(&cmd.binary) { + last_success = builtins::execute(&cmd.binary, &cmd.args)?; + prev_operator = cmd.operator.as_deref(); + continue; + } + + // === EXTERNAL COMMAND WITH FILTERING === + last_success = spawn_with_filter(&cmd.binary, &cmd.args, verbose)?; + prev_operator = cmd.operator.as_deref(); + } + + Ok(last_success) +} + +/// Spawn external command and apply appropriate filter +fn spawn_with_filter(binary: &str, args: &[String], _verbose: u8) -> Result { + let timer = tracking::TimedExecution::start(); + + // Try to find the binary in PATH + let binary_path = match which::which(binary) { + Ok(path) => path, + Err(_) => { + // Binary not found + eprintln!("rtk: {}: command not found", binary); + return Ok(false); + } + }; + + // Use wait_with_output() to avoid deadlock when child output exceeds + // pipe buffer (~64KB Linux, ~16KB macOS). This reads stdout/stderr in + // separate threads internally before calling wait(). 
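    // Clarifying note: the call below uses `Command::output()`, which pipes
    // stdout/stderr and drains both to completion before returning (it is
    // effectively spawn + wait_with_output), so neither pipe can fill up and
    // stall the child — that is the deadlock described above.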
+ let output = Command::new(&binary_path) + .args(args) + .stdin(Stdio::inherit()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .with_context(|| format!("Failed to execute: {}", binary))?; + + let raw_out = String::from_utf8_lossy(&output.stdout); + let raw_err = String::from_utf8_lossy(&output.stderr); + + // Determine filter type and apply + let filter_type = filters::get_filter_type(binary); + let filtered_out = filters::apply_to_string(filter_type, &raw_out); + let filtered_err = crate::utils::strip_ansi(&raw_err); + + // Print filtered output + print!("{}", filtered_out); + eprint!("{}", filtered_err); + + // Track usage with raw vs filtered for accurate savings + let raw_output = format!("{}{}", raw_out, raw_err); + let filtered_output = format!("{}{}", filtered_out, filtered_err); + timer.track( + &format!("{} {}", binary, args.join(" ")), + &format!("rtk run {} {}", binary, args.join(" ")), + &raw_output, + &filtered_output, + ); + + Ok(output.status.success()) +} + +/// Run command via system shell (passthrough mode) +pub fn run_passthrough(raw: &str, verbose: u8) -> Result { + if verbose > 0 { + eprintln!("rtk: Passthrough mode for complex command"); + } + + let timer = tracking::TimedExecution::start(); + + let shell = if cfg!(windows) { "cmd" } else { "sh" }; + let flag = if cfg!(windows) { "/C" } else { "-c" }; + + let output = Command::new(shell) + .arg(flag) + .arg(raw) + .stdin(Stdio::inherit()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .context("Failed to execute passthrough")?; + + let raw_out = String::from_utf8_lossy(&output.stdout); + let raw_err = String::from_utf8_lossy(&output.stderr); + + // Basic filtering even in passthrough (strip ANSI) + let filtered_out = crate::utils::strip_ansi(&raw_out); + let filtered_err = crate::utils::strip_ansi(&raw_err); + print!("{}", filtered_out); + eprint!("{}", filtered_err); + + let raw_output = format!("{}{}", raw_out, raw_err); + let filtered_output = format!("{}{}", filtered_out, filtered_err); + timer.track( + raw, + &format!("rtk passthrough {}", raw), + &raw_output, + &filtered_output, + ); + + Ok(output.status.success()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cmd::test_helpers::EnvGuard; + + // === RAII GUARD TESTS === + + #[test] + fn test_is_rtk_active_default() { + let _env = EnvGuard::new(); + assert!(!is_rtk_active()); + } + + #[test] + fn test_raii_guard_sets_and_clears() { + let _env = EnvGuard::new(); + { + let _guard = RtkActiveGuard::new(); + assert!(is_rtk_active()); + } + assert!( + !is_rtk_active(), + "RTK_ACTIVE must be cleared when guard drops" + ); + } + + #[test] + fn test_raii_guard_clears_on_panic() { + let _env = EnvGuard::new(); + let result = std::panic::catch_unwind(|| { + let _guard = RtkActiveGuard::new(); + assert!(is_rtk_active()); + panic!("simulated panic"); + }); + assert!(result.is_err()); + assert!( + !is_rtk_active(), + "RTK_ACTIVE must be cleared even after panic" + ); + } + + // === EXECUTE TESTS === + + #[test] + fn test_execute_empty() { + let result = execute("", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_whitespace_only() { + let result = execute(" ", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_simple_command() { + let result = execute("echo hello", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_builtin_cd() { + let original = std::env::current_dir().unwrap(); + let result = execute("cd /tmp", 0).unwrap(); + assert!(result); + // On macOS, /tmp might be a 
symlink to /private/tmp + // Just verify the command succeeded (the cd happened) + let _ = std::env::set_current_dir(&original); + } + + #[test] + fn test_execute_builtin_pwd() { + let result = execute("pwd", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_builtin_true() { + let result = execute("true", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_builtin_false() { + let result = execute("false", 0).unwrap(); + assert!(!result); + } + + #[test] + fn test_execute_chain_and_success() { + let result = execute("true && echo success", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_chain_and_failure() { + let result = execute("false && echo should_not_run", 0).unwrap(); + // Chain stops at false, so result is false + assert!(!result); + } + + #[test] + fn test_execute_chain_or_success() { + let result = execute("true || echo should_not_run", 0).unwrap(); + // true succeeds, || doesn't run second command + assert!(result); + } + + #[test] + fn test_execute_chain_or_failure() { + let result = execute("false || echo fallback", 0).unwrap(); + // false fails, || runs fallback + assert!(result); + } + + #[test] + fn test_execute_chain_semicolon() { + let result = execute("true ; false", 0).unwrap(); + // Both run, last result is false + assert!(!result); + } + + #[test] + fn test_execute_passthrough_for_glob() { + let result = execute("echo *", 0).unwrap(); + // Should work via passthrough + assert!(result); + } + + #[test] + fn test_execute_passthrough_for_pipe() { + let result = execute("echo hello | cat", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_quoted_operator() { + let result = execute(r#"echo "hello && world""#, 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_binary_not_found() { + let result = execute("nonexistent_command_xyz_123", 0).unwrap(); + assert!(!result); + } + + #[test] + fn test_execute_chain_and_three_commands() { + // 3-command chain: true succeeds, false fails, stops before third + let result = execute("true && false && true", 0).unwrap(); + assert!(!result); + } + + #[test] + fn test_execute_chain_semicolon_last_wins() { + // Semicolon runs all; last result (true) determines outcome + let result = execute("false ; true", 0).unwrap(); + assert!(result); + } + + // === INTEGRATION TESTS (moved from edge_cases.rs) === + + #[test] + fn test_chain_mixed_operators() { + // false -> || runs true -> true && runs echo + let result = execute("false || true && echo works", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_passthrough_redirect() { + let result = execute("echo test > /dev/null", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_integration_cd_tilde() { + let original = std::env::current_dir().unwrap(); + let result = execute("cd ~", 0).unwrap(); + assert!(result); + let _ = std::env::set_current_dir(&original); + } + + #[test] + fn test_integration_export() { + let result = execute("export TEST_VAR=value", 0).unwrap(); + assert!(result); + std::env::remove_var("TEST_VAR"); + } + + #[test] + fn test_integration_env_prefix() { + let result = execute("TEST=1 echo hello", 0); + assert!(result.is_ok()); + } + + #[test] + fn test_integration_dash_args() { + let result = execute("echo --help -v --version", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_integration_quoted_empty() { + let result = execute(r#"echo """#, 0).unwrap(); + assert!(result); + } + + // === RECURRENCE PREVENTION TESTS === + + #[test] + fn test_execute_rtk_recursion() { + // This 
should flatten, not infinitely recurse + let result = execute("rtk run \"echo hello\"", 0); + assert!(result.is_ok()); + } +} diff --git a/src/cmd/filters.rs b/src/cmd/filters.rs new file mode 100644 index 0000000..0bd9dea --- /dev/null +++ b/src/cmd/filters.rs @@ -0,0 +1,212 @@ +//! Filter Registry — basic token reduction for `rtk run` native execution. +//! +//! This module provides **basic filtering (20-40% savings)** for commands +//! executed through rtk run. It is a **fallback** for commands +//! without dedicated RTK implementations. +//! +//! For **specialized filtering (60-90% savings)**, use dedicated modules: +//! - `src/git.rs` — git commands (diff, log, status, etc.) +//! - `src/runner.rs` — test commands (cargo test, pytest, etc.) +//! - `src/grep_cmd.rs` — code search (grep, ripgrep) +//! - `src/pnpm_cmd.rs` — package managers + +use crate::utils; + +/// Filter types for different command categories +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum FilterType { + Git, + Cargo, + Test, + Pnpm, + Npm, + Generic, + None, +} + +/// Determine which filter to apply based on binary name +pub fn get_filter_type(binary: &str) -> FilterType { + match binary { + "git" => FilterType::Git, + "cargo" => FilterType::Cargo, + "npm" | "npx" => FilterType::Npm, + "pnpm" => FilterType::Pnpm, + "pytest" | "go" | "vitest" | "jest" | "mocha" => FilterType::Test, + "ls" | "find" | "grep" | "rg" | "fd" => FilterType::Generic, + _ => FilterType::None, + } +} + +/// Apply filter to already-captured string output +pub fn apply_to_string(filter: FilterType, output: &str) -> String { + match filter { + FilterType::Git => utils::strip_ansi(output), + FilterType::Cargo => filter_cargo_output(output), + FilterType::Test => filter_test_output(output), + FilterType::Generic => truncate_lines(output, 100), + FilterType::Npm | FilterType::Pnpm => utils::strip_ansi(output), + FilterType::None => output.to_string(), + } +} + +/// Filter cargo output: remove verbose "Compiling" lines +fn filter_cargo_output(output: &str) -> String { + output + .lines() + .filter(|line| { + let line = line.trim(); + !line.starts_with("Compiling ") || line.contains("error") || line.contains("warning") + }) + .collect::>() + .join("\n") +} + +/// Filter test output: remove passing tests, keep failures +fn filter_test_output(output: &str) -> String { + output + .lines() + .filter(|line| { + let line = line.trim(); + line.contains("FAILED") + || line.contains("error") + || line.contains("Error") + || line.contains("failed") + || line.contains("test result:") + || line.starts_with("----") + }) + .collect::>() + .join("\n") +} + +/// Truncate output to max lines +fn truncate_lines(output: &str, max_lines: usize) -> String { + let lines: Vec<&str> = output.lines().collect(); + if lines.len() <= max_lines { + output.to_string() + } else { + let truncated: Vec<&str> = lines.iter().take(max_lines).copied().collect(); + format!( + "{}\n... 
({} more lines)", + truncated.join("\n"), + lines.len() - max_lines + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // === GET_FILTER_TYPE TESTS === + + #[test] + fn test_filter_type_git() { + assert_eq!(get_filter_type("git"), FilterType::Git); + } + + #[test] + fn test_filter_type_cargo() { + assert_eq!(get_filter_type("cargo"), FilterType::Cargo); + } + + #[test] + fn test_filter_type_npm() { + assert_eq!(get_filter_type("npm"), FilterType::Npm); + assert_eq!(get_filter_type("npx"), FilterType::Npm); + } + + #[test] + fn test_filter_type_generic() { + assert_eq!(get_filter_type("ls"), FilterType::Generic); + assert_eq!(get_filter_type("grep"), FilterType::Generic); + } + + #[test] + fn test_filter_type_none() { + assert_eq!(get_filter_type("unknown_command"), FilterType::None); + } + + // === STRIP_ANSI TESTS (now testing utils::strip_ansi) === + + #[test] + fn test_strip_ansi_no_codes() { + assert_eq!(utils::strip_ansi("hello world"), "hello world"); + } + + #[test] + fn test_strip_ansi_color() { + assert_eq!(utils::strip_ansi("\x1b[32mgreen\x1b[0m"), "green"); + } + + #[test] + fn test_strip_ansi_bold() { + assert_eq!(utils::strip_ansi("\x1b[1mbold\x1b[0m"), "bold"); + } + + #[test] + fn test_strip_ansi_multiple() { + assert_eq!( + utils::strip_ansi("\x1b[31mred\x1b[0m \x1b[32mgreen\x1b[0m"), + "red green" + ); + } + + #[test] + fn test_strip_ansi_complex() { + assert_eq!( + utils::strip_ansi("\x1b[1;31;42mbold red on green\x1b[0m"), + "bold red on green" + ); + } + + // === FILTER_CARGO_OUTPUT TESTS === + + #[test] + fn test_filter_cargo_keeps_errors() { + let input = "Compiling dep1\nerror: something wrong\nCompiling dep2"; + let output = filter_cargo_output(input); + assert!(output.contains("error")); + assert!(!output.contains("Compiling dep1")); + } + + #[test] + fn test_filter_cargo_keeps_warnings() { + let input = "Compiling dep1\nwarning: unused variable\nCompiling dep2"; + let output = filter_cargo_output(input); + assert!(output.contains("warning")); + } + + // === TRUNCATE_LINES TESTS === + + #[test] + fn test_truncate_short() { + let input = "line1\nline2\nline3"; + let output = truncate_lines(input, 10); + assert_eq!(output, input); + } + + #[test] + fn test_truncate_long() { + let input = "line1\nline2\nline3\nline4\nline5"; + let output = truncate_lines(input, 3); + assert!(output.contains("line3")); + assert!(!output.contains("line4")); + assert!(output.contains("2 more lines")); + } + + // === APPLY_TO_STRING TESTS === + + #[test] + fn test_apply_to_string_none() { + let input = "hello world"; + let output = apply_to_string(FilterType::None, input); + assert_eq!(output, input); + } + + #[test] + fn test_apply_to_string_git() { + let input = "\x1b[32mgreen\x1b[0m"; + let output = apply_to_string(FilterType::Git, input); + assert_eq!(output, "green"); + } +} diff --git a/src/cmd/gemini_hook.rs b/src/cmd/gemini_hook.rs new file mode 100644 index 0000000..98eadc6 --- /dev/null +++ b/src/cmd/gemini_hook.rs @@ -0,0 +1,490 @@ +//! Gemini CLI BeforeTool hook protocol handler. +//! +//! Reads JSON from stdin, applies safety checks and rewrites, +//! outputs JSON to stdout. +//! +//! Protocol: https://geminicli.com/docs/hooks/reference/ +//! +//! ## Exit Code Behavior +//! +//! - Exit 0 = normal (JSON `decision` field is respected) +//! - Exit 2 = blocking error (equivalent to `decision: "deny"`) +//! +//! ## Gemini CLI Stderr Rule +//! +//! 
**Source:** See `/Users/athundt/.claude/clautorun/.worktrees/claude-stable-pre-v0.8.0/notes/hooks_api_reference.md:740-753` +//! +//! Unlike Claude Code, Gemini CLI **allows stderr for debugging**: +//! ```text +//! stderr is SAFE for debug/logging (shown to user/agent) +//! ``` +//! +//! **This module's stderr usage:** +//! - Currently: **NO stderr output** (JSON `reason` field sufficient for all cases) +//! - Future: Could add debug logging to stderr if needed (safe in Gemini) +//! +//! ## I/O Enforcement (Module-Specific) +//! +//! **This restriction applies ONLY to gemini_hook.rs and claude_hook.rs.** +//! All other RTK modules (main.rs, git.rs, etc.) use `println!`/`eprintln!` normally. +//! +//! **Why restricted here:** +//! - Hook protocol requires JSON-only stdout +//! - Accidental prints corrupt the JSON response +//! - Consistency with claude_hook.rs architecture +//! +//! **Enforcement mechanism:** +//! - `#![deny(clippy::print_stdout, clippy::print_stderr)]` at module level (line 42) +//! - `run_inner()` returns `HookResponse` enum — pure logic, no I/O +//! - `run()` is the ONLY function that writes output — single I/O point +//! - Uses `write!`/`writeln!` which are NOT caught by the clippy lint +//! +//! **Pathway:** main.rs → Commands::Hook → gemini_hook::run() [DENY ENFORCED HERE] +//! +//! Fail-open: Any parse error or unexpected input → exit 0, no output. + +// Compile-time I/O enforcement for THIS MODULE ONLY. +// Other RTK modules (main.rs, git.rs, etc.) use println!/eprintln! normally. +// +// Why restrict here: +// - Gemini CLI hook protocol requires JSON-only stdout +// - Accidental prints would corrupt the JSON response +// - Architectural consistency with claude_hook.rs +// +// Note: Unlike Claude Code, Gemini ALLOWS stderr for debug logging +// (see hooks_api_reference.md:740-753), but we don't need it. +// The JSON `reason` field is sufficient for all messaging. +// +// Mechanism: +// - Denies println!/eprintln! at compile-time +// - Allows write!/writeln! (used only in run() for controlled output) +// - run_inner() returns HookResponse (no I/O) +// - run() is the single I/O point +#![deny(clippy::print_stdout, clippy::print_stderr)] + +use super::hook::{ + check_for_hook, is_hook_disabled, should_passthrough, update_command_in_tool_input, + HookResponse, HookResult, +}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::io::{self, Read, Write}; + +#[derive(Deserialize)] +struct GeminiPayload { + hook_event_name: Option, + tool_name: Option, + tool_input: Option, +} + +#[derive(Serialize)] +struct GeminiResponse { + decision: String, // "allow" or "deny" + #[serde(skip_serializing_if = "Option::is_none")] + reason: Option, + #[serde(rename = "hookSpecificOutput")] + #[serde(skip_serializing_if = "Option::is_none")] + hook_specific_output: Option, +} + +#[derive(Serialize)] +struct HookSpecificOutput { + tool_input: Value, +} + +/// Tool names that represent shell command execution in Gemini CLI +fn is_shell_tool(name: &str) -> bool { + // Gemini CLI built-in shell tool, plus common MCP patterns + name == "run_shell_command" || name == "shell" || name.ends_with("__run_shell_command") +} + +/// Run the Gemini hook handler. +/// +/// This is the ONLY function that performs I/O (stdout). +/// `run_inner()` returns a `HookResponse` enum — pure logic, no I/O. +/// Combined with `#![deny(clippy::print_stdout, clippy::print_stderr)]`, +/// this ensures no stray output corrupts the JSON hook protocol. 
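/// As a rough illustration (assumed example command; field order is not
/// significant), a rewrite decision emitted here looks like:
///
/// ```text
/// {"decision":"allow",
///  "reason":"RTK applied safety optimizations.",
///  "hookSpecificOutput":{"tool_input":{"command":"rtk run -c 'git status'"}}}
/// ```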
+/// +/// Fail-open design: malformed input → exit 0, no output. +pub fn run() -> anyhow::Result<()> { + let response = match run_inner() { + Ok(r) => r, + Err(_) => HookResponse::NoOpinion, // Fail-open: swallow errors + }; + + // ┌────────────────────────────────────────────────────────────────┐ + // │ SINGLE I/O POINT - All stdout output happens here only │ + // │ │ + // │ Why: Gemini CLI hook protocol requires JSON-only stdout │ + // │ (Gemini ALLOWS stderr for debug, but we don't need it) │ + // │ │ + // │ Enforcement: #![deny(...)] at line 42 prevents println!/eprintln! │ + // │ write!/writeln! are not caught by lint (allowed) │ + // └────────────────────────────────────────────────────────────────┘ + match response { + HookResponse::NoOpinion => { + // Exit 0, NO stdout, NO stderr + // Gemini CLI sees no output → proceeds with original command + } + HookResponse::Allow(json) | HookResponse::Deny(json, _) => { + // Exit 0, JSON to stdout, NO stderr + // Note: Gemini ALLOWS stderr for debug (unlike Claude), but JSON + // `reason` field is sufficient. The HookResponse::Deny + // second field (stderr_reason) is empty for Gemini. + writeln!(io::stdout(), "{json}")?; + } + } + Ok(()) +} + +/// Inner handler: pure decision logic, no I/O. +/// Returns `HookResponse` for `run()` to output. +fn run_inner() -> anyhow::Result { + let mut buffer = String::new(); + io::stdin().read_to_string(&mut buffer)?; + + let payload: GeminiPayload = match serde_json::from_str(&buffer) { + Ok(p) => p, + Err(_) => return Ok(HookResponse::NoOpinion), + }; + + // Only handle BeforeTool events — other events get a plain allow + if payload.hook_event_name.as_deref() != Some("BeforeTool") { + return Ok(HookResponse::Allow(r#"{"decision": "allow"}"#.into())); + } + + // Only intercept shell execution tools + match &payload.tool_name { + Some(name) if is_shell_tool(name) => {} + _ => return Ok(HookResponse::Allow(r#"{"decision": "allow"}"#.into())), + }; + + // Extract the command string from tool_input + let cmd = match &payload.tool_input { + Some(input) => input + .get("command") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + None => return Ok(HookResponse::Allow(r#"{"decision": "allow"}"#.into())), + }; + + if cmd.is_empty() { + return Ok(HookResponse::Allow(r#"{"decision": "allow"}"#.into())); + } + + // Shared guard checks (same as claude_hook.rs, DRY via hook.rs) + if is_hook_disabled() || should_passthrough(&cmd) { + return Ok(HookResponse::NoOpinion); + } + + let decision = check_for_hook(&cmd, "gemini"); + + let response = match decision { + HookResult::Rewrite(new_cmd) => { + // Preserve all original tool_input fields, only replace "command" + // Shared helper (DRY with claude_hook.rs via hook.rs) + let new_input = update_command_in_tool_input(payload.tool_input, new_cmd); + + GeminiResponse { + decision: "allow".into(), + reason: Some("RTK applied safety optimizations.".into()), + hook_specific_output: Some(HookSpecificOutput { + tool_input: new_input, + }), + } + } + HookResult::Blocked(msg) => GeminiResponse { + decision: "deny".into(), + reason: Some(msg), + hook_specific_output: None, + }, + }; + + let json = serde_json::to_string(&response)?; + // Gemini deny uses JSON response only (no stderr/exit-code workaround needed) + if response.decision == "deny" { + Ok(HookResponse::Deny(json, String::new())) + } else { + Ok(HookResponse::Allow(json)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ========================================================================= + // 
GEMINI WIRE FORMAT CONFORMANCE + // https://geminicli.com/docs/hooks/reference/ + // + // These tests verify exact JSON field names per the Gemini CLI spec. + // A wrong field name means Gemini silently ignores the response. + // ========================================================================= + + // --- Input: field name conformance --- + + #[test] + fn test_input_uses_hook_event_name_not_type() { + // Gemini sends "hook_event_name", NOT "type" + let json = r#"{"hook_event_name": "BeforeTool", "tool_name": "run_shell_command", "tool_input": {"command": "git status"}}"#; + let payload: GeminiPayload = serde_json::from_str(json).unwrap(); + assert_eq!(payload.hook_event_name.as_deref(), Some("BeforeTool")); + + // Verify the old wrong field name does NOT populate our struct + let wrong_json = r#"{"type": "BeforeTool", "tool_name": "run_shell_command"}"#; + let payload: GeminiPayload = serde_json::from_str(wrong_json).unwrap(); + assert_eq!( + payload.hook_event_name, None, + "\"type\" must not be accepted as event name" + ); + } + + #[test] + fn test_input_includes_tool_name() { + let json = r#"{"hook_event_name": "BeforeTool", "tool_name": "run_shell_command", "tool_input": {"command": "ls"}}"#; + let payload: GeminiPayload = serde_json::from_str(json).unwrap(); + assert_eq!(payload.tool_name.as_deref(), Some("run_shell_command")); + } + + #[test] + fn test_input_tool_input_is_object() { + let json = r#"{"hook_event_name": "BeforeTool", "tool_name": "run_shell_command", "tool_input": {"command": "git status", "timeout": 30}}"#; + let payload: GeminiPayload = serde_json::from_str(json).unwrap(); + let input = payload.tool_input.unwrap(); + assert_eq!(input["command"].as_str().unwrap(), "git status"); + assert_eq!(input["timeout"].as_i64().unwrap(), 30); + } + + #[test] + fn test_input_extra_fields_ignored() { + // Gemini sends session_id, cwd, timestamp, transcript_path etc. 
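        // This works because serde ignores unknown fields by default (there is
        // no #[serde(deny_unknown_fields)] on GeminiPayload), so fields the CLI
        // adds later keep parsing without changes here.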
+ let json = r#"{"hook_event_name": "BeforeTool", "tool_name": "run_shell_command", "tool_input": {"command": "ls"}, "session_id": "abc123", "cwd": "/tmp", "timestamp": "2026-01-01T00:00:00Z", "transcript_path": "/path/to/transcript"}"#; + let payload: GeminiPayload = serde_json::from_str(json).unwrap(); + assert_eq!(payload.hook_event_name.as_deref(), Some("BeforeTool")); + } + + // --- Output: field name conformance --- + + #[test] + fn test_output_uses_decision_not_result() { + // Gemini expects "decision", NOT "result" + let response = GeminiResponse { + decision: "allow".into(), + reason: None, + hook_specific_output: None, + }; + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert!( + parsed.get("decision").is_some(), + "must have 'decision' field" + ); + assert!( + parsed.get("result").is_none(), + "must NOT have 'result' field" + ); + } + + #[test] + fn test_output_uses_reason_not_message() { + // Gemini expects "reason", NOT "message" + let response = GeminiResponse { + decision: "deny".into(), + reason: Some("Blocked for safety".into()), + hook_specific_output: None, + }; + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert!(parsed.get("reason").is_some(), "must have 'reason' field"); + assert!( + parsed.get("message").is_none(), + "must NOT have 'message' field" + ); + } + + #[test] + fn test_output_uses_hook_specific_output_not_modified_input() { + // Gemini expects "hookSpecificOutput", NOT "modified_input" + let response = GeminiResponse { + decision: "allow".into(), + reason: None, + hook_specific_output: Some(HookSpecificOutput { + tool_input: serde_json::json!({"command": "rtk run -c 'ls'"}), + }), + }; + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert!( + parsed.get("hookSpecificOutput").is_some(), + "must have 'hookSpecificOutput' field" + ); + assert!( + parsed.get("modified_input").is_none(), + "must NOT have 'modified_input' field" + ); + } + + #[test] + fn test_output_rewrite_nests_under_tool_input() { + // Gemini merges hookSpecificOutput.tool_input into the original + let response = GeminiResponse { + decision: "allow".into(), + reason: Some("RTK applied safety optimizations.".into()), + hook_specific_output: Some(HookSpecificOutput { + tool_input: serde_json::json!({"command": "rtk run -c 'git status'"}), + }), + }; + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert_eq!( + parsed["hookSpecificOutput"]["tool_input"]["command"], + "rtk run -c 'git status'" + ); + } + + #[test] + fn test_output_allow_omits_optional_fields() { + let response = GeminiResponse { + decision: "allow".into(), + reason: None, + hook_specific_output: None, + }; + let json = serde_json::to_string(&response).unwrap(); + assert!(!json.contains("reason"), "reason must be omitted when None"); + assert!( + !json.contains("hookSpecificOutput"), + "hookSpecificOutput must be omitted when None" + ); + } + + #[test] + fn test_output_decision_values() { + // Only "allow" and "deny" are valid + for val in ["allow", "deny"] { + let response = GeminiResponse { + decision: val.into(), + reason: Some("test".into()), + hook_specific_output: None, + }; + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + 
assert_eq!(parsed["decision"].as_str().unwrap(), val); + } + } + + // --- Tool filtering --- + + #[test] + fn test_is_shell_tool() { + assert!(is_shell_tool("run_shell_command")); + assert!(is_shell_tool("shell")); + assert!(is_shell_tool("mcp__server__run_shell_command")); + assert!(!is_shell_tool("read_file")); + assert!(!is_shell_tool("write_file")); + assert!(!is_shell_tool("search_code")); + assert!(!is_shell_tool("list_directory")); + } + + #[test] + fn test_non_shell_tools_always_allowed() { + // read_file, write_file, etc. must never be intercepted + for tool in ["read_file", "write_file", "search_code", "list_directory"] { + let json = format!( + r#"{{"hook_event_name": "BeforeTool", "tool_name": "{}", "tool_input": {{"path": "/etc/passwd"}}}}"#, + tool + ); + let payload: GeminiPayload = serde_json::from_str(&json).unwrap(); + assert!( + !is_shell_tool(payload.tool_name.as_deref().unwrap()), + "tool '{}' must not be treated as shell tool", + tool + ); + } + } + + // --- Event filtering --- + + #[test] + fn test_non_before_tool_events_ignored() { + for event in ["AfterTool", "BeforeAgent", "AfterAgent", "SessionStart"] { + let json = format!( + r#"{{"hook_event_name": "{}", "tool_name": "run_shell_command", "tool_input": {{"command": "rm -rf /"}}}}"#, + event + ); + let payload: GeminiPayload = serde_json::from_str(&json).unwrap(); + assert_ne!(payload.hook_event_name.as_deref(), Some("BeforeTool")); + } + } + + // --- Rewrite preserves other tool_input fields --- + + #[test] + fn test_rewrite_preserves_other_tool_input_fields() { + let original_input = serde_json::json!({ + "command": "git status", + "timeout": 30, + "cwd": "/project" + }); + + let mut new_input = original_input.clone(); + if let Some(obj) = new_input.as_object_mut() { + obj.insert( + "command".into(), + Value::String("rtk run -c 'git status'".into()), + ); + } + + assert_eq!(new_input["timeout"], 30); + assert_eq!(new_input["cwd"], "/project"); + assert_eq!(new_input["command"], "rtk run -c 'git status'"); + } + + // --- Edge cases --- + + #[test] + fn test_missing_tool_input() { + let json = r#"{"hook_event_name": "BeforeTool", "tool_name": "run_shell_command"}"#; + let payload: GeminiPayload = serde_json::from_str(json).unwrap(); + assert!(payload.tool_input.is_none()); + } + + #[test] + fn test_missing_command_in_tool_input() { + let json = r#"{"hook_event_name": "BeforeTool", "tool_name": "run_shell_command", "tool_input": {"cwd": "/tmp"}}"#; + let payload: GeminiPayload = serde_json::from_str(json).unwrap(); + let input = payload.tool_input.unwrap(); + assert!(input.get("command").is_none()); + } + + #[test] + fn test_malformed_json_does_not_panic() { + let bad_inputs = ["", "not json", "{}", r#"{"hook_event_name": 42}"#, "null"]; + for input in bad_inputs { + // Should not panic, just return Err or deserialize to defaults + let _ = serde_json::from_str::(input); + } + } + + // --- Guard parity with Claude hook --- + + #[test] + fn test_shared_guards_available() { + // Verify shared guard functions are accessible (DRY with claude_hook.rs) + assert!(!should_passthrough("git status")); + assert!(should_passthrough("rtk git status")); + assert!(should_passthrough("cat < HookResult { + check_for_hook_inner(raw, 0) +} + +fn check_for_hook_inner(raw: &str, depth: usize) -> HookResult { + if depth >= MAX_REWRITE_DEPTH { + return HookResult::Blocked( + "Safety rewrite loop detected (max depth exceeded)".to_string(), + ); + } + + // Handle empty + if raw.trim().is_empty() { + return 
HookResult::Rewrite(raw.to_string()); + } + + // Remap expansion (aliases like "t" → "cargo test") + if let Some(expanded) = crate::config::rules::try_remap(raw) { + return check_for_hook_inner(&expanded, depth + 1); + } + + let tokens = lexer::tokenize(raw); + + // Check for shellisms - if present, pass through + // but still check safety + if analysis::needs_shell(&tokens) { + match safety::check_raw(raw) { + safety::SafetyResult::Blocked(msg) => return HookResult::Blocked(msg), + safety::SafetyResult::Safe => {} + // check_raw currently only returns Safe/Blocked; defensive no-op + safety::SafetyResult::Rewritten(_) | safety::SafetyResult::TrashRequested(_) => {} + } + // Passthrough: just return as-is wrapped in rtk run + return HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))); + } + + // Native mode: parse and check each command + match analysis::parse_chain(tokens) { + Ok(commands) => { + // Check safety on each command + for cmd in &commands { + match safety::check(&cmd.binary, &cmd.args) { + safety::SafetyResult::Blocked(msg) => { + return HookResult::Blocked(msg); + } + safety::SafetyResult::Rewritten(new_cmd) => { + return check_for_hook_inner(&new_cmd, depth + 1); + } + safety::SafetyResult::TrashRequested(_) => { + // Redirect to rtk run which handles trash + return HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))); + } + safety::SafetyResult::Safe => {} + } + } + + // All safe - wrap in rtk run for token optimization + HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))) + } + Err(_) => { + // Parse error - passthrough with wrapping + HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))) + } + } +} + +// --- Shared guard logic (used by both claude_hook.rs and gemini_hook.rs) --- + +/// Check if hook processing is disabled by environment. +/// +/// Returns true if: +/// - `RTK_HOOK_ENABLED=0` (master toggle off) +/// - `RTK_ACTIVE` is set (recursion prevention — rtk sets this when running commands) +pub fn is_hook_disabled() -> bool { + std::env::var("RTK_HOOK_ENABLED").as_deref() == Ok("0") || std::env::var("RTK_ACTIVE").is_ok() +} + +/// Check if this command should bypass hook processing entirely. +/// +/// Returns true for commands that should not be rewritten: +/// - Already routed through rtk (`rtk ...` or `/path/to/rtk ...`) +/// - Contains heredoc (`<<`) which needs raw shell processing +pub fn should_passthrough(cmd: &str) -> bool { + cmd.starts_with("rtk ") || cmd.contains("/rtk ") || cmd.contains("<<") +} + +/// Replace the command field in a tool_input object, preserving other fields. +/// +/// Used by both claude_hook.rs and gemini_hook.rs when rewriting commands. +/// If tool_input is None or not an object, creates a new object with just the command. +/// +/// # Arguments +/// * `tool_input` - The original tool_input from the hook payload (may be None) +/// * `new_cmd` - The rewritten command string to replace with +/// +/// # Returns +/// A Value with the command field updated, all other fields preserved. +pub fn update_command_in_tool_input( + tool_input: Option, + new_cmd: String, +) -> serde_json::Value { + use serde_json::Value; + let mut updated = tool_input.unwrap_or_else(|| Value::Object(Default::default())); + if let Some(obj) = updated.as_object_mut() { + obj.insert("command".into(), Value::String(new_cmd)); + } + updated +} + +/// Hook output for protocol handlers (claude_hook.rs, gemini_hook.rs). 
+/// +/// This enum separates decision logic from I/O: `run_inner()` returns a +/// `HookResponse`, and `run()` is the single place that writes to stdout/stderr. +/// Combined with `#[deny(clippy::print_stdout, clippy::print_stderr)]` on the +/// hook modules, this prevents any stray output from corrupting the JSON protocol. +#[derive(Debug, Clone, PartialEq)] +pub enum HookResponse { + /// No opinion — exit 0, no output. Host proceeds normally. + NoOpinion, + /// Allow/rewrite — exit 0, JSON to stdout. + Allow(String), + /// Deny — exit 2, JSON to stdout + reason to stderr. + /// Fields: (stdout_json, stderr_reason) + Deny(String, String), +} + +/// Escape single quotes for shell +fn escape_quotes(s: &str) -> String { + s.replace("'", "'\\''") +} + +/// Format hook result for Claude (text output) +/// +/// Exit codes: +/// - 0: Success, command rewritten/allowed +/// - 2: Blocking error, command should be denied +pub fn format_for_claude(result: HookResult) -> (String, bool, i32) { + match result { + HookResult::Rewrite(cmd) => (cmd, true, 0), + HookResult::Blocked(msg) => (msg, false, 2), // Exit 2 = blocking error per Claude Code spec + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // === TEST HELPERS === + + fn assert_rewrite(input: &str, contains: &str) { + match check_for_hook(input, "claude") { + HookResult::Rewrite(cmd) => assert!( + cmd.contains(contains), + "'{}' rewrite should contain '{}', got '{}'", + input, + contains, + cmd + ), + other => panic!("Expected Rewrite for '{}', got {:?}", input, other), + } + } + + fn assert_blocked(input: &str, contains: &str) { + match check_for_hook(input, "claude") { + HookResult::Blocked(msg) => assert!( + msg.contains(contains), + "'{}' block msg should contain '{}', got '{}'", + input, + contains, + msg + ), + other => panic!("Expected Blocked for '{}', got {:?}", input, other), + } + } + + // === ESCAPE_QUOTES === + + #[test] + fn test_escape_quotes() { + assert_eq!(escape_quotes("hello"), "hello"); + assert_eq!(escape_quotes("it's"), "it'\\''s"); + assert_eq!(escape_quotes("it's a test's"), "it'\\''s a test'\\''s"); + } + + // === EMPTY / WHITESPACE === + + #[test] + fn test_check_empty_and_whitespace() { + match check_for_hook("", "claude") { + HookResult::Rewrite(cmd) => assert!(cmd.is_empty()), + _ => panic!("Expected Rewrite for empty"), + } + match check_for_hook(" ", "claude") { + HookResult::Rewrite(cmd) => assert!(cmd.trim().is_empty()), + _ => panic!("Expected Rewrite for whitespace"), + } + } + + // === COMMANDS THAT SHOULD REWRITE (table-driven) === + + #[test] + fn test_safe_commands_rewrite() { + let cases = [ + ("git status", "rtk run"), + ("ls *.rs", "rtk run"), // shellism passthrough + (r#"git commit -m "Fix && Bug""#, "rtk run"), // quoted operator + ("FOO=bar echo hello", "rtk run"), // env prefix + ("echo `date`", "rtk run"), // backticks + ("echo $(date)", "rtk run"), // subshell + ("echo {a,b}.txt", "rtk run"), // brace expansion + ("echo 'hello!@#$%^&*()'", "rtk run"), // special chars + ("echo '日本語 🎉'", "rtk run"), // unicode + ("cd /tmp && git status", "rtk run"), // chain rewrite + ]; + for (input, expected) in cases { + assert_rewrite(input, expected); + } + // Chain rewrite preserves operator structure + match check_for_hook("cd /tmp && git status", "claude") { + HookResult::Rewrite(cmd) => assert!( + cmd.contains("&&"), + "Chain rewrite must preserve '&&', got '{}'", + cmd + ), + other => panic!("Expected Rewrite for chain, got {:?}", other), + } + // Very long command + assert_rewrite(&format!("echo 
{}", "a".repeat(1000)), "rtk run"); + } + + // === ENV VAR PREFIX PRESERVATION === + // Ported from old hooks/test-rtk-rewrite.sh Section 2. + // Commands prefixed with KEY=VALUE env vars must not be blocked. + + #[test] + fn test_env_var_prefix_preserved() { + let cases = [ + "GIT_PAGER=cat git status", + "GIT_PAGER=cat git log --oneline -10", + "NODE_ENV=test CI=1 npx vitest run", + "LANG=C ls -la", + "NODE_ENV=test npm run test:e2e", + "COMPOSE_PROJECT_NAME=test docker compose up -d", + "TEST_SESSION_ID=2 npx playwright test --config=foo", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === GLOBAL OPTIONS (PR #99 parity) === + // Commands with global options before subcommands must not be blocked. + // Ported from upstream hooks/rtk-rewrite.sh global option stripping. + + #[test] + fn test_global_options_not_blocked() { + let cases = [ + // Git global options + "git --no-pager status", + "git -C /path/to/project status", + "git -C /path --no-pager log --oneline", + "git --no-optional-locks diff HEAD", + "git --bare log", + // Cargo toolchain prefix + "cargo +nightly test", + "cargo +stable build --release", + // Docker global options + "docker --context prod ps", + "docker -H tcp://host:2375 images", + // Kubectl global options + "kubectl -n kube-system get pods", + "kubectl --context prod describe pod foo", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === SPECIFIC COMMANDS NOT BLOCKED === + // Ported from old hooks/test-rtk-rewrite.sh Sections 1 & 3. + // These commands must pass through (not be blocked by safety rules). + + #[test] + fn test_specific_commands_not_blocked() { + let cases = [ + // Git variants + "git log --oneline -10", + "git diff HEAD", + "git show abc123", + "git add .", + // GitHub CLI + "gh pr list", + "gh api repos/owner/repo", + "gh release list", + // Package managers + "npm run test:e2e", + "npm run build", + "npm test", + // Docker + "docker compose up -d", + "docker compose logs postgrest", + "docker compose down", + "docker run --rm postgres", + "docker exec -it db psql", + // Kubernetes + "kubectl describe pod foo", + "kubectl apply -f deploy.yaml", + // Test runners + "npx playwright test", + "npx prisma migrate", + "cargo test", + // Vitest variants (dedup is internal to rtk run, not hook level) + "vitest", + "vitest run", + "vitest run --reporter=verbose", + "npx vitest run", + "pnpm vitest run --coverage", + // TypeScript + "vue-tsc -b", + "npx vue-tsc --noEmit", + // Utilities + "curl -s https://example.com", + "ls -la", + "grep -rn pattern src/", + "rg pattern src/", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === COMMANDS THAT PASS THROUGH (builtins/unknown) === + // Ported from old hooks/test-rtk-rewrite.sh Section 5. + // These are not blocked — they get wrapped in rtk run -c. + + #[test] + fn test_builtins_not_blocked() { + let cases = [ + "echo hello world", + "cd /tmp", + "mkdir -p foo/bar", + "python3 script.py", + "node -e 'console.log(1)'", + "find . -name '*.ts'", + "tree src/", + "wget https://example.com/file", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === COMPOUND COMMANDS (chained with &&, ||, ;) === + // Shell script only matched FIRST command in a chain. + // Rust hook parses each command independently (#112). 
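Before the tests, a compact sketch of the end-to-end behaviour they assert, assuming the `check_for_hook` / `HookResult` API added in this diff is in scope and the default safety rules are loaded (so `cat` is redirected to the Read tool):

```rust
fn demo_chain_checks() {
    // Each command in a chain is checked on its own, so a harmless `cd`
    // cannot shield a blocked `cat` later in the same line.
    match check_for_hook("cd /tmp && cat file.txt", "claude") {
        HookResult::Blocked(msg) => eprintln!("denied: {msg}"),
        other => panic!("expected a block, got {other:?}"),
    }

    // A fully safe chain is wrapped in `rtk run -c '...'` as a whole,
    // preserving its `&&` structure.
    match check_for_hook("cd /tmp && git status", "claude") {
        HookResult::Rewrite(cmd) => {
            assert!(cmd.starts_with("rtk run"));
            assert!(cmd.contains("&&"));
        }
        other => panic!("expected a rewrite, got {other:?}"),
    }
}
```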
+ + #[test] + fn test_compound_commands_rewrite() { + let cases = [ + // Basic chains — each command rewritten independently + ("cd /tmp && git status", "&&"), + ("cd dir && git status && git diff", "&&"), + ("git add . && git commit -m msg", "&&"), + // Semicolon chains + ("echo start ; git status ; echo done", ";"), + // Or-chains + ("git pull || echo failed", "||"), + ]; + for (input, operator) in cases { + match check_for_hook(input, "claude") { + HookResult::Rewrite(cmd) => { + assert!(cmd.contains("rtk run"), "'{input}' should rewrite"); + assert!( + cmd.contains(operator), + "'{input}' must preserve '{operator}', got '{cmd}'" + ); + } + other => panic!("Expected Rewrite for '{input}', got {other:?}"), + } + } + } + + #[test] + fn test_compound_blocked_in_chain() { + // Safety rules catch dangerous commands even mid-chain + let cases = [ + ("cd /tmp && cat file.txt", "file-reading"), + ("echo start && sed -i 's/x/y/' f", "file-editing"), + ("git add . && head -5 f.txt", "file-reading"), + ]; + for (input, expected_msg) in cases { + assert_blocked(input, expected_msg); + } + } + + #[test] + fn test_compound_quoted_operators_not_split() { + // && inside quotes must NOT split the command + let input = r#"git commit -m "Fix && Bug""#; + match check_for_hook(input, "claude") { + HookResult::Rewrite(cmd) => { + assert!(cmd.contains("rtk run"), "Should rewrite, got '{cmd}'"); + } + other => panic!("Expected Rewrite for quoted &&, got {other:?}"), + } + } + + // === COMMANDS THAT SHOULD BLOCK (table-driven) === + + #[test] + fn test_blocked_commands() { + let cases = [ + ("cat file.txt", "file-reading"), + ("sed -i 's/old/new/' file.txt", "file-editing"), + ("head -n 10 file.txt", "file-reading"), + ("cd /tmp && cat file.txt", "file-reading"), // cat in chain + ]; + for (input, expected_msg) in cases { + assert_blocked(input, expected_msg); + } + } + + // === SHELLISM PASSTHROUGH: cat/sed/head allowed with pipe/redirect === + + #[test] + fn test_token_waste_allowed_in_pipelines() { + let cases = [ + "cat file.txt | grep pattern", + "cat file.txt > output.txt", + "sed 's/old/new/' file.txt > output.txt", + "head -n 10 file.txt | grep pattern", + "for f in *.txt; do cat \"$f\" | grep x; done", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === MULTI-AGENT === + + #[test] + fn test_different_agents_same_result() { + for agent in ["claude", "gemini"] { + match check_for_hook("git status", agent) { + HookResult::Rewrite(cmd) => assert!(cmd.contains("rtk run")), + _ => panic!("Expected Rewrite for agent '{}'", agent), + } + } + } + + // === FORMAT_FOR_CLAUDE === + + #[test] + fn test_format_for_claude() { + let (output, success, code) = + format_for_claude(HookResult::Rewrite("rtk run -c 'git status'".to_string())); + assert_eq!(output, "rtk run -c 'git status'"); + assert!(success); + assert_eq!(code, 0); + + let (output, success, code) = + format_for_claude(HookResult::Blocked("Error message".to_string())); + assert_eq!(output, "Error message"); + assert!(!success); + assert_eq!(code, 2); // Exit 2 = blocking error per Claude Code spec + } + + // === RECURSION DEPTH LIMIT === + + #[test] + fn test_rewrite_depth_limit() { + // At max depth → blocked + match check_for_hook_inner("echo hello", MAX_REWRITE_DEPTH) { + HookResult::Blocked(msg) => assert!(msg.contains("loop"), "msg: {}", msg), + _ => panic!("Expected Blocked at max depth"), + } + // At depth 0 → normal rewrite + match check_for_hook_inner("echo hello", 0) { + HookResult::Rewrite(cmd) => 
assert!(cmd.contains("rtk run")), + _ => panic!("Expected Rewrite at depth 0"), + } + } + + // ========================================================================= + // CLAUDE CODE WIRE FORMAT CONFORMANCE + // https://docs.anthropic.com/en/docs/claude-code/hooks + // + // Claude Code hook protocol: + // - Rewrite: command on stdout, exit code 0 + // - Block: message on stderr, exit code 2 + // - Other exit codes are non-blocking errors + // + // format_for_claude() is the boundary between HookResult and the wire. + // These tests verify it produces the exact contract Claude Code expects. + // ========================================================================= + + #[test] + fn test_claude_rewrite_exit_code_is_zero() { + let (_, _, code) = format_for_claude(HookResult::Rewrite("rtk run -c 'ls'".into())); + assert_eq!(code, 0, "Rewrite must exit 0 (success)"); + } + + #[test] + fn test_claude_block_exit_code_is_two() { + let (_, _, code) = format_for_claude(HookResult::Blocked("denied".into())); + assert_eq!( + code, 2, + "Block must exit 2 (blocking error per Claude Code spec)" + ); + } + + #[test] + fn test_claude_rewrite_output_is_command_text() { + // Claude Code reads stdout as the rewritten command — must be plain text, not JSON + let (output, success, _) = + format_for_claude(HookResult::Rewrite("rtk run -c 'git status'".into())); + assert_eq!(output, "rtk run -c 'git status'"); + assert!(success); + // Must NOT be JSON + assert!( + !output.starts_with('{'), + "Rewrite output must be plain text, not JSON" + ); + } + + #[test] + fn test_claude_block_output_is_human_message() { + // Claude Code reads stderr for the block reason + let (output, success, _) = + format_for_claude(HookResult::Blocked("Use Read tool instead".into())); + assert_eq!(output, "Use Read tool instead"); + assert!(!success); + // Must NOT be JSON + assert!( + !output.starts_with('{'), + "Block output must be plain text, not JSON" + ); + } + + #[test] + fn test_claude_rewrite_success_flag_true() { + let (_, success, _) = format_for_claude(HookResult::Rewrite("cmd".into())); + assert!(success, "Rewrite must set success=true"); + } + + #[test] + fn test_claude_block_success_flag_false() { + let (_, success, _) = format_for_claude(HookResult::Blocked("msg".into())); + assert!(!success, "Block must set success=false"); + } + + #[test] + fn test_claude_exit_codes_not_one() { + // Exit code 1 means non-blocking error in Claude Code — we must never use it + let (_, _, rewrite_code) = format_for_claude(HookResult::Rewrite("cmd".into())); + let (_, _, block_code) = format_for_claude(HookResult::Blocked("msg".into())); + assert_ne!( + rewrite_code, 1, + "Exit code 1 is non-blocking error, not valid for rewrite" + ); + assert_ne!( + block_code, 1, + "Exit code 1 is non-blocking error, not valid for block" + ); + } + + // === CROSS-PROTOCOL: Same decision for both agents === + + #[test] + fn test_cross_protocol_safe_command_allowed_by_both() { + // Both Claude and Gemini must allow the same safe commands + for cmd in ["git status", "cargo test", "ls -la", "echo hello"] { + let claude = check_for_hook(cmd, "claude"); + let gemini = check_for_hook(cmd, "gemini"); + match (&claude, &gemini) { + (HookResult::Rewrite(_), HookResult::Rewrite(_)) => {} + _ => panic!( + "'{}': Claude={:?}, Gemini={:?} — both should Rewrite", + cmd, claude, gemini + ), + } + } + } + + #[test] + fn test_cross_protocol_blocked_command_denied_by_both() { + // Both Claude and Gemini must block the same unsafe commands + for cmd in ["cat file.txt", 
"head -n 10 file.txt"] { + let claude = check_for_hook(cmd, "claude"); + let gemini = check_for_hook(cmd, "gemini"); + match (&claude, &gemini) { + (HookResult::Blocked(_), HookResult::Blocked(_)) => {} + _ => panic!( + "'{}': Claude={:?}, Gemini={:?} — both should Block", + cmd, claude, gemini + ), + } + } + } +} diff --git a/src/cmd/lexer.rs b/src/cmd/lexer.rs new file mode 100644 index 0000000..5f820bc --- /dev/null +++ b/src/cmd/lexer.rs @@ -0,0 +1,474 @@ +//! State-machine lexer that respects quotes and escapes. +//! Critical: `git commit -m "Fix && Bug"` must NOT split on && + +#[derive(Debug, PartialEq, Clone)] +pub enum TokenKind { + Arg, // Regular argument + Operator, // &&, ||, ; + Pipe, // | + Redirect, // >, >>, <, 2> + Shellism, // *, $, `, (, ), {, } - forces passthrough +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ParsedToken { + pub kind: TokenKind, + pub value: String, // The actual string value +} + +/// Tokenize input with quote awareness. +/// Returns Vec of parsed tokens. +pub fn tokenize(input: &str) -> Vec { + let mut tokens = Vec::new(); + let mut current = String::new(); + let mut chars = input.chars().peekable(); + + let mut quote: Option = None; // None, Some('\''), Some('"') + let mut escaped = false; + + while let Some(c) = chars.next() { + // Handle escape sequences (but NOT inside single quotes) + if escaped { + current.push(c); + escaped = false; + continue; + } + if c == '\\' && quote != Some('\'') { + escaped = true; + current.push(c); + continue; + } + + // Handle quotes + if let Some(q) = quote { + if c == q { + quote = None; // Close quote + } + current.push(c); + continue; + } + if c == '\'' || c == '"' { + quote = Some(c); + current.push(c); + continue; + } + + // Outside quotes - handle operators and shellisms + match c { + // Shellisms force passthrough (includes ! for history expansion/negation) + '*' | '?' | '$' | '`' | '(' | ')' | '{' | '}' | '!' 
=> { + flush_arg(&mut tokens, &mut current); + tokens.push(ParsedToken { + kind: TokenKind::Shellism, + value: c.to_string(), + }); + } + // Operators + '&' | '|' | ';' | '>' | '<' => { + flush_arg(&mut tokens, &mut current); + + let mut op = c.to_string(); + // Lookahead for double-char operators + if let Some(&next) = chars.peek() { + if (next == c && c != ';' && c != '<') || (c == '>' && next == '>') { + op.push(chars.next().unwrap()); + } + } + + let kind = match op.as_str() { + "&&" | "||" | ";" => TokenKind::Operator, + "|" => TokenKind::Pipe, + "&" => TokenKind::Shellism, // Background job needs real shell + _ => TokenKind::Redirect, + }; + tokens.push(ParsedToken { kind, value: op }); + } + // Whitespace delimits arguments + c if c.is_whitespace() => { + flush_arg(&mut tokens, &mut current); + } + // Regular character + _ => current.push(c), + } + } + + // Handle unclosed quote (treat remaining as arg, don't panic) + flush_arg(&mut tokens, &mut current); + tokens +} + +fn flush_arg(tokens: &mut Vec, current: &mut String) { + let trimmed = current.trim(); + if !trimmed.is_empty() { + tokens.push(ParsedToken { + kind: TokenKind::Arg, + value: trimmed.to_string(), + }); + } + current.clear(); +} + +/// Strip quotes from a token value +pub fn strip_quotes(s: &str) -> String { + let chars: Vec = s.chars().collect(); + if chars.len() >= 2 + && ((chars[0] == '"' && chars[chars.len() - 1] == '"') + || (chars[0] == '\'' && chars[chars.len() - 1] == '\'')) + { + return chars[1..chars.len() - 1].iter().collect(); + } + s.to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + // === BASIC FUNCTIONALITY TESTS === + + #[test] + fn test_simple_command() { + let tokens = tokenize("git status"); + assert_eq!(tokens.len(), 2); + assert_eq!(tokens[0].kind, TokenKind::Arg); + assert_eq!(tokens[0].value, "git"); + assert_eq!(tokens[1].value, "status"); + } + + #[test] + fn test_command_with_args() { + let tokens = tokenize("git commit -m message"); + assert_eq!(tokens.len(), 4); + assert_eq!(tokens[0].value, "git"); + assert_eq!(tokens[1].value, "commit"); + assert_eq!(tokens[2].value, "-m"); + assert_eq!(tokens[3].value, "message"); + } + + // === QUOTE HANDLING TESTS === + + #[test] + fn test_quoted_operator_not_split() { + let tokens = tokenize(r#"git commit -m "Fix && Bug""#); + // && inside quotes should NOT be an Operator token + assert!(!tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Operator) && t.value == "&&")); + assert!(tokens.iter().any(|t| t.value.contains("Fix && Bug"))); + } + + #[test] + fn test_single_quoted_string() { + let tokens = tokenize("echo 'hello world'"); + assert!(tokens.iter().any(|t| t.value == "'hello world'")); + } + + #[test] + fn test_double_quoted_string() { + let tokens = tokenize("echo \"hello world\""); + assert!(tokens.iter().any(|t| t.value == "\"hello world\"")); + } + + #[test] + fn test_empty_quoted_string() { + let tokens = tokenize("echo \"\""); + // Should have echo and "" + assert!(tokens.iter().any(|t| t.value == "\"\"")); + } + + #[test] + fn test_nested_quotes() { + let tokens = tokenize(r#"echo "outer 'inner' outer""#); + assert!(tokens.iter().any(|t| t.value.contains("'inner'"))); + } + + #[test] + fn test_strip_quotes_double() { + assert_eq!(strip_quotes("\"hello\""), "hello"); + } + + #[test] + fn test_strip_quotes_single() { + assert_eq!(strip_quotes("'hello'"), "hello"); + } + + #[test] + fn test_strip_quotes_none() { + assert_eq!(strip_quotes("hello"), "hello"); + } + + #[test] + fn test_strip_quotes_mismatched() { + 
assert_eq!(strip_quotes("\"hello'"), "\"hello'"); + } + + // === ESCAPE HANDLING TESTS === + + #[test] + fn test_escaped_space() { + let tokens = tokenize("echo hello\\ world"); + // Escaped space should be part of the arg + assert!(tokens.iter().any(|t| t.value.contains("hello"))); + } + + #[test] + fn test_backslash_in_single_quotes() { + // In single quotes, backslash is literal + let tokens = tokenize(r#"echo 'hello\nworld'"#); + assert!(tokens.iter().any(|t| t.value.contains(r#"\n"#))); + } + + #[test] + fn test_escaped_quote_in_double() { + let tokens = tokenize(r#"echo "hello\"world""#); + assert!(tokens.iter().any(|t| t.value.contains("hello"))); + } + + // === EDGE CASE TESTS === + + #[test] + fn test_empty_input() { + let tokens = tokenize(""); + assert!(tokens.is_empty()); + } + + #[test] + fn test_whitespace_only() { + let tokens = tokenize(" "); + assert!(tokens.is_empty()); + } + + #[test] + fn test_unclosed_single_quote() { + // Should not panic, treat remaining as part of arg + let tokens = tokenize("'unclosed"); + assert!(!tokens.is_empty()); + } + + #[test] + fn test_unclosed_double_quote() { + // Should not panic, treat remaining as part of arg + let tokens = tokenize("\"unclosed"); + assert!(!tokens.is_empty()); + } + + #[test] + fn test_unicode_preservation() { + let tokens = tokenize("echo \"héllo wörld\""); + assert!(tokens.iter().any(|t| t.value.contains("héllo"))); + } + + #[test] + fn test_multiple_spaces() { + let tokens = tokenize("git status"); + assert_eq!(tokens.len(), 2); + } + + #[test] + fn test_leading_trailing_spaces() { + let tokens = tokenize(" git status "); + assert_eq!(tokens.len(), 2); + } + + // === OPERATOR TESTS === + + #[test] + fn test_and_operator() { + let tokens = tokenize("cmd1 && cmd2"); + assert!(tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Operator) && t.value == "&&")); + } + + #[test] + fn test_or_operator() { + let tokens = tokenize("cmd1 || cmd2"); + assert!(tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Operator) && t.value == "||")); + } + + #[test] + fn test_semicolon() { + let tokens = tokenize("cmd1 ; cmd2"); + assert!(tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Operator) && t.value == ";")); + } + + #[test] + fn test_multiple_and() { + let tokens = tokenize("a && b && c"); + let ops: Vec<_> = tokens + .iter() + .filter(|t| matches!(t.kind, TokenKind::Operator)) + .collect(); + assert_eq!(ops.len(), 2); + } + + #[test] + fn test_mixed_operators() { + let tokens = tokenize("a && b || c"); + let ops: Vec<_> = tokens + .iter() + .filter(|t| matches!(t.kind, TokenKind::Operator)) + .collect(); + assert_eq!(ops.len(), 2); + } + + #[test] + fn test_operator_at_start() { + let tokens = tokenize("&& cmd"); + // Should still parse, just with operator first + assert!(tokens.iter().any(|t| t.value == "&&")); + } + + #[test] + fn test_operator_at_end() { + let tokens = tokenize("cmd &&"); + assert!(tokens.iter().any(|t| t.value == "&&")); + } + + // === PIPE TESTS === + + #[test] + fn test_pipe_detection() { + let tokens = tokenize("cat file | grep pattern"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Pipe))); + } + + #[test] + fn test_quoted_pipe_not_pipe() { + let tokens = tokenize("\"a|b\""); + // Pipe inside quotes is not a Pipe token + assert!(!tokens.iter().any(|t| matches!(t.kind, TokenKind::Pipe))); + } + + #[test] + fn test_multiple_pipes() { + let tokens = tokenize("a | b | c"); + let pipes: Vec<_> = tokens + .iter() + .filter(|t| matches!(t.kind, TokenKind::Pipe)) + .collect(); + 
assert_eq!(pipes.len(), 2); + } + + // === SHELLISM TESTS === + + #[test] + fn test_glob_detection() { + let tokens = tokenize("ls *.rs"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_quoted_glob_not_shellism() { + let tokens = tokenize("echo \"*.txt\""); + // Glob inside quotes is not a Shellism token + assert!(!tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_variable_detection() { + let tokens = tokenize("echo $HOME"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_quoted_variable_not_shellism() { + let tokens = tokenize("echo \"$HOME\""); + // $ inside double quotes is NOT detected as a Shellism token + // because the lexer respects quotes + // This is correct - the variable can't be expanded by us anyway + // so the whole command will need to passthrough to shell + // But at the tokenization level, it's not a Shellism + assert!(!tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_backtick_substitution() { + let tokens = tokenize("echo `date`"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_subshell_detection() { + let tokens = tokenize("echo $(date)"); + // Both $ and ( should be shellisms + let shellisms: Vec<_> = tokens + .iter() + .filter(|t| matches!(t.kind, TokenKind::Shellism)) + .collect(); + assert!(!shellisms.is_empty()); + } + + #[test] + fn test_brace_expansion() { + let tokens = tokenize("echo {a,b}.txt"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_escaped_glob() { + let tokens = tokenize("echo \\*.txt"); + // Escaped glob should not be a shellism + assert!(!tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Shellism) && t.value == "*")); + } + + // === REDIRECT TESTS === + + #[test] + fn test_redirect_out() { + let tokens = tokenize("cmd > file"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Redirect))); + } + + #[test] + fn test_redirect_append() { + let tokens = tokenize("cmd >> file"); + assert!(tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Redirect) && t.value == ">>")); + } + + #[test] + fn test_redirect_in() { + let tokens = tokenize("cmd < file"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Redirect))); + } + + #[test] + fn test_redirect_stderr() { + let tokens = tokenize("cmd 2> file"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Redirect))); + } + + // === EXCLAMATION / NEGATION TESTS === + + #[test] + fn test_exclamation_is_shellism() { + let tokens = tokenize("if ! grep -q pattern file; then echo missing; fi"); + assert!( + tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Shellism) && t.value == "!"), + "! (negation) must be Shellism" + ); + } + + // === BACKGROUND JOB TESTS === + + #[test] + fn test_background_job_is_shellism() { + let tokens = tokenize("sleep 10 &"); + assert!( + tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Shellism) && t.value == "&"), + "Single & (background job) must be Shellism, not Redirect" + ); + } +} diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs new file mode 100644 index 0000000..979821a --- /dev/null +++ b/src/cmd/mod.rs @@ -0,0 +1,27 @@ +//! RTK command interceptor — safety checks and token-optimized execution. +//! +//! This module provides: +//! - Quote-aware lexing for shell commands +//! - Native execution for simple chains +//! 
- Passthrough to /bin/sh for complex scripts +//! - Safety interception (rm -> trash, etc.) +//! - Token-optimized output filtering +//! - Hook protocol support (Claude/Gemini) + +pub(crate) mod analysis; +pub(crate) mod builtins; +pub mod claude_hook; +pub mod exec; +pub(crate) mod filters; +pub mod gemini_hook; +pub mod hook; +pub(crate) mod lexer; +pub(crate) mod predicates; +pub(crate) mod safety; +pub(crate) mod trash_cmd; + +#[cfg(test)] +pub(crate) mod test_helpers; + +pub use exec::execute; +pub use hook::check_for_hook; diff --git a/src/cmd/predicates.rs b/src/cmd/predicates.rs new file mode 100644 index 0000000..9bd5a6a --- /dev/null +++ b/src/cmd/predicates.rs @@ -0,0 +1,94 @@ +//! Context-aware predicates for conditional safety rules. +//! These give RTK "situational awareness" - checking git state, file existence, etc. + +use std::process::Command; + +/// Check if there are unstaged changes in the current git repo +pub(crate) fn has_unstaged_changes() -> bool { + Command::new("git") + .args(["diff", "--quiet"]) + .status() + .map(|s| !s.success()) // git diff --quiet returns 1 if changes exist + .unwrap_or(false) +} + +/// Critical for token reduction: detect if output goes to human or agent +pub(crate) fn is_interactive() -> bool { + use std::io::IsTerminal; + std::io::stderr().is_terminal() +} + +/// Expand ~ to $HOME, with fallback +pub(crate) fn expand_tilde(path: &str) -> String { + if path.starts_with("~") { + // Try HOME first, then USERPROFILE (Windows) + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .unwrap_or_else(|_| "/".to_string()); + path.replacen("~", &home, 1) + } else { + path.to_string() + } +} + +/// Get HOME directory with fallback +pub(crate) fn get_home() -> String { + std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .unwrap_or_else(|_| "/".to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + + // === PATH EXPANSION TESTS === + + #[test] + fn test_expand_tilde_simple() { + let home = env::var("HOME").unwrap_or("/".to_string()); + assert_eq!(expand_tilde("~/src"), format!("{}/src", home)); + } + + #[test] + fn test_expand_tilde_no_tilde() { + assert_eq!(expand_tilde("/absolute/path"), "/absolute/path"); + } + + #[test] + fn test_expand_tilde_only_tilde() { + let home = env::var("HOME").unwrap_or("/".to_string()); + assert_eq!(expand_tilde("~"), home); + } + + #[test] + fn test_expand_tilde_relative() { + assert_eq!(expand_tilde("relative/path"), "relative/path"); + } + + // === HOME DIRECTORY TESTS === + + #[test] + fn test_get_home_returns_something() { + let home = get_home(); + assert!(!home.is_empty()); + } + + // === INTERACTIVE TESTS === + + #[test] + fn test_is_interactive() { + // This will be false when running tests + // Just ensure it doesn't panic + let _ = is_interactive(); + } + + // === GIT PREDICATE TESTS === + + #[test] + fn test_has_unstaged_changes() { + // Just ensure it doesn't panic + let _ = has_unstaged_changes(); + } +} diff --git a/src/cmd/safety.rs b/src/cmd/safety.rs new file mode 100644 index 0000000..21ea981 --- /dev/null +++ b/src/cmd/safety.rs @@ -0,0 +1,538 @@ +//! Safety Policy Engine — unified rule-based implementation. +//! +//! All safety rules, remaps, and blocking rules are loaded from the unified +//! Rule system (`config::rules`). Rules are MD files with YAML frontmatter, +//! loaded from built-in defaults and user directories. 
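Before the implementation, a usage sketch of the policy engine. Illustrative only: `safety` is `pub(crate)`, so these calls only make sense from inside the crate, and they assume the default built-in rules (rm redirected to trash, plain commands untouched) are loaded:

```rust
use crate::cmd::safety::{self, SafetyResult};

fn demo_safety_engine() {
    // `rm` is intercepted: flags are dropped, remaining args become trash targets.
    match safety::check("rm", &["-rf".to_string(), "build".to_string()]) {
        SafetyResult::TrashRequested(paths) => assert_eq!(paths, vec!["build"]),
        other => panic!("expected a trash redirect, got {other:?}"),
    }

    // Raw (passthrough) checks work on the unparsed string; paths cannot be
    // extracted reliably there, so an `rm` in passthrough is blocked outright.
    assert!(matches!(
        safety::check_raw("sudo rm -rf /tmp/x"),
        SafetyResult::Blocked(_)
    ));

    // Ordinary commands fall straight through.
    assert_eq!(safety::check("git", &["status".to_string()]), SafetyResult::Safe);
}
```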
+ +use crate::config::rules::{self, Rule}; + +use super::predicates; + +/// Result of safety check +#[derive(Clone, Debug, PartialEq)] +pub enum SafetyResult { + /// Command is safe to execute as-is + Safe, + /// Command is blocked with error message + Blocked(String), + /// Command was rewritten to a new command string + Rewritten(String), + /// Request to move files to trash (built-in) + TrashRequested(Vec), +} + +/// Dispatch a matched rule into a SafetyResult. +fn dispatch(rule: &Rule, args: &str) -> SafetyResult { + match rule.action.as_str() { + "trash" => { + let paths: Vec = args + .split_whitespace() + .filter(|a| !a.starts_with('-')) + .map(String::from) + .collect(); + SafetyResult::TrashRequested(paths) + } + "rewrite" => { + let redirect = rule.redirect.as_deref().unwrap_or(args); + SafetyResult::Rewritten(redirect.replace("{args}", args)) + } + "suggest_tool" | "block" => { + // Use interactive-aware message (human vs agent) + let msg = if predicates::is_interactive() { + // For suggest_tool, human message references the tool name + if rule.action == "suggest_tool" { + // First line of message is typically the human-friendly version + rule.message + .lines() + .next() + .unwrap_or(&rule.message) + .to_string() + } else { + rule.message.clone() + } + } else { + // Agent: use the full message (contains BLOCK: prefix) + rule.message.clone() + }; + SafetyResult::Blocked(msg) + } + "warn" => { + eprintln!("{}", rule.message); + SafetyResult::Safe + } + _ => SafetyResult::Safe, + } +} + +/// Check a parsed command against all safety rules. +pub fn check(binary: &str, args: &[String]) -> SafetyResult { + let full_cmd = if args.is_empty() { + binary.to_string() + } else { + format!("{} {}", binary, args.join(" ")) + }; + + for rule in rules::load_all() { + if !rules::matches_rule(rule, Some(binary), &full_cmd) { + continue; + } + if !rule.should_apply() { + continue; + } + return dispatch(rule, &args.join(" ")); + } + SafetyResult::Safe +} + +/// Check raw command string (for passthrough mode). +/// Catches dangerous patterns even when we can't parse the command. +pub fn check_raw(raw: &str) -> SafetyResult { + for rule in rules::load_all() { + if !rules::matches_rule(rule, None, raw) { + continue; + } + if !rule.should_apply() { + continue; + } + // In passthrough, suggest_tool rules don't apply (cat in pipelines is valid) + if rule.action == "suggest_tool" { + continue; + } + // In passthrough, trash becomes block (can't extract paths reliably) + if rule.action == "trash" { + return SafetyResult::Blocked(format!( + "Passthrough blocked: '{}' detected. 
Use native mode for safe trash.", + rule.patterns.first().map(|s| s.as_str()).unwrap_or("rm") + )); + } + return dispatch(rule, raw); + } + SafetyResult::Safe +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cmd::test_helpers::EnvGuard; + use std::env; + + // === BASIC CHECK TESTS === + + #[test] + fn test_check_safe_command() { + let _guard = EnvGuard::new(); + let result = check("ls", &["-la".to_string()]); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_check_git_status() { + let _guard = EnvGuard::new(); + let result = check("git", &["status".to_string()]); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_check_empty_args() { + let _guard = EnvGuard::new(); + let result = check("pwd", &[]); + assert_eq!(result, SafetyResult::Safe); + } + + // === RM SAFETY TESTS (RTK_SAFE_COMMANDS) === + + #[test] + fn test_check_rm_blocked_when_env_set() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check("rm", &["file.txt".to_string()]); + match result { + SafetyResult::TrashRequested(paths) => { + assert_eq!(paths, vec!["file.txt"]); + } + _ => panic!("Expected TrashRequested, got {:?}", result), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_rm_blocked_by_default() { + let _guard = EnvGuard::new(); + // rm should be redirected to trash by default now + let result = check("rm", &["file.txt".to_string()]); + match result { + SafetyResult::TrashRequested(paths) => { + assert_eq!(paths, vec!["file.txt"]); + } + _ => panic!("Expected TrashRequested by default, got {:?}", result), + } + } + + #[test] + fn test_check_rm_passes_when_disabled() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "0"); + let result = check("rm", &["file.txt".to_string()]); + assert_eq!(result, SafetyResult::Safe); + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_rm_with_flags() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check("rm", &["-rf".to_string(), "dir".to_string()]); + match result { + SafetyResult::TrashRequested(paths) => { + // Flags should be filtered out + assert_eq!(paths, vec!["dir"]); + } + _ => panic!("Expected TrashRequested"), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_rm_multiple_files() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check( + "rm", + &[ + "a.txt".to_string(), + "b.txt".to_string(), + "c.txt".to_string(), + ], + ); + match result { + SafetyResult::TrashRequested(paths) => { + assert_eq!(paths, vec!["a.txt", "b.txt", "c.txt"]); + } + _ => panic!("Expected TrashRequested"), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_rm_no_files() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check("rm", &["-rf".to_string()]); + match result { + SafetyResult::TrashRequested(paths) => { + assert!(paths.is_empty()); + } + _ => panic!("Expected TrashRequested, got {:?}", result), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + // === CAT/SED/HEAD TESTS (blocked by default, opt-out with RTK_BLOCK_TOKEN_WASTE=0) === + + #[test] + fn test_check_cat_blocked() { + let _guard = EnvGuard::new(); + let result = check("cat", &["file.txt".to_string()]); + match result { + SafetyResult::Blocked(msg) => { + assert!(msg.contains("file-reading"), "msg: {}", msg); + } + _ => panic!("Expected Blocked"), + } + } + + #[test] + fn 
test_check_cat_passes_when_disabled() { + let _guard = EnvGuard::new(); + env::set_var("RTK_BLOCK_TOKEN_WASTE", "0"); + let result = check("cat", &["file.txt".to_string()]); + env::remove_var("RTK_BLOCK_TOKEN_WASTE"); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_check_sed_blocked() { + let _guard = EnvGuard::new(); + let result = check("sed", &["-i".to_string(), "s/old/new/g".to_string()]); + match result { + SafetyResult::Blocked(msg) => { + assert!(msg.contains("file-editing"), "msg: {}", msg); + } + _ => panic!("Expected Blocked"), + } + } + + #[test] + fn test_check_head_blocked() { + let _guard = EnvGuard::new(); + let result = check( + "head", + &["-n".to_string(), "10".to_string(), "file.txt".to_string()], + ); + match result { + SafetyResult::Blocked(msg) => { + assert!(msg.contains("file-reading"), "msg: {}", msg); + } + _ => panic!("Expected Blocked"), + } + } + + // === GIT SAFETY TESTS (RTK_SAFE_COMMANDS) === + + #[test] + fn test_check_git_reset_hard_blocked_when_env_set() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + // This test may or may not trigger depending on git state + // Just ensure it doesn't panic + let _ = check("git", &["reset".to_string(), "--hard".to_string()]); + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_git_clean_fd_rewritten() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check("git", &["clean".to_string(), "-fd".to_string()]); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash -u")); + assert!(cmd.contains("clean")); + } + _ => panic!("Expected Rewritten, got {:?}", result), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_git_clean_rewritten_by_default() { + let _guard = EnvGuard::new(); + // git clean should be rewritten with stash by default + let result = check("git", &["clean".to_string(), "-fd".to_string()]); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash -u")); + } + _ => panic!("Expected Rewritten by default, got {:?}", result), + } + } + + #[test] + fn test_check_git_clean_passes_when_disabled() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "0"); + let result = check("git", &["clean".to_string(), "-fd".to_string()]); + assert_eq!(result, SafetyResult::Safe); + env::remove_var("RTK_SAFE_COMMANDS"); + } + + // === CHECK_RAW TESTS === + + #[test] + fn test_check_raw_rm_detected() { + let _guard = EnvGuard::new(); + // RTK_SAFE_COMMANDS is enabled by default, so rm should be blocked + let result = check_raw("rm file.txt"); + match result { + SafetyResult::Blocked(_) => {} + _ => panic!("Expected Blocked"), + } + } + + #[test] + fn test_check_raw_sudo_rm_detected() { + let _guard = EnvGuard::new(); + // RTK_SAFE_COMMANDS is enabled by default, so sudo rm should be blocked + let result = check_raw("sudo rm file.txt"); + match result { + SafetyResult::Blocked(_) => {} + _ => panic!("Expected Blocked"), + } + } + + #[test] + fn test_check_raw_sudo_flags_rm_detected() { + let _guard = EnvGuard::new(); + let result = check_raw("sudo -u root rm file.txt"); + match result { + SafetyResult::Blocked(_) => {} + _ => panic!("Expected Blocked for sudo -u root rm"), + } + } + + #[test] + fn test_check_raw_safe_command() { + let _guard = EnvGuard::new(); + let result = check_raw("ls -la"); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_check_raw_rm_in_quoted_string() { + let _guard = EnvGuard::new(); + 
let result = check_raw("echo \"rm file\""); + // This will be blocked because we can't distinguish quoted rm + // That's intentional - better safe than sorry + match result { + SafetyResult::Blocked(_) => {} + SafetyResult::Safe => {} // Either is acceptable + SafetyResult::Rewritten(_) => {} + SafetyResult::TrashRequested(_) => {} + } + } + + // === NEW GIT SAFETY TESTS === + + #[test] + fn test_git_checkout_dot_stash_prepended() { + let _guard = EnvGuard::new(); + let result = check("git", &["checkout".to_string(), ".".to_string()]); + // May or may not trigger based on predicate, just ensure no panic + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash")); + assert!(cmd.contains("checkout")); + } + SafetyResult::Safe => {} // Predicate returned false (no changes) + _ => {} + } + } + + #[test] + fn test_git_checkout_dashdash_stash_prepended() { + let _guard = EnvGuard::new(); + let result = check( + "git", + &[ + "checkout".to_string(), + "--".to_string(), + "file.txt".to_string(), + ], + ); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash")); + assert!(cmd.contains("checkout")); + } + SafetyResult::Safe => {} + _ => {} + } + } + + #[test] + fn test_git_stash_drop_rewritten_to_pop() { + let _guard = EnvGuard::new(); + let result = check("git", &["stash".to_string(), "drop".to_string()]); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash pop")); + } + _ => panic!("Expected Rewritten to stash pop"), + } + } + + #[test] + fn test_git_clean_f_rewritten() { + let _guard = EnvGuard::new(); + let result = check("git", &["clean".to_string(), "-f".to_string()]); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash -u")); + assert!(cmd.contains("clean")); + } + _ => panic!("Expected Rewritten with stash -u"), + } + } + + #[test] + fn test_git_branch_checkout_safe() { + // git checkout should be safe (not matched by checkout . 
or checkout --) + let _guard = EnvGuard::new(); + let result = check("git", &["checkout".to_string(), "main".to_string()]); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_git_checkout_new_branch_safe() { + let _guard = EnvGuard::new(); + let result = check( + "git", + &[ + "checkout".to_string(), + "-b".to_string(), + "feature".to_string(), + ], + ); + assert_eq!(result, SafetyResult::Safe); + } + + // === PATTERN MATCHING FALSE POSITIVE TESTS === + + #[test] + fn test_no_false_positive_catalog() { + let _guard = EnvGuard::new(); + let result = check("catalog", &["show".to_string()]); + assert_eq!( + result, + SafetyResult::Safe, + "catalog must not match cat rule" + ); + } + + #[test] + fn test_no_false_positive_sedan() { + let _guard = EnvGuard::new(); + let result = check("sedan", &[]); + assert_eq!(result, SafetyResult::Safe, "sedan must not match sed rule"); + } + + #[test] + fn test_no_false_positive_headless() { + let _guard = EnvGuard::new(); + let result = check("headless", &["chrome".to_string()]); + assert_eq!( + result, + SafetyResult::Safe, + "headless must not match head rule" + ); + } + + #[test] + fn test_no_false_positive_rmdir() { + let _guard = EnvGuard::new(); + let result = check("rmdir", &["empty_dir".to_string()]); + assert_eq!(result, SafetyResult::Safe, "rmdir must not match rm rule"); + } + + // === CHECK_RAW WORD BOUNDARY TESTS === + + #[test] + fn test_check_raw_no_false_positive_trim() { + let _guard = EnvGuard::new(); + std::env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check_raw("trim file.txt"); + assert_eq!(result, SafetyResult::Safe, "trim must not match rm pattern"); + std::env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_raw_no_false_positive_farm() { + let _guard = EnvGuard::new(); + std::env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check_raw("farm --harvest"); + assert_eq!(result, SafetyResult::Safe, "farm must not match rm pattern"); + std::env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_raw_catches_standalone_rm() { + let _guard = EnvGuard::new(); + std::env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check_raw("rm file.txt"); + assert!( + matches!(result, SafetyResult::Blocked(_)), + "standalone rm must be caught" + ); + std::env::remove_var("RTK_SAFE_COMMANDS"); + } +} diff --git a/src/cmd/test_helpers.rs b/src/cmd/test_helpers.rs new file mode 100644 index 0000000..06f929a --- /dev/null +++ b/src/cmd/test_helpers.rs @@ -0,0 +1,35 @@ +//! Shared test utilities for the cmd module. + +use std::sync::{Mutex, MutexGuard, OnceLock}; + +static ENV_LOCK: OnceLock> = OnceLock::new(); + +/// RAII guard that serializes env-var-mutating tests and auto-cleans on drop. +/// Prevents race conditions between parallel test threads and ensures cleanup +/// even if a test panics. +pub struct EnvGuard { + _lock: MutexGuard<'static, ()>, +} + +impl EnvGuard { + pub fn new() -> Self { + let lock = ENV_LOCK + .get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(|e| e.into_inner()); + Self::cleanup(); + Self { _lock: lock } + } + + fn cleanup() { + std::env::remove_var("RTK_SAFE_COMMANDS"); + std::env::remove_var("RTK_BLOCK_TOKEN_WASTE"); + std::env::remove_var("RTK_ACTIVE"); + } +} + +impl Drop for EnvGuard { + fn drop(&mut self) { + Self::cleanup(); + } +} diff --git a/src/cmd/trash_cmd.rs b/src/cmd/trash_cmd.rs new file mode 100644 index 0000000..70ac9ae --- /dev/null +++ b/src/cmd/trash_cmd.rs @@ -0,0 +1,76 @@ +//! 
Built-in trash - mirrors rm behavior: silent on success, error on failure. + +use anyhow::Result; +use std::path::Path; + +pub fn execute(paths: &[String]) -> Result { + let expanded: Vec = paths + .iter() + .filter(|p| !p.is_empty()) + .map(|p| super::predicates::expand_tilde(p)) + .collect(); + + if expanded.is_empty() { + eprintln!("trash: no paths specified"); + return Ok(false); + } + + let (existing, missing): (Vec<_>, Vec<_>) = + expanded.iter().partition(|p| Path::new(p).exists()); + + // Report missing like rm does + for p in &missing { + eprintln!("trash: cannot remove '{}': No such path", p); + } + + if existing.is_empty() { + return Ok(false); + } + + let refs: Vec<&str> = existing.iter().map(|s| s.as_str()).collect(); + match trash::delete_all(&refs) { + Ok(_) => Ok(true), + Err(e) => { + eprintln!("trash: {}", e); + Ok(false) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::path::PathBuf; + + fn tmp(name: &str) -> PathBuf { + let p = std::env::temp_dir().join(format!("rtk_{}", name)); + fs::write(&p, "x").unwrap(); + p + } + fn rm(p: &PathBuf) { + let _ = fs::remove_file(p); + } + + #[test] + fn t_empty() { + assert!(!execute(&[]).unwrap()); + } + #[test] + fn t_missing() { + assert!(!execute(&["/nope".into()]).unwrap()); + } + #[test] + fn t_single() { + let p = tmp("s"); + assert!(execute(&[p.to_string_lossy().into()]).unwrap()); + rm(&p); + } + #[test] + fn t_multi() { + let (a, b) = (tmp("a"), tmp("b")); + assert!(execute(&[a.to_string_lossy().into(), b.to_string_lossy().into()]).unwrap()); + rm(&a); + rm(&b); + } +} diff --git a/src/config.rs b/src/config.rs deleted file mode 100644 index 830ec32..0000000 --- a/src/config.rs +++ /dev/null @@ -1,125 +0,0 @@ -use anyhow::Result; -use serde::{Deserialize, Serialize}; -use std::path::PathBuf; - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct Config { - #[serde(default)] - pub tracking: TrackingConfig, - #[serde(default)] - pub display: DisplayConfig, - #[serde(default)] - pub filters: FilterConfig, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct TrackingConfig { - pub enabled: bool, - pub history_days: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub database_path: Option, -} - -impl Default for TrackingConfig { - fn default() -> Self { - Self { - enabled: true, - history_days: 90, - database_path: None, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct DisplayConfig { - pub colors: bool, - pub emoji: bool, - pub max_width: usize, -} - -impl Default for DisplayConfig { - fn default() -> Self { - Self { - colors: true, - emoji: true, - max_width: 120, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct FilterConfig { - pub ignore_dirs: Vec, - pub ignore_files: Vec, -} - -impl Default for FilterConfig { - fn default() -> Self { - Self { - ignore_dirs: vec![ - ".git".into(), - "node_modules".into(), - "target".into(), - "__pycache__".into(), - ".venv".into(), - "vendor".into(), - ], - ignore_files: vec!["*.lock".into(), "*.min.js".into(), "*.min.css".into()], - } - } -} - -impl Config { - pub fn load() -> Result { - let path = get_config_path()?; - - if path.exists() { - let content = std::fs::read_to_string(&path)?; - let config: Config = toml::from_str(&content)?; - Ok(config) - } else { - Ok(Config::default()) - } - } - - pub fn save(&self) -> Result<()> { - let path = get_config_path()?; - - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent)?; - } - - let content = toml::to_string_pretty(self)?; - 
std::fs::write(&path, content)?; - Ok(()) - } - - pub fn create_default() -> Result { - let config = Config::default(); - config.save()?; - get_config_path() - } -} - -fn get_config_path() -> Result { - let config_dir = dirs::config_dir().unwrap_or_else(|| PathBuf::from(".")); - Ok(config_dir.join("rtk").join("config.toml")) -} - -pub fn show_config() -> Result<()> { - let path = get_config_path()?; - println!("Config: {}", path.display()); - println!(); - - if path.exists() { - let config = Config::load()?; - println!("{}", toml::to_string_pretty(&config)?); - } else { - println!("(default config, file not created)"); - println!(); - let config = Config::default(); - println!("{}", toml::to_string_pretty(&config)?); - } - - Ok(()) -} diff --git a/src/config/discovery.rs b/src/config/discovery.rs new file mode 100644 index 0000000..edd5112 --- /dev/null +++ b/src/config/discovery.rs @@ -0,0 +1,330 @@ +//! Directory walk-up discovery for `rtk.*.md` rule files. +//! +//! Walks from cwd to home, scanning configurable dirs in each ancestor. +//! Search dirs, global dirs, and extra rules_dirs are read from config. +//! Results cached via `OnceLock` — zero cost after first call. + +use std::collections::HashSet; +use std::path::{Path, PathBuf}; +use std::sync::OnceLock; + +static DISCOVERED: OnceLock> = OnceLock::new(); + +/// Return all `rtk.*.md` files ordered lowest→highest priority. +/// +/// Precedence (highest wins): +/// 0 (lowest). Compiled `include_str!()` defaults (handled in rules.rs, not here) +/// 1. Platform config dir + `~/.config/rtk/` (global RTK config) +/// 2. Config `discovery.rules_dirs` (explicit extra dirs) +/// 3. Config `discovery.global_dirs` under $HOME (default: `.claude/`, `.gemini/`) +/// 4. Walk up from cwd using config `discovery.search_dirs` +/// (default: `.claude/`, `.gemini/`, `.rtk/` — furthest from cwd first, cwd last) +/// 5. CLI `--rules-add` paths (highest file priority) +/// +/// If `--rules-path` is set, ONLY those paths are searched (skips all discovery). +/// All dirs configurable via `[discovery]` section in config.toml or env vars. +pub fn discover_rtk_files() -> &'static [PathBuf] { + DISCOVERED.get_or_init(discover_impl) +} + +fn discover_impl() -> Vec { + let mut seen = HashSet::new(); + let mut files = Vec::new(); + let overrides = super::cli_overrides(); + + // If --rules-path is set, use ONLY those paths (exclusive mode) + if let Some(ref exclusive_paths) = overrides.rules_path { + for dir in exclusive_paths { + collect_from_dir(dir, &mut files, &mut seen); + } + return files; + } + + let config = super::get_merged(); + + // Normal discovery + let home = match dirs::home_dir() { + Some(h) => h, + None => return files, + }; + + // 1. Platform-specific config dir (macOS: ~/Library/Application Support/rtk/) + if let Some(config_dir) = dirs::config_dir() { + let platform_rtk = config_dir.join("rtk"); + collect_from_dir(&platform_rtk, &mut files, &mut seen); + } + + // 2. Canonical RTK config dir: ~/.config/rtk/ + let canonical_rtk = home.join(".config").join("rtk"); + collect_from_dir(&canonical_rtk, &mut files, &mut seen); + + // 3. Config discovery.rules_dirs (explicit extra directories) + for dir in &config.discovery.rules_dirs { + collect_from_dir(dir, &mut files, &mut seen); + } + + // 4. Global dirs under $HOME (from config discovery.global_dirs) + for name in &config.discovery.global_dirs { + collect_from_dir(&home.join(name), &mut files, &mut seen); + } + + // 5. 
Walk up from cwd to home using config discovery.search_dirs + let cwd = match std::env::current_dir() { + Ok(c) => c, + Err(_) => return files, + }; + + let mut ancestors: Vec = Vec::new(); + let mut current = cwd.as_path(); + loop { + ancestors.push(current.to_path_buf()); + if current == home { + break; + } + match current.parent() { + Some(p) if p != current => current = p, + _ => break, + } + } + // Reverse: furthest ancestor first (lowest priority), cwd last (highest) + ancestors.reverse(); + + for ancestor in &ancestors { + for search_dir in &config.discovery.search_dirs { + let dir = ancestor.join(search_dir); + collect_from_dir(&dir, &mut files, &mut seen); + } + } + + // 6. --rules-add paths (highest file priority, after all discovery) + for dir in &overrides.rules_add { + collect_from_dir(dir, &mut files, &mut seen); + } + + files +} + +/// Collect `rtk.*.md` files from a directory, deduplicating by canonical path. +fn collect_from_dir(dir: &Path, files: &mut Vec, seen: &mut HashSet) { + let entries = match std::fs::read_dir(dir) { + Ok(e) => e, + Err(_) => return, // Silently skip unreadable dirs + }; + + let mut dir_files: Vec = Vec::new(); + for entry in entries.flatten() { + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + if is_rtk_rule_file(&name_str) { + let path = entry.path(); + // Canonicalize for dedup: detects symlink loops and duplicate real paths + let canon = match path.canonicalize() { + Ok(c) => c, + Err(_) => continue, // Broken symlink or unreadable + }; + if seen.insert(canon) { + dir_files.push(path); + } + } + } + // Sort within directory for deterministic ordering + dir_files.sort(); + files.extend(dir_files); +} + +/// Match `rtk.*.md` pattern: starts with "rtk.", ends with ".md", has content between. 
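A small usage sketch of the discovery entry point (illustrative; module path as declared in this diff, and the result is cached by the `OnceLock` after the first call):

```rust
// List every discovered rule file, lowest priority first. Entries closer to
// the cwd, and any --rules-add paths, come later and win on conflicts.
fn list_discovered_rules() {
    for path in crate::config::discovery::discover_rtk_files() {
        println!("{}", path.display());
    }
}
```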
+fn is_rtk_rule_file(name: &str) -> bool { + name.starts_with("rtk.") && name.ends_with(".md") && name.len() > 7 +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + + #[test] + fn test_is_rtk_rule_file_valid() { + assert!(is_rtk_rule_file("rtk.safety.rm-to-trash.md")); + assert!(is_rtk_rule_file("rtk.remap.t.md")); + assert!(is_rtk_rule_file("rtk.x.md")); // minimal valid: 8 chars + } + + #[test] + fn test_is_rtk_rule_file_invalid() { + assert!(!is_rtk_rule_file("rtk.md")); // too short (7 chars, not > 7) + assert!(!is_rtk_rule_file("foo.md")); + assert!(!is_rtk_rule_file("rtk.safety.txt")); + assert!(!is_rtk_rule_file("")); + } + + #[test] + fn test_collect_from_empty_dir() { + let tmp = tempfile::tempdir().unwrap(); + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(tmp.path(), &mut files, &mut seen); + assert!(files.is_empty()); + } + + #[test] + fn test_collect_from_dir_with_rules() { + let tmp = tempfile::tempdir().unwrap(); + fs::write(tmp.path().join("rtk.test.md"), "---\nname: test\n---\n").unwrap(); + fs::write(tmp.path().join("not-a-rule.md"), "ignored").unwrap(); + fs::write(tmp.path().join("rtk.md"), "too short name").unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(tmp.path(), &mut files, &mut seen); + assert_eq!(files.len(), 1); + assert!(files[0].file_name().unwrap().to_str().unwrap() == "rtk.test.md"); + } + + #[test] + fn test_collect_deduplicates_symlinks() { + let tmp = tempfile::tempdir().unwrap(); + let real = tmp.path().join("rtk.test.md"); + fs::write(&real, "---\nname: test\n---\n").unwrap(); + + // Create a subdirectory with a symlink to the same file + let subdir = tmp.path().join("sub"); + fs::create_dir(&subdir).unwrap(); + #[cfg(unix)] + std::os::unix::fs::symlink(&real, subdir.join("rtk.test.md")).unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(tmp.path(), &mut files, &mut seen); + collect_from_dir(&subdir, &mut files, &mut seen); + + #[cfg(unix)] + assert_eq!(files.len(), 1, "Symlink should be deduplicated"); + } + + #[test] + fn test_collect_skips_unreadable_dir() { + let mut files = Vec::new(); + let mut seen = HashSet::new(); + // Non-existent directory should be silently skipped + collect_from_dir(Path::new("/nonexistent/path"), &mut files, &mut seen); + assert!(files.is_empty()); + } + + #[test] + fn test_collect_skips_file_as_dir() { + // If a file is passed instead of a directory, read_dir will fail — should be skipped + let tmp = tempfile::tempdir().unwrap(); + let file_path = tmp.path().join("not_a_dir"); + fs::write(&file_path, "i am a file").unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(&file_path, &mut files, &mut seen); + assert!(files.is_empty()); // Not a dir, silently skipped + } + + #[test] + fn test_collect_skips_broken_symlinks() { + let tmp = tempfile::tempdir().unwrap(); + + #[cfg(unix)] + { + // Create a broken symlink (target doesn't exist) + let broken_link = tmp.path().join("rtk.broken.md"); + std::os::unix::fs::symlink("/nonexistent/target", &broken_link).unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(tmp.path(), &mut files, &mut seen); + // Broken symlink: canonicalize fails → continue (skipped) + assert!(files.is_empty()); + } + } + + #[test] + fn test_collect_handles_non_utf8_filenames() { + // Files with non-UTF8 names should be handled via to_string_lossy + let tmp = tempfile::tempdir().unwrap(); + // 
Create a normal rtk rule file alongside a non-matching file
+        fs::write(tmp.path().join("rtk.valid.md"), "---\nname: v\n---\n").unwrap();
+        fs::write(tmp.path().join("other.txt"), "not a rule").unwrap();
+
+        let mut files = Vec::new();
+        let mut seen = HashSet::new();
+        collect_from_dir(tmp.path(), &mut files, &mut seen);
+        assert_eq!(files.len(), 1);
+    }
+
+    #[test]
+    fn test_collect_multiple_dirs_deduplicates() {
+        let tmp = tempfile::tempdir().unwrap();
+        let dir_a = tmp.path().join("a");
+        let dir_b = tmp.path().join("b");
+        fs::create_dir_all(&dir_a).unwrap();
+        fs::create_dir_all(&dir_b).unwrap();
+
+        let real_file = dir_a.join("rtk.test.md");
+        fs::write(&real_file, "---\nname: test\n---\n").unwrap();
+
+        #[cfg(unix)]
+        {
+            // Symlink from dir_b to same real file
+            std::os::unix::fs::symlink(&real_file, dir_b.join("rtk.test.md")).unwrap();
+
+            let mut files = Vec::new();
+            let mut seen = HashSet::new();
+            collect_from_dir(&dir_a, &mut files, &mut seen);
+            collect_from_dir(&dir_b, &mut files, &mut seen);
+            assert_eq!(
+                files.len(),
+                1,
+                "Same file via symlink should be deduplicated"
+            );
+        }
+    }
+
+    #[cfg(unix)]
+    #[test]
+    fn test_collect_permission_denied_dir_skipped() {
+        use std::os::unix::fs::PermissionsExt;
+
+        let tmp = tempfile::tempdir().unwrap();
+        let restricted = tmp.path().join("restricted");
+        fs::create_dir(&restricted).unwrap();
+        fs::write(restricted.join("rtk.test.md"), "---\nname: t\n---\n").unwrap();
+
+        // Remove read permission
+        fs::set_permissions(&restricted, fs::Permissions::from_mode(0o000)).unwrap();
+
+        let mut files = Vec::new();
+        let mut seen = HashSet::new();
+        collect_from_dir(&restricted, &mut files, &mut seen);
+        // Permission denied → silently skipped
+        assert!(files.is_empty());
+
+        // Restore permissions for cleanup
+        fs::set_permissions(&restricted, fs::Permissions::from_mode(0o755)).unwrap();
+    }
+
+    #[test]
+    fn test_default_search_dirs_match_expected() {
+        // Verify defaults match the previously hardcoded values
+        let config = crate::config::DiscoveryConfig::default();
+        assert_eq!(config.search_dirs, vec![".claude", ".gemini", ".rtk"]);
+    }
+
+    #[test]
+    fn test_default_global_dirs_match_expected() {
+        let config = crate::config::DiscoveryConfig::default();
+        assert_eq!(config.global_dirs, vec![".claude", ".gemini"]);
+    }
+
+    #[test]
+    fn test_default_rules_dirs_empty() {
+        let config = crate::config::DiscoveryConfig::default();
+        assert!(
+            config.rules_dirs.is_empty(),
+            "Default rules_dirs should be empty (uses ~/.config/rtk/ implicitly)"
+        );
+    }
+}
diff --git a/src/config/mod.rs b/src/config/mod.rs
new file mode 100644
index 0000000..1e75105
--- /dev/null
+++ b/src/config/mod.rs
@@ -0,0 +1,1208 @@
+//! Configuration system: scalar config (TOML) + unified rules (MD with YAML frontmatter).
+//!
+//! Two config layers:
+//! 1. Scalar config (`config.toml`): tracking, display, filters
+//! 2. Rules (`rtk.*.md`): safety, remaps, warnings — via `rules` submodule
+
+pub mod discovery;
+pub mod rules;
+
+use anyhow::{anyhow, Result};
+use serde::{Deserialize, Serialize};
+use std::path::PathBuf;
+use std::sync::OnceLock;
+
+/// CLI overrides for config paths. Set from main.rs before any config loading.
+#[derive(Debug, Default)]
+pub struct CliConfigOverrides {
+    /// Exclusive config paths — replaces all discovery. Multiple files merged in order.
+    pub config_path: Option<Vec<PathBuf>>,
+    /// Additional config paths — loaded with highest priority (after env vars).
+    pub config_add: Vec<PathBuf>,
+    /// Exclusive rule discovery paths — replaces walk-up discovery.
+    pub rules_path: Option<Vec<PathBuf>>,
+    /// Additional rule discovery paths — loaded with highest priority.
+    pub rules_add: Vec<PathBuf>,
+}
+
+static CLI_OVERRIDES: OnceLock<CliConfigOverrides> = OnceLock::new();
+
+/// Set CLI config overrides. Must be called before any config loading.
+pub fn set_cli_overrides(overrides: CliConfigOverrides) {
+    let _ = CLI_OVERRIDES.set(overrides);
+}
+
+/// Get CLI config overrides (or defaults if never set).
+pub fn cli_overrides() -> &'static CliConfigOverrides {
+    CLI_OVERRIDES.get_or_init(CliConfigOverrides::default)
+}
+
+#[derive(Debug, Serialize, Deserialize, Default, Clone)]
+pub struct Config {
+    #[serde(default)]
+    pub tracking: TrackingConfig,
+    #[serde(default)]
+    pub display: DisplayConfig,
+    #[serde(default)]
+    pub filters: FilterConfig,
+    #[serde(default)]
+    pub discovery: DiscoveryConfig,
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct TrackingConfig {
+    pub enabled: bool,
+    pub history_days: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub database_path: Option<PathBuf>,
+}
+
+impl Default for TrackingConfig {
+    fn default() -> Self {
+        Self {
+            enabled: true,
+            history_days: 90,
+            database_path: None,
+        }
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct DisplayConfig {
+    pub colors: bool,
+    pub emoji: bool,
+    pub max_width: usize,
+}
+
+impl Default for DisplayConfig {
+    fn default() -> Self {
+        Self {
+            colors: true,
+            emoji: true,
+            max_width: 120,
+        }
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct FilterConfig {
+    pub ignore_dirs: Vec<String>,
+    pub ignore_files: Vec<String>,
+}
+
+impl Default for FilterConfig {
+    fn default() -> Self {
+        Self {
+            ignore_dirs: vec![
+                ".git".into(),
+                "node_modules".into(),
+                "target".into(),
+                "__pycache__".into(),
+                ".venv".into(),
+                "vendor".into(),
+            ],
+            ignore_files: vec!["*.lock".into(), "*.min.js".into(), "*.min.css".into()],
+        }
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize, Clone)]
+pub struct DiscoveryConfig {
+    /// Dirs to search in each ancestor during walk-up (e.g. [".claude", ".gemini", ".rtk"]).
+    pub search_dirs: Vec<String>,
+    /// Global dirs under $HOME to check before walk-up (e.g. [".claude", ".gemini"]).
+    pub global_dirs: Vec<String>,
+    /// Additional rule directories to search. First entry is also the export/write target.
+    /// Default: [] (uses ~/.config/rtk/ as the implicit primary).
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub rules_dirs: Vec<PathBuf>,
+}
+
+impl Default for DiscoveryConfig {
+    fn default() -> Self {
+        Self {
+            search_dirs: vec![".claude".into(), ".gemini".into(), ".rtk".into()],
+            global_dirs: vec![".claude".into(), ".gemini".into()],
+            rules_dirs: vec![],
+        }
+    }
+}
+
+impl Config {
+    /// Load global config from `~/.config/rtk/config.toml`.
+    /// Falls back to defaults if file is missing or unreadable.
+    pub fn load() -> Result<Self> {
+        let path = match get_config_path() {
+            Ok(p) => p,
+            Err(_) => return Ok(Config::default()),
+        };
+
+        if path.exists() {
+            match std::fs::read_to_string(&path) {
+                Ok(content) => match toml::from_str(&content) {
+                    Ok(config) => Ok(config),
+                    Err(_) => Ok(Config::default()), // Malformed config → defaults
+                },
+                Err(_) => Ok(Config::default()), // Unreadable → defaults
+            }
+        } else {
+            Ok(Config::default())
+        }
+    }
+
+    /// Load merged config with full precedence chain.
+    ///
+    /// Precedence (highest wins):
+    /// 1. Environment variables (RTK_*)
+    /// 2. CLI `--config-add` paths (additive); `--config-path` is exclusive and replaces layers 3-4
+    /// 3. Project-local `.rtk/config.toml` (nearest ancestor)
+    /// 4. Global `~/.config/rtk/config.toml` (or platform config dir)
+    /// 5. Compiled defaults
+    pub fn load_merged() -> Result<Self> {
+        let overrides = cli_overrides();
+
+        // If --config-path is set, use ONLY those files (skip global + walk-up)
+        let mut config = if let Some(ref exclusive_paths) = overrides.config_path {
+            let mut cfg = Config::default();
+            for path in exclusive_paths {
+                if path.exists() {
+                    if let Ok(content) = std::fs::read_to_string(path) {
+                        if let Ok(overlay) = toml::from_str::<ConfigOverlay>(&content) {
+                            overlay.apply(&mut cfg);
+                        }
+                    }
+                }
+            }
+            cfg
+        } else {
+            // Normal: start with global config
+            let mut cfg = Self::load()?;
+
+            // Layer 3: Walk up from cwd looking for .rtk/config.toml
+            if let Ok(cwd) = std::env::current_dir() {
+                let mut current = cwd.as_path();
+                loop {
+                    let project_config = current.join(".rtk").join("config.toml");
+                    if project_config.exists() {
+                        match std::fs::read_to_string(&project_config) {
+                            Ok(content) => {
+                                if let Ok(overlay) = toml::from_str::<ConfigOverlay>(&content) {
+                                    overlay.apply(&mut cfg);
+                                }
+                            }
+                            Err(_) => {} // Silently skip unreadable project config
+                        }
+                        break;
+                    }
+                    match current.parent() {
+                        Some(p) if p != current => current = p,
+                        _ => break,
+                    }
+                }
+            }
+            cfg
+        };
+
+        // Layer 2: --config-add paths (higher than project-local, lower than env vars)
+        for add_path in &overrides.config_add {
+            if add_path.exists() {
+                if let Ok(content) = std::fs::read_to_string(add_path) {
+                    if let Ok(overlay) = toml::from_str::<ConfigOverlay>(&content) {
+                        overlay.apply(&mut config);
+                    }
+                }
+            }
+        }
+
+        // Layer 1 (highest priority): Environment variable overrides
+        if let Ok(val) = std::env::var("RTK_TRACKING_ENABLED") {
+            if let Ok(b) = val.parse::<bool>() {
+                config.tracking.enabled = b;
+            } else if val == "0" {
+                config.tracking.enabled = false;
+            } else if val == "1" {
+                config.tracking.enabled = true;
+            }
+        }
+        if let Ok(val) = std::env::var("RTK_HISTORY_DAYS") {
+            if let Ok(days) = val.parse::<u32>() {
+                config.tracking.history_days = days;
+            }
+        }
+        if let Ok(path) = std::env::var("RTK_DB_PATH") {
+            config.tracking.database_path = Some(PathBuf::from(path));
+        }
+        if let Ok(val) = std::env::var("RTK_DISPLAY_COLORS") {
+            if let Ok(b) = val.parse::<bool>() {
+                config.display.colors = b;
+            }
+        }
+        if let Ok(val) = std::env::var("RTK_DISPLAY_EMOJI") {
+            if let Ok(b) = val.parse::<bool>() {
+                config.display.emoji = b;
+            }
+        }
+        if let Ok(val) = std::env::var("RTK_MAX_WIDTH") {
+            if let Ok(w) = val.parse::<usize>() {
+                config.display.max_width = w;
+            }
+        }
+        if let Ok(val) = std::env::var("RTK_SEARCH_DIRS") {
+            config.discovery.search_dirs = val.split(',').map(|s| s.trim().to_string()).collect();
+        }
+        if let Ok(val) = std::env::var("RTK_RULES_DIRS") {
+            config.discovery.rules_dirs = val.split(',').map(|s| PathBuf::from(s.trim())).collect();
+        }
+
+        Ok(config)
+    }
+
+    pub fn save(&self) -> Result<()> {
+        let path = get_config_path()?;
+
+        if let Some(parent) = path.parent() {
+            std::fs::create_dir_all(parent)?;
+        }
+
+        let content = toml::to_string_pretty(self)?;
+        std::fs::write(&path, content)?;
+        Ok(())
+    }
+
+    /// Save to a specific path (for --local support).
+    pub fn save_to(&self, path: &std::path::Path) -> Result<()> {
+        if let Some(parent) = path.parent() {
+            std::fs::create_dir_all(parent)?;
+        }
+        let content = toml::to_string_pretty(self)?;
+        std::fs::write(path, content)?;
+        Ok(())
+    }
+
+    pub fn create_default() -> Result<PathBuf> {
+        let config = Config::default();
+        config.save()?;
+        get_config_path()
+    }
+}
+
+/// Overlay config for merging project config onto global config.
+/// All fields are Option — only present fields override.
+#[derive(Debug, Deserialize, Default)]
+pub struct ConfigOverlay {
+    pub tracking: Option<TrackingOverlay>,
+    pub display: Option<DisplayOverlay>,
+    pub filters: Option<FilterOverlay>,
+    pub discovery: Option<DiscoveryOverlay>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct TrackingOverlay {
+    pub enabled: Option<bool>,
+    pub history_days: Option<u32>,
+    pub database_path: Option<PathBuf>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct DisplayOverlay {
+    pub colors: Option<bool>,
+    pub emoji: Option<bool>,
+    pub max_width: Option<usize>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct FilterOverlay {
+    pub ignore_dirs: Option<Vec<String>>,
+    pub ignore_files: Option<Vec<String>>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct DiscoveryOverlay {
+    pub search_dirs: Option<Vec<String>>,
+    pub global_dirs: Option<Vec<String>>,
+    pub rules_dirs: Option<Vec<PathBuf>>,
+}
+
+impl ConfigOverlay {
+    fn apply(&self, config: &mut Config) {
+        if let Some(ref t) = self.tracking {
+            if let Some(v) = t.enabled {
+                config.tracking.enabled = v;
+            }
+            if let Some(v) = t.history_days {
+                config.tracking.history_days = v;
+            }
+            if let Some(ref v) = t.database_path {
+                config.tracking.database_path = Some(v.clone());
+            }
+        }
+        if let Some(ref d) = self.display {
+            if let Some(v) = d.colors {
+                config.display.colors = v;
+            }
+            if let Some(v) = d.emoji {
+                config.display.emoji = v;
+            }
+            if let Some(v) = d.max_width {
+                config.display.max_width = v;
+            }
+        }
+        if let Some(ref f) = self.filters {
+            if let Some(ref v) = f.ignore_dirs {
+                config.filters.ignore_dirs = v.clone();
+            }
+            if let Some(ref v) = f.ignore_files {
+                config.filters.ignore_files = v.clone();
+            }
+        }
+        if let Some(ref d) = self.discovery {
+            if let Some(ref v) = d.search_dirs {
+                config.discovery.search_dirs = v.clone();
+            }
+            if let Some(ref v) = d.global_dirs {
+                config.discovery.global_dirs = v.clone();
+            }
+            if let Some(ref v) = d.rules_dirs {
+                config.discovery.rules_dirs = v.clone();
+            }
+        }
+    }
+}
+
+/// Global config path: `~/.config/rtk/config.toml`
+pub fn get_config_path() -> Result<PathBuf> {
+    let config_dir = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
+    Ok(config_dir.join("rtk").join("config.toml"))
+}
+
+/// Canonical RTK rules directory: `~/.config/rtk/`
+///
+/// This is distinct from `dirs::config_dir()` which on macOS returns
+/// `~/Library/Application Support/` — not appropriate for a CLI tool's
+/// user-facing rule files. We use `~/.config/rtk/` on all platforms.
+/// Primary rules directory (for writes/exports). First entry of rules_dirs, or ~/.config/rtk/.
+pub fn get_rules_dir() -> Result<PathBuf> {
+    let config = get_merged();
+    if let Some(first) = config.discovery.rules_dirs.first() {
+        return Ok(first.clone());
+    }
+    let home = dirs::home_dir().ok_or_else(|| anyhow!("Cannot determine home directory"))?;
+    Ok(home.join(".config").join("rtk"))
+}
+
+/// Project-local config path: `.rtk/config.toml` in cwd
+pub fn get_local_config_path() -> Result<PathBuf> {
+    let cwd = std::env::current_dir()?;
+    Ok(cwd.join(".rtk").join("config.toml"))
+}
+
+/// Cached merged config (loaded once per process).
+static MERGED_CONFIG: OnceLock<Config> = OnceLock::new();
+
+/// Get the merged config (cached). For use by tracking, display, etc.
+pub fn get_merged() -> &'static Config {
+    MERGED_CONFIG.get_or_init(|| Config::load_merged().unwrap_or_default())
+}
+
+pub fn show_config() -> Result<()> {
+    let path = get_config_path()?;
+    if path.exists() {
+        println!("# {}", path.display());
+        let config = Config::load()?;
+        println!("{}", toml::to_string_pretty(&config)?);
+    } else {
+        println!("# (defaults, no config file)");
+        println!("{}", toml::to_string_pretty(&Config::default())?);
+    }
+    Ok(())
+}
+
+// === Config CRUD ===
+
+/// Get a config value by dotted key (e.g., "tracking.enabled").
+pub fn get_value(key: &str) -> Result<String> {
+    let config = Config::load_merged()?;
+    let toml_val = toml::Value::try_from(&config)?;
+
+    let parts: Vec<&str> = key.split('.').collect();
+    let mut current = &toml_val;
+    for part in &parts {
+        current = current
+            .get(part)
+            .ok_or_else(|| anyhow!("Unknown config key: {key}"))?;
+    }
+
+    match current {
+        toml::Value::String(s) => Ok(s.clone()),
+        toml::Value::Boolean(b) => Ok(b.to_string()),
+        toml::Value::Integer(i) => Ok(i.to_string()),
+        toml::Value::Float(f) => Ok(f.to_string()),
+        toml::Value::Array(a) => Ok(format!("{:?}", a)),
+        other => Ok(other.to_string()),
+    }
+}
+
+/// Set a config value by dotted key.
+pub fn set_value(key: &str, value: &str, local: bool) -> Result<()> {
+    let path = if local {
+        get_local_config_path()?
+    } else {
+        get_config_path()?
+    };
+
+    let mut config = if path.exists() {
+        let content = std::fs::read_to_string(&path)?;
+        toml::from_str(&content)?
+    } else {
+        Config::default()
+    };
+
+    apply_value(&mut config, key, value)?;
+
+    if local {
+        config.save_to(&path)?;
+    } else {
+        config.save()?;
+    }
+    Ok(())
+}
+
+/// Unset a config value (reset to default).
+pub fn unset_value(key: &str, local: bool) -> Result<()> {
+    let path = if local {
+        get_local_config_path()?
+    } else {
+        get_config_path()?
+    };
+
+    if !path.exists() {
+        return Err(anyhow!("Config file not found: {}", path.display()));
+    }
+
+    let content = std::fs::read_to_string(&path)?;
+    let mut toml_val: toml::Value = toml::from_str(&content)?;
+
+    let parts: Vec<&str> = key.split('.').collect();
+    if parts.len() == 2 {
+        if let Some(table) = toml_val.get_mut(parts[0]).and_then(|v| v.as_table_mut()) {
+            table.remove(parts[1]);
+        }
+    } else {
+        return Err(anyhow!("Invalid key format: {key}. Use section.field"));
+    }
+
+    let content = toml::to_string_pretty(&toml_val)?;
+    std::fs::write(&path, content)?;
+    Ok(())
+}
+
+/// List all config values with optional origin info.
+pub fn list_values(origin: bool) -> Result<()> { + let config = Config::load_merged()?; + let toml_str = toml::to_string_pretty(&config)?; + + if origin { + let global_path = get_config_path()?; + let has_global = global_path.exists(); + + // Check for project config + let mut has_project = false; + if let Ok(cwd) = std::env::current_dir() { + let mut current = cwd.as_path(); + loop { + if current.join(".rtk").join("config.toml").exists() { + has_project = true; + break; + } + match current.parent() { + Some(p) if p != current => current = p, + _ => break, + } + } + } + + println!("# Sources:"); + if has_global { + println!("# global: {}", global_path.display()); + } + if has_project { + println!("# project: .rtk/config.toml"); + } + if !has_global && !has_project { + println!("# (all defaults)"); + } + println!(); + } + + println!("{toml_str}"); + + // Show rules summary only with --origin flag + if origin { + let rules = rules::load_all(); + if !rules.is_empty() { + println!("# Rules ({} loaded):", rules.len()); + for rule in rules { + println!("# {} [{}] — {}", rule.name, rule.action, rule.source); + } + } + } + + Ok(()) +} + +/// Apply a string value to a config struct by dotted key. +fn apply_value(config: &mut Config, key: &str, value: &str) -> Result<()> { + match key { + "tracking.enabled" => config.tracking.enabled = value.parse()?, + "tracking.history_days" => config.tracking.history_days = value.parse()?, + "tracking.database_path" => { + config.tracking.database_path = Some(PathBuf::from(value)); + } + "display.colors" => config.display.colors = value.parse()?, + "display.emoji" => config.display.emoji = value.parse()?, + "display.max_width" => config.display.max_width = value.parse()?, + "discovery.search_dirs" => { + config.discovery.search_dirs = value.split(',').map(|s| s.trim().to_string()).collect(); + } + "discovery.global_dirs" => { + config.discovery.global_dirs = value.split(',').map(|s| s.trim().to_string()).collect(); + } + "discovery.rules_dirs" => { + config.discovery.rules_dirs = + value.split(',').map(|s| PathBuf::from(s.trim())).collect(); + } + _ => return Err(anyhow!("Unknown config key: {key}")), + } + Ok(()) +} + +/// Create or update a rule MD file. +pub fn set_rule( + name: &str, + pattern: Option<&str>, + action: Option<&str>, + redirect: Option<&str>, + local: bool, +) -> Result<()> { + let dir = if local { + let cwd = std::env::current_dir()?; + cwd.join(".rtk") + } else { + get_rules_dir()? + }; + std::fs::create_dir_all(&dir)?; + + let action_str = action.unwrap_or("rewrite"); + let filename = format!("rtk.{name}.md"); + let path = dir.join(&filename); + + let mut content = String::from("---\n"); + content.push_str(&format!("name: {name}\n")); + if let Some(pat) = pattern { + // Single pattern without quotes for simple, quoted for multi-word + if pat.contains(' ') { + content.push_str(&format!("patterns: [\"{pat}\"]\n")); + } else { + content.push_str(&format!("patterns: [{pat}]\n")); + } + } + content.push_str(&format!("action: {action_str}\n")); + if let Some(redir) = redirect { + content.push_str(&format!("redirect: \"{redir}\"\n")); + } + content.push_str("---\n\nUser-defined rule.\n"); + + std::fs::write(&path, &content)?; + println!("Created rule: {}", path.display()); + Ok(()) +} + +/// Delete a rule MD file. +pub fn unset_rule(name: &str, local: bool) -> Result<()> { + let dir = if local { + let cwd = std::env::current_dir()?; + cwd.join(".rtk") + } else { + get_rules_dir()? 
+ }; + + let filename = format!("rtk.{name}.md"); + let path = dir.join(&filename); + + if path.exists() { + std::fs::remove_file(&path)?; + println!("Removed rule: {}", path.display()); + } else { + // If it's a built-in rule, create a disabled override + let is_builtin = rules::DEFAULT_RULES.iter().any(|content| { + rules::parse_rule(content, "builtin") + .map(|r| r.name == name) + .unwrap_or(false) + }); + if is_builtin { + std::fs::create_dir_all(&dir)?; + let content = format!("---\nname: {name}\nenabled: false\n---\n\nDisabled by user.\n"); + std::fs::write(&path, content)?; + println!("Disabled built-in rule: {}", path.display()); + } else { + return Err(anyhow!("Rule file not found: {}", path.display())); + } + } + Ok(()) +} + +/// Export built-in rules to a directory. +pub fn export_rules(claude: bool) -> Result<()> { + let dir = if claude { + crate::init::resolve_claude_dir()? + } else { + get_rules_dir()? + }; + std::fs::create_dir_all(&dir)?; + + let mut count = 0; + for content in rules::DEFAULT_RULES { + let rule = rules::parse_rule(content, "builtin")?; + let filename = format!("rtk.{}.md", rule.name); + let path = dir.join(&filename); + // Skip if content unchanged; tolerate unreadable existing files + if path.exists() { + if let Ok(existing) = std::fs::read_to_string(&path) { + if existing.trim() == content.trim() { + continue; + } + } + // If unreadable, overwrite anyway + } + std::fs::write(&path, content)?; + count += 1; + } + + println!("Exported {} rules to {}", count, dir.display()); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_default() { + let config = Config::default(); + assert!(config.tracking.enabled); + assert_eq!(config.tracking.history_days, 90); + assert!(config.display.colors); + assert_eq!(config.display.max_width, 120); + } + + #[test] + fn test_config_overlay_none_fields_dont_override() { + let mut config = Config::default(); + config.tracking.history_days = 30; + config.display.max_width = 80; + + let overlay = ConfigOverlay::default(); + overlay.apply(&mut config); + + // None fields should not override + assert_eq!(config.tracking.history_days, 30); + assert_eq!(config.display.max_width, 80); + } + + #[test] + fn test_config_overlay_applies() { + let mut config = Config::default(); + + let overlay_toml = r#" +[tracking] +history_days = 30 + +[display] +max_width = 80 +"#; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + + assert_eq!(config.tracking.history_days, 30); + assert_eq!(config.display.max_width, 80); + // Unmentioned fields unchanged + assert!(config.tracking.enabled); + assert!(config.display.colors); + } + + #[test] + fn test_apply_value_tracking() { + let mut config = Config::default(); + apply_value(&mut config, "tracking.enabled", "false").unwrap(); + assert!(!config.tracking.enabled); + + apply_value(&mut config, "tracking.history_days", "30").unwrap(); + assert_eq!(config.tracking.history_days, 30); + } + + #[test] + fn test_apply_value_display() { + let mut config = Config::default(); + apply_value(&mut config, "display.max_width", "80").unwrap(); + assert_eq!(config.display.max_width, 80); + + apply_value(&mut config, "display.colors", "false").unwrap(); + assert!(!config.display.colors); + } + + #[test] + fn test_apply_value_unknown_key() { + let mut config = Config::default(); + assert!(apply_value(&mut config, "unknown.key", "value").is_err()); + } + + #[test] + fn test_get_value_existing() { + // This uses load_merged which reads from 
disk, so just test the happy path + let result = get_value("tracking.enabled"); + assert!(result.is_ok()); + let val = result.unwrap(); + assert!(val == "true" || val == "false"); + } + + #[test] + fn test_get_value_unknown() { + let result = get_value("nonexistent.key"); + assert!(result.is_err()); + } + + #[test] + fn test_load_merged_env_override() { + std::env::set_var("RTK_DB_PATH", "/tmp/test.db"); + let config = Config::load_merged().unwrap(); + assert_eq!( + config.tracking.database_path, + Some(PathBuf::from("/tmp/test.db")) + ); + std::env::remove_var("RTK_DB_PATH"); + } + + #[test] + fn test_env_overrides_all_fields() { + // Single test to avoid parallel env var interference. + // Tests all RTK_* env var overrides sequentially. + + // tracking.enabled: "false" overrides default true + std::env::set_var("RTK_TRACKING_ENABLED", "false"); + let config = Config::load_merged().unwrap(); + assert!(!config.tracking.enabled); + std::env::remove_var("RTK_TRACKING_ENABLED"); + + // tracking.enabled: "0" also disables + std::env::set_var("RTK_TRACKING_ENABLED", "0"); + let config = Config::load_merged().unwrap(); + assert!(!config.tracking.enabled); + std::env::remove_var("RTK_TRACKING_ENABLED"); + + // tracking.enabled: "1" enables + std::env::set_var("RTK_TRACKING_ENABLED", "1"); + let config = Config::load_merged().unwrap(); + assert!(config.tracking.enabled); + std::env::remove_var("RTK_TRACKING_ENABLED"); + + // tracking.history_days + std::env::set_var("RTK_HISTORY_DAYS", "7"); + let config = Config::load_merged().unwrap(); + assert_eq!(config.tracking.history_days, 7); + std::env::remove_var("RTK_HISTORY_DAYS"); + + // display.colors + std::env::set_var("RTK_DISPLAY_COLORS", "false"); + let config = Config::load_merged().unwrap(); + assert!(!config.display.colors); + std::env::remove_var("RTK_DISPLAY_COLORS"); + + // display.emoji + std::env::set_var("RTK_DISPLAY_EMOJI", "false"); + let config = Config::load_merged().unwrap(); + assert!(!config.display.emoji); + std::env::remove_var("RTK_DISPLAY_EMOJI"); + + // display.max_width + std::env::set_var("RTK_MAX_WIDTH", "200"); + let config = Config::load_merged().unwrap(); + assert_eq!(config.display.max_width, 200); + std::env::remove_var("RTK_MAX_WIDTH"); + } + + #[test] + fn test_project_local_overlay_overrides_global() { + let tmp = tempfile::tempdir().unwrap(); + let rtk_dir = tmp.path().join(".rtk"); + std::fs::create_dir_all(&rtk_dir).unwrap(); + std::fs::write( + rtk_dir.join("config.toml"), + "[tracking]\nhistory_days = 14\n", + ) + .unwrap(); + + // Simulate being in a project with .rtk/config.toml + let mut config = Config::default(); + assert_eq!(config.tracking.history_days, 90); // default + + let overlay_toml = "[tracking]\nhistory_days = 14\n"; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + assert_eq!(config.tracking.history_days, 14); // project-local overrides + } + + #[test] + fn test_env_overrides_project_local_overlay() { + // Env vars have highest priority — even over project-local config. + // Tests overlay application directly (no env var race). + let mut config = Config::default(); + let overlay_toml = "[tracking]\nhistory_days = 14\n"; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + assert_eq!(config.tracking.history_days, 14); // overlay applied + + // In load_merged, env vars are applied AFTER project overlay, + // so env vars always win. Tested via test_env_overrides_all_fields. 
+ } + + #[test] + fn test_load_robust_to_missing_config() { + // Config::load() should fall back to defaults when config doesn't exist + let config = Config::load().unwrap(); + // Should have defaults — no crash + assert!(config.tracking.enabled); + assert_eq!(config.tracking.history_days, 90); + } + + #[test] + fn test_overlay_partial_sections() { + // Only display section in overlay — tracking should be untouched + let mut config = Config::default(); + config.tracking.history_days = 45; + + let overlay_toml = "[display]\nmax_width = 60\n"; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + + assert_eq!(config.display.max_width, 60); // overridden + assert_eq!(config.tracking.history_days, 45); // untouched + } + + #[test] + fn test_overlay_partial_fields_within_section() { + // Only one field in tracking overlay — others untouched + let mut config = Config::default(); + config.tracking.enabled = false; + + let overlay_toml = "[tracking]\nhistory_days = 7\n"; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + + assert_eq!(config.tracking.history_days, 7); // overridden + assert!(!config.tracking.enabled); // untouched (was false) + } + + #[test] + fn test_get_rules_dir_returns_dot_config_rtk() { + let dir = get_rules_dir().unwrap(); + let home = dirs::home_dir().unwrap(); + assert_eq!(dir, home.join(".config").join("rtk")); + } + + #[test] + fn test_env_override_invalid_value_ignored() { + // Invalid env values should be silently ignored, keeping the default + std::env::set_var("RTK_HISTORY_DAYS", "not_a_number"); + let config = Config::load_merged().unwrap(); + assert_eq!(config.tracking.history_days, 90); // default kept + std::env::remove_var("RTK_HISTORY_DAYS"); + + std::env::set_var("RTK_MAX_WIDTH", "abc"); + let config = Config::load_merged().unwrap(); + assert_eq!(config.display.max_width, 120); // default kept + std::env::remove_var("RTK_MAX_WIDTH"); + } + + #[test] + fn test_cli_overrides_default() { + // Default CLI overrides should not change behavior + let overrides = CliConfigOverrides::default(); + assert!(overrides.config_path.is_none()); // None = use normal discovery + assert!(overrides.config_add.is_empty()); + assert!(overrides.rules_path.is_none()); // None = use normal discovery + assert!(overrides.rules_add.is_empty()); + } + + #[test] + fn test_cli_config_path_multiple_files_merged() { + // --config-path a.toml --config-path b.toml merges both in order + let tmp = tempfile::tempdir().unwrap(); + let file_a = tmp.path().join("a.toml"); + let file_b = tmp.path().join("b.toml"); + std::fs::write(&file_a, "[tracking]\nhistory_days = 5\n").unwrap(); + std::fs::write(&file_b, "[display]\nmax_width = 60\n").unwrap(); + + // Simulate load_merged with exclusive paths + let mut cfg = Config::default(); + for path in &[&file_a, &file_b] { + if let Ok(content) = std::fs::read_to_string(path) { + if let Ok(overlay) = toml::from_str::(&content) { + overlay.apply(&mut cfg); + } + } + } + + assert_eq!(cfg.tracking.history_days, 5); // from a.toml + assert_eq!(cfg.display.max_width, 60); // from b.toml + assert!(cfg.tracking.enabled); // default (not in either file) + } + + #[test] + fn test_cli_config_path_exclusive() { + // --config-path loads ONLY from that file + let tmp = tempfile::tempdir().unwrap(); + let config_file = tmp.path().join("custom.toml"); + std::fs::write( + &config_file, + "[tracking]\nhistory_days = 5\nenabled = false\n", + ) + .unwrap(); + + // Simulate what 
load_merged does with exclusive path + let path = &config_file; + let config: Config = if path.exists() { + let content = std::fs::read_to_string(path).unwrap(); + toml::from_str(&content).unwrap() + } else { + Config::default() + }; + + assert_eq!(config.tracking.history_days, 5); + assert!(!config.tracking.enabled); + // Other fields get defaults since only tracking was specified + assert!(config.display.colors); + } + + #[test] + fn test_cli_config_add_overlay() { + // --config-add applies as high-priority overlay + let mut config = Config::default(); + assert_eq!(config.display.max_width, 120); + + let add_toml = "[display]\nmax_width = 60\n"; + let overlay: ConfigOverlay = toml::from_str(add_toml).unwrap(); + overlay.apply(&mut config); + + assert_eq!(config.display.max_width, 60); // overridden by --config-add + assert!(config.tracking.enabled); // untouched + } + + // === Error Robustness Tests === + + #[test] + fn test_load_robust_to_malformed_toml() { + let tmp = tempfile::tempdir().unwrap(); + let bad_config = tmp.path().join("config.toml"); + std::fs::write(&bad_config, "this is not valid toml {{{{").unwrap(); + + // Malformed TOML should parse to default (not crash) + let result: Result = toml::from_str("this is not valid toml {{{{"); + assert!(result.is_err()); + + // Config::load falls back to defaults for malformed content + let config = Config::load().unwrap(); + assert!(config.tracking.enabled); // defaults + } + + #[test] + fn test_load_robust_to_empty_config_file() { + // Empty string is valid TOML (all defaults) + let config: Config = toml::from_str("").unwrap(); + assert!(config.tracking.enabled); + assert_eq!(config.tracking.history_days, 90); + assert_eq!(config.display.max_width, 120); + } + + #[test] + fn test_load_robust_to_binary_garbage_config() { + let garbage = "\x00\x01\x02 binary garbage"; + let result: Result = toml::from_str(garbage); + assert!(result.is_err()); // Should error, not panic + } + + #[test] + fn test_overlay_robust_to_malformed_toml() { + let result: Result = toml::from_str("not valid {{{"); + assert!(result.is_err()); // Should error, not panic + } + + #[test] + fn test_overlay_from_empty_string() { + // Empty overlay should be all-None (no overrides) + let overlay: ConfigOverlay = toml::from_str("").unwrap(); + assert!(overlay.tracking.is_none()); + assert!(overlay.display.is_none()); + assert!(overlay.filters.is_none()); + } + + #[test] + fn test_config_path_exclusive_nonexistent_falls_back() { + // If --config-path points to non-existent file, use defaults + let path = PathBuf::from("/nonexistent/config.toml"); + assert!(!path.exists()); + // Simulates load_merged logic: non-existent → Config::default() + let config = Config::default(); + assert!(config.tracking.enabled); + } + + #[test] + fn test_config_add_nonexistent_path_skipped() { + // --config-add with non-existent path should be silently skipped + let path = PathBuf::from("/nonexistent/overlay.toml"); + assert!(!path.exists()); + // The load_merged code does `if add_path.exists()` — non-existent skipped + let mut config = Config::default(); + config.tracking.history_days = 42; + // Config unchanged because path doesn't exist + assert_eq!(config.tracking.history_days, 42); + } + + #[test] + fn test_config_add_malformed_file_skipped() { + let tmp = tempfile::tempdir().unwrap(); + let bad_file = tmp.path().join("bad.toml"); + std::fs::write(&bad_file, "not valid {{{{ toml").unwrap(); + + // Simulates load_merged: if let Ok(overlay) = toml::from_str(...) 
+ let content = std::fs::read_to_string(&bad_file).unwrap(); + let result = toml::from_str::(&content); + assert!(result.is_err()); // Bad TOML → no overlay applied + + // Config should remain at defaults + let config = Config::default(); + assert!(config.tracking.enabled); + } + + #[test] + fn test_set_value_creates_parent_dirs() { + let tmp = tempfile::tempdir().unwrap(); + let config_path = tmp.path().join("nested").join("deep").join("config.toml"); + + // save_to should create parent dirs + let config = Config::default(); + let result = config.save_to(&config_path); + assert!(result.is_ok()); + assert!(config_path.exists()); + } + + // === DiscoveryConfig tests === + + #[test] + fn test_default_discovery_config() { + let config = DiscoveryConfig::default(); + assert_eq!(config.search_dirs, vec![".claude", ".gemini", ".rtk"]); + assert_eq!(config.global_dirs, vec![".claude", ".gemini"]); + assert!(config.rules_dirs.is_empty()); + } + + #[test] + fn test_discovery_config_roundtrip_toml() { + let config = Config::default(); + let toml_str = toml::to_string_pretty(&config).unwrap(); + let parsed: Config = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.discovery.search_dirs, config.discovery.search_dirs); + assert_eq!(parsed.discovery.global_dirs, config.discovery.global_dirs); + assert_eq!(parsed.discovery.rules_dirs, config.discovery.rules_dirs); + } + + #[test] + fn test_discovery_config_from_toml_custom() { + let toml_str = r#" +[discovery] +search_dirs = [".rtk", ".custom"] +global_dirs = [".mytools"] +rules_dirs = ["/opt/rtk/rules", "/home/user/rules"] +"#; + let config: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(config.discovery.search_dirs, vec![".rtk", ".custom"]); + assert_eq!(config.discovery.global_dirs, vec![".mytools"]); + assert_eq!( + config.discovery.rules_dirs, + vec![ + PathBuf::from("/opt/rtk/rules"), + PathBuf::from("/home/user/rules") + ] + ); + } + + #[test] + fn test_discovery_overlay_applies() { + let mut config = Config::default(); + let overlay: ConfigOverlay = toml::from_str( + r#" +[discovery] +search_dirs = [".only-rtk"] +rules_dirs = ["/custom/rules"] +"#, + ) + .unwrap(); + overlay.apply(&mut config); + assert_eq!(config.discovery.search_dirs, vec![".only-rtk"]); + // global_dirs unchanged (not in overlay) + assert_eq!(config.discovery.global_dirs, vec![".claude", ".gemini"]); + assert_eq!( + config.discovery.rules_dirs, + vec![PathBuf::from("/custom/rules")] + ); + } + + #[test] + fn test_apply_value_discovery_search_dirs() { + let mut config = Config::default(); + apply_value(&mut config, "discovery.search_dirs", ".rtk,.custom").unwrap(); + assert_eq!(config.discovery.search_dirs, vec![".rtk", ".custom"]); + } + + #[test] + fn test_apply_value_discovery_global_dirs() { + let mut config = Config::default(); + apply_value(&mut config, "discovery.global_dirs", ".claude").unwrap(); + assert_eq!(config.discovery.global_dirs, vec![".claude"]); + } + + #[test] + fn test_apply_value_discovery_rules_dirs() { + let mut config = Config::default(); + apply_value(&mut config, "discovery.rules_dirs", "/a,/b,/c").unwrap(); + assert_eq!( + config.discovery.rules_dirs, + vec![ + PathBuf::from("/a"), + PathBuf::from("/b"), + PathBuf::from("/c") + ] + ); + } + + #[test] + fn test_get_rules_dir_default() { + // Without any config override, get_rules_dir returns ~/.config/rtk/ + let dir = get_rules_dir().unwrap(); + assert!( + dir.to_string_lossy().contains("rtk"), + "Default rules dir should contain 'rtk': {}", + dir.display() + ); + } + + #[test] + fn 
test_discovery_config_empty_rules_dirs_not_serialized() {
+        // Empty rules_dirs should be omitted from TOML output (skip_serializing_if)
+        let config = Config::default();
+        let toml_str = toml::to_string_pretty(&config).unwrap();
+        assert!(
+            !toml_str.contains("rules_dirs"),
+            "Empty rules_dirs should be omitted from serialization"
+        );
+    }
+}
diff --git a/src/config/rules.rs b/src/config/rules.rs
new file mode 100644
index 0000000..1363474
--- /dev/null
+++ b/src/config/rules.rs
@@ -0,0 +1,881 @@
+//! Unified Rule system: safety rules, remaps, and warnings as data-driven MD files.
+//!
+//! Replaces `SafetyAction`, `SafetyRule`, `rule!()` macro, and `get_rules()` from safety.rs.
+//! Rules are MD files with YAML frontmatter, loaded from built-in defaults and user directories.
+
+use anyhow::{anyhow, Result};
+use std::collections::{BTreeMap, HashMap};
+use std::sync::OnceLock;
+
+/// A unified rule: safety, remap, warning, or block.
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Rule {
+    pub name: String,
+    #[serde(default)]
+    pub patterns: Vec<String>,
+    #[serde(default = "default_block")]
+    pub action: String,
+    #[serde(default)]
+    pub redirect: Option<String>,
+    #[serde(default = "default_always")]
+    pub when: String,
+    #[serde(default)]
+    pub env_var: Option<String>,
+    #[serde(default = "default_true")]
+    pub enabled: bool,
+    #[serde(skip)]
+    pub message: String,
+    #[serde(skip)]
+    pub source: String,
+}
+
+fn default_block() -> String {
+    "block".into()
+}
+fn default_always() -> String {
+    "always".into()
+}
+fn default_true() -> bool {
+    true
+}
+
+impl Rule {
+    /// Check if rule should apply given current env + predicates.
+    pub fn should_apply(&self) -> bool {
+        // Env var opt-out check
+        if let Some(ref env) = self.env_var {
+            if let Ok(val) = std::env::var(env) {
+                if val == "0" || val == "false" {
+                    return false;
+                }
+            }
+        }
+        // When predicate
+        check_when(&self.when)
+    }
+}
+
+// === Predicate Registry ===
+
+type PredicateFn = fn() -> bool;
+
+fn predicate_registry() -> &'static HashMap<&'static str, PredicateFn> {
+    static REGISTRY: OnceLock<HashMap<&'static str, PredicateFn>> = OnceLock::new();
+    REGISTRY.get_or_init(|| {
+        let mut m = HashMap::new();
+        m.insert("always", (|| true) as PredicateFn);
+        m.insert(
+            "has_unstaged_changes",
+            crate::cmd::predicates::has_unstaged_changes as PredicateFn,
+        );
+        m
+    })
+}
+
+pub fn check_when(when: &str) -> bool {
+    if when == "always" || when.is_empty() {
+        return true;
+    }
+    if let Some(func) = predicate_registry().get(when) {
+        return func();
+    }
+    // Bash fallback (matches clautorun behavior)
+    std::process::Command::new("sh")
+        .args(["-c", when])
+        .status()
+        .map(|s| s.success())
+        .unwrap_or(false)
+}
+
+// === Parse & Load ===
+
+/// Parse a rule from MD content with YAML frontmatter.
+pub fn parse_rule(content: &str, source: &str) -> Result<Rule> {
+    let trimmed = content.trim();
+    let rest = trimmed
+        .strip_prefix("---")
+        .ok_or_else(|| anyhow!("No frontmatter: missing opening ---"))?;
+    let end = rest
+        .find("\n---")
+        .ok_or_else(|| anyhow!("Unclosed frontmatter: missing closing ---"))?;
+    let yaml = &rest[..end];
+    let body = rest[end + 4..].trim();
+    let mut rule: Rule = serde_yaml::from_str(yaml)?;
+    rule.message = body.to_string();
+    rule.source = source.to_string();
+    Ok(rule)
+}
+
+/// Embedded default rules (compiled into binary).
+pub const DEFAULT_RULES: &[&str] = &[
+    include_str!("../rules/rtk.safety.rm-to-trash.md"),
+    include_str!("../rules/rtk.safety.git-reset-hard.md"),
+    include_str!("../rules/rtk.safety.git-checkout-dashdash.md"),
+    include_str!("../rules/rtk.safety.git-checkout-dot.md"),
+    include_str!("../rules/rtk.safety.git-stash-drop.md"),
+    include_str!("../rules/rtk.safety.git-clean-fd.md"),
+    include_str!("../rules/rtk.safety.git-clean-df.md"),
+    include_str!("../rules/rtk.safety.git-clean-f.md"),
+    include_str!("../rules/rtk.safety.block-cat.md"),
+    include_str!("../rules/rtk.safety.block-sed.md"),
+    include_str!("../rules/rtk.safety.block-head.md"),
+];
+
+static RULES_CACHE: OnceLock<Vec<Rule>> = OnceLock::new();
+
+/// Load all rules: embedded defaults + user overrides. Cached via OnceLock.
+pub fn load_all() -> &'static [Rule] {
+    RULES_CACHE.get_or_init(|| {
+        let mut rules_by_name: BTreeMap<String, Rule> = BTreeMap::new();
+
+        // 1. Embedded defaults (lowest priority)
+        for content in DEFAULT_RULES {
+            match parse_rule(content, "builtin") {
+                Ok(rule) if rule.enabled => {
+                    rules_by_name.insert(rule.name.clone(), rule);
+                }
+                Ok(rule) => {
+                    rules_by_name.remove(&rule.name);
+                }
+                Err(e) => eprintln!("rtk: bad builtin rule: {e}"),
+            }
+        }
+
+        // 2. User files (higher priority overrides by name)
+        for path in super::discovery::discover_rtk_files() {
+            let content = match std::fs::read_to_string(&path) {
+                Ok(c) => c,
+                Err(_) => continue,
+            };
+            match parse_rule(&content, &path.display().to_string()) {
+                Ok(rule) if rule.enabled => {
+                    rules_by_name.insert(rule.name.clone(), rule);
+                }
+                Ok(rule) => {
+                    rules_by_name.remove(&rule.name);
+                }
+                Err(_) => continue,
+            }
+        }
+
+        rules_by_name.into_values().collect()
+    })
+}
+
+// === Global Option Stripping ===
+
+/// Strip global options that appear between a command and its subcommand.
+///
+/// Tools like git, cargo, docker, and kubectl accept global options before
+/// the subcommand (e.g., `git -C /path --no-pager status`). These must be
+/// stripped before pattern matching so that safety rules like `"git reset --hard"`
+/// still match `git --no-pager reset --hard`.
+///
+/// Based on the patterns from upstream PR #99 (hooks/rtk-rewrite.sh).
+fn strip_global_options(full_cmd: &str) -> String { + let words: Vec<&str> = full_cmd.split_whitespace().collect(); + if words.is_empty() { + return full_cmd.to_string(); + } + + let binary = words[0]; + let rest = &words[1..]; + + match binary { + "git" => { + // Strip: -C , -c , --no-pager, --no-optional-locks, + // --bare, --literal-pathspecs, --key=value + let mut result = vec!["git"]; + let mut i = 0; + while i < rest.len() { + let w = rest[i]; + if (w == "-C" || w == "-c") && i + 1 < rest.len() { + i += 2; // skip flag + argument + } else if w.starts_with("--") + && w.contains('=') + && !w.starts_with("--hard") + && !w.starts_with("--force") + { + i += 1; // skip --key=value global options + } else if matches!( + w, + "--no-pager" + | "--no-optional-locks" + | "--bare" + | "--literal-pathspecs" + | "--paginate" + | "--git-dir" + ) { + i += 1; // skip standalone boolean global options + } else { + // First non-global-option word is the subcommand; keep everything from here + result.extend_from_slice(&rest[i..]); + break; + } + } + result.join(" ") + } + "cargo" => { + // Strip: +toolchain (e.g., cargo +nightly test) + let mut result = vec!["cargo"]; + let mut i = 0; + while i < rest.len() { + let w = rest[i]; + if w.starts_with('+') { + i += 1; // skip +toolchain + } else { + result.extend_from_slice(&rest[i..]); + break; + } + } + result.join(" ") + } + "docker" => { + // Strip: -H , --context , --config , --key=value + let mut result = vec!["docker"]; + let mut i = 0; + while i < rest.len() { + let w = rest[i]; + if matches!(w, "-H" | "--context" | "--config") && i + 1 < rest.len() { + i += 2; // skip flag + argument + } else if w.starts_with("--") && w.contains('=') { + i += 1; // skip --key=value + } else { + result.extend_from_slice(&rest[i..]); + break; + } + } + result.join(" ") + } + "kubectl" => { + // Strip: --context , --kubeconfig , --namespace , -n , --key=value + let mut result = vec!["kubectl"]; + let mut i = 0; + while i < rest.len() { + let w = rest[i]; + if matches!(w, "--context" | "--kubeconfig" | "--namespace" | "-n") + && i + 1 < rest.len() + { + i += 2; // skip flag + argument + } else if w.starts_with("--") && w.contains('=') { + i += 1; // skip --key=value + } else { + result.extend_from_slice(&rest[i..]); + break; + } + } + result.join(" ") + } + _ => full_cmd.to_string(), + } +} + +// === Pattern Matching === + +/// Check if a rule matches a command. +/// +/// - Single-word pattern: exact binary match (avoids "cat" matching "catalog") +/// - Multi-word pattern: prefix match on full command string (with global option stripping) +/// - Raw mode (binary=None): word-boundary search (handles "sudo rm") +pub fn matches_rule(rule: &Rule, binary: Option<&str>, full_cmd: &str) -> bool { + rule.patterns.iter().any(|pat| { + if pat.contains(' ') { + // Multi-word: prefix match, also try with global options stripped + let normalized = strip_global_options(full_cmd); + full_cmd.starts_with(pat.as_str()) || normalized.starts_with(pat.as_str()) + } else if let Some(bin) = binary { + // Parsed mode: exact binary + bin == pat + } else { + // Raw mode: word-boundary (handles "sudo rm", "/usr/bin/rm") + full_cmd + .split_whitespace() + .any(|w| w == pat || w.ends_with(&format!("/{pat}"))) + } + }) +} + +// === Remap Helper === + +/// Try to expand a single-word remap alias (e.g., "t --lib" → "cargo test --lib"). +/// +/// Only matches single-word patterns with `action: "rewrite"`. Multi-word rewrites +/// are safety rules handled by `check()`. Order: remap → safety → execute. 
+pub fn try_remap(raw: &str) -> Option<String> {
+    let first_word = raw.split_whitespace().next()?;
+    for rule in load_all() {
+        if rule.action != "rewrite" {
+            continue;
+        }
+        // Only remap single-word pattern matches (aliases like "t" → "cargo test")
+        if !rule
+            .patterns
+            .iter()
+            .any(|p| !p.contains(' ') && p == first_word)
+        {
+            continue;
+        }
+        if !rule.should_apply() {
+            continue;
+        }
+        if let Some(ref redirect) = rule.redirect {
+            let rest = raw[first_word.len()..].trim();
+            return Some(redirect.replace("{args}", rest));
+        }
+    }
+    None
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_rule_valid() {
+        let content = "---\nname: test-rule\npatterns: [rm]\naction: trash\n---\nSafety message.";
+        let rule = parse_rule(content, "test").unwrap();
+        assert_eq!(rule.name, "test-rule");
+        assert_eq!(rule.patterns, vec!["rm"]);
+        assert_eq!(rule.action, "trash");
+        assert_eq!(rule.message, "Safety message.");
+        assert_eq!(rule.source, "test");
+    }
+
+    #[test]
+    fn test_parse_rule_no_frontmatter() {
+        let content = "No frontmatter here";
+        assert!(parse_rule(content, "test").is_err());
+    }
+
+    #[test]
+    fn test_parse_rule_unclosed_frontmatter() {
+        let content = "---\nname: broken\n";
+        assert!(parse_rule(content, "test").is_err());
+    }
+
+    #[test]
+    fn test_parse_rule_message_body() {
+        let content = "---\nname: test\n---\n\nLine 1\n\nLine 2";
+        let rule = parse_rule(content, "test").unwrap();
+        assert_eq!(rule.message, "Line 1\n\nLine 2");
+    }
+
+    #[test]
+    fn test_parse_rule_defaults() {
+        let content = "---\nname: minimal\n---\n";
+        let rule = parse_rule(content, "test").unwrap();
+        assert_eq!(rule.action, "block"); // default
+        assert_eq!(rule.when, "always"); // default
+        assert!(rule.enabled); // default true
+        assert!(rule.patterns.is_empty()); // default empty
+    }
+
+    #[test]
+    fn test_parse_rule_all_fields() {
+        let content = r#"---
+name: full
+patterns: ["git reset --hard"]
+action: rewrite
+redirect: "git stash && git reset --hard {args}"
+when: has_unstaged_changes
+env_var: RTK_SAFE_COMMANDS
+enabled: true
+---
+Full message."#;
+        let rule = parse_rule(content, "builtin").unwrap();
+        assert_eq!(rule.name, "full");
+        assert_eq!(rule.patterns, vec!["git reset --hard"]);
+        assert_eq!(rule.action, "rewrite");
+        assert_eq!(
+            rule.redirect.as_deref(),
+            Some("git stash && git reset --hard {args}")
+        );
+        assert_eq!(rule.when, "has_unstaged_changes");
+        assert_eq!(rule.env_var.as_deref(), Some("RTK_SAFE_COMMANDS"));
+        assert!(rule.enabled);
+        assert_eq!(rule.message, "Full message.");
+    }
+
+    #[test]
+    fn test_matches_rule_single_word_binary() {
+        let content = "---\nname: test\npatterns: [rm]\n---\n";
+        let rule = parse_rule(content, "test").unwrap();
+        assert!(matches_rule(&rule, Some("rm"), "rm file.txt"));
+        assert!(!matches_rule(&rule, Some("rmdir"), "rmdir empty"));
+    }
+
+    #[test]
+    fn test_matches_rule_multiple_patterns_in_one_rule() {
+        let content =
+            "---\nname: test\npatterns: [\"chmod -R 777\", \"chmod 777\"]\naction: warn\n---\n";
+        let rule = parse_rule(content, "test").unwrap();
+        assert_eq!(rule.patterns.len(), 2);
+        assert!(matches_rule(&rule, Some("chmod"), "chmod -R 777 /tmp"));
+        assert!(matches_rule(&rule, Some("chmod"), "chmod 777 /tmp"));
+        assert!(!matches_rule(&rule, Some("chmod"), "chmod 755 /tmp"));
+    }
+
+    #[test]
+    fn test_matches_rule_multi_word_prefix() {
+        let content = "---\nname: test\npatterns: [\"git reset --hard\"]\n---\n";
+        let rule = parse_rule(content, "test").unwrap();
+        assert!(matches_rule(&rule, Some("git"), "git
reset --hard HEAD~1")); + assert!(!matches_rule(&rule, Some("git"), "git reset --soft HEAD")); + } + + #[test] + fn test_matches_rule_raw_mode_word_boundary() { + let content = "---\nname: test\npatterns: [rm]\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + // Raw mode: None for binary + assert!(matches_rule(&rule, None, "rm file.txt")); + assert!(matches_rule(&rule, None, "sudo rm file.txt")); + assert!(matches_rule(&rule, None, "/usr/bin/rm file.txt")); + // Should NOT match substrings + assert!(!matches_rule(&rule, None, "trim file.txt")); + assert!(!matches_rule(&rule, None, "farm --harvest")); + } + + #[test] + fn test_should_apply_env_var_opt_out() { + let content = "---\nname: test\npatterns: [rm]\nenv_var: RTK_TEST_VAR\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + + // No env var set → applies (opt-out model) + assert!(rule.should_apply()); + + // Set to "0" → disabled + std::env::set_var("RTK_TEST_VAR", "0"); + assert!(!rule.should_apply()); + + // Set to "false" → disabled + std::env::set_var("RTK_TEST_VAR", "false"); + assert!(!rule.should_apply()); + + // Set to "1" → enabled + std::env::set_var("RTK_TEST_VAR", "1"); + assert!(rule.should_apply()); + + std::env::remove_var("RTK_TEST_VAR"); + } + + #[test] + fn test_should_apply_when_always() { + let content = "---\nname: test\nwhen: always\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + assert!(rule.should_apply()); + } + + #[test] + fn test_load_all_includes_builtins() { + let rules = load_all(); + assert!( + rules.len() >= 11, + "Should have at least 11 built-in rules, got {}", + rules.len() + ); + // Check specific built-in names + let names: Vec<&str> = rules.iter().map(|r| r.name.as_str()).collect(); + assert!(names.contains(&"rm-to-trash")); + assert!(names.contains(&"block-cat")); + assert!(names.contains(&"git-reset-hard")); + } + + #[test] + fn test_check_when_always() { + assert!(check_when("always")); + assert!(check_when("")); + } + + #[test] + fn test_check_when_builtin_predicate() { + // has_unstaged_changes is registered - should not panic + let _ = check_when("has_unstaged_changes"); + } + + #[test] + fn test_check_when_bash_fallback() { + assert!(check_when("true")); + assert!(!check_when("false")); + } + + #[test] + fn test_try_remap_no_match() { + // "ls" is not a registered remap alias + assert!(try_remap("ls -la").is_none()); + } + + // Note: try_remap with a match requires user-defined rules in discovery dirs, + // which is tested in E2E tests rather than unit tests. 
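The note above refers to remap rules that only exist on disk. As a rough sketch (the filename follows the `rtk.*.md` convention checked by `is_rtk_rule_file`; the `name` and message text are illustrative, not a shipped default), a user-defined remap dropped into a discovery dir such as `~/.config/rtk/rtk.remap.t.md` could look like:

    ---
    name: remap-t
    patterns: [t]
    action: rewrite
    redirect: "cargo test {args}"
    ---

    Expand `t` to `cargo test`.

With such a file discovered, `try_remap("t --lib")` would return `Some("cargo test --lib")` via the `{args}` substitution in `redirect`, matching the example in the `try_remap` doc comment.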
+ + #[test] + fn test_rule_override_by_name() { + // Simulate: builtin rule overridden by user rule with same name + let mut rules_by_name: BTreeMap = BTreeMap::new(); + + let builtin = parse_rule( + "---\nname: rm-to-trash\npatterns: [rm]\naction: trash\n---\nBuiltin message.", + "builtin", + ) + .unwrap(); + rules_by_name.insert(builtin.name.clone(), builtin); + + // User override: same name, different action + let user_rule = parse_rule( + "---\nname: rm-to-trash\npatterns: [rm]\naction: block\n---\nUser blocked rm.", + "~/.config/rtk/rtk.safety.rm-to-trash.md", + ) + .unwrap(); + rules_by_name.insert(user_rule.name.clone(), user_rule); + + let rules: Vec = rules_by_name.into_values().collect(); + assert_eq!(rules.len(), 1); // Overridden, not duplicated + assert_eq!(rules[0].action, "block"); // User's action wins + assert_eq!(rules[0].message, "User blocked rm."); // User's message wins + } + + #[test] + fn test_rule_disabled_override_removes() { + // Simulate: user disables a builtin rule + let mut rules_by_name: BTreeMap = BTreeMap::new(); + + let builtin = parse_rule( + "---\nname: block-cat\npatterns: [cat]\naction: suggest_tool\n---\nUse Read.", + "builtin", + ) + .unwrap(); + rules_by_name.insert(builtin.name.clone(), builtin); + + // User disables it + let disabled = parse_rule( + "---\nname: block-cat\nenabled: false\n---\nDisabled by user.", + "~/.config/rtk/rtk.safety.block-cat.md", + ) + .unwrap(); + assert!(!disabled.enabled); + + // The load_all logic: enabled=false removes from map + if !disabled.enabled { + rules_by_name.remove(&disabled.name); + } + + assert!(rules_by_name.is_empty()); // Rule removed + } + + #[test] + fn test_all_builtin_rules_parse_successfully() { + for (i, content) in DEFAULT_RULES.iter().enumerate() { + let result = parse_rule(content, "builtin"); + assert!( + result.is_ok(), + "Built-in rule #{} failed to parse: {:?}", + i, + result.err() + ); + let rule = result.unwrap(); + assert!(!rule.name.is_empty(), "Rule #{} has empty name", i); + assert!( + rule.enabled, + "Rule #{} ({}) should be enabled", + i, rule.name + ); + } + } + + #[test] + fn test_all_builtin_rules_have_patterns() { + for content in DEFAULT_RULES { + let rule = parse_rule(content, "builtin").unwrap(); + assert!( + !rule.patterns.is_empty(), + "Rule '{}' has no patterns", + rule.name + ); + } + } + + // === Error Robustness Tests === + + #[test] + fn test_parse_rule_empty_string() { + assert!(parse_rule("", "test").is_err()); + } + + #[test] + fn test_parse_rule_binary_garbage() { + assert!(parse_rule("\x00\x01\x02 garbage", "test").is_err()); + } + + #[test] + fn test_parse_rule_valid_frontmatter_invalid_yaml() { + let content = "---\n: : : not valid yaml\n---\nbody"; + assert!(parse_rule(content, "test").is_err()); + } + + #[test] + fn test_parse_rule_missing_name_field() { + // YAML without required 'name' field + let content = "---\npatterns: [rm]\n---\nbody"; + assert!(parse_rule(content, "test").is_err()); + } + + #[test] + fn test_parse_rule_only_frontmatter_delimiters() { + let content = "---\n---\n"; + // Empty YAML → missing name → error + assert!(parse_rule(content, "test").is_err()); + } + + #[test] + fn test_parse_rule_extra_fields_ignored() { + // Unknown fields in YAML should be silently ignored (serde default) + let content = "---\nname: test\nunknown_field: 42\nextra: true\n---\nbody"; + let rule = parse_rule(content, "test"); + assert!( + rule.is_ok(), + "Unknown fields should be ignored, got: {:?}", + rule.err() + ); + assert_eq!(rule.unwrap().name, "test"); + } 
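To make the override-by-name tests above concrete: a project that wants to harden the built-in `rm-to-trash` rule would place a same-named file at `.rtk/rtk.safety.rm-to-trash.md` (a sketch; the frontmatter mirrors the strings used in `test_rule_override_by_name`, the body text is illustrative):

    ---
    name: rm-to-trash
    patterns: [rm]
    action: block
    ---

    This project blocks `rm` outright instead of redirecting it to trash.

Because `load_all()` keys rules by `name`, this file replaces the compiled-in default, and a variant whose frontmatter contains only `name` plus `enabled: false` removes the rule entirely, which is exactly what `unset_rule` writes for built-ins.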
+ + #[test] + fn test_check_when_nonexistent_command() { + // A nonsense bash command should return false (not panic) + assert!(!check_when("totally_nonexistent_command_xyz_12345")); + } + + #[test] + fn test_try_remap_empty_string() { + assert!(try_remap("").is_none()); + } + + #[test] + fn test_try_remap_whitespace_only() { + assert!(try_remap(" ").is_none()); + } + + #[test] + fn test_matches_rule_empty_patterns() { + let content = "---\nname: no-patterns\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + assert!(!matches_rule(&rule, Some("rm"), "rm file")); + assert!(!matches_rule(&rule, None, "rm file")); + } + + // === Precedence Chain Tests === + + #[test] + fn test_full_precedence_chain_builtin_global_project() { + // Simulates the full load_all() precedence: builtin → global → project + let mut rules_by_name: BTreeMap = BTreeMap::new(); + + // 1. Builtin (lowest priority): action=trash + let builtin = parse_rule( + "---\nname: rm-to-trash\npatterns: [rm]\naction: trash\n---\nBuiltin.", + "builtin", + ) + .unwrap(); + rules_by_name.insert(builtin.name.clone(), builtin); + + // 2. Global user file (~/.config/rtk/): action=warn (user edited the exported file) + let global = parse_rule( + "---\nname: rm-to-trash\npatterns: [rm]\naction: warn\n---\nGlobal user override.", + "~/.config/rtk/rtk.safety.rm-to-trash.md", + ) + .unwrap(); + rules_by_name.insert(global.name.clone(), global); + + // 3. Project-local (.rtk/): action=block (project-specific) + let project = parse_rule( + "---\nname: rm-to-trash\npatterns: [rm]\naction: block\n---\nProject override.", + "/project/.rtk/rtk.safety.rm-to-trash.md", + ) + .unwrap(); + rules_by_name.insert(project.name.clone(), project); + + let rules: Vec = rules_by_name.into_values().collect(); + assert_eq!(rules.len(), 1, "Should be 1 rule after all overrides"); + assert_eq!(rules[0].action, "block", "Project-local should win"); + assert_eq!(rules[0].source, "/project/.rtk/rtk.safety.rm-to-trash.md"); + } + + #[test] + fn test_user_edited_export_overrides_builtin() { + // User exports builtins then edits one: edited file should override compiled builtin + let mut rules_by_name: BTreeMap = BTreeMap::new(); + + // Compiled builtin + let builtin = parse_rule( + "---\nname: block-cat\npatterns: [cat]\naction: suggest_tool\nredirect: Read\n---\nBuiltin.", + "builtin", + ) + .unwrap(); + rules_by_name.insert(builtin.name.clone(), builtin); + + // User-edited export: changed redirect + let edited = parse_rule( + "---\nname: block-cat\npatterns: [cat]\naction: suggest_tool\nredirect: \"Read (with limit=50)\"\n---\nUser customized.", + "~/.config/rtk/rtk.safety.block-cat.md", + ) + .unwrap(); + rules_by_name.insert(edited.name.clone(), edited); + + let rules: Vec = rules_by_name.into_values().collect(); + assert_eq!(rules.len(), 1); + assert_eq!( + rules[0].redirect.as_deref(), + Some("Read (with limit=50)"), + "User-edited redirect should win" + ); + assert!(rules[0].source.contains(".config/rtk/")); + } + + #[test] + fn test_project_local_disable_overrides_global_and_builtin() { + // Project disables a rule that exists both in builtins and global + let mut rules_by_name: BTreeMap = BTreeMap::new(); + + // Builtin + let builtin = parse_rule( + "---\nname: block-sed\npatterns: [sed]\naction: suggest_tool\n---\nBuiltin.", + "builtin", + ) + .unwrap(); + rules_by_name.insert(builtin.name.clone(), builtin); + + // Global user file (same as builtin, maybe exported) + let global = parse_rule( + "---\nname: block-sed\npatterns: [sed]\naction: 
suggest_tool\n---\nGlobal.", + "~/.config/rtk/rtk.safety.block-sed.md", + ) + .unwrap(); + rules_by_name.insert(global.name.clone(), global); + + // Project-local disables it + let disabled = parse_rule( + "---\nname: block-sed\nenabled: false\n---\nDisabled for this project.", + "/project/.rtk/rtk.safety.block-sed.md", + ) + .unwrap(); + if !disabled.enabled { + rules_by_name.remove(&disabled.name); + } + + assert!( + rules_by_name.is_empty(), + "Project-local disable should remove rule entirely" + ); + } + + // === Global Option Stripping (PR #99 parity) === + // Table-driven: (input, expected_output) pairs covering git, cargo, docker, kubectl. + + #[test] + fn test_strip_global_options() { + let cases: &[(&str, &str)] = &[ + // Git: single flags + ("git --no-pager status", "git status"), + ("git -C /path/to/project status", "git status"), + ("git -c core.autocrlf=true diff", "git diff"), + ("git --git-dir=/path/.git status", "git status"), + ("git --no-optional-locks status", "git status"), + ("git --bare log --oneline", "git log --oneline"), + ("git --literal-pathspecs add .", "git add ."), + // Git: multiple globals stacked + ( + "git -C /path --no-pager --no-optional-locks reset --hard", + "git reset --hard", + ), + // Git: subcommand flags preserved (not stripped) + ("git reset --hard HEAD~1", "git reset --hard HEAD~1"), + ("git checkout --force main", "git checkout --force main"), + // Git: no globals (identity) + ("git status", "git status"), + ("git log --oneline -10", "git log --oneline -10"), + // Cargo: toolchain prefix + ("cargo +nightly test", "cargo test"), + ("cargo +stable build --release", "cargo build --release"), + ("cargo test", "cargo test"), // no prefix (identity) + // Docker: global flags + ("docker --context prod ps", "docker ps"), + ("docker -H tcp://host:2375 images", "docker images"), + ("docker --config /tmp/.docker run hello", "docker run hello"), + ("docker ps", "docker ps"), // no globals (identity) + // Kubectl: global flags + ("kubectl -n kube-system get pods", "kubectl get pods"), + ( + "kubectl --context prod --namespace default describe pod foo", + "kubectl describe pod foo", + ), + ("kubectl --kubeconfig=/path get svc", "kubectl get svc"), + ("kubectl get pods", "kubectl get pods"), // no globals (identity) + // Non-matching commands (identity) + ("rm -rf /tmp/foo", "rm -rf /tmp/foo"), + ("cat file.txt", "cat file.txt"), + ("echo hello", "echo hello"), + ]; + for (input, expected) in cases { + assert_eq!( + strip_global_options(input), + *expected, + "strip_global_options({input:?})" + ); + } + } + + // === Rule Matching with Global Options (PR #99 parity) === + // Multi-word safety patterns must match even with global options inserted. 
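// ----------------------------------------------------------------------------
// Editorial note (not part of this patch): the assertions below imply that
// matching first normalizes the command via strip_global_options() and then
// checks the multi-word pattern as a word prefix. The real matches_rule() and
// strip_global_options() bodies are not in this diff; this is a sketch of that
// behaviour under those assumptions (git globals only -- the shipped version
// also covers cargo/docker/kubectl per the table above).
fn pattern_matches_after_strip(pattern: &str, full_cmd: &str) -> bool {
    let normalized = strip_git_globals_sketch(full_cmd);
    let pat: Vec<&str> = pattern.split_whitespace().collect();
    let cmd: Vec<&str> = normalized.split_whitespace().collect();
    cmd.len() >= pat.len() && cmd[..pat.len()] == pat[..]
}

fn strip_git_globals_sketch(cmd: &str) -> String {
    let words: Vec<&str> = cmd.split_whitespace().collect();
    if words.first() != Some(&"git") {
        return cmd.to_string();
    }
    let mut out: Vec<&str> = vec!["git"];
    let mut i = 1;
    while i < words.len() {
        match words[i] {
            "--no-pager" | "--no-optional-locks" | "--bare" | "--literal-pathspecs" => i += 1,
            "-C" | "-c" => i += 2, // skip the flag and its value
            w if w.starts_with("--git-dir=") => i += 1,
            _ => break, // first non-global word is the subcommand
        }
    }
    out.extend_from_slice(words.get(i..).unwrap_or(&[]));
    out.join(" ")
}
// ----------------------------------------------------------------------------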
+ + #[test] + fn test_matches_rule_with_global_options() { + let cases: &[(&str, &str, bool)] = &[ + // (pattern, full_cmd, expected_match) + ("git reset --hard", "git --no-pager reset --hard HEAD", true), + ("git reset --hard", "git -C /path reset --hard", true), + ( + "git reset --hard", + "git -C /p --no-pager --no-optional-locks reset --hard", + true, + ), + ("git checkout .", "git -C /project checkout .", true), + ( + "git checkout --", + "git --no-pager checkout -- file.txt", + true, + ), + ( + "git clean -fd", + "git -C /path --no-pager --no-optional-locks clean -fd", + true, + ), + ("git stash drop", "git --no-pager stash drop", true), + // No globals: direct match still works + ("git reset --hard", "git reset --hard HEAD~1", true), + ("git checkout .", "git checkout .", true), + // Non-matching + ("git reset --hard", "git reset --soft HEAD", false), + ("git checkout .", "git checkout main", false), + ]; + for (pattern, full_cmd, expected) in cases { + let yaml = format!("---\nname: test\npatterns: [\"{pattern}\"]\n---\n"); + let rule = parse_rule(&yaml, "test").unwrap(); + let binary = full_cmd.split_whitespace().next(); + assert_eq!( + matches_rule(&rule, binary, full_cmd), + *expected, + "matches_rule(pat={pattern:?}, cmd={full_cmd:?})" + ); + } + } + + #[test] + fn test_matches_rule_empty_command() { + let content = "---\nname: test\npatterns: [rm]\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + // Parsed mode: binary match is independent of full_cmd + assert!(matches_rule(&rule, Some("rm"), "")); + // Raw mode: empty string has no words → no match + assert!(!matches_rule(&rule, None, "")); + } +} diff --git a/src/container.rs b/src/container.rs index c017b43..b6cb0cf 100644 --- a/src/container.rs +++ b/src/container.rs @@ -60,7 +60,12 @@ fn docker_ps(_verbose: u8) -> Result<()> { if parts.len() >= 4 { let id = &parts[0][..12.min(parts[0].len())]; let name = parts[1]; - let short_image = parts.get(3).unwrap_or(&"").split('/').last().unwrap_or(""); + let short_image = parts + .get(3) + .unwrap_or(&"") + .split('/') + .next_back() + .unwrap_or(""); let ports = compact_ports(parts.get(4).unwrap_or(&"")); if ports == "-" { rtk.push_str(&format!(" {} {} ({})\n", id, name, short_image)); @@ -393,7 +398,7 @@ fn compact_ports(ports: &str) -> String { // Extract just the port numbers let port_nums: Vec<&str> = ports .split(',') - .filter_map(|p| p.split("->").next().and_then(|s| s.split(':').last())) + .filter_map(|p| p.split("->").next().and_then(|s| s.split(':').next_back())) .collect(); if port_nums.len() <= 3 { diff --git a/src/find_cmd.rs b/src/find_cmd.rs index 679288e..a56a73a 100644 --- a/src/find_cmd.rs +++ b/src/find_cmd.rs @@ -57,7 +57,7 @@ pub fn run( }; let ft = entry.file_type(); - let is_dir = ft.as_ref().map_or(false, |t| t.is_dir()); + let is_dir = ft.as_ref().is_some_and(|t| t.is_dir()); // Filter by type if want_dirs && !is_dir { diff --git a/src/gh_cmd.rs b/src/gh_cmd.rs index 1e32fad..b0aafe4 100644 --- a/src/gh_cmd.rs +++ b/src/gh_cmd.rs @@ -488,12 +488,10 @@ fn list_issues(args: &[String], _verbose: u8, ultra_compact: bool) -> Result<()> } else { "C" } + } else if state == "OPEN" { + "🟢" } else { - if state == "OPEN" { - "🟢" - } else { - "🔴" - } + "🔴" }; let line = format!(" {} #{} {}\n", icon, number, truncate(title, 60)); filtered.push_str(&line); diff --git a/src/git.rs b/src/git.rs index 8f8d891..86e9c47 100644 --- a/src/git.rs +++ b/src/git.rs @@ -307,9 +307,9 @@ fn run_log(args: &[String], _max_lines: Option, verbose: u8) -> Result<() }); // 
Check if user provided limit flag - let has_limit_flag = args.iter().any(|arg| { - arg.starts_with('-') && arg.chars().nth(1).map_or(false, |c| c.is_ascii_digit()) - }); + let has_limit_flag = args + .iter() + .any(|arg| arg.starts_with('-') && arg.chars().nth(1).is_some_and(|c| c.is_ascii_digit())); // Apply RTK defaults only if user didn't specify them if !has_format_flag { @@ -323,7 +323,7 @@ fn run_log(args: &[String], _max_lines: Option, verbose: u8) -> Result<() // Extract limit from args if provided args.iter() .find(|arg| { - arg.starts_with('-') && arg.chars().nth(1).map_or(false, |c| c.is_ascii_digit()) + arg.starts_with('-') && arg.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) }) .and_then(|arg| arg[1..].parse::().ok()) .unwrap_or(10) @@ -677,7 +677,7 @@ fn run_commit(message: &str, verbose: u8) -> Result<()> { // Extract commit hash from output like "[main abc1234] message" let compact = if let Some(line) = stdout.lines().next() { if let Some(hash_start) = line.find(' ') { - let hash = line[1..hash_start].split(' ').last().unwrap_or(""); + let hash = line[1..hash_start].split(' ').next_back().unwrap_or(""); if !hash.is_empty() && hash.len() >= 7 { format!("ok ✓ {}", &hash[..7.min(hash.len())]) } else { @@ -698,23 +698,21 @@ fn run_commit(message: &str, verbose: u8) -> Result<()> { &raw_output, &compact, ); + } else if stderr.contains("nothing to commit") || stdout.contains("nothing to commit") { + println!("ok (nothing to commit)"); + timer.track( + &format!("git commit -m \"{}\"", message), + "rtk git commit", + &raw_output, + "ok (nothing to commit)", + ); } else { - if stderr.contains("nothing to commit") || stdout.contains("nothing to commit") { - println!("ok (nothing to commit)"); - timer.track( - &format!("git commit -m \"{}\"", message), - "rtk git commit", - &raw_output, - "ok (nothing to commit)", - ); - } else { - eprintln!("FAILED: git commit"); - if !stderr.trim().is_empty() { - eprintln!("{}", stderr); - } - if !stdout.trim().is_empty() { - eprintln!("{}", stdout); - } + eprintln!("FAILED: git commit"); + if !stderr.trim().is_empty() { + eprintln!("{}", stderr); + } + if !stdout.trim().is_empty() { + eprintln!("{}", stdout); } } diff --git a/src/golangci_cmd.rs b/src/golangci_cmd.rs index ca55a0c..6e4b3f0 100644 --- a/src/golangci_cmd.rs +++ b/src/golangci_cmd.rs @@ -9,8 +9,10 @@ use std::process::Command; struct Position { #[serde(rename = "Filename")] filename: String, + #[serde(default)] #[serde(rename = "Line")] line: usize, + #[serde(default)] #[serde(rename = "Column")] column: usize, } @@ -19,6 +21,7 @@ struct Position { struct Issue { #[serde(rename = "FromLinter")] from_linter: String, + #[serde(default)] #[serde(rename = "Text")] text: String, #[serde(rename = "Pos")] diff --git a/src/init.rs b/src/init.rs index 961e4ac..f575e62 100644 --- a/src/init.rs +++ b/src/init.rs @@ -4,9 +4,6 @@ use std::io::Write; use std::path::{Path, PathBuf}; use tempfile::NamedTempFile; -// Embedded hook script (guards before set -euo pipefail) -const REWRITE_HOOK: &str = include_str!("../hooks/rtk-rewrite.sh"); - // Embedded slim RTK awareness instructions const RTK_SLIM: &str = include_str!("../hooks/rtk-awareness.md"); @@ -179,55 +176,13 @@ pub fn run( } } -/// Prepare hook directory and return paths (hook_dir, hook_path) -fn prepare_hook_paths() -> Result<(PathBuf, PathBuf)> { - let claude_dir = resolve_claude_dir()?; - let hook_dir = claude_dir.join("hooks"); - fs::create_dir_all(&hook_dir) - .with_context(|| format!("Failed to create hook directory: {}", 
hook_dir.display()))?; - let hook_path = hook_dir.join("rtk-rewrite.sh"); - Ok((hook_dir, hook_path)) -} - -/// Write hook file if missing or outdated, return true if changed -#[cfg(unix)] -fn ensure_hook_installed(hook_path: &Path, verbose: u8) -> Result { - let changed = if hook_path.exists() { - let existing = fs::read_to_string(hook_path) - .with_context(|| format!("Failed to read existing hook: {}", hook_path.display()))?; - - if existing == REWRITE_HOOK { - if verbose > 0 { - eprintln!("Hook already up to date: {}", hook_path.display()); - } - false - } else { - fs::write(hook_path, REWRITE_HOOK) - .with_context(|| format!("Failed to write hook to {}", hook_path.display()))?; - if verbose > 0 { - eprintln!("Updated hook: {}", hook_path.display()); - } - true - } - } else { - fs::write(hook_path, REWRITE_HOOK) - .with_context(|| format!("Failed to write hook to {}", hook_path.display()))?; - if verbose > 0 { - eprintln!("Created hook: {}", hook_path.display()); - } - true - }; - - // Set executable permissions - use std::os::unix::fs::PermissionsExt; - fs::set_permissions(hook_path, fs::Permissions::from_mode(0o755)) - .with_context(|| format!("Failed to set hook permissions: {}", hook_path.display()))?; - - Ok(changed) -} - /// Idempotent file write: create or update if content differs -fn write_if_changed(path: &Path, content: &str, name: &str, verbose: u8) -> Result { +pub(crate) fn write_if_changed( + path: &Path, + content: &str, + name: &str, + verbose: u8, +) -> Result { if path.exists() { let existing = fs::read_to_string(path) .with_context(|| format!("Failed to read {}: {}", name, path.display()))?; @@ -257,7 +212,7 @@ fn write_if_changed(path: &Path, content: &str, name: &str, verbose: u8) -> Resu /// Atomic write using tempfile + rename /// Prevents corruption on crash/interrupt -fn atomic_write(path: &Path, content: &str) -> Result<()> { +pub(crate) fn atomic_write(path: &Path, content: &str) -> Result<()> { let parent = path.parent().with_context(|| { format!( "Cannot write to {}: path has no parent directory", @@ -311,13 +266,13 @@ fn prompt_user_consent(settings_path: &Path) -> Result { } /// Print manual instructions for settings.json patching -fn print_manual_instructions(hook_path: &Path) { +fn print_manual_instructions(hook_command: &str) { println!("\n MANUAL STEP: Add this to ~/.claude/settings.json:"); println!(" {{"); println!(" \"hooks\": {{ \"PreToolUse\": [{{"); println!(" \"matcher\": \"Bash\","); println!(" \"hooks\": [{{ \"type\": \"command\","); - println!(" \"command\": \"{}\"", hook_path.display()); + println!(" \"command\": \"{}\"", hook_command); println!(" }}]"); println!(" }}]}}"); println!(" }}"); @@ -343,7 +298,7 @@ fn remove_hook_from_json(root: &mut serde_json::Value) -> bool { if let Some(hooks_array) = entry.get("hooks").and_then(|h| h.as_array()) { for hook in hooks_array { if let Some(command) = hook.get("command").and_then(|c| c.as_str()) { - if command.contains("rtk-rewrite.sh") { + if command.contains("rtk-rewrite.sh") || command.contains("rtk hook claude") { return false; // Remove this entry } } @@ -355,20 +310,21 @@ fn remove_hook_from_json(root: &mut serde_json::Value) -> bool { pre_tool_use_array.len() < original_len } -/// Remove RTK hook from settings.json file -/// Backs up before modification, returns true if hook was found and removed -fn remove_hook_from_settings(verbose: u8) -> Result { - let claude_dir = resolve_claude_dir()?; - let settings_path = claude_dir.join("settings.json"); - +/// Shared: remove a hook from a 
settings.json file +/// Reads, parses, applies `remover`, backs up, and atomically writes if changed. +fn remove_hook_from_settings_file( + settings_path: &Path, + remover: impl FnOnce(&mut serde_json::Value) -> bool, + verbose: u8, +) -> Result { if !settings_path.exists() { if verbose > 0 { - eprintln!("settings.json not found, nothing to remove"); + eprintln!("{} not found, nothing to remove", settings_path.display()); } return Ok(false); } - let content = fs::read_to_string(&settings_path) + let content = fs::read_to_string(settings_path) .with_context(|| format!("Failed to read {}", settings_path.display()))?; if content.trim().is_empty() { @@ -378,27 +334,31 @@ fn remove_hook_from_settings(verbose: u8) -> Result { let mut root: serde_json::Value = serde_json::from_str(&content) .with_context(|| format!("Failed to parse {} as JSON", settings_path.display()))?; - let removed = remove_hook_from_json(&mut root); + let removed = remover(&mut root); if removed { - // Backup original let backup_path = settings_path.with_extension("json.bak"); - fs::copy(&settings_path, &backup_path) + fs::copy(settings_path, &backup_path) .with_context(|| format!("Failed to backup to {}", backup_path.display()))?; - // Atomic write let serialized = serde_json::to_string_pretty(&root).context("Failed to serialize settings.json")?; - atomic_write(&settings_path, &serialized)?; + atomic_write(settings_path, &serialized)?; if verbose > 0 { - eprintln!("Removed RTK hook from settings.json"); + eprintln!("Removed RTK hook from {}", settings_path.display()); } } Ok(removed) } +/// Remove RTK hook from Claude settings.json +fn remove_hook_from_settings(verbose: u8) -> Result { + let settings_path = resolve_claude_dir()?.join("settings.json"); + remove_hook_from_settings_file(&settings_path, remove_hook_from_json, verbose) +} + /// Full uninstall: remove hook, RTK.md, @RTK.md reference, settings.json entry pub fn uninstall(global: bool, verbose: u8) -> Result<()> { if !global { @@ -408,12 +368,12 @@ pub fn uninstall(global: bool, verbose: u8) -> Result<()> { let claude_dir = resolve_claude_dir()?; let mut removed = Vec::new(); - // 1. Remove hook file + // 1. Remove legacy hook file (if present from old installs) let hook_path = claude_dir.join("hooks").join("rtk-rewrite.sh"); if hook_path.exists() { fs::remove_file(&hook_path) .with_context(|| format!("Failed to remove hook: {}", hook_path.display()))?; - removed.push(format!("Hook: {}", hook_path.display())); + removed.push(format!("Legacy hook: {}", hook_path.display())); } // 2. Remove RTK.md @@ -426,30 +386,18 @@ pub fn uninstall(global: bool, verbose: u8) -> Result<()> { // 3. Remove @RTK.md reference from CLAUDE.md let claude_md_path = claude_dir.join("CLAUDE.md"); - if claude_md_path.exists() { - let content = fs::read_to_string(&claude_md_path) - .with_context(|| format!("Failed to read CLAUDE.md: {}", claude_md_path.display()))?; - - if content.contains("@RTK.md") { - let new_content = content - .lines() - .filter(|line| !line.trim().starts_with("@RTK.md")) - .collect::>() - .join("\n"); - - // Clean up double blanks - let cleaned = clean_double_blanks(&new_content); - - fs::write(&claude_md_path, cleaned).with_context(|| { - format!("Failed to write CLAUDE.md: {}", claude_md_path.display()) - })?; - removed.push(format!("CLAUDE.md: removed @RTK.md reference")); - } + if remove_rtk_reference_from_file(&claude_md_path, "CLAUDE.md")? { + removed.push("CLAUDE.md: removed @RTK.md reference".to_string()); } - // 4. Remove hook entry from settings.json + // 4. 
Remove hook entry from Claude Code settings.json if remove_hook_from_settings(verbose)? { - removed.push("settings.json: removed RTK hook entry".to_string()); + removed.push("Claude settings.json: removed RTK hook entry".to_string()); + } + + // 5. Remove hook entry from Gemini settings.json + if remove_gemini_hook_from_settings(verbose)? { + removed.push("Gemini settings.json: removed RTK hook entry".to_string()); } // Report results @@ -466,18 +414,111 @@ pub fn uninstall(global: bool, verbose: u8) -> Result<()> { Ok(()) } -/// Orchestrator: patch settings.json with RTK hook -/// Handles reading, checking, prompting, merging, backing up, and atomic writing -fn patch_settings_json(hook_path: &Path, mode: PatchMode, verbose: u8) -> Result { - let claude_dir = resolve_claude_dir()?; - let settings_path = claude_dir.join("settings.json"); - let hook_command = hook_path - .to_str() - .context("Hook path contains invalid UTF-8")?; +/// Uninstall RTK Gemini CLI integration. +/// Mirrors Claude uninstall: removes RTK.md, @RTK.md reference, and settings.json hook. +pub fn uninstall_gemini(verbose: u8) -> Result<()> { + let gemini_dir = resolve_gemini_dir()?; + let mut removed = Vec::new(); + // 1. Remove RTK.md + let rtk_md_path = gemini_dir.join("RTK.md"); + if rtk_md_path.exists() { + fs::remove_file(&rtk_md_path) + .with_context(|| format!("Failed to remove RTK.md: {}", rtk_md_path.display()))?; + removed.push(format!("RTK.md: {}", rtk_md_path.display())); + } + + // 2. Remove @RTK.md reference from GEMINI.md + let gemini_md_path = gemini_dir.join("GEMINI.md"); + if remove_rtk_reference_from_file(&gemini_md_path, "GEMINI.md")? { + removed.push("GEMINI.md: removed @RTK.md reference".to_string()); + } + + // 3. Remove hook entry from Gemini settings.json + if remove_gemini_hook_from_settings(verbose)? { + removed.push("Gemini settings.json: removed RTK hook entry".to_string()); + } + + // Report results + if removed.is_empty() { + println!("RTK Gemini integration was not installed (nothing to remove)"); + } else { + println!("RTK Gemini integration uninstalled:"); + for item in removed { + println!(" - {}", item); + } + println!("\nRestart Gemini CLI to apply changes."); + } + + Ok(()) +} + +// ============================================================================ +// MULTI-PLATFORM ARCHITECTURE (Claude Code + Gemini CLI) +// ============================================================================ +// +// RTK supports both Claude Code and Gemini CLI via a DRY architecture that +// shares common logic while respecting protocol-specific differences. +// +// ## Shared Infrastructure (DRY) +// +// 1. patch_settings_shared() - Core settings.json patching logic +// 2. patch_instruction_file() - Add @RTK.md to CLAUDE.md / GEMINI.md +// 3. remove_rtk_reference_from_file() - Remove @RTK.md (for uninstall) +// 4. show_agent_hook_status() - Hook status verification +// 5. prompt_user_consent() - User confirmation prompt +// 6. atomic_write() / write_if_changed() - Safe file I/O +// 7. 
PatchMode / PatchResult enums - Behavior control and outcome reporting +// +// ## Symmetric Installation Workflow (Both Platforms) +// +// ### Claude Code +// - Create: ~/.claude/RTK.md +// - Patch: ~/.claude/CLAUDE.md (add @RTK.md) +// - Patch: ~/.claude/settings.json (PreToolUse hook) +// - Uninstall: Removes all 3 artifacts +// +// ### Gemini CLI +// - Create: ~/.gemini/RTK.md (same content) +// - Patch: ~/.gemini/GEMINI.md (add @RTK.md) +// - Patch: ~/.gemini/settings.json (BeforeTool hook) +// - Uninstall: Removes all 3 artifacts +// +// ## Protocol-Specific Differences (Settings.json Only) +// +// - Claude: Event=PreToolUse, Matcher=Bash, Command="rtk hook claude" +// - Gemini: Event=BeforeTool, Matcher=run_shell_command, Command="rtk hook gemini" +// +// These reflect API differences and cannot be unified. +// +// ## Default Behavior (as of v0.15.3) +// +// `rtk init` (no platform flags) → Sets up BOTH Claude and Gemini +// `rtk init --claude` → Claude only +// `rtk init --gemini` → Gemini only +// `rtk init --uninstall` → Remove both +// `rtk init --uninstall --claude` → Remove Claude only +// `rtk init --uninstall --gemini` → Remove Gemini only +// ============================================================================ + +/// Shared: patch a settings.json with an agent hook. +/// Reads/creates JSON, checks idempotency, handles PatchMode, inserts hook, +/// backs up, and atomically writes. +/// +/// Used by both Claude Code (patch_settings_json) and Gemini CLI (patch_gemini_settings). +fn patch_settings_shared( + settings_path: &Path, + is_present: impl Fn(&serde_json::Value) -> bool, + insert_hook: impl FnOnce(&mut serde_json::Value), + print_manual: impl Fn(), + mode: PatchMode, + label: &str, + restart_msg: &str, + verbose: u8, +) -> Result { // Read or create settings.json let mut root = if settings_path.exists() { - let content = fs::read_to_string(&settings_path) + let content = fs::read_to_string(settings_path) .with_context(|| format!("Failed to read {}", settings_path.display()))?; if content.trim().is_empty() { @@ -491,9 +532,9 @@ fn patch_settings_json(hook_path: &Path, mode: PatchMode, verbose: u8) -> Result }; // Check idempotency - if hook_already_present(&root, &hook_command) { + if is_present(&root) { if verbose > 0 { - eprintln!("settings.json: hook already present"); + eprintln!("{}: hook already present", label); } return Ok(PatchResult::AlreadyPresent); } @@ -501,27 +542,25 @@ fn patch_settings_json(hook_path: &Path, mode: PatchMode, verbose: u8) -> Result // Handle mode match mode { PatchMode::Skip => { - print_manual_instructions(hook_path); + print_manual(); return Ok(PatchResult::Skipped); } PatchMode::Ask => { - if !prompt_user_consent(&settings_path)? { - print_manual_instructions(hook_path); + if !prompt_user_consent(settings_path)? 
{ + print_manual(); return Ok(PatchResult::Declined); } } - PatchMode::Auto => { - // Proceed without prompting - } + PatchMode::Auto => {} } // Deep-merge hook - insert_hook_entry(&mut root, &hook_command); + insert_hook(&mut root); // Backup original if settings_path.exists() { let backup_path = settings_path.with_extension("json.bak"); - fs::copy(&settings_path, &backup_path) + fs::copy(settings_path, &backup_path) .with_context(|| format!("Failed to backup to {}", backup_path.display()))?; if verbose > 0 { eprintln!("Backup: {}", backup_path.display()); @@ -531,20 +570,37 @@ fn patch_settings_json(hook_path: &Path, mode: PatchMode, verbose: u8) -> Result // Atomic write let serialized = serde_json::to_string_pretty(&root).context("Failed to serialize settings.json")?; - atomic_write(&settings_path, &serialized)?; + atomic_write(settings_path, &serialized)?; - println!("\n settings.json: hook added"); + println!("\n {}: hook added", label); if settings_path.with_extension("json.bak").exists() { println!( " Backup: {}", settings_path.with_extension("json.bak").display() ); } - println!(" Restart Claude Code. Test with: git status"); + println!(" {}", restart_msg); Ok(PatchResult::Patched) } +/// Patch Claude settings.json with RTK hook +fn patch_settings_json(mode: PatchMode, verbose: u8) -> Result { + let settings_path = resolve_claude_dir()?.join("settings.json"); + let hook_command = "rtk hook claude"; + + patch_settings_shared( + &settings_path, + |root| hook_already_present(root, hook_command), + |root| insert_hook_entry(root, hook_command), + || print_manual_instructions(hook_command), + mode, + "settings.json", + "Restart Claude Code. Test with: git status", + verbose, + ) +} + /// Clean up consecutive blank lines (collapse 3+ to 2) /// Used when removing @RTK.md line from CLAUDE.md fn clean_double_blanks(content: &str) -> String { @@ -558,7 +614,6 @@ fn clean_double_blanks(content: &str) -> String { if line.trim().is_empty() { // Count consecutive blank lines let mut blank_count = 0; - let start = i; while i < lines.len() && lines[i].trim().is_empty() { blank_count += 1; i += 1; @@ -566,9 +621,7 @@ fn clean_double_blanks(content: &str) -> String { // Keep at most 2 blank lines let keep = blank_count.min(2); - for _ in 0..keep { - result.push(""); - } + result.extend(std::iter::repeat_n("", keep)); } else { result.push(line); i += 1; @@ -615,7 +668,7 @@ fn insert_hook_entry(root: &mut serde_json::Value, hook_command: &str) { } /// Check if RTK hook is already present in settings.json -/// Matches on rtk-rewrite.sh substring to handle different path formats +/// Matches on rtk-rewrite.sh (legacy) or rtk hook claude (current) fn hook_already_present(root: &serde_json::Value, hook_command: &str) -> bool { let pre_tool_use_array = match root .get("hooks") @@ -632,9 +685,9 @@ fn hook_already_present(root: &serde_json::Value, hook_command: &str) -> bool { .flatten() .filter_map(|hook| hook.get("command")?.as_str()) .any(|cmd| { - // Exact match OR both contain rtk-rewrite.sh cmd == hook_command - || (cmd.contains("rtk-rewrite.sh") && hook_command.contains("rtk-rewrite.sh")) + || cmd.contains("rtk-rewrite.sh") // Legacy match for migration + || cmd.contains("rtk hook claude") // New direct binary invocation }) } @@ -658,29 +711,34 @@ fn run_default_mode(global: bool, patch_mode: PatchMode, verbose: u8) -> Result< let rtk_md_path = claude_dir.join("RTK.md"); let claude_md_path = claude_dir.join("CLAUDE.md"); - // 1. 
Prepare hook directory and install hook - let (_hook_dir, hook_path) = prepare_hook_paths()?; - ensure_hook_installed(&hook_path, verbose)?; - - // 2. Write RTK.md + // 1. Write RTK.md write_if_changed(&rtk_md_path, RTK_SLIM, "RTK.md", verbose)?; - // 3. Patch CLAUDE.md (add @RTK.md, migrate if needed) + // 2. Patch CLAUDE.md (add @RTK.md, migrate if needed) let migrated = patch_claude_md(&claude_md_path, verbose)?; - // 4. Print success message + // 3. Print success message println!("\nRTK hook installed (global).\n"); - println!(" Hook: {}", hook_path.display()); + println!(" Hook: rtk hook claude (direct binary)"); println!(" RTK.md: {} (10 lines)", rtk_md_path.display()); println!(" CLAUDE.md: @RTK.md reference added"); if migrated { - println!("\n ✅ Migrated: removed 137-line RTK block from CLAUDE.md"); - println!(" replaced with @RTK.md (10 lines)"); + println!("\n Migrated: removed 137-line RTK block from CLAUDE.md"); + println!(" replaced with @RTK.md (10 lines)"); + } + + // 4. Export default rules to ~/.config/rtk/ for discoverability + if let Err(e) = crate::config::export_rules(false) { + if verbose > 0 { + eprintln!(" Note: could not export default rules: {e}"); + } + } else { + println!(" Rules: ~/.config/rtk/rtk.*.md (customizable)"); } // 5. Patch settings.json - let patch_result = patch_settings_json(&hook_path, patch_mode, verbose)?; + let patch_result = patch_settings_json(patch_mode, verbose)?; // Report result match patch_result { @@ -715,18 +773,14 @@ fn run_hook_only_mode(global: bool, patch_mode: PatchMode, verbose: u8) -> Resul return Ok(()); } - // Prepare and install hook - let (_hook_dir, hook_path) = prepare_hook_paths()?; - ensure_hook_installed(&hook_path, verbose)?; - println!("\nRTK hook installed (hook-only mode).\n"); - println!(" Hook: {}", hook_path.display()); + println!(" Hook: rtk hook claude (direct binary)"); println!( " Note: No RTK.md created. Claude won't know about meta commands (gain, discover, proxy)." ); // Patch settings.json - let patch_result = patch_settings_json(&hook_path, patch_mode, verbose)?; + let patch_result = patch_settings_json(patch_mode, verbose)?; // Report result match patch_result { @@ -886,8 +940,11 @@ fn upsert_rtk_block(content: &str, block: &str) -> (String, RtkBlockUpsert) { } } -/// Patch CLAUDE.md: add @RTK.md, migrate if old block exists -fn patch_claude_md(path: &Path, verbose: u8) -> Result { +// --- patch_instruction_file: @RTK.md reference management --- + +/// Shared: Patch instruction file (CLAUDE.md or GEMINI.md) to add @RTK.md reference. +/// Migrates old RTK blocks if present. Returns true if migration occurred. +fn patch_instruction_file(path: &Path, file_label: &str, verbose: u8) -> Result { let mut content = if path.exists() { fs::read_to_string(path)? 
} else { @@ -903,7 +960,7 @@ fn patch_claude_md(path: &Path, verbose: u8) -> Result { content = new_content; migrated = true; if verbose > 0 { - eprintln!("Migrated: removed old RTK block from CLAUDE.md"); + eprintln!("Migrated: removed old RTK block from {}", file_label); } } } @@ -911,7 +968,7 @@ fn patch_claude_md(path: &Path, verbose: u8) -> Result { // Check if @RTK.md already present if content.contains("@RTK.md") { if verbose > 0 { - eprintln!("@RTK.md reference already present in CLAUDE.md"); + eprintln!("@RTK.md reference already present in {}", file_label); } if migrated { fs::write(path, content)?; @@ -929,13 +986,51 @@ fn patch_claude_md(path: &Path, verbose: u8) -> Result { fs::write(path, new_content)?; if verbose > 0 { - eprintln!("Added @RTK.md reference to CLAUDE.md"); + eprintln!("Added @RTK.md reference to {}", file_label); } Ok(migrated) } -/// Remove old RTK block from CLAUDE.md (migration helper) +/// Shared: Remove @RTK.md reference from an instruction file (CLAUDE.md or GEMINI.md). +/// Returns true if the reference was found and removed. +fn remove_rtk_reference_from_file(path: &Path, file_label: &str) -> Result { + if !path.exists() { + return Ok(false); + } + + let content = fs::read_to_string(path) + .with_context(|| format!("Failed to read {}: {}", file_label, path.display()))?; + + if !content.contains("@RTK.md") { + return Ok(false); + } + + let new_content = content + .lines() + .filter(|line| !line.trim().starts_with("@RTK.md")) + .collect::>() + .join("\n"); + + let cleaned = clean_double_blanks(&new_content); + + fs::write(path, cleaned) + .with_context(|| format!("Failed to write {}: {}", file_label, path.display()))?; + + Ok(true) +} + +/// Patch CLAUDE.md: add @RTK.md, migrate if old block exists +fn patch_claude_md(path: &Path, verbose: u8) -> Result { + patch_instruction_file(path, "CLAUDE.md", verbose) +} + +/// Patch GEMINI.md: add @RTK.md, migrate if old block exists +fn patch_gemini_md(path: &Path, verbose: u8) -> Result { + patch_instruction_file(path, "GEMINI.md", verbose) +} + +/// Remove old RTK block from CLAUDE.md or GEMINI.md (migration helper) fn remove_rtk_block(content: &str) -> (String, bool) { if let (Some(start), Some(end)) = ( content.find("