From 730f31c97777de3e8f80217a4f8a84f675305fc9 Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Mon, 16 Feb 2026 18:58:45 -0500 Subject: [PATCH 1/5] feat: add Rust-based hook engine with chained command rewriting Rust binary replaces 204-line bash script as Claude Code PreToolUse hook. Adds rtk hook claude, rtk run -c, and Windows support via cfg!(windows). Closes #112 (chained commands missed). Based on updated master (70c3786) which includes: - Hook audit mode (#151) - Claude Code agents and skills (d8f4659) - tee raw output feature (#134) Migrated from feat/rust-hooks (571bd86) with conflict resolution for: - src/main.rs: Commands enum (preserved both hook audit + our hook commands) - src/init.rs: Hook registration (integrated both approaches) New files (src/cmd/ module): - mod.rs: Module declarations (10 modules, excluding safety/trash/gemini for PR 1) - hook.rs: Shared hook decision logic (21 tests, 3 safety tests removed for PR 2) - claude_hook.rs: Claude Code JSON protocol handler (18 tests) - lexer.rs: Quote-aware tokenizer (28 tests) - analysis.rs: Chain parsing and shellism detection (10 tests) - builtins.rs: cd/export/pwd/echo/true/false (8 tests) - exec.rs: Command executor with recursion guard (22 tests, safety dispatch removed for PR 2) - filters.rs: Output filter registry (5 tests) - predicates.rs: Context predicates (4 tests) - test_helpers.rs: Test utilities Modified files: - src/main.rs: Added Commands::Run, Commands::Hook, HookCommands enum, routing - src/init.rs: Changed patch_settings_json to use rtk hook claude binary command - hooks/rtk-rewrite.sh: Replaced 204-line bash script with 4-line shim (exec rtk hook claude) - Cargo.toml: Added which = 7 for PATH resolution - INSTALL.md: Added Windows installation section Windows support: - exec.rs:175-176: cfg!(windows) selects cmd /C vs sh -c for shell passthrough - predicates.rs:26: USERPROFILE fallback for Windows home directory - No bash, node, or bun dependency - rtk hook claude is a compiled Rust binary Tests: All 541 tests pass --- Cargo.lock | 31 +++ Cargo.toml | 1 + INSTALL.md | 28 ++ hooks/rtk-rewrite.sh | 213 +-------------- src/cmd/analysis.rs | 249 ++++++++++++++++++ src/cmd/builtins.rs | 246 +++++++++++++++++ src/cmd/claude_hook.rs | 506 +++++++++++++++++++++++++++++++++++ src/cmd/exec.rs | 426 ++++++++++++++++++++++++++++++ src/cmd/filters.rs | 212 +++++++++++++++ src/cmd/hook.rs | 569 ++++++++++++++++++++++++++++++++++++++++ src/cmd/lexer.rs | 474 +++++++++++++++++++++++++++++++++ src/cmd/mod.rs | 33 +++ src/cmd/predicates.rs | 94 +++++++ src/cmd/test_helpers.rs | 35 +++ src/init.rs | 20 +- src/main.rs | 53 ++++ 16 files changed, 2969 insertions(+), 221 deletions(-) create mode 100644 src/cmd/analysis.rs create mode 100644 src/cmd/builtins.rs create mode 100644 src/cmd/claude_hook.rs create mode 100644 src/cmd/exec.rs create mode 100644 src/cmd/filters.rs create mode 100644 src/cmd/hook.rs create mode 100644 src/cmd/lexer.rs create mode 100644 src/cmd/mod.rs create mode 100644 src/cmd/predicates.rs create mode 100644 src/cmd/test_helpers.rs diff --git a/Cargo.lock b/Cargo.lock index 6fa4eb0..f9baf4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,6 +253,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "env_home" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" + [[package]] name = "equivalent" version = "1.0.2" @@ -598,6 +610,7 @@ dependencies = [ "thiserror", "toml", "walkdir", + "which", ] [[package]] @@ -892,6 +905,18 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "which" +version = "7.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" +dependencies = [ + "either", + "env_home", + "rustix", + "winsafe", +] + [[package]] name = "winapi-util" version = "0.1.11" @@ -1117,6 +1142,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "winsafe" +version = "0.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index af6896b..4f6809f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ rusqlite = { version = "0.31", features = ["bundled"] } toml = "0.8" chrono = "0.4" thiserror = "1.0" +which = "7" tempfile = "3" [dev-dependencies] diff --git a/INSTALL.md b/INSTALL.md index 55b32fd..6de2e31 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -67,6 +67,34 @@ rtk gain # MUST show token savings, not "command not found" ⚠️ **WARNING**: `cargo install rtk` from crates.io might install the wrong package. Always verify with `rtk gain`. +### Windows Installation + +RTK compiles natively for Windows. Download from GitHub Releases or: + +```bash +cargo install --path . +``` + +Configure Claude Code `settings.json`: + +```json +{ + "hooks": { + "PreToolUse": [{ + "matcher": "Bash", + "hooks": [{ + "type": "command", + "command": "rtk hook claude" + }] + }] + } +} +``` + +No bash, node, or bun required — `rtk hook claude` is a native Windows binary. +The exec.rs shell selection (`cfg!(windows)`) automatically uses `cmd /C` on Windows +and `sh -c` on Unix for passthrough commands. + ## Project Initialization ### Recommended: Global Hook-First Setup diff --git a/hooks/rtk-rewrite.sh b/hooks/rtk-rewrite.sh index 59e02ca..f895725 100644 --- a/hooks/rtk-rewrite.sh +++ b/hooks/rtk-rewrite.sh @@ -1,209 +1,4 @@ -#!/bin/bash -# RTK auto-rewrite hook for Claude Code PreToolUse:Bash -# Transparently rewrites raw commands to their rtk equivalents. -# Outputs JSON with updatedInput to modify the command before execution. - -# Guards: skip silently if dependencies missing -if ! command -v rtk &>/dev/null || ! command -v jq &>/dev/null; then - exit 0 -fi - -set -euo pipefail - -INPUT=$(cat) -CMD=$(echo "$INPUT" | jq -r '.tool_input.command // empty') - -if [ -z "$CMD" ]; then - exit 0 -fi - -# Extract the first meaningful command (before pipes, &&, etc.) -# We only rewrite if the FIRST command in a chain matches. -FIRST_CMD="$CMD" - -# Skip if already using rtk -case "$FIRST_CMD" in - rtk\ *|*/rtk\ *) exit 0 ;; -esac - -# Skip commands with heredocs, variable assignments as the whole command, etc. -case "$FIRST_CMD" in - *'<<'*) exit 0 ;; -esac - -# Strip leading env var assignments for pattern matching -# e.g., "TEST_SESSION_ID=2 npx playwright test" → match against "npx playwright test" -# but preserve them in the rewritten command for execution. 
-ENV_PREFIX=$(echo "$FIRST_CMD" | grep -oE '^([A-Za-z_][A-Za-z0-9_]*=[^ ]* +)+' || echo "") -if [ -n "$ENV_PREFIX" ]; then - MATCH_CMD="${FIRST_CMD:${#ENV_PREFIX}}" - CMD_BODY="${CMD:${#ENV_PREFIX}}" -else - MATCH_CMD="$FIRST_CMD" - CMD_BODY="$CMD" -fi - -REWRITTEN="" - -# --- Git commands --- -if echo "$MATCH_CMD" | grep -qE '^git[[:space:]]'; then - GIT_SUBCMD=$(echo "$MATCH_CMD" | sed -E \ - -e 's/^git[[:space:]]+//' \ - -e 's/(-C|-c)[[:space:]]+[^[:space:]]+[[:space:]]*//g' \ - -e 's/--[a-z-]+=[^[:space:]]+[[:space:]]*//g' \ - -e 's/--(no-pager|no-optional-locks|bare|literal-pathspecs)[[:space:]]*//g' \ - -e 's/^[[:space:]]+//') - case "$GIT_SUBCMD" in - status|status\ *|diff|diff\ *|log|log\ *|add|add\ *|commit|commit\ *|push|push\ *|pull|pull\ *|branch|branch\ *|fetch|fetch\ *|stash|stash\ *|show|show\ *) - REWRITTEN="${ENV_PREFIX}rtk $CMD_BODY" - ;; - esac - -# --- GitHub CLI (added: api, release) --- -elif echo "$MATCH_CMD" | grep -qE '^gh[[:space:]]+(pr|issue|run|api|release)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^gh /rtk gh /')" - -# --- Cargo --- -elif echo "$MATCH_CMD" | grep -qE '^cargo[[:space:]]'; then - CARGO_SUBCMD=$(echo "$MATCH_CMD" | sed -E 's/^cargo[[:space:]]+(\+[^[:space:]]+[[:space:]]+)?//') - case "$CARGO_SUBCMD" in - test|test\ *|build|build\ *|clippy|clippy\ *|check|check\ *|install|install\ *|fmt|fmt\ *) - REWRITTEN="${ENV_PREFIX}rtk $CMD_BODY" - ;; - esac - -# --- File operations --- -elif echo "$MATCH_CMD" | grep -qE '^cat[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^cat /rtk read /')" -elif echo "$MATCH_CMD" | grep -qE '^(rg|grep)[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(rg|grep) /rtk grep /')" -elif echo "$MATCH_CMD" | grep -qE '^ls([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^ls/rtk ls/')" -elif echo "$MATCH_CMD" | grep -qE '^tree([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^tree/rtk tree/')" -elif echo "$MATCH_CMD" | grep -qE '^find[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^find /rtk find /')" -elif echo "$MATCH_CMD" | grep -qE '^diff[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^diff /rtk diff /')" -elif echo "$MATCH_CMD" | grep -qE '^head[[:space:]]+'; then - # Transform: head -N file → rtk read file --max-lines N - # Also handle: head --lines=N file - if echo "$MATCH_CMD" | grep -qE '^head[[:space:]]+-[0-9]+[[:space:]]+'; then - LINES=$(echo "$MATCH_CMD" | sed -E 's/^head +-([0-9]+) +.+$/\1/') - FILE=$(echo "$MATCH_CMD" | sed -E 's/^head +-[0-9]+ +(.+)$/\1/') - REWRITTEN="${ENV_PREFIX}rtk read $FILE --max-lines $LINES" - elif echo "$MATCH_CMD" | grep -qE '^head[[:space:]]+--lines=[0-9]+[[:space:]]+'; then - LINES=$(echo "$MATCH_CMD" | sed -E 's/^head +--lines=([0-9]+) +.+$/\1/') - FILE=$(echo "$MATCH_CMD" | sed -E 's/^head +--lines=[0-9]+ +(.+)$/\1/') - REWRITTEN="${ENV_PREFIX}rtk read $FILE --max-lines $LINES" - fi - -# --- JS/TS tooling (added: npm run, npm test, vue-tsc) --- -elif echo "$MATCH_CMD" | grep -qE '^(pnpm[[:space:]]+)?(npx[[:space:]]+)?vitest([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(pnpm )?(npx )?vitest( run)?/rtk vitest run/')" -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+test([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm test/rtk vitest run/')" -elif echo "$MATCH_CMD" | grep -qE '^npm[[:space:]]+test([[:space:]]|$)'; then - 
REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^npm test/rtk npm test/')" -elif echo "$MATCH_CMD" | grep -qE '^npm[[:space:]]+run[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^npm run /rtk npm /')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?vue-tsc([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?vue-tsc/rtk tsc/')" -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+tsc([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm tsc/rtk tsc/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?tsc([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?tsc/rtk tsc/')" -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+lint([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm lint/rtk lint/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?eslint([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?eslint/rtk lint/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?prettier([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?prettier/rtk prettier/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?playwright([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?playwright/rtk playwright/')" -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+playwright([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm playwright/rtk playwright/')" -elif echo "$MATCH_CMD" | grep -qE '^(npx[[:space:]]+)?prisma([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed -E 's/^(npx )?prisma/rtk prisma/')" - -# --- Containers (added: docker compose, docker run/build/exec, kubectl describe/apply) --- -elif echo "$MATCH_CMD" | grep -qE '^docker[[:space:]]'; then - if echo "$MATCH_CMD" | grep -qE '^docker[[:space:]]+compose([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^docker /rtk docker /')" - else - DOCKER_SUBCMD=$(echo "$MATCH_CMD" | sed -E \ - -e 's/^docker[[:space:]]+//' \ - -e 's/(-H|--context|--config)[[:space:]]+[^[:space:]]+[[:space:]]*//g' \ - -e 's/--[a-z-]+=[^[:space:]]+[[:space:]]*//g' \ - -e 's/^[[:space:]]+//') - case "$DOCKER_SUBCMD" in - ps|ps\ *|images|images\ *|logs|logs\ *|run|run\ *|build|build\ *|exec|exec\ *) - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^docker /rtk docker /')" - ;; - esac - fi -elif echo "$MATCH_CMD" | grep -qE '^kubectl[[:space:]]'; then - KUBE_SUBCMD=$(echo "$MATCH_CMD" | sed -E \ - -e 's/^kubectl[[:space:]]+//' \ - -e 's/(--context|--kubeconfig|--namespace|-n)[[:space:]]+[^[:space:]]+[[:space:]]*//g' \ - -e 's/--[a-z-]+=[^[:space:]]+[[:space:]]*//g' \ - -e 's/^[[:space:]]+//') - case "$KUBE_SUBCMD" in - get|get\ *|logs|logs\ *|describe|describe\ *|apply|apply\ *) - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^kubectl /rtk kubectl /')" - ;; - esac - -# --- Network --- -elif echo "$MATCH_CMD" | grep -qE '^curl[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^curl /rtk curl /')" -elif echo "$MATCH_CMD" | grep -qE '^wget[[:space:]]+'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^wget /rtk wget /')" - -# --- pnpm package management --- -elif echo "$MATCH_CMD" | grep -qE '^pnpm[[:space:]]+(list|ls|outdated)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pnpm /rtk pnpm /')" - -# --- Python tooling --- -elif echo "$MATCH_CMD" 
| grep -qE '^pytest([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pytest/rtk pytest/')" -elif echo "$MATCH_CMD" | grep -qE '^python[[:space:]]+-m[[:space:]]+pytest([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^python -m pytest/rtk pytest/')" -elif echo "$MATCH_CMD" | grep -qE '^ruff[[:space:]]+(check|format)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^ruff /rtk ruff /')" -elif echo "$MATCH_CMD" | grep -qE '^pip[[:space:]]+(list|outdated|install|show)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^pip /rtk pip /')" -elif echo "$MATCH_CMD" | grep -qE '^uv[[:space:]]+pip[[:space:]]+(list|outdated|install|show)([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^uv pip /rtk pip /')" - -# --- Go tooling --- -elif echo "$MATCH_CMD" | grep -qE '^go[[:space:]]+test([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^go test/rtk go test/')" -elif echo "$MATCH_CMD" | grep -qE '^go[[:space:]]+build([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^go build/rtk go build/')" -elif echo "$MATCH_CMD" | grep -qE '^go[[:space:]]+vet([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^go vet/rtk go vet/')" -elif echo "$MATCH_CMD" | grep -qE '^golangci-lint([[:space:]]|$)'; then - REWRITTEN="${ENV_PREFIX}$(echo "$CMD_BODY" | sed 's/^golangci-lint/rtk golangci-lint/')" -fi - -# If no rewrite needed, approve as-is -if [ -z "$REWRITTEN" ]; then - exit 0 -fi - -# Build the updated tool_input with all original fields preserved, only command changed -ORIGINAL_INPUT=$(echo "$INPUT" | jq -c '.tool_input') -UPDATED_INPUT=$(echo "$ORIGINAL_INPUT" | jq --arg cmd "$REWRITTEN" '.command = $cmd') - -# Output the rewrite instruction -jq -n \ - --argjson updated "$UPDATED_INPUT" \ - '{ - "hookSpecificOutput": { - "hookEventName": "PreToolUse", - "permissionDecision": "allow", - "permissionDecisionReason": "RTK auto-rewrite", - "updatedInput": $updated - } - }' +#!/bin/sh +# Legacy shim — actual hook logic is in the rtk binary. +# Direct usage: rtk hook claude (reads JSON from stdin) +exec rtk hook claude diff --git a/src/cmd/analysis.rs b/src/cmd/analysis.rs new file mode 100644 index 0000000..1d816e2 --- /dev/null +++ b/src/cmd/analysis.rs @@ -0,0 +1,249 @@ +//! Analyzes tokens to decide: Native execution or Passthrough? 
+
+use super::lexer::{strip_quotes, ParsedToken, TokenKind};
+
+/// Represents a single command in a chain
+#[derive(Debug, Clone, PartialEq)]
+pub struct NativeCommand {
+    pub binary: String,
+    pub args: Vec<String>,
+    pub operator: Option<String>, // &&, ||, ;, or None for last command
+}
+
+/// Check if command needs real shell (has shellisms, pipes, redirects)
+pub fn needs_shell(tokens: &[ParsedToken]) -> bool {
+    tokens.iter().any(|t| {
+        matches!(
+            t.kind,
+            TokenKind::Shellism | TokenKind::Pipe | TokenKind::Redirect
+        )
+    })
+}
+
+/// Parse tokens into native command chain
+/// Returns error if syntax is invalid (e.g., operator with no preceding command)
+pub fn parse_chain(tokens: Vec<ParsedToken>) -> Result<Vec<NativeCommand>, String> {
+    let mut commands = Vec::new();
+    let mut current_args = Vec::new();
+
+    for token in tokens {
+        match token.kind {
+            TokenKind::Arg => {
+                // Strip quotes from the argument
+                current_args.push(strip_quotes(&token.value));
+            }
+            TokenKind::Operator => {
+                if current_args.is_empty() {
+                    return Err(format!(
+                        "Syntax error: operator {} with no command",
+                        token.value
+                    ));
+                }
+                // First arg is the binary, rest are args
+                let binary = current_args.remove(0);
+                commands.push(NativeCommand {
+                    binary,
+                    args: current_args.clone(),
+                    operator: Some(token.value.clone()),
+                });
+                current_args.clear();
+            }
+            TokenKind::Pipe | TokenKind::Redirect | TokenKind::Shellism => {
+                // Should not reach here if needs_shell() was checked first
+                // But handle gracefully
+                return Err(format!(
+                    "Unexpected {:?} in native mode - use passthrough",
+                    token.kind
+                ));
+            }
+        }
+    }
+
+    // Handle last command (no trailing operator)
+    if !current_args.is_empty() {
+        let binary = current_args.remove(0);
+        commands.push(NativeCommand {
+            binary,
+            args: current_args,
+            operator: None,
+        });
+    }
+
+    Ok(commands)
+}
+
+/// Should the next command run based on operator and last result?
+pub fn should_run(operator: Option<&str>, last_success: bool) -> bool { + match operator { + Some("&&") => last_success, + Some("||") => !last_success, + Some(";") | None => true, + _ => true, // Unknown operator, just run + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cmd::lexer::tokenize; + + // === NEEDS_SHELL TESTS === + + #[test] + fn test_needs_shell_simple() { + let tokens = tokenize("git status"); + assert!(!needs_shell(&tokens)); + } + + #[test] + fn test_needs_shell_with_glob() { + let tokens = tokenize("ls *.rs"); + assert!(needs_shell(&tokens)); + } + + #[test] + fn test_needs_shell_with_pipe() { + let tokens = tokenize("cat file | grep x"); + assert!(needs_shell(&tokens)); + } + + #[test] + fn test_needs_shell_with_redirect() { + let tokens = tokenize("cmd > file"); + assert!(needs_shell(&tokens)); + } + + #[test] + fn test_needs_shell_with_chain() { + let tokens = tokenize("cd dir && git status"); + // && is an Operator, not a Shellism - should NOT need shell + assert!(!needs_shell(&tokens)); + } + + // === PARSE_CHAIN TESTS === + + #[test] + fn test_parse_simple_command() { + let tokens = tokenize("git status"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 1); + assert_eq!(cmds[0].binary, "git"); + assert_eq!(cmds[0].args, vec!["status"]); + assert_eq!(cmds[0].operator, None); + } + + #[test] + fn test_parse_command_with_multiple_args() { + let tokens = tokenize("git commit -m message"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 1); + assert_eq!(cmds[0].binary, "git"); + assert_eq!(cmds[0].args, vec!["commit", "-m", "message"]); + } + + #[test] + fn test_parse_chained_and() { + let tokens = tokenize("cd dir && git status"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 2); + assert_eq!(cmds[0].binary, "cd"); + assert_eq!(cmds[0].args, vec!["dir"]); + assert_eq!(cmds[0].operator, Some("&&".to_string())); + assert_eq!(cmds[1].binary, "git"); + assert_eq!(cmds[1].args, vec!["status"]); + assert_eq!(cmds[1].operator, None); + } + + #[test] + fn test_parse_chained_or() { + let tokens = tokenize("cmd1 || cmd2"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 2); + assert_eq!(cmds[0].operator, Some("||".to_string())); + } + + #[test] + fn test_parse_chained_semicolon() { + let tokens = tokenize("cmd1 ; cmd2 ; cmd3"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 3); + assert_eq!(cmds[0].operator, Some(";".to_string())); + assert_eq!(cmds[1].operator, Some(";".to_string())); + assert_eq!(cmds[2].operator, None); + } + + #[test] + fn test_parse_triple_chain() { + let tokens = tokenize("a && b && c"); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 3); + } + + #[test] + fn test_parse_operator_at_start() { + let tokens = tokenize("&& cmd"); + let result = parse_chain(tokens); + assert!(result.is_err()); + } + + #[test] + fn test_parse_operator_at_end() { + let tokens = tokenize("cmd &&"); + let cmds = parse_chain(tokens).unwrap(); + // cmd is parsed, && triggers flush but no second command + assert_eq!(cmds.len(), 1); + assert_eq!(cmds[0].operator, Some("&&".to_string())); + } + + #[test] + fn test_parse_quoted_arg() { + let tokens = tokenize("git commit -m \"Fix && Bug\""); + let cmds = parse_chain(tokens).unwrap(); + assert_eq!(cmds.len(), 1); + // The && inside quotes should be in the arg, not an operator + // args are: commit, -m, "Fix && Bug" + assert_eq!(cmds[0].args.len(), 3); + assert_eq!(cmds[0].args[2], "Fix && Bug"); + } + + 
#[test]
+    fn test_parse_empty() {
+        let tokens = tokenize("");
+        let cmds = parse_chain(tokens).unwrap();
+        assert!(cmds.is_empty());
+    }
+
+    // === SHOULD_RUN TESTS ===
+
+    #[test]
+    fn test_should_run_and_success() {
+        assert!(should_run(Some("&&"), true));
+    }
+
+    #[test]
+    fn test_should_run_and_failure() {
+        assert!(!should_run(Some("&&"), false));
+    }
+
+    #[test]
+    fn test_should_run_or_success() {
+        assert!(!should_run(Some("||"), true));
+    }
+
+    #[test]
+    fn test_should_run_or_failure() {
+        assert!(should_run(Some("||"), false));
+    }
+
+    #[test]
+    fn test_should_run_semicolon() {
+        assert!(should_run(Some(";"), true));
+        assert!(should_run(Some(";"), false));
+    }
+
+    #[test]
+    fn test_should_run_none() {
+        assert!(should_run(None, true));
+        assert!(should_run(None, false));
+    }
+}
diff --git a/src/cmd/builtins.rs b/src/cmd/builtins.rs
new file mode 100644
index 0000000..fa11be6
--- /dev/null
+++ b/src/cmd/builtins.rs
@@ -0,0 +1,246 @@
+//! Built-in commands that RTK handles natively.
+//! These maintain session state across hook calls.
+
+use super::predicates::{expand_tilde, get_home};
+use anyhow::{Context, Result};
+
+/// Change directory (persists in RTK process)
+pub fn builtin_cd(args: &[String]) -> Result<bool> {
+    let target = args
+        .first()
+        .map(|s| expand_tilde(s))
+        .unwrap_or_else(get_home);
+
+    std::env::set_current_dir(&target)
+        .with_context(|| format!("cd: {}: No such file or directory", target))?;
+
+    Ok(true)
+}
+
+/// Export environment variable
+pub fn builtin_export(args: &[String]) -> Result<bool> {
+    for arg in args {
+        if let Some((key, value)) = arg.split_once('=') {
+            // Handle quoted values: export FOO="bar baz"
+            let clean_value = value
+                .strip_prefix('"')
+                .and_then(|v| v.strip_suffix('"'))
+                .or_else(|| value.strip_prefix('\'').and_then(|v| v.strip_suffix('\'')))
+                .unwrap_or(value);
+            std::env::set_var(key, clean_value);
+        }
+    }
+    Ok(true)
+}
+
+/// Check if a binary is a builtin
+pub fn is_builtin(binary: &str) -> bool {
+    matches!(
+        binary,
+        "cd" | "export" | "pwd" | "echo" | "true" | "false" | ":"
+    )
+}
+
+/// Execute a builtin command
+pub fn execute(binary: &str, args: &[String]) -> Result<bool> {
+    match binary {
+        "cd" => builtin_cd(args),
+        "export" => builtin_export(args),
+        "pwd" => {
+            println!("{}", std::env::current_dir()?.display());
+            Ok(true)
+        }
+        "echo" => {
+            let (print_args, no_newline) = if args.first().map(|s| s.as_str()) == Some("-n") {
+                (&args[1..], true)
+            } else {
+                (args, false)
+            };
+            print!("{}", print_args.join(" "));
+            if !no_newline {
+                println!();
+            }
+            Ok(true)
+        }
+        "true" | ":" => Ok(true),
+        "false" => Ok(false),
+        _ => anyhow::bail!("Unknown builtin: {}", binary),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::env;
+
+    // === CD TESTS ===
+    // Consolidated into one test: cwd is process-global, so parallel tests race.
+
+    #[test]
+    fn test_cd_all_cases() {
+        let original = env::current_dir().unwrap();
+        let home = get_home();
+
+        // 1. cd to existing dir
+        let result = builtin_cd(&["/tmp".to_string()]).unwrap();
+        assert!(result);
+        let new_dir = env::current_dir().unwrap();
+        // On macOS /tmp symlinks to /private/tmp — canonicalize both sides
+        let canon_tmp = std::fs::canonicalize("/tmp").unwrap();
+        let canon_new = std::fs::canonicalize(&new_dir).unwrap();
+        assert_eq!(canon_new, canon_tmp, "cd /tmp should land in /tmp");
+
+        // 2. 
cd to nonexistent dir + let result = builtin_cd(&["/nonexistent/path/xyz".to_string()]); + assert!(result.is_err()); + // cwd unchanged after failed cd + assert_eq!( + std::fs::canonicalize(env::current_dir().unwrap()).unwrap(), + canon_tmp + ); + + // 3. cd with no args → home + let result = builtin_cd(&[]).unwrap(); + assert!(result); + let cwd = env::current_dir().unwrap(); + let canon_home = std::fs::canonicalize(&home).unwrap(); + let canon_cwd = std::fs::canonicalize(&cwd).unwrap(); + assert_eq!(canon_cwd, canon_home, "cd with no args should go home"); + + // 4. cd ~ → home + let _ = env::set_current_dir("/tmp"); + let result = builtin_cd(&["~".to_string()]).unwrap(); + assert!(result); + let cwd = std::fs::canonicalize(env::current_dir().unwrap()).unwrap(); + assert_eq!(cwd, canon_home, "cd ~ should go home"); + + // 5. cd ~/nonexistent-subpath — may fail, just verify no panic + let _ = builtin_cd(&["~/nonexistent_rtk_test_subpath_xyz".to_string()]); + + // Restore original cwd + let _ = env::set_current_dir(&original); + } + + // === EXPORT TESTS === + + #[test] + fn test_export_simple() { + builtin_export(&["RTK_TEST_SIMPLE=value".to_string()]).unwrap(); + assert_eq!(env::var("RTK_TEST_SIMPLE").unwrap(), "value"); + env::remove_var("RTK_TEST_SIMPLE"); + } + + #[test] + fn test_export_with_equals_in_value() { + builtin_export(&["RTK_TEST_EQUALS=key=value".to_string()]).unwrap(); + assert_eq!(env::var("RTK_TEST_EQUALS").unwrap(), "key=value"); + env::remove_var("RTK_TEST_EQUALS"); + } + + #[test] + fn test_export_quoted_value() { + builtin_export(&["RTK_TEST_QUOTED=\"hello world\"".to_string()]).unwrap(); + assert_eq!(env::var("RTK_TEST_QUOTED").unwrap(), "hello world"); + env::remove_var("RTK_TEST_QUOTED"); + } + + #[test] + fn test_export_multiple() { + builtin_export(&["RTK_TEST_A=1".to_string(), "RTK_TEST_B=2".to_string()]).unwrap(); + assert_eq!(env::var("RTK_TEST_A").unwrap(), "1"); + assert_eq!(env::var("RTK_TEST_B").unwrap(), "2"); + env::remove_var("RTK_TEST_A"); + env::remove_var("RTK_TEST_B"); + } + + #[test] + fn test_export_no_equals() { + // Should be silently ignored (like bash) + let result = builtin_export(&["NO_EQUALS_HERE".to_string()]).unwrap(); + assert!(result); + } + + // === IS_BUILTIN TESTS === + + #[test] + fn test_is_builtin_cd() { + assert!(is_builtin("cd")); + } + + #[test] + fn test_is_builtin_export() { + assert!(is_builtin("export")); + } + + #[test] + fn test_is_builtin_pwd() { + assert!(is_builtin("pwd")); + } + + #[test] + fn test_is_builtin_echo() { + assert!(is_builtin("echo")); + } + + #[test] + fn test_is_builtin_true() { + assert!(is_builtin("true")); + } + + #[test] + fn test_is_builtin_false() { + assert!(is_builtin("false")); + } + + #[test] + fn test_is_builtin_external() { + assert!(!is_builtin("git")); + assert!(!is_builtin("ls")); + assert!(!is_builtin("cargo")); + } + + // === EXECUTE TESTS === + + #[test] + fn test_execute_pwd() { + let result = execute("pwd", &[]).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_echo() { + let result = execute("echo", &["hello".to_string(), "world".to_string()]).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_true() { + let result = execute("true", &[]).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_false() { + let result = execute("false", &[]).unwrap(); + assert!(!result); + } + + #[test] + fn test_execute_unknown_builtin() { + let result = execute("notabuiltin", &[]); + assert!(result.is_err()); + } + + #[test] + fn test_execute_echo_n_flag() { + // 
echo -n should succeed (prints without newline) + let result = execute("echo", &["-n".to_string(), "hello".to_string()]).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_echo_empty_args() { + let result = execute("echo", &[]).unwrap(); + assert!(result); + } +} diff --git a/src/cmd/claude_hook.rs b/src/cmd/claude_hook.rs new file mode 100644 index 0000000..1918c16 --- /dev/null +++ b/src/cmd/claude_hook.rs @@ -0,0 +1,506 @@ +//! Claude Code PreToolUse hook protocol handler. +//! +//! Reads JSON from stdin, applies safety checks and rewrites, +//! outputs JSON to stdout. +//! +//! Protocol: https://docs.anthropic.com/en/docs/claude-code/hooks +//! +//! ## Exit Code Behavior +//! +//! - Exit 0 = success (allow/rewrite) — tool proceeds +//! - Exit 2 = blocking error (deny) — tool rejected +//! +//! ## Claude Code Stderr Rule (CRITICAL) +//! +//! **Source:** See `/Users/athundt/.claude/clautorun/.worktrees/claude-stable-pre-v0.8.0/notes/hooks_api_reference.md:720-728` +//! +//! ```text +//! CRITICAL: ANY stderr output at exit 0 = hook error = fail-open +//! ``` +//! +//! **Implication:** +//! - Exit 0 + ANY stderr → Claude Code treats hook as FAILED → tool executes anyway (fail-open) +//! - Exit 2 + stderr → Claude Code treats stderr as the block reason → tool blocked, AI sees reason +//! +//! **This module's stderr usage:** +//! - ✅ Exit 0 paths (NoOpinion, Allow): **NEVER write to stderr** +//! - ✅ Exit 2 path (Deny): **stderr ONLY** for bug #4669 workaround (see below) +//! +//! ## Bug #4669 Workaround (Dual-Path Deny) +//! +//! **Issue:** https://github.com/anthropics/claude-code/issues/4669 +//! **Versions:** v1.0.62+ through current (not fixed) +//! **Problem:** `permissionDecision: "deny"` at exit 0 is IGNORED — tool executes anyway +//! +//! **Workaround:** +//! ```text +//! stdout: JSON with permissionDecision "deny" (documented main path, but broken) +//! stderr: plain text reason (fallback path that actually works) +//! exit code: 2 (triggers Claude Code to read stderr as error) +//! ``` +//! +//! This ensures deny works regardless of which path Claude Code processes. +//! +//! ## I/O Enforcement (Module-Specific) +//! +//! **This restriction applies ONLY to claude_hook.rs and gemini_hook.rs.** +//! All other RTK modules (main.rs, git.rs, etc.) use `println!`/`eprintln!` normally. +//! +//! **Why restricted here:** +//! - Hook protocol requires JSON-only stdout +//! - Claude Code's "ANY stderr = hook error" rule (see above) +//! - Accidental prints corrupt the JSON protocol +//! +//! **Enforcement mechanism:** +//! - `#![deny(clippy::print_stdout, clippy::print_stderr)]` at module level (line 52) +//! - `run_inner()` returns `HookResponse` enum — pure logic, no I/O +//! - `run()` is the ONLY function that writes output — single I/O point +//! - Uses `write!`/`writeln!` which are NOT caught by the clippy lint +//! +//! **Pathway:** main.rs → Commands::Hook → claude_hook::run() [DENY ENFORCED HERE] +//! +//! Fail-open: Any parse error or unexpected input → exit 0, no output. + +// Compile-time I/O enforcement for THIS MODULE ONLY. +// Other RTK modules (main.rs, git.rs, etc.) use println!/eprintln! normally. +// +// Why restrict here: +// - Claude Code hook protocol requires JSON-only stdout +// - Claude Code rule: "ANY stderr at exit 0 = hook error = fail-open" +// (Source: clautorun hooks_api_reference.md:720-728) +// - Accidental prints would corrupt the JSON response +// +// Mechanism: +// - Denies println!/eprintln! at compile-time +// - Allows write!/writeln! 
(used only in run() for controlled output)
+// - run_inner() returns HookResponse (no I/O)
+// - run() is the single I/O point
+#![deny(clippy::print_stdout, clippy::print_stderr)]
+
+use super::hook::{
+    check_for_hook, is_hook_disabled, should_passthrough, update_command_in_tool_input,
+    HookResponse, HookResult,
+};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::io::{self, Read, Write};
+
+// --- Wire format structs (field names must match Claude Code spec exactly) ---
+
+#[derive(Deserialize)]
+pub(crate) struct ClaudePayload {
+    tool_input: Option<Value>,
+    // Claude Code also sends: tool_name, session_id, session_cwd,
+    // transcript_path — serde silently ignores unknown fields.
+    // The settings.json matcher already filters to Bash-only events.
+}
+
+#[derive(Serialize)]
+#[serde(rename_all = "camelCase")]
+pub(crate) struct ClaudeResponse {
+    hook_specific_output: HookOutput,
+}
+
+#[derive(Serialize)]
+#[serde(rename_all = "camelCase")]
+struct HookOutput {
+    hook_event_name: &'static str,
+    permission_decision: &'static str,
+    permission_decision_reason: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    updated_input: Option<Value>,
+}
+
+// --- Guard logic (extracted for testability) ---
+
+/// Extract the command string from a parsed payload.
+/// Returns None if payload has no tool_input or no command field.
+pub(crate) fn extract_command(payload: &ClaudePayload) -> Option<&str> {
+    payload
+        .tool_input
+        .as_ref()?
+        .get("command")?
+        .as_str()
+        .filter(|s| !s.is_empty())
+}
+
+// Guard functions `is_hook_disabled()` and `should_passthrough()` are shared
+// with gemini_hook.rs via hook.rs to avoid duplication (DRY).
+
+/// Build a ClaudeResponse for an allowed/rewritten command.
+pub(crate) fn allow_response(reason: String, updated_input: Option<Value>) -> ClaudeResponse {
+    ClaudeResponse {
+        hook_specific_output: HookOutput {
+            hook_event_name: "PreToolUse",
+            permission_decision: "allow",
+            permission_decision_reason: reason,
+            updated_input,
+        },
+    }
+}
+
+/// Build a ClaudeResponse for a blocked command.
+pub(crate) fn deny_response(reason: String) -> ClaudeResponse {
+    ClaudeResponse {
+        hook_specific_output: HookOutput {
+            hook_event_name: "PreToolUse",
+            permission_decision: "deny",
+            permission_decision_reason: reason,
+            updated_input: None,
+        },
+    }
+}
+
+// --- Entry point ---
+
+/// Run the Claude Code hook handler.
+///
+/// This is the ONLY function that performs I/O (stdout/stderr).
+/// `run_inner()` returns a `HookResponse` enum — pure logic, no I/O.
+/// Combined with `#![deny(clippy::print_stdout, clippy::print_stderr)]`,
+/// this ensures no stray output corrupts the JSON hook protocol.
+///
+/// Fail-open design: malformed input → exit 0, no output.
+/// Claude Code interprets this as "no opinion" and proceeds normally.
+pub fn run() -> anyhow::Result<()> {
+    // Fail-open: wrap entire handler so ANY error → exit 0 (no opinion).
+    let response = match run_inner() {
+        Ok(r) => r,
+        Err(_) => HookResponse::NoOpinion, // Fail-open: swallow errors
+    };
+
+    // ┌────────────────────────────────────────────────────────────────────┐
+    // │ SINGLE I/O POINT - All stdout/stderr output happens here only      │
+    // │                                                                    │
+    // │ Why: Claude Code rule "ANY stderr at exit 0 = hook error"          │
+    // │      (Source: hooks_api_reference.md:720-728)                      │
+    // │                                                                    │
+    // │ Enforcement: #![deny(...)] at line 52 prevents println!/eprintln!  │
+    // │              write!/writeln! are not caught by lint (allowed)      │
+    // └────────────────────────────────────────────────────────────────────┘
+    match response {
+        HookResponse::NoOpinion => {
+            // Exit 0, NO stdout, NO stderr
+            // Claude Code sees no output → proceeds with original command
+        }
+        HookResponse::Allow(json) => {
+            // Exit 0, JSON to stdout, NO stderr
+            // CRITICAL: No stderr at exit 0 (would cause fail-open)
+            writeln!(io::stdout(), "{json}")?;
+        }
+        HookResponse::Deny(json, reason) => {
+            // Exit 2, JSON to stdout, reason to stderr
+            // This is the ONLY path that writes to stderr (valid at exit 2 only)
+            //
+            // Dual-path deny for bug #4669 workaround:
+            // - stdout: JSON with permissionDecision "deny" (documented path, but ignored)
+            // - stderr: plain text reason (actual blocking mechanism via exit 2)
+            // - exit 2: Triggers Claude Code to read stderr and block tool
+            writeln!(io::stdout(), "{json}")?;
+            writeln!(io::stderr(), "{reason}")?;
+            std::process::exit(2);
+        }
+    }
+    Ok(())
+}
+
+/// Inner handler: pure decision logic, no I/O.
+/// Returns `HookResponse` for `run()` to output.
+fn run_inner() -> anyhow::Result<HookResponse> {
+    let mut buffer = String::new();
+    io::stdin().read_to_string(&mut buffer)?;
+
+    let payload: ClaudePayload = match serde_json::from_str(&buffer) {
+        Ok(p) => p,
+        Err(_) => return Ok(HookResponse::NoOpinion),
+    };
+
+    let cmd = match extract_command(&payload) {
+        Some(c) => c,
+        None => return Ok(HookResponse::NoOpinion),
+    };
+
+    if is_hook_disabled() || should_passthrough(cmd) {
+        return Ok(HookResponse::NoOpinion);
+    }
+
+    let result = check_for_hook(cmd, "claude");
+
+    match result {
+        HookResult::Rewrite(new_cmd) => {
+            // Preserve all original tool_input fields, only replace "command"
+            // Shared helper (DRY with gemini_hook.rs via hook.rs)
+            let updated = update_command_in_tool_input(payload.tool_input, new_cmd);
+
+            let response = allow_response("RTK safety rewrite applied".into(), Some(updated));
+            let json = serde_json::to_string(&response)?;
+            Ok(HookResponse::Allow(json))
+        }
+        HookResult::Blocked(msg) => {
+            let response = deny_response(msg.clone());
+            let json = serde_json::to_string(&response)?;
+            Ok(HookResponse::Deny(json, msg))
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // =========================================================================
+    // CLAUDE CODE WIRE FORMAT CONFORMANCE
+    // https://docs.anthropic.com/en/docs/claude-code/hooks
+    //
+    // These tests verify exact JSON field names per the Claude Code spec.
+    // A wrong field name means Claude Code silently ignores the response. 
+ // ========================================================================= + + // --- Output: field name conformance --- + + #[test] + fn test_output_uses_hook_specific_output() { + // Claude expects "hookSpecificOutput" (camelCase), NOT "hook_specific_output" + let response = allow_response("test".into(), None); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert!( + parsed.get("hookSpecificOutput").is_some(), + "must have 'hookSpecificOutput' field" + ); + assert!( + parsed.get("hook_specific_output").is_none(), + "must NOT have snake_case field" + ); + } + + #[test] + fn test_output_uses_permission_decision() { + // Claude expects "permissionDecision", NOT "decision" + let response = allow_response("test".into(), None); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + let output = &parsed["hookSpecificOutput"]; + + assert!( + output.get("permissionDecision").is_some(), + "must have 'permissionDecision' field" + ); + assert!( + output.get("decision").is_none(), + "must NOT have Gemini-style 'decision' field" + ); + } + + #[test] + fn test_output_uses_permission_decision_reason() { + let response = deny_response("blocked".into()); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + let output = &parsed["hookSpecificOutput"]; + + assert!( + output.get("permissionDecisionReason").is_some(), + "must have 'permissionDecisionReason'" + ); + } + + #[test] + fn test_output_uses_hook_event_name() { + let response = allow_response("test".into(), None); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed["hookSpecificOutput"]["hookEventName"], "PreToolUse"); + } + + #[test] + fn test_output_uses_updated_input_for_rewrite() { + let input = serde_json::json!({"command": "rtk run -c 'git status'"}); + let response = allow_response("rewrite".into(), Some(input)); + let json = serde_json::to_string(&response).unwrap(); + let parsed: Value = serde_json::from_str(&json).unwrap(); + + assert!( + parsed["hookSpecificOutput"].get("updatedInput").is_some(), + "must have 'updatedInput' for rewrites" + ); + } + + #[test] + fn test_allow_omits_updated_input_when_none() { + let response = allow_response("passthrough".into(), None); + let json = serde_json::to_string(&response).unwrap(); + + assert!( + !json.contains("updatedInput"), + "updatedInput must be omitted when None" + ); + } + + #[test] + fn test_rewrite_preserves_other_tool_input_fields() { + let original = serde_json::json!({ + "command": "git status", + "timeout": 30, + "description": "check repo" + }); + + let mut updated = original.clone(); + if let Some(obj) = updated.as_object_mut() { + obj.insert( + "command".into(), + Value::String("rtk run -c 'git status'".into()), + ); + } + + assert_eq!(updated["timeout"], 30); + assert_eq!(updated["description"], "check repo"); + assert_eq!(updated["command"], "rtk run -c 'git status'"); + } + + #[test] + fn test_output_decision_values() { + let allow = allow_response("test".into(), None); + let deny = deny_response("blocked".into()); + + let allow_json: Value = + serde_json::from_str(&serde_json::to_string(&allow).unwrap()).unwrap(); + let deny_json: Value = + serde_json::from_str(&serde_json::to_string(&deny).unwrap()).unwrap(); + + assert_eq!( + allow_json["hookSpecificOutput"]["permissionDecision"], + "allow" 
+        );
+        assert_eq!(
+            deny_json["hookSpecificOutput"]["permissionDecision"],
+            "deny"
+        );
+    }
+
+    // --- Input: payload parsing ---
+
+    #[test]
+    fn test_input_extra_fields_ignored() {
+        // Claude sends session_id, tool_name, transcript_path, etc.
+        let json = r#"{"tool_input": {"command": "ls"}, "tool_name": "Bash", "session_id": "abc-123", "session_cwd": "/tmp", "transcript_path": "/path/to/transcript.jsonl"}"#;
+        let payload: ClaudePayload = serde_json::from_str(json).unwrap();
+        assert_eq!(extract_command(&payload), Some("ls"));
+    }
+
+    #[test]
+    fn test_input_tool_input_is_object() {
+        let json = r#"{"tool_input": {"command": "git status", "timeout": 30}}"#;
+        let payload: ClaudePayload = serde_json::from_str(json).unwrap();
+        let input = payload.tool_input.unwrap();
+        assert_eq!(input["command"].as_str().unwrap(), "git status");
+        assert_eq!(input["timeout"].as_i64().unwrap(), 30);
+    }
+
+    // --- Guard function tests ---
+
+    #[test]
+    fn test_extract_command_basic() {
+        let payload: ClaudePayload =
+            serde_json::from_str(r#"{"tool_input": {"command": "git status"}}"#).unwrap();
+        assert_eq!(extract_command(&payload), Some("git status"));
+    }
+
+    #[test]
+    fn test_extract_command_missing_tool_input() {
+        let payload: ClaudePayload = serde_json::from_str(r#"{}"#).unwrap();
+        assert_eq!(extract_command(&payload), None);
+    }
+
+    #[test]
+    fn test_extract_command_missing_command_field() {
+        let payload: ClaudePayload =
+            serde_json::from_str(r#"{"tool_input": {"cwd": "/tmp"}}"#).unwrap();
+        assert_eq!(extract_command(&payload), None);
+    }
+
+    #[test]
+    fn test_extract_command_empty_string() {
+        let payload: ClaudePayload =
+            serde_json::from_str(r#"{"tool_input": {"command": ""}}"#).unwrap();
+        assert_eq!(extract_command(&payload), None);
+    }
+
+    #[test]
+    fn test_shared_should_passthrough_rtk_prefix() {
+        assert!(should_passthrough("rtk run -c 'ls'"));
+        assert!(should_passthrough("rtk cargo test"));
+        assert!(should_passthrough("/usr/local/bin/rtk run -c 'ls'"));
+    }
+
+    #[test]
+    fn test_shared_should_passthrough_heredoc() {
+        assert!(should_passthrough("cat <<EOF"));
+    }
+
+    // --- Fail-open behavior ---
+
+    #[test]
+    fn test_run_inner_returns_no_opinion_for_empty_payload() {
+        // "{}" has no tool_input → no command → NoOpinion
+        let payload: ClaudePayload = serde_json::from_str("{}").unwrap();
+        assert_eq!(extract_command(&payload), None);
+    }
+
+    #[test]
+    fn test_shared_is_hook_disabled_hook_enabled_zero() {
+        std::env::set_var("RTK_HOOK_ENABLED", "0");
+        assert!(is_hook_disabled());
+        std::env::remove_var("RTK_HOOK_ENABLED");
+    }
+
+    #[test]
+    fn test_shared_is_hook_disabled_rtk_active() {
+        std::env::set_var("RTK_ACTIVE", "1");
+        assert!(is_hook_disabled());
+        std::env::remove_var("RTK_ACTIVE");
+    }
+
+    // --- Integration: Bug #4669 workaround verification ---
+
+    #[test]
+    fn test_deny_response_includes_reason_for_stderr() {
+        // Bug #4669 workaround: deny must provide plain text reason
+        // that can be output to stderr alongside the JSON stdout.
+        // The msg is cloned for both paths in run_inner(). 
+        let msg = "RTK: cat is blocked (use rtk read instead)";
+        let response = deny_response(msg.to_string());
+        let json = serde_json::to_string(&response).unwrap();
+        let parsed: Value = serde_json::from_str(&json).unwrap();
+
+        // JSON stdout path
+        assert_eq!(parsed["hookSpecificOutput"]["permissionDecision"], "deny");
+        assert_eq!(
+            parsed["hookSpecificOutput"]["permissionDecisionReason"],
+            msg
+        );
+        // The same msg string is used for stderr in run() via HookResponse::Deny
+    }
+
+    // Note: Integration tests for check_for_hook() safety decisions are in
+    // src/cmd/hook.rs (test_safe_commands_rewrite, test_blocked_commands, etc.)
+    // to avoid duplication. This module focuses on Claude Code wire format.
+}
diff --git a/src/cmd/exec.rs b/src/cmd/exec.rs
new file mode 100644
index 0000000..5eabd63
--- /dev/null
+++ b/src/cmd/exec.rs
@@ -0,0 +1,426 @@
+//! Command executor: runs simple chains natively, delegates complex shell to /bin/sh.
+
+use anyhow::{Context, Result};
+use std::process::{Command, Stdio};
+
+use super::{analysis, builtins, filters, lexer};
+use crate::tracking;
+
+/// Check if RTK is already active (recursion guard)
+fn is_rtk_active() -> bool {
+    std::env::var("RTK_ACTIVE").is_ok()
+}
+
+/// RAII guard: sets RTK_ACTIVE on creation, removes on drop (even on panic).
+struct RtkActiveGuard;
+
+impl RtkActiveGuard {
+    fn new() -> Self {
+        std::env::set_var("RTK_ACTIVE", "1");
+        RtkActiveGuard
+    }
+}
+
+impl Drop for RtkActiveGuard {
+    fn drop(&mut self) {
+        std::env::remove_var("RTK_ACTIVE");
+    }
+}
+
+/// Execute a raw command string
+pub fn execute(raw: &str, verbose: u8) -> Result<bool> {
+    // Recursion guard
+    if is_rtk_active() {
+        if verbose > 0 {
+            eprintln!("rtk: Recursion detected, passing through");
+        }
+        return run_passthrough(raw, verbose);
+    }
+
+    // Handle empty input
+    if raw.trim().is_empty() {
+        return Ok(true);
+    }
+
+    let _guard = RtkActiveGuard::new();
+    execute_inner(raw, verbose)
+}
+
+fn execute_inner(raw: &str, verbose: u8) -> Result<bool> {
+    // PR 2 adds: crate::config::rules::try_remap() alias expansion
+
+    let tokens = lexer::tokenize(raw);
+
+    // === STEP 1: Decide Native vs Passthrough ===
+    if analysis::needs_shell(&tokens) {
+        // PR 2 adds: safety::check_raw(raw) before passthrough
+        return run_passthrough(raw, verbose);
+    }
+
+    // === STEP 2: Parse into native command chain ===
+    let commands =
+        analysis::parse_chain(tokens).map_err(|e| anyhow::anyhow!("Parse error: {}", e))?;
+
+    // === STEP 3: Execute native chain ===
+    run_native(&commands, verbose)
+}
+
+/// Run commands in native mode (iterate, check safety, filter output)
+fn run_native(commands: &[analysis::NativeCommand], verbose: u8) -> Result<bool> {
+    let mut last_success = true;
+    let mut prev_operator: Option<&str> = None;
+
+    for cmd in commands {
+        // === SHORT-CIRCUIT LOGIC ===
+        // Check if we should run based on PREVIOUS operator and result
+        // The operator stored in cmd is the one AFTER it, so we use prev_operator
+        if !analysis::should_run(prev_operator, last_success) {
+            // For && with failure or || with success, skip this command
+            prev_operator = cmd.operator.as_deref();
+            continue;
+        }
+
+        // === RECURSION PREVENTION ===
+        // Handle "rtk run" or "rtk" binary specially
+        if cmd.binary == "rtk" && cmd.args.first().map(|s| s.as_str()) == Some("run") {
+            // Flatten: execute the inner command directly
+            // rtk run -c "git status" → args = ["run", "-c", "git status"]
+            let inner = if cmd.args.get(1).map(|s| s.as_str()) == Some("-c") {
+                cmd.args.get(2).cloned().unwrap_or_default()
+            } else {
+                cmd.args.get(1).cloned().unwrap_or_default()
+            };
+            if verbose > 0 {
+                eprintln!("rtk: Flattening nested rtk run");
+            }
+            return execute(&inner, verbose);
+        }
+        // Other rtk commands: spawn as external (they have their own filters)
+
+        // PR 2 adds: safety::check() dispatch block
+
+        // === BUILTINS ===
+        if builtins::is_builtin(&cmd.binary) {
+            last_success = builtins::execute(&cmd.binary, &cmd.args)?;
+            prev_operator = cmd.operator.as_deref();
+            continue;
+        }
+
+        // === EXTERNAL COMMAND WITH FILTERING ===
+        last_success = spawn_with_filter(&cmd.binary, &cmd.args, verbose)?;
+        prev_operator = cmd.operator.as_deref();
+    }
+
+    Ok(last_success)
+}
+
+/// Spawn external command and apply appropriate filter
+fn spawn_with_filter(binary: &str, args: &[String], _verbose: u8) -> Result<bool> {
+    let timer = tracking::TimedExecution::start();
+
+    // Try to find the binary in PATH
+    let binary_path = match which::which(binary) {
+        Ok(path) => path,
+        Err(_) => {
+            // Binary not found
+            eprintln!("rtk: {}: command not found", binary);
+            return Ok(false);
+        }
+    };
+
+    // Use wait_with_output() to avoid deadlock when child output exceeds
+    // pipe buffer (~64KB Linux, ~16KB macOS). This reads stdout/stderr in
+    // separate threads internally before calling wait().
+    let output = Command::new(&binary_path)
+        .args(args)
+        .stdin(Stdio::inherit())
+        .stdout(Stdio::piped())
+        .stderr(Stdio::piped())
+        .output()
+        .with_context(|| format!("Failed to execute: {}", binary))?;
+
+    let raw_out = String::from_utf8_lossy(&output.stdout);
+    let raw_err = String::from_utf8_lossy(&output.stderr);
+
+    // Determine filter type and apply
+    let filter_type = filters::get_filter_type(binary);
+    let filtered_out = filters::apply_to_string(filter_type, &raw_out);
+    let filtered_err = crate::utils::strip_ansi(&raw_err);
+
+    // Print filtered output
+    print!("{}", filtered_out);
+    eprint!("{}", filtered_err);
+
+    // Track usage with raw vs filtered for accurate savings
+    let raw_output = format!("{}{}", raw_out, raw_err);
+    let filtered_output = format!("{}{}", filtered_out, filtered_err);
+    timer.track(
+        &format!("{} {}", binary, args.join(" ")),
+        &format!("rtk run {} {}", binary, args.join(" ")),
+        &raw_output,
+        &filtered_output,
+    );
+
+    Ok(output.status.success())
+}
+
+/// Run command via system shell (passthrough mode)
+pub fn run_passthrough(raw: &str, verbose: u8) -> Result<bool> {
+    if verbose > 0 {
+        eprintln!("rtk: Passthrough mode for complex command");
+    }
+
+    let timer = tracking::TimedExecution::start();
+
+    let shell = if cfg!(windows) { "cmd" } else { "sh" };
+    let flag = if cfg!(windows) { "/C" } else { "-c" };
+
+    let output = Command::new(shell)
+        .arg(flag)
+        .arg(raw)
+        .stdin(Stdio::inherit())
+        .stdout(Stdio::piped())
+        .stderr(Stdio::piped())
+        .output()
+        .context("Failed to execute passthrough")?;
+
+    let raw_out = String::from_utf8_lossy(&output.stdout);
+    let raw_err = String::from_utf8_lossy(&output.stderr);
+
+    // Basic filtering even in passthrough (strip ANSI)
+    let filtered_out = crate::utils::strip_ansi(&raw_out);
+    let filtered_err = crate::utils::strip_ansi(&raw_err);
+    print!("{}", filtered_out);
+    eprint!("{}", filtered_err);
+
+    let raw_output = format!("{}{}", raw_out, raw_err);
+    let filtered_output = format!("{}{}", filtered_out, filtered_err);
+    timer.track(
+        raw,
+        &format!("rtk passthrough {}", raw),
+        &raw_output,
+        &filtered_output,
+    );
+
+    Ok(output.status.success())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::cmd::test_helpers::EnvGuard;
+
+    // === RAII GUARD 
TESTS === + + #[test] + fn test_is_rtk_active_default() { + let _env = EnvGuard::new(); + assert!(!is_rtk_active()); + } + + #[test] + fn test_raii_guard_sets_and_clears() { + let _env = EnvGuard::new(); + { + let _guard = RtkActiveGuard::new(); + assert!(is_rtk_active()); + } + assert!( + !is_rtk_active(), + "RTK_ACTIVE must be cleared when guard drops" + ); + } + + #[test] + fn test_raii_guard_clears_on_panic() { + let _env = EnvGuard::new(); + let result = std::panic::catch_unwind(|| { + let _guard = RtkActiveGuard::new(); + assert!(is_rtk_active()); + panic!("simulated panic"); + }); + assert!(result.is_err()); + assert!( + !is_rtk_active(), + "RTK_ACTIVE must be cleared even after panic" + ); + } + + // === EXECUTE TESTS === + + #[test] + fn test_execute_empty() { + let result = execute("", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_whitespace_only() { + let result = execute(" ", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_simple_command() { + let result = execute("echo hello", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_builtin_cd() { + let original = std::env::current_dir().unwrap(); + let result = execute("cd /tmp", 0).unwrap(); + assert!(result); + // On macOS, /tmp might be a symlink to /private/tmp + // Just verify the command succeeded (the cd happened) + let _ = std::env::set_current_dir(&original); + } + + #[test] + fn test_execute_builtin_pwd() { + let result = execute("pwd", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_builtin_true() { + let result = execute("true", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_builtin_false() { + let result = execute("false", 0).unwrap(); + assert!(!result); + } + + #[test] + fn test_execute_chain_and_success() { + let result = execute("true && echo success", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_chain_and_failure() { + let result = execute("false && echo should_not_run", 0).unwrap(); + // Chain stops at false, so result is false + assert!(!result); + } + + #[test] + fn test_execute_chain_or_success() { + let result = execute("true || echo should_not_run", 0).unwrap(); + // true succeeds, || doesn't run second command + assert!(result); + } + + #[test] + fn test_execute_chain_or_failure() { + let result = execute("false || echo fallback", 0).unwrap(); + // false fails, || runs fallback + assert!(result); + } + + #[test] + fn test_execute_chain_semicolon() { + let result = execute("true ; false", 0).unwrap(); + // Both run, last result is false + assert!(!result); + } + + #[test] + fn test_execute_passthrough_for_glob() { + let result = execute("echo *", 0).unwrap(); + // Should work via passthrough + assert!(result); + } + + #[test] + fn test_execute_passthrough_for_pipe() { + let result = execute("echo hello | cat", 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_quoted_operator() { + let result = execute(r#"echo "hello && world""#, 0).unwrap(); + assert!(result); + } + + #[test] + fn test_execute_binary_not_found() { + let result = execute("nonexistent_command_xyz_123", 0).unwrap(); + assert!(!result); + } + + #[test] + fn test_execute_chain_and_three_commands() { + // 3-command chain: true succeeds, false fails, stops before third + let result = execute("true && false && true", 0).unwrap(); + assert!(!result); + } + + #[test] + fn test_execute_chain_semicolon_last_wins() { + // Semicolon runs all; last result (true) determines outcome + let result = execute("false ; true", 0).unwrap(); 
+        assert!(result);
+    }
+
+    // === INTEGRATION TESTS (moved from edge_cases.rs) ===
+
+    #[test]
+    fn test_chain_mixed_operators() {
+        // false -> || runs true -> true && runs echo
+        let result = execute("false || true && echo works", 0).unwrap();
+        assert!(result);
+    }
+
+    #[test]
+    fn test_passthrough_redirect() {
+        let result = execute("echo test > /dev/null", 0).unwrap();
+        assert!(result);
+    }
+
+    #[test]
+    fn test_integration_cd_tilde() {
+        let original = std::env::current_dir().unwrap();
+        let result = execute("cd ~", 0).unwrap();
+        assert!(result);
+        let _ = std::env::set_current_dir(&original);
+    }
+
+    #[test]
+    fn test_integration_export() {
+        let result = execute("export TEST_VAR=value", 0).unwrap();
+        assert!(result);
+        std::env::remove_var("TEST_VAR");
+    }
+
+    #[test]
+    fn test_integration_env_prefix() {
+        let result = execute("TEST=1 echo hello", 0);
+        assert!(result.is_ok());
+    }
+
+    #[test]
+    fn test_integration_dash_args() {
+        let result = execute("echo --help -v --version", 0).unwrap();
+        assert!(result);
+    }
+
+    #[test]
+    fn test_integration_quoted_empty() {
+        let result = execute(r#"echo """#, 0).unwrap();
+        assert!(result);
+    }
+
+    // === RECURSION PREVENTION TESTS ===
+
+    #[test]
+    fn test_execute_rtk_recursion() {
+        // This should flatten, not infinitely recurse
+        let result = execute("rtk run \"echo hello\"", 0);
+        assert!(result.is_ok());
+    }
+}
diff --git a/src/cmd/filters.rs b/src/cmd/filters.rs
new file mode 100644
index 0000000..0bd9dea
--- /dev/null
+++ b/src/cmd/filters.rs
@@ -0,0 +1,212 @@
+//! Filter Registry — basic token reduction for `rtk run` native execution.
+//!
+//! This module provides **basic filtering (20-40% savings)** for commands
+//! executed through rtk run. It is a **fallback** for commands
+//! without dedicated RTK implementations.
+//!
+//! For **specialized filtering (60-90% savings)**, use dedicated modules:
+//! - `src/git.rs` — git commands (diff, log, status, etc.)
+//! - `src/runner.rs` — test commands (cargo test, pytest, etc.)
+//! - `src/grep_cmd.rs` — code search (grep, ripgrep)
+//! 
- `src/pnpm_cmd.rs` — package managers
+
+use crate::utils;
+
+/// Filter types for different command categories
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum FilterType {
+    Git,
+    Cargo,
+    Test,
+    Pnpm,
+    Npm,
+    Generic,
+    None,
+}
+
+/// Determine which filter to apply based on binary name
+pub fn get_filter_type(binary: &str) -> FilterType {
+    match binary {
+        "git" => FilterType::Git,
+        "cargo" => FilterType::Cargo,
+        "npm" | "npx" => FilterType::Npm,
+        "pnpm" => FilterType::Pnpm,
+        "pytest" | "go" | "vitest" | "jest" | "mocha" => FilterType::Test,
+        "ls" | "find" | "grep" | "rg" | "fd" => FilterType::Generic,
+        _ => FilterType::None,
+    }
+}
+
+/// Apply filter to already-captured string output
+pub fn apply_to_string(filter: FilterType, output: &str) -> String {
+    match filter {
+        FilterType::Git => utils::strip_ansi(output),
+        FilterType::Cargo => filter_cargo_output(output),
+        FilterType::Test => filter_test_output(output),
+        FilterType::Generic => truncate_lines(output, 100),
+        FilterType::Npm | FilterType::Pnpm => utils::strip_ansi(output),
+        FilterType::None => output.to_string(),
+    }
+}
+
+/// Filter cargo output: remove verbose "Compiling" lines
+fn filter_cargo_output(output: &str) -> String {
+    output
+        .lines()
+        .filter(|line| {
+            let line = line.trim();
+            !line.starts_with("Compiling ") || line.contains("error") || line.contains("warning")
+        })
+        .collect::<Vec<_>>()
+        .join("\n")
+}
+
+/// Filter test output: remove passing tests, keep failures
+fn filter_test_output(output: &str) -> String {
+    output
+        .lines()
+        .filter(|line| {
+            let line = line.trim();
+            line.contains("FAILED")
+                || line.contains("error")
+                || line.contains("Error")
+                || line.contains("failed")
+                || line.contains("test result:")
+                || line.starts_with("----")
+        })
+        .collect::<Vec<_>>()
+        .join("\n")
+}
+
+/// Truncate output to max lines
+fn truncate_lines(output: &str, max_lines: usize) -> String {
+    let lines: Vec<&str> = output.lines().collect();
+    if lines.len() <= max_lines {
+        output.to_string()
+    } else {
+        let truncated: Vec<&str> = lines.iter().take(max_lines).copied().collect();
+        format!(
+            "{}\n... 
({} more lines)", + truncated.join("\n"), + lines.len() - max_lines + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // === GET_FILTER_TYPE TESTS === + + #[test] + fn test_filter_type_git() { + assert_eq!(get_filter_type("git"), FilterType::Git); + } + + #[test] + fn test_filter_type_cargo() { + assert_eq!(get_filter_type("cargo"), FilterType::Cargo); + } + + #[test] + fn test_filter_type_npm() { + assert_eq!(get_filter_type("npm"), FilterType::Npm); + assert_eq!(get_filter_type("npx"), FilterType::Npm); + } + + #[test] + fn test_filter_type_generic() { + assert_eq!(get_filter_type("ls"), FilterType::Generic); + assert_eq!(get_filter_type("grep"), FilterType::Generic); + } + + #[test] + fn test_filter_type_none() { + assert_eq!(get_filter_type("unknown_command"), FilterType::None); + } + + // === STRIP_ANSI TESTS (now testing utils::strip_ansi) === + + #[test] + fn test_strip_ansi_no_codes() { + assert_eq!(utils::strip_ansi("hello world"), "hello world"); + } + + #[test] + fn test_strip_ansi_color() { + assert_eq!(utils::strip_ansi("\x1b[32mgreen\x1b[0m"), "green"); + } + + #[test] + fn test_strip_ansi_bold() { + assert_eq!(utils::strip_ansi("\x1b[1mbold\x1b[0m"), "bold"); + } + + #[test] + fn test_strip_ansi_multiple() { + assert_eq!( + utils::strip_ansi("\x1b[31mred\x1b[0m \x1b[32mgreen\x1b[0m"), + "red green" + ); + } + + #[test] + fn test_strip_ansi_complex() { + assert_eq!( + utils::strip_ansi("\x1b[1;31;42mbold red on green\x1b[0m"), + "bold red on green" + ); + } + + // === FILTER_CARGO_OUTPUT TESTS === + + #[test] + fn test_filter_cargo_keeps_errors() { + let input = "Compiling dep1\nerror: something wrong\nCompiling dep2"; + let output = filter_cargo_output(input); + assert!(output.contains("error")); + assert!(!output.contains("Compiling dep1")); + } + + #[test] + fn test_filter_cargo_keeps_warnings() { + let input = "Compiling dep1\nwarning: unused variable\nCompiling dep2"; + let output = filter_cargo_output(input); + assert!(output.contains("warning")); + } + + // === TRUNCATE_LINES TESTS === + + #[test] + fn test_truncate_short() { + let input = "line1\nline2\nline3"; + let output = truncate_lines(input, 10); + assert_eq!(output, input); + } + + #[test] + fn test_truncate_long() { + let input = "line1\nline2\nline3\nline4\nline5"; + let output = truncate_lines(input, 3); + assert!(output.contains("line3")); + assert!(!output.contains("line4")); + assert!(output.contains("2 more lines")); + } + + // === APPLY_TO_STRING TESTS === + + #[test] + fn test_apply_to_string_none() { + let input = "hello world"; + let output = apply_to_string(FilterType::None, input); + assert_eq!(output, input); + } + + #[test] + fn test_apply_to_string_git() { + let input = "\x1b[32mgreen\x1b[0m"; + let output = apply_to_string(FilterType::Git, input); + assert_eq!(output, "green"); + } +} diff --git a/src/cmd/hook.rs b/src/cmd/hook.rs new file mode 100644 index 0000000..2d8c0f8 --- /dev/null +++ b/src/cmd/hook.rs @@ -0,0 +1,569 @@ +//! Hook protocol for Claude Code and Gemini support. +//! +//! This module provides **shared decision logic** for both Claude Code and Gemini CLI hooks. +//! Protocol-specific I/O handling lives in `claude_hook.rs` and `gemini_hook.rs`. +//! +//! ## Architecture: Separation of Concerns +//! +//! ```text +//! main.rs (CAN use println! - normal RTK behavior) +//! ↓ +//! Commands::Hook match +//! ├─→ HookCommands::Check → hook::check_for_hook() (THIS MODULE - CAN use println!) +//! 
├─→ HookCommands::Claude → claude_hook::run() [DENY ENFORCED - see claude_hook.rs:52] +//! └─→ HookCommands::Gemini → gemini_hook::run() [DENY ENFORCED - see gemini_hook.rs:42] +//! ``` +//! +//! **I/O Policy Scope:** +//! - **This module (hook.rs)**: CAN use `println!`/`eprintln!` (used by `rtk hook check` text protocol) +//! - **main.rs and all command modules**: CAN use `println!`/`eprintln!` (normal RTK behavior) +//! - **claude_hook.rs, gemini_hook.rs ONLY**: CANNOT use `println!`/`eprintln!` (JSON protocols) +//! +//! The `#![deny(clippy::print_stdout, clippy::print_stderr)]` attribute is applied +//! at the **module boundary** (earliest possible stage) — when control enters +//! `claude_hook::run()` or `gemini_hook::run()`, the deny is enforced. +//! +//! ## Protocol Differences +//! +//! **Claude Code** (`rtk hook check` text protocol): +//! - Success: rewritten command on stdout, exit 0 +//! - Blocked: error message on stderr, exit 2 (blocking error) +//! - Other exit codes: non-blocking errors +//! +//! **Claude Code** (JSON protocol via `claude_hook.rs`): +//! - See `claude_hook.rs` module documentation +//! +//! **Gemini CLI** (JSON protocol via `gemini_hook.rs`): +//! - See `gemini_hook.rs` module documentation + +use super::{analysis, lexer}; +// PR 2 adds: use super::safety; + +/// Hook check result +#[derive(Debug, Clone)] +pub enum HookResult { + /// Command is safe, rewrite to this + Rewrite(String), + /// Command is blocked with this message + Blocked(String), +} + +/// Maximum rewrite depth to prevent infinite recursion from cyclic safety rules. +const MAX_REWRITE_DEPTH: usize = 3; + +/// Check a command for the hook protocol. +/// Returns the rewritten command or an error message. +/// +/// The `_agent` parameter is reserved for future per-agent behavior. +pub fn check_for_hook(raw: &str, _agent: &str) -> HookResult { + check_for_hook_inner(raw, 0) +} + +fn check_for_hook_inner(raw: &str, depth: usize) -> HookResult { + if depth >= MAX_REWRITE_DEPTH { + return HookResult::Blocked("Rewrite loop detected (max depth exceeded)".to_string()); + } + if raw.trim().is_empty() { + return HookResult::Rewrite(raw.to_string()); + } + // PR 2 adds: crate::config::rules::try_remap() alias expansion + // PR 2 adds: safety::check_raw() and safety::check() dispatch + HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))) +} + +// --- Shared guard logic (used by both claude_hook.rs and gemini_hook.rs) --- + +/// Check if hook processing is disabled by environment. +/// +/// Returns true if: +/// - `RTK_HOOK_ENABLED=0` (master toggle off) +/// - `RTK_ACTIVE` is set (recursion prevention — rtk sets this when running commands) +pub fn is_hook_disabled() -> bool { + std::env::var("RTK_HOOK_ENABLED").as_deref() == Ok("0") || std::env::var("RTK_ACTIVE").is_ok() +} + +/// Check if this command should bypass hook processing entirely. +/// +/// Returns true for commands that should not be rewritten: +/// - Already routed through rtk (`rtk ...` or `/path/to/rtk ...`) +/// - Contains heredoc (`<<`) which needs raw shell processing +pub fn should_passthrough(cmd: &str) -> bool { + cmd.starts_with("rtk ") || cmd.contains("/rtk ") || cmd.contains("<<") +} + +/// Replace the command field in a tool_input object, preserving other fields. +/// +/// Used by both claude_hook.rs and gemini_hook.rs when rewriting commands. +/// If tool_input is None or not an object, creates a new object with just the command. 
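+///
+/// e.g. a `tool_input` of `{"command": "cat x.txt", "description": "read file"}` with
+/// `new_cmd = "rtk read x.txt"` becomes `{"command": "rtk read x.txt", "description": "read file"}`
+/// (the extra field here is only illustrative; nothing but `command` is touched).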
+///
+/// # Arguments
+/// * `tool_input` - The original tool_input from the hook payload (may be None)
+/// * `new_cmd` - The rewritten command string to replace with
+///
+/// # Returns
+/// A Value with the command field updated, all other fields preserved.
+pub fn update_command_in_tool_input(
+    tool_input: Option<serde_json::Value>,
+    new_cmd: String,
+) -> serde_json::Value {
+    use serde_json::Value;
+    let mut updated = tool_input.unwrap_or_else(|| Value::Object(Default::default()));
+    if let Some(obj) = updated.as_object_mut() {
+        obj.insert("command".into(), Value::String(new_cmd));
+    }
+    updated
+}
+
+/// Hook output for protocol handlers (claude_hook.rs, gemini_hook.rs).
+///
+/// This enum separates decision logic from I/O: `run_inner()` returns a
+/// `HookResponse`, and `run()` is the single place that writes to stdout/stderr.
+/// Combined with `#[deny(clippy::print_stdout, clippy::print_stderr)]` on the
+/// hook modules, this prevents any stray output from corrupting the JSON protocol.
+#[derive(Debug, Clone, PartialEq)]
+pub enum HookResponse {
+    /// No opinion — exit 0, no output. Host proceeds normally.
+    NoOpinion,
+    /// Allow/rewrite — exit 0, JSON to stdout.
+    Allow(String),
+    /// Deny — exit 2, JSON to stdout + reason to stderr.
+    /// Fields: (stdout_json, stderr_reason)
+    Deny(String, String),
+}
+
+/// Escape single quotes for shell
+fn escape_quotes(s: &str) -> String {
+    s.replace("'", "'\\''")
+}
+
+/// Format hook result for Claude (text output)
+///
+/// Exit codes:
+/// - 0: Success, command rewritten/allowed
+/// - 2: Blocking error, command should be denied
+pub fn format_for_claude(result: HookResult) -> (String, bool, i32) {
+    match result {
+        HookResult::Rewrite(cmd) => (cmd, true, 0),
+        HookResult::Blocked(msg) => (msg, false, 2), // Exit 2 = blocking error per Claude Code spec
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // === TEST HELPERS ===
+
+    fn assert_rewrite(input: &str, contains: &str) {
+        match check_for_hook(input, "claude") {
+            HookResult::Rewrite(cmd) => assert!(
+                cmd.contains(contains),
+                "'{}' rewrite should contain '{}', got '{}'",
+                input,
+                contains,
+                cmd
+            ),
+            other => panic!("Expected Rewrite for '{}', got {:?}", input, other),
+        }
+    }
+
+    fn assert_blocked(input: &str, contains: &str) {
+        match check_for_hook(input, "claude") {
+            HookResult::Blocked(msg) => assert!(
+                msg.contains(contains),
+                "'{}' block msg should contain '{}', got '{}'",
+                input,
+                contains,
+                msg
+            ),
+            other => panic!("Expected Blocked for '{}', got {:?}", input, other),
+        }
+    }
+
+    // === ESCAPE_QUOTES ===
+
+    #[test]
+    fn test_escape_quotes() {
+        assert_eq!(escape_quotes("hello"), "hello");
+        assert_eq!(escape_quotes("it's"), "it'\\''s");
+        assert_eq!(escape_quotes("it's a test's"), "it'\\''s a test'\\''s");
+    }
+
+    // === EMPTY / WHITESPACE ===
+
+    #[test]
+    fn test_check_empty_and_whitespace() {
+        match check_for_hook("", "claude") {
+            HookResult::Rewrite(cmd) => assert!(cmd.is_empty()),
+            _ => panic!("Expected Rewrite for empty"),
+        }
+        match check_for_hook("   ", "claude") {
+            HookResult::Rewrite(cmd) => assert!(cmd.trim().is_empty()),
+            _ => panic!("Expected Rewrite for whitespace"),
+        }
+    }
+
+    // === COMMANDS THAT SHOULD REWRITE (table-driven) ===
+
+    #[test]
+    fn test_safe_commands_rewrite() {
+        let cases = [
+            ("git status", "rtk run"),
+            ("ls *.rs", "rtk run"),                       // shellism passthrough
+            (r#"git commit -m "Fix && Bug""#, "rtk run"), // quoted operator
+            ("FOO=bar echo hello", "rtk run"),            // env prefix
+            ("echo `date`", "rtk run"),                   // backticks
+            ("echo 
$(date)", "rtk run"), // subshell + ("echo {a,b}.txt", "rtk run"), // brace expansion + ("echo 'hello!@#$%^&*()'", "rtk run"), // special chars + ("echo '日本語 🎉'", "rtk run"), // unicode + ("cd /tmp && git status", "rtk run"), // chain rewrite + ]; + for (input, expected) in cases { + assert_rewrite(input, expected); + } + // Chain rewrite preserves operator structure + match check_for_hook("cd /tmp && git status", "claude") { + HookResult::Rewrite(cmd) => assert!( + cmd.contains("&&"), + "Chain rewrite must preserve '&&', got '{}'", + cmd + ), + other => panic!("Expected Rewrite for chain, got {:?}", other), + } + // Very long command + assert_rewrite(&format!("echo {}", "a".repeat(1000)), "rtk run"); + } + + // === ENV VAR PREFIX PRESERVATION === + // Ported from old hooks/test-rtk-rewrite.sh Section 2. + // Commands prefixed with KEY=VALUE env vars must not be blocked. + + #[test] + fn test_env_var_prefix_preserved() { + let cases = [ + "GIT_PAGER=cat git status", + "GIT_PAGER=cat git log --oneline -10", + "NODE_ENV=test CI=1 npx vitest run", + "LANG=C ls -la", + "NODE_ENV=test npm run test:e2e", + "COMPOSE_PROJECT_NAME=test docker compose up -d", + "TEST_SESSION_ID=2 npx playwright test --config=foo", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === GLOBAL OPTIONS (PR #99 parity) === + // Commands with global options before subcommands must not be blocked. + // Ported from upstream hooks/rtk-rewrite.sh global option stripping. + + #[test] + fn test_global_options_not_blocked() { + let cases = [ + // Git global options + "git --no-pager status", + "git -C /path/to/project status", + "git -C /path --no-pager log --oneline", + "git --no-optional-locks diff HEAD", + "git --bare log", + // Cargo toolchain prefix + "cargo +nightly test", + "cargo +stable build --release", + // Docker global options + "docker --context prod ps", + "docker -H tcp://host:2375 images", + // Kubectl global options + "kubectl -n kube-system get pods", + "kubectl --context prod describe pod foo", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === SPECIFIC COMMANDS NOT BLOCKED === + // Ported from old hooks/test-rtk-rewrite.sh Sections 1 & 3. + // These commands must pass through (not be blocked by safety rules). + + #[test] + fn test_specific_commands_not_blocked() { + let cases = [ + // Git variants + "git log --oneline -10", + "git diff HEAD", + "git show abc123", + "git add .", + // GitHub CLI + "gh pr list", + "gh api repos/owner/repo", + "gh release list", + // Package managers + "npm run test:e2e", + "npm run build", + "npm test", + // Docker + "docker compose up -d", + "docker compose logs postgrest", + "docker compose down", + "docker run --rm postgres", + "docker exec -it db psql", + // Kubernetes + "kubectl describe pod foo", + "kubectl apply -f deploy.yaml", + // Test runners + "npx playwright test", + "npx prisma migrate", + "cargo test", + // Vitest variants (dedup is internal to rtk run, not hook level) + "vitest", + "vitest run", + "vitest run --reporter=verbose", + "npx vitest run", + "pnpm vitest run --coverage", + // TypeScript + "vue-tsc -b", + "npx vue-tsc --noEmit", + // Utilities + "curl -s https://example.com", + "ls -la", + "grep -rn pattern src/", + "rg pattern src/", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === COMMANDS THAT PASS THROUGH (builtins/unknown) === + // Ported from old hooks/test-rtk-rewrite.sh Section 5. + // These are not blocked — they get wrapped in rtk run -c. 
+ + #[test] + fn test_builtins_not_blocked() { + let cases = [ + "echo hello world", + "cd /tmp", + "mkdir -p foo/bar", + "python3 script.py", + "node -e 'console.log(1)'", + "find . -name '*.ts'", + "tree src/", + "wget https://example.com/file", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === COMPOUND COMMANDS (chained with &&, ||, ;) === + // Shell script only matched FIRST command in a chain. + // Rust hook parses each command independently (#112). + + #[test] + fn test_compound_commands_rewrite() { + let cases = [ + // Basic chains — each command rewritten independently + ("cd /tmp && git status", "&&"), + ("cd dir && git status && git diff", "&&"), + ("git add . && git commit -m msg", "&&"), + // Semicolon chains + ("echo start ; git status ; echo done", ";"), + // Or-chains + ("git pull || echo failed", "||"), + ]; + for (input, operator) in cases { + match check_for_hook(input, "claude") { + HookResult::Rewrite(cmd) => { + assert!(cmd.contains("rtk run"), "'{input}' should rewrite"); + assert!( + cmd.contains(operator), + "'{input}' must preserve '{operator}', got '{cmd}'" + ); + } + other => panic!("Expected Rewrite for '{input}', got {other:?}"), + } + } + } + + // PR 2 adds: test_compound_blocked_in_chain (safety-dependent test) + + #[test] + fn test_compound_quoted_operators_not_split() { + // && inside quotes must NOT split the command + let input = r#"git commit -m "Fix && Bug""#; + match check_for_hook(input, "claude") { + HookResult::Rewrite(cmd) => { + assert!(cmd.contains("rtk run"), "Should rewrite, got '{cmd}'"); + } + other => panic!("Expected Rewrite for quoted &&, got {other:?}"), + } + } + + // PR 2 adds: test_blocked_commands (safety-dependent test) + + // === SHELLISM PASSTHROUGH: cat/sed/head allowed with pipe/redirect === + + #[test] + fn test_token_waste_allowed_in_pipelines() { + let cases = [ + "cat file.txt | grep pattern", + "cat file.txt > output.txt", + "sed 's/old/new/' file.txt > output.txt", + "head -n 10 file.txt | grep pattern", + "for f in *.txt; do cat \"$f\" | grep x; done", + ]; + for input in cases { + assert_rewrite(input, "rtk run"); + } + } + + // === MULTI-AGENT === + + #[test] + fn test_different_agents_same_result() { + for agent in ["claude", "gemini"] { + match check_for_hook("git status", agent) { + HookResult::Rewrite(cmd) => assert!(cmd.contains("rtk run")), + _ => panic!("Expected Rewrite for agent '{}'", agent), + } + } + } + + // === FORMAT_FOR_CLAUDE === + + #[test] + fn test_format_for_claude() { + let (output, success, code) = + format_for_claude(HookResult::Rewrite("rtk run -c 'git status'".to_string())); + assert_eq!(output, "rtk run -c 'git status'"); + assert!(success); + assert_eq!(code, 0); + + let (output, success, code) = + format_for_claude(HookResult::Blocked("Error message".to_string())); + assert_eq!(output, "Error message"); + assert!(!success); + assert_eq!(code, 2); // Exit 2 = blocking error per Claude Code spec + } + + // === RECURSION DEPTH LIMIT === + + #[test] + fn test_rewrite_depth_limit() { + // At max depth → blocked + match check_for_hook_inner("echo hello", MAX_REWRITE_DEPTH) { + HookResult::Blocked(msg) => assert!(msg.contains("loop"), "msg: {}", msg), + _ => panic!("Expected Blocked at max depth"), + } + // At depth 0 → normal rewrite + match check_for_hook_inner("echo hello", 0) { + HookResult::Rewrite(cmd) => assert!(cmd.contains("rtk run")), + _ => panic!("Expected Rewrite at depth 0"), + } + } + + // 
========================================================================= + // CLAUDE CODE WIRE FORMAT CONFORMANCE + // https://docs.anthropic.com/en/docs/claude-code/hooks + // + // Claude Code hook protocol: + // - Rewrite: command on stdout, exit code 0 + // - Block: message on stderr, exit code 2 + // - Other exit codes are non-blocking errors + // + // format_for_claude() is the boundary between HookResult and the wire. + // These tests verify it produces the exact contract Claude Code expects. + // ========================================================================= + + #[test] + fn test_claude_rewrite_exit_code_is_zero() { + let (_, _, code) = format_for_claude(HookResult::Rewrite("rtk run -c 'ls'".into())); + assert_eq!(code, 0, "Rewrite must exit 0 (success)"); + } + + #[test] + fn test_claude_block_exit_code_is_two() { + let (_, _, code) = format_for_claude(HookResult::Blocked("denied".into())); + assert_eq!( + code, 2, + "Block must exit 2 (blocking error per Claude Code spec)" + ); + } + + #[test] + fn test_claude_rewrite_output_is_command_text() { + // Claude Code reads stdout as the rewritten command — must be plain text, not JSON + let (output, success, _) = + format_for_claude(HookResult::Rewrite("rtk run -c 'git status'".into())); + assert_eq!(output, "rtk run -c 'git status'"); + assert!(success); + // Must NOT be JSON + assert!( + !output.starts_with('{'), + "Rewrite output must be plain text, not JSON" + ); + } + + #[test] + fn test_claude_block_output_is_human_message() { + // Claude Code reads stderr for the block reason + let (output, success, _) = + format_for_claude(HookResult::Blocked("Use Read tool instead".into())); + assert_eq!(output, "Use Read tool instead"); + assert!(!success); + // Must NOT be JSON + assert!( + !output.starts_with('{'), + "Block output must be plain text, not JSON" + ); + } + + #[test] + fn test_claude_rewrite_success_flag_true() { + let (_, success, _) = format_for_claude(HookResult::Rewrite("cmd".into())); + assert!(success, "Rewrite must set success=true"); + } + + #[test] + fn test_claude_block_success_flag_false() { + let (_, success, _) = format_for_claude(HookResult::Blocked("msg".into())); + assert!(!success, "Block must set success=false"); + } + + #[test] + fn test_claude_exit_codes_not_one() { + // Exit code 1 means non-blocking error in Claude Code — we must never use it + let (_, _, rewrite_code) = format_for_claude(HookResult::Rewrite("cmd".into())); + let (_, _, block_code) = format_for_claude(HookResult::Blocked("msg".into())); + assert_ne!( + rewrite_code, 1, + "Exit code 1 is non-blocking error, not valid for rewrite" + ); + assert_ne!( + block_code, 1, + "Exit code 1 is non-blocking error, not valid for block" + ); + } + + // === CROSS-PROTOCOL: Same decision for both agents === + + #[test] + fn test_cross_protocol_safe_command_allowed_by_both() { + // Both Claude and Gemini must allow the same safe commands + for cmd in ["git status", "cargo test", "ls -la", "echo hello"] { + let claude = check_for_hook(cmd, "claude"); + let gemini = check_for_hook(cmd, "gemini"); + match (&claude, &gemini) { + (HookResult::Rewrite(_), HookResult::Rewrite(_)) => {} + _ => panic!( + "'{}': Claude={:?}, Gemini={:?} — both should Rewrite", + cmd, claude, gemini + ), + } + } + } + + // PR 2 adds: test_cross_protocol_blocked_command_denied_by_both (safety-dependent test) +} diff --git a/src/cmd/lexer.rs b/src/cmd/lexer.rs new file mode 100644 index 0000000..5f820bc --- /dev/null +++ b/src/cmd/lexer.rs @@ -0,0 +1,474 @@ +//! 
State-machine lexer that respects quotes and escapes.
+//! Critical: `git commit -m "Fix && Bug"` must NOT split on &&
+
+#[derive(Debug, PartialEq, Clone)]
+pub enum TokenKind {
+    Arg,      // Regular argument
+    Operator, // &&, ||, ;
+    Pipe,     // |
+    Redirect, // >, >>, <, 2>
+    Shellism, // *, ?, $, `, (, ), {, }, ! - forces passthrough
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct ParsedToken {
+    pub kind: TokenKind,
+    pub value: String, // The actual string value
+}
+
+/// Tokenize input with quote awareness.
+/// Returns Vec of parsed tokens.
+pub fn tokenize(input: &str) -> Vec<ParsedToken> {
+    let mut tokens = Vec::new();
+    let mut current = String::new();
+    let mut chars = input.chars().peekable();
+
+    let mut quote: Option<char> = None; // None, Some('\''), Some('"')
+    let mut escaped = false;
+
+    while let Some(c) = chars.next() {
+        // Handle escape sequences (but NOT inside single quotes)
+        if escaped {
+            current.push(c);
+            escaped = false;
+            continue;
+        }
+        if c == '\\' && quote != Some('\'') {
+            escaped = true;
+            current.push(c);
+            continue;
+        }
+
+        // Handle quotes
+        if let Some(q) = quote {
+            if c == q {
+                quote = None; // Close quote
+            }
+            current.push(c);
+            continue;
+        }
+        if c == '\'' || c == '"' {
+            quote = Some(c);
+            current.push(c);
+            continue;
+        }
+
+        // Outside quotes - handle operators and shellisms
+        match c {
+            // Shellisms force passthrough (includes ! for history expansion/negation)
+            '*' | '?' | '$' | '`' | '(' | ')' | '{' | '}' | '!' => {
+                flush_arg(&mut tokens, &mut current);
+                tokens.push(ParsedToken {
+                    kind: TokenKind::Shellism,
+                    value: c.to_string(),
+                });
+            }
+            // Operators
+            '&' | '|' | ';' | '>' | '<' => {
+                flush_arg(&mut tokens, &mut current);
+
+                let mut op = c.to_string();
+                // Lookahead for double-char operators
+                if let Some(&next) = chars.peek() {
+                    if (next == c && c != ';' && c != '<') || (c == '>' && next == '>') {
+                        op.push(chars.next().unwrap());
+                    }
+                }
+
+                let kind = match op.as_str() {
+                    "&&" | "||" | ";" => TokenKind::Operator,
+                    "|" => TokenKind::Pipe,
+                    "&" => TokenKind::Shellism, // Background job needs real shell
+                    _ => TokenKind::Redirect,
+                };
+                tokens.push(ParsedToken { kind, value: op });
+            }
+            // Whitespace delimits arguments
+            c if c.is_whitespace() => {
+                flush_arg(&mut tokens, &mut current);
+            }
+            // Regular character
+            _ => current.push(c),
+        }
+    }
+
+    // Handle unclosed quote (treat remaining as arg, don't panic)
+    flush_arg(&mut tokens, &mut current);
+    tokens
+}
+
+fn flush_arg(tokens: &mut Vec<ParsedToken>, current: &mut String) {
+    let trimmed = current.trim();
+    if !trimmed.is_empty() {
+        tokens.push(ParsedToken {
+            kind: TokenKind::Arg,
+            value: trimmed.to_string(),
+        });
+    }
+    current.clear();
+}
+
+/// Strip quotes from a token value
+pub fn strip_quotes(s: &str) -> String {
+    let chars: Vec<char> = s.chars().collect();
+    if chars.len() >= 2
+        && ((chars[0] == '"' && chars[chars.len() - 1] == '"')
+            || (chars[0] == '\'' && chars[chars.len() - 1] == '\''))
+    {
+        return chars[1..chars.len() - 1].iter().collect();
+    }
+    s.to_string()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // === BASIC FUNCTIONALITY TESTS ===
+
+    #[test]
+    fn test_simple_command() {
+        let tokens = tokenize("git status");
+        assert_eq!(tokens.len(), 2);
+        assert_eq!(tokens[0].kind, TokenKind::Arg);
+        assert_eq!(tokens[0].value, "git");
+        assert_eq!(tokens[1].value, "status");
+    }
+
+    #[test]
+    fn test_command_with_args() {
+        let tokens = tokenize("git commit -m message");
+        assert_eq!(tokens.len(), 4);
+        assert_eq!(tokens[0].value, "git");
+        assert_eq!(tokens[1].value, "commit");
+
assert_eq!(tokens[2].value, "-m"); + assert_eq!(tokens[3].value, "message"); + } + + // === QUOTE HANDLING TESTS === + + #[test] + fn test_quoted_operator_not_split() { + let tokens = tokenize(r#"git commit -m "Fix && Bug""#); + // && inside quotes should NOT be an Operator token + assert!(!tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Operator) && t.value == "&&")); + assert!(tokens.iter().any(|t| t.value.contains("Fix && Bug"))); + } + + #[test] + fn test_single_quoted_string() { + let tokens = tokenize("echo 'hello world'"); + assert!(tokens.iter().any(|t| t.value == "'hello world'")); + } + + #[test] + fn test_double_quoted_string() { + let tokens = tokenize("echo \"hello world\""); + assert!(tokens.iter().any(|t| t.value == "\"hello world\"")); + } + + #[test] + fn test_empty_quoted_string() { + let tokens = tokenize("echo \"\""); + // Should have echo and "" + assert!(tokens.iter().any(|t| t.value == "\"\"")); + } + + #[test] + fn test_nested_quotes() { + let tokens = tokenize(r#"echo "outer 'inner' outer""#); + assert!(tokens.iter().any(|t| t.value.contains("'inner'"))); + } + + #[test] + fn test_strip_quotes_double() { + assert_eq!(strip_quotes("\"hello\""), "hello"); + } + + #[test] + fn test_strip_quotes_single() { + assert_eq!(strip_quotes("'hello'"), "hello"); + } + + #[test] + fn test_strip_quotes_none() { + assert_eq!(strip_quotes("hello"), "hello"); + } + + #[test] + fn test_strip_quotes_mismatched() { + assert_eq!(strip_quotes("\"hello'"), "\"hello'"); + } + + // === ESCAPE HANDLING TESTS === + + #[test] + fn test_escaped_space() { + let tokens = tokenize("echo hello\\ world"); + // Escaped space should be part of the arg + assert!(tokens.iter().any(|t| t.value.contains("hello"))); + } + + #[test] + fn test_backslash_in_single_quotes() { + // In single quotes, backslash is literal + let tokens = tokenize(r#"echo 'hello\nworld'"#); + assert!(tokens.iter().any(|t| t.value.contains(r#"\n"#))); + } + + #[test] + fn test_escaped_quote_in_double() { + let tokens = tokenize(r#"echo "hello\"world""#); + assert!(tokens.iter().any(|t| t.value.contains("hello"))); + } + + // === EDGE CASE TESTS === + + #[test] + fn test_empty_input() { + let tokens = tokenize(""); + assert!(tokens.is_empty()); + } + + #[test] + fn test_whitespace_only() { + let tokens = tokenize(" "); + assert!(tokens.is_empty()); + } + + #[test] + fn test_unclosed_single_quote() { + // Should not panic, treat remaining as part of arg + let tokens = tokenize("'unclosed"); + assert!(!tokens.is_empty()); + } + + #[test] + fn test_unclosed_double_quote() { + // Should not panic, treat remaining as part of arg + let tokens = tokenize("\"unclosed"); + assert!(!tokens.is_empty()); + } + + #[test] + fn test_unicode_preservation() { + let tokens = tokenize("echo \"héllo wörld\""); + assert!(tokens.iter().any(|t| t.value.contains("héllo"))); + } + + #[test] + fn test_multiple_spaces() { + let tokens = tokenize("git status"); + assert_eq!(tokens.len(), 2); + } + + #[test] + fn test_leading_trailing_spaces() { + let tokens = tokenize(" git status "); + assert_eq!(tokens.len(), 2); + } + + // === OPERATOR TESTS === + + #[test] + fn test_and_operator() { + let tokens = tokenize("cmd1 && cmd2"); + assert!(tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Operator) && t.value == "&&")); + } + + #[test] + fn test_or_operator() { + let tokens = tokenize("cmd1 || cmd2"); + assert!(tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Operator) && t.value == "||")); + } + + #[test] + fn test_semicolon() { + let tokens 
= tokenize("cmd1 ; cmd2"); + assert!(tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Operator) && t.value == ";")); + } + + #[test] + fn test_multiple_and() { + let tokens = tokenize("a && b && c"); + let ops: Vec<_> = tokens + .iter() + .filter(|t| matches!(t.kind, TokenKind::Operator)) + .collect(); + assert_eq!(ops.len(), 2); + } + + #[test] + fn test_mixed_operators() { + let tokens = tokenize("a && b || c"); + let ops: Vec<_> = tokens + .iter() + .filter(|t| matches!(t.kind, TokenKind::Operator)) + .collect(); + assert_eq!(ops.len(), 2); + } + + #[test] + fn test_operator_at_start() { + let tokens = tokenize("&& cmd"); + // Should still parse, just with operator first + assert!(tokens.iter().any(|t| t.value == "&&")); + } + + #[test] + fn test_operator_at_end() { + let tokens = tokenize("cmd &&"); + assert!(tokens.iter().any(|t| t.value == "&&")); + } + + // === PIPE TESTS === + + #[test] + fn test_pipe_detection() { + let tokens = tokenize("cat file | grep pattern"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Pipe))); + } + + #[test] + fn test_quoted_pipe_not_pipe() { + let tokens = tokenize("\"a|b\""); + // Pipe inside quotes is not a Pipe token + assert!(!tokens.iter().any(|t| matches!(t.kind, TokenKind::Pipe))); + } + + #[test] + fn test_multiple_pipes() { + let tokens = tokenize("a | b | c"); + let pipes: Vec<_> = tokens + .iter() + .filter(|t| matches!(t.kind, TokenKind::Pipe)) + .collect(); + assert_eq!(pipes.len(), 2); + } + + // === SHELLISM TESTS === + + #[test] + fn test_glob_detection() { + let tokens = tokenize("ls *.rs"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_quoted_glob_not_shellism() { + let tokens = tokenize("echo \"*.txt\""); + // Glob inside quotes is not a Shellism token + assert!(!tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_variable_detection() { + let tokens = tokenize("echo $HOME"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_quoted_variable_not_shellism() { + let tokens = tokenize("echo \"$HOME\""); + // $ inside double quotes is NOT detected as a Shellism token + // because the lexer respects quotes + // This is correct - the variable can't be expanded by us anyway + // so the whole command will need to passthrough to shell + // But at the tokenization level, it's not a Shellism + assert!(!tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_backtick_substitution() { + let tokens = tokenize("echo `date`"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_subshell_detection() { + let tokens = tokenize("echo $(date)"); + // Both $ and ( should be shellisms + let shellisms: Vec<_> = tokens + .iter() + .filter(|t| matches!(t.kind, TokenKind::Shellism)) + .collect(); + assert!(!shellisms.is_empty()); + } + + #[test] + fn test_brace_expansion() { + let tokens = tokenize("echo {a,b}.txt"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Shellism))); + } + + #[test] + fn test_escaped_glob() { + let tokens = tokenize("echo \\*.txt"); + // Escaped glob should not be a shellism + assert!(!tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Shellism) && t.value == "*")); + } + + // === REDIRECT TESTS === + + #[test] + fn test_redirect_out() { + let tokens = tokenize("cmd > file"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Redirect))); + } + + #[test] + fn 
test_redirect_append() { + let tokens = tokenize("cmd >> file"); + assert!(tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Redirect) && t.value == ">>")); + } + + #[test] + fn test_redirect_in() { + let tokens = tokenize("cmd < file"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Redirect))); + } + + #[test] + fn test_redirect_stderr() { + let tokens = tokenize("cmd 2> file"); + assert!(tokens.iter().any(|t| matches!(t.kind, TokenKind::Redirect))); + } + + // === EXCLAMATION / NEGATION TESTS === + + #[test] + fn test_exclamation_is_shellism() { + let tokens = tokenize("if ! grep -q pattern file; then echo missing; fi"); + assert!( + tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Shellism) && t.value == "!"), + "! (negation) must be Shellism" + ); + } + + // === BACKGROUND JOB TESTS === + + #[test] + fn test_background_job_is_shellism() { + let tokens = tokenize("sleep 10 &"); + assert!( + tokens + .iter() + .any(|t| matches!(t.kind, TokenKind::Shellism) && t.value == "&"), + "Single & (background job) must be Shellism, not Redirect" + ); + } +} diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs new file mode 100644 index 0000000..e371cc5 --- /dev/null +++ b/src/cmd/mod.rs @@ -0,0 +1,33 @@ +//! Command execution subsystem for RTK hook integration. +//! +//! This module provides the core hook engine that powers `rtk hook claude`. +//! It handles chained command rewriting, native command execution, and output filtering. + +// Analysis and lexing (no external deps) +pub(crate) mod analysis; +pub(crate) mod lexer; + +// Predicates and utilities (no external deps) +pub(crate) mod predicates; + +// Builtins (depends on predicates) +pub(crate) mod builtins; + +// Filters (depends on crate::utils) +pub(crate) mod filters; + +// Exec (depends on analysis, builtins, filters, lexer) +pub mod exec; + +// Hook logic (depends on analysis, lexer) +pub mod hook; + +// Claude hook protocol (depends on hook) +pub mod claude_hook; + +#[cfg(test)] +pub(crate) mod test_helpers; + +// Public exports +pub use exec::execute; +pub use hook::check_for_hook; diff --git a/src/cmd/predicates.rs b/src/cmd/predicates.rs new file mode 100644 index 0000000..9bd5a6a --- /dev/null +++ b/src/cmd/predicates.rs @@ -0,0 +1,94 @@ +//! Context-aware predicates for conditional safety rules. +//! These give RTK "situational awareness" - checking git state, file existence, etc. 
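+//!
+//! For example, `expand_tilde("~/src")` resolves against `$HOME` (falling back to
+//! `%USERPROFILE%` on Windows), and `has_unstaged_changes()` shells out to
+//! `git diff --quiet`.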
+
+use std::process::Command;
+
+/// Check if there are unstaged changes in the current git repo
+pub(crate) fn has_unstaged_changes() -> bool {
+    Command::new("git")
+        .args(["diff", "--quiet"])
+        .status()
+        .map(|s| !s.success()) // git diff --quiet returns 1 if changes exist
+        .unwrap_or(false)
+}
+
+/// Critical for token reduction: detect if output goes to human or agent
+pub(crate) fn is_interactive() -> bool {
+    use std::io::IsTerminal;
+    std::io::stderr().is_terminal()
+}
+
+/// Expand ~ to $HOME, with fallback
+pub(crate) fn expand_tilde(path: &str) -> String {
+    if path.starts_with("~") {
+        // Try HOME first, then USERPROFILE (Windows)
+        let home = std::env::var("HOME")
+            .or_else(|_| std::env::var("USERPROFILE"))
+            .unwrap_or_else(|_| "/".to_string());
+        path.replacen("~", &home, 1)
+    } else {
+        path.to_string()
+    }
+}
+
+/// Get HOME directory with fallback
+pub(crate) fn get_home() -> String {
+    std::env::var("HOME")
+        .or_else(|_| std::env::var("USERPROFILE"))
+        .unwrap_or_else(|_| "/".to_string())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::env;
+
+    // === PATH EXPANSION TESTS ===
+
+    #[test]
+    fn test_expand_tilde_simple() {
+        let home = env::var("HOME").unwrap_or("/".to_string());
+        assert_eq!(expand_tilde("~/src"), format!("{}/src", home));
+    }
+
+    #[test]
+    fn test_expand_tilde_no_tilde() {
+        assert_eq!(expand_tilde("/absolute/path"), "/absolute/path");
+    }
+
+    #[test]
+    fn test_expand_tilde_only_tilde() {
+        let home = env::var("HOME").unwrap_or("/".to_string());
+        assert_eq!(expand_tilde("~"), home);
+    }
+
+    #[test]
+    fn test_expand_tilde_relative() {
+        assert_eq!(expand_tilde("relative/path"), "relative/path");
+    }
+
+    // === HOME DIRECTORY TESTS ===
+
+    #[test]
+    fn test_get_home_returns_something() {
+        let home = get_home();
+        assert!(!home.is_empty());
+    }
+
+    // === INTERACTIVE TESTS ===
+
+    #[test]
+    fn test_is_interactive() {
+        // This will be false when running tests
+        // Just ensure it doesn't panic
+        let _ = is_interactive();
+    }
+
+    // === GIT PREDICATE TESTS ===
+
+    #[test]
+    fn test_has_unstaged_changes() {
+        // Just ensure it doesn't panic
+        let _ = has_unstaged_changes();
+    }
+}
diff --git a/src/cmd/test_helpers.rs b/src/cmd/test_helpers.rs
new file mode 100644
index 0000000..06f929a
--- /dev/null
+++ b/src/cmd/test_helpers.rs
@@ -0,0 +1,35 @@
+//! Shared test utilities for the cmd module.
+
+use std::sync::{Mutex, MutexGuard, OnceLock};
+
+static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+
+/// RAII guard that serializes env-var-mutating tests and auto-cleans on drop.
+/// Prevents race conditions between parallel test threads and ensures cleanup
+/// even if a test panics.
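+///
+/// Typical use: `let _env = EnvGuard::new();` as the first statement of a test, so the
+/// RTK_* variables are cleared both before the test body runs and again on drop.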
+pub struct EnvGuard {
+    _lock: MutexGuard<'static, ()>,
+}
+
+impl EnvGuard {
+    pub fn new() -> Self {
+        let lock = ENV_LOCK
+            .get_or_init(|| Mutex::new(()))
+            .lock()
+            .unwrap_or_else(|e| e.into_inner());
+        Self::cleanup();
+        Self { _lock: lock }
+    }
+
+    fn cleanup() {
+        std::env::remove_var("RTK_SAFE_COMMANDS");
+        std::env::remove_var("RTK_BLOCK_TOKEN_WASTE");
+        std::env::remove_var("RTK_ACTIVE");
+    }
+}
+
+impl Drop for EnvGuard {
+    fn drop(&mut self) {
+        Self::cleanup();
+    }
+}
diff --git a/src/init.rs b/src/init.rs
index 961e4ac..3255374 100644
--- a/src/init.rs
+++ b/src/init.rs
@@ -471,9 +471,10 @@ pub fn uninstall(global: bool, verbose: u8) -> Result<()> {
 fn patch_settings_json(hook_path: &Path, mode: PatchMode, verbose: u8) -> Result {
     let claude_dir = resolve_claude_dir()?;
     let settings_path = claude_dir.join("settings.json");
-    let hook_command = hook_path
-        .to_str()
-        .context("Hook path contains invalid UTF-8")?;
+    // Use binary command instead of .sh file path for PR 1 v2
+    // The rtk hook claude command is a compiled Rust binary
+    let hook_command = "rtk hook claude";
+    let _ = hook_path; // Suppress unused parameter warning (still passed for API compatibility)
 
     // Read or create settings.json
     let mut root = if settings_path.exists() {
@@ -1134,15 +1135,10 @@ mod tests {
 
     #[test]
     fn test_hook_has_guards() {
-        assert!(REWRITE_HOOK.contains("command -v rtk"));
-        assert!(REWRITE_HOOK.contains("command -v jq"));
-        // Guards must be BEFORE set -euo pipefail
-        let guard_pos = REWRITE_HOOK.find("command -v rtk").unwrap();
-        let set_pos = REWRITE_HOOK.find("set -euo pipefail").unwrap();
-        assert!(
-            guard_pos < set_pos,
-            "Guards must come before set -euo pipefail"
-        );
+        // PR 1 v2: Replaced bash hook with binary shim
+        // Old test checked for "command -v rtk" and "command -v jq" guards
+        // New shim just execs rtk hook claude binary (no bash guards needed)
+        assert!(REWRITE_HOOK.contains("exec rtk hook claude"));
     }
 
     #[test]
diff --git a/src/main.rs b/src/main.rs
index 5bec4da..da7affa 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,7 @@
 mod cargo_cmd;
 mod cc_economics;
 mod ccusage;
+mod cmd;
 mod config;
 mod container;
 mod curl_cmd;
@@ -526,6 +527,34 @@ enum Commands {
         #[arg(short, long, default_value = "7")]
         since: u64,
     },
+
+    /// Run command with safety checks and token-optimized output
+    Run {
+        /// Command string to execute
+        #[arg(short = 'c', long)]
+        command: String,
+    },
+
+    /// Hook protocol for Claude Code integration
+    Hook {
+        #[command(subcommand)]
+        command: HookCommands,
+    },
+}
+
+#[derive(Subcommand)]
+enum HookCommands {
+    /// Check command for safety and rewrite (text protocol for debugging)
+    Check {
+        /// Agent type: claude or gemini
+        #[arg(long, default_value = "claude")]
+        agent: String,
+        /// Command to check
+        #[arg(trailing_var_arg = true, allow_hyphen_values = true)]
+        command: Vec<String>,
+    },
+    /// Claude Code JSON protocol handler (reads stdin, writes stdout)
+    Claude,
 }
 
 #[derive(Subcommand)]
@@ -1372,6 +1401,30 @@ fn main() -> Result<()> {
             hook_audit_cmd::run(since, cli.verbose)?;
         }
 
+        Commands::Run { command } => {
+            let success = cmd::execute(&command, cli.verbose)?;
+            if !success {
+                std::process::exit(1);
+            }
+        }
+
+        Commands::Hook { command } => match command {
+            HookCommands::Check { agent, command } => {
+                let cmd_str = command.join(" ");
+                let result = cmd::check_for_hook(&cmd_str, &agent);
+                let (output, _success, code) = cmd::hook::format_for_claude(result);
+                if code == 0 {
+                    println!("{}", output);
+                } else {
+                    eprintln!("{}", output);
+
std::process::exit(code); + } + } + HookCommands::Claude => { + cmd::claude_hook::run()?; + } + }, + Commands::Proxy { args } => { use std::process::Command; From 1f769f0962bde50271d4c7ba82d7551f639589d1 Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Wed, 18 Feb 2026 23:17:44 -0500 Subject: [PATCH 2/5] fix(hook): add lexer integration and route single commands to rtk subcommands Replaces stub check_for_hook_inner with full tokenize+native-path dispatch. Adds route_native_command() with replace_first_word/route_pnpm/route_npx helpers to route single parsed commands to optimized RTK subcommands. Chains (&&/||/;) and shellisms still use rtk run -c. No safety integration (PR #157 adds that). Mirrors ~/.claude/hooks/rtk-rewrite.sh routing table. Corrects shell script vitest double-run bug for pnpm vitest run flags. --- src/cmd/hook.rs | 329 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 315 insertions(+), 14 deletions(-) diff --git a/src/cmd/hook.rs b/src/cmd/hook.rs index 2d8c0f8..53bc973 100644 --- a/src/cmd/hook.rs +++ b/src/cmd/hook.rs @@ -68,7 +68,25 @@ fn check_for_hook_inner(raw: &str, depth: usize) -> HookResult { } // PR 2 adds: crate::config::rules::try_remap() alias expansion // PR 2 adds: safety::check_raw() and safety::check() dispatch - HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))) + + let tokens = lexer::tokenize(raw); + + if analysis::needs_shell(&tokens) { + return HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))); + } + + match analysis::parse_chain(tokens) { + Ok(commands) => { + // Single command: route to optimized RTK subcommand. + // Chained commands (&&, ||, ;): wrap entire chain in rtk run -c. + if commands.len() == 1 { + HookResult::Rewrite(route_native_command(&commands[0], raw)) + } else { + HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))) + } + } + Err(_) => HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))), + } } // --- Shared guard logic (used by both claude_hook.rs and gemini_hook.rs) --- @@ -136,6 +154,175 @@ fn escape_quotes(s: &str) -> String { s.replace("'", "'\\''") } +/// Replace the first occurrence of `old_prefix` in `raw` with `new_prefix`. +/// +/// Preserves everything after the prefix (including original quoting). +/// Falls back to `rtk run -c ''` if prefix not found (safe degradation). +/// +/// # Examples +/// - `replace_first_word("grep -r p src/", "grep", "rtk grep")` → `"rtk grep -r p src/"` +/// - `replace_first_word("rg pattern", "rg", "rtk grep")` → `"rtk grep pattern"` +fn replace_first_word(raw: &str, old_prefix: &str, new_prefix: &str) -> String { + raw.strip_prefix(old_prefix) + .map(|rest| format!("{new_prefix}{rest}")) + .unwrap_or_else(|| format!("rtk run -c '{}'", escape_quotes(raw))) +} + +/// Route pnpm subcommands to RTK equivalents. +/// +/// Uses `cmd.args` (parsed, quote-stripped) for routing decisions. +/// Uses `raw` or reconstructed args for output to preserve original quoting. +fn route_pnpm(cmd: &analysis::NativeCommand, raw: &str) -> String { + let sub = cmd.args.first().map(String::as_str).unwrap_or(""); + match sub { + "list" | "ls" | "outdated" | "install" => format!("rtk {raw}"), + + // pnpm vitest [run] [flags] → rtk vitest run [flags] + // Shell script sed bug: 's/^(pnpm )?vitest/rtk vitest run/' on + // "pnpm vitest run --coverage" produces "rtk vitest run run --coverage". + // Binary hook corrects this by stripping the leading "run" from parsed args. + "vitest" => { + let after_vitest: Vec<&str> = cmd.args[1..] 
+ .iter() + .map(String::as_str) + .skip_while(|&a| a == "run") + .collect(); + if after_vitest.is_empty() { + "rtk vitest run".to_string() + } else { + format!("rtk vitest run {}", after_vitest.join(" ")) + } + } + + // pnpm test [flags] → rtk vitest run [flags] + "test" => { + let after_test: Vec<&str> = cmd.args[1..].iter().map(String::as_str).collect(); + if after_test.is_empty() { + "rtk vitest run".to_string() + } else { + format!("rtk vitest run {}", after_test.join(" ")) + } + } + + "tsc" => replace_first_word(raw, "pnpm tsc", "rtk tsc"), + "lint" => replace_first_word(raw, "pnpm lint", "rtk lint"), + "playwright" => replace_first_word(raw, "pnpm playwright", "rtk playwright"), + + _ => format!("rtk run -c '{}'", escape_quotes(raw)), + } +} + +/// Route npx subcommands to RTK equivalents. +fn route_npx(cmd: &analysis::NativeCommand, raw: &str) -> String { + let sub = cmd.args.first().map(String::as_str).unwrap_or(""); + match sub { + "tsc" | "typescript" => replace_first_word(raw, &format!("npx {sub}"), "rtk tsc"), + "eslint" => replace_first_word(raw, "npx eslint", "rtk lint"), + "prettier" => replace_first_word(raw, "npx prettier", "rtk prettier"), + "playwright" => replace_first_word(raw, "npx playwright", "rtk playwright"), + "prisma" => replace_first_word(raw, "npx prisma", "rtk prisma"), + _ => format!("rtk run -c '{}'", escape_quotes(raw)), + } +} + +/// Route a single parsed native command to its optimized RTK subcommand. +/// +/// ## Design +/// - Uses `cmd.binary`/`cmd.args` (lexer→parse_chain output) for routing DECISIONS. +/// - Uses `raw: &str` with `replace_first_word` for string REPLACEMENT (preserves quoting). +/// - `format!("rtk {raw}")` works when the binary name equals the RTK subcommand. +/// - `replace_first_word` handles renames: `rg → rtk grep`, `cat → rtk read`. +/// +/// ## Fallback +/// Unknown binaries or unrecognized subcommands → `rtk run -c ''` (safe passthrough). +/// +/// ## Mirrors +/// `~/.claude/hooks/rtk-rewrite.sh` routing table. Corrects the shell script's +/// `vitest run` double-"run" bug by using parsed args rather than regex substitution. +/// +/// ## Safety interaction +/// PR 2 adds safety::check before this function. The `cat` arm is defensive for +/// when `RTK_BLOCK_TOKEN_WASTE=0`. +fn route_native_command(cmd: &analysis::NativeCommand, raw: &str) -> String { + let sub = cmd.args.first().map(String::as_str).unwrap_or(""); + let sub2 = cmd.args.get(1).map(String::as_str).unwrap_or(""); + + match cmd.binary.as_str() { + // Git: known subcommands (global options like --no-pager fall through to fallback) + "git" + if matches!( + sub, + "status" + | "diff" + | "log" + | "add" + | "commit" + | "push" + | "pull" + | "branch" + | "fetch" + | "stash" + | "show" + ) => + { + format!("rtk {raw}") + } + + // GitHub CLI + "gh" if matches!(sub, "pr" | "issue" | "run") => format!("rtk {raw}"), + + // Cargo: test/build/clippy/check have rtk equivalents + "cargo" if matches!(sub, "test" | "build" | "clippy" | "check") => format!("rtk {raw}"), + + // File ops — renames (rg/grep → rtk grep, cat → rtk read) + // NOTE: PR 2 adds safety rules that block cat/head/sed before reaching here. + // These arms are defensive for if RTK_BLOCK_TOKEN_WASTE=0. 
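+        // e.g. "cat notes.txt" → "rtk read notes.txt", "rg TODO src/" → "rtk grep TODO src/"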
+ "cat" => replace_first_word(raw, "cat", "rtk read"), + "grep" | "rg" => replace_first_word(raw, cmd.binary.as_str(), "rtk grep"), + "eslint" => replace_first_word(raw, "eslint", "rtk lint"), + + // Direct prepend: rtk subcommand name = binary name + "ls" | "tsc" | "prettier" | "playwright" | "prisma" | "curl" | "pytest" + | "golangci-lint" => format!("rtk {raw}"), + + // tail: may be blocked by safety (PR 2); defensive routing if allowed + "tail" => format!("rtk {raw}"), + + // vitest: bare vitest → rtk vitest run (not rtk vitest) + "vitest" if sub.is_empty() => "rtk vitest run".to_string(), + "vitest" => format!("rtk {raw}"), + + // Containers: info-read subcommands only + "docker" if matches!(sub, "ps" | "images" | "logs") => format!("rtk {raw}"), + "kubectl" if matches!(sub, "get" | "logs") => format!("rtk {raw}"), + + // Go + "go" if matches!(sub, "test" | "build" | "vet") => format!("rtk {raw}"), + + // Ruff: check/format only + "ruff" if matches!(sub, "check" | "format") => format!("rtk {raw}"), + + // pip/uv: list/outdated/install/show only + "pip" if matches!(sub, "list" | "outdated" | "install" | "show") => format!("rtk {raw}"), + "uv" if sub == "pip" && matches!(sub2, "list" | "outdated" | "install" | "show") => { + replace_first_word(raw, "uv pip", "rtk pip") + } + + // python/python3 -m pytest + "python" | "python3" if sub == "-m" && sub2 == "pytest" => { + let prefix = format!("{} -m pytest", cmd.binary); + replace_first_word(raw, &prefix, "rtk pytest") + } + + // pnpm / npx: delegated to helpers (complex sub-routing) + "pnpm" => route_pnpm(cmd, raw), + "npx" => route_npx(cmd, raw), + + // Fallback: unknown binary or unrecognized subcommand + _ => format!("rtk run -c '{}'", escape_quotes(raw)), + } +} + /// Format hook result for Claude (text output) /// /// Exit codes: @@ -208,15 +395,15 @@ mod tests { #[test] fn test_safe_commands_rewrite() { let cases = [ - ("git status", "rtk run"), - ("ls *.rs", "rtk run"), // shellism passthrough - (r#"git commit -m "Fix && Bug""#, "rtk run"), // quoted operator - ("FOO=bar echo hello", "rtk run"), // env prefix - ("echo `date`", "rtk run"), // backticks - ("echo $(date)", "rtk run"), // subshell - ("echo {a,b}.txt", "rtk run"), // brace expansion + ("git status", "rtk git status"), // now routes to optimized subcommand + ("ls *.rs", "rtk run"), // shellism passthrough (glob) + (r#"git commit -m "Fix && Bug""#, "rtk git commit"), // quoted &&: single cmd, routes + ("FOO=bar echo hello", "rtk run"), // env prefix → shellism + ("echo `date`", "rtk run"), // backticks + ("echo $(date)", "rtk run"), // subshell + ("echo {a,b}.txt", "rtk run"), // brace expansion ("echo 'hello!@#$%^&*()'", "rtk run"), // special chars - ("echo '日本語 🎉'", "rtk run"), // unicode + ("echo '日本語 🎉'", "rtk run"), // unicode ("cd /tmp && git status", "rtk run"), // chain rewrite ]; for (input, expected) in cases { @@ -332,7 +519,13 @@ mod tests { "rg pattern src/", ]; for input in cases { - assert_rewrite(input, "rtk run"); + // Test name intent: commands must Rewrite (not Blocked), regardless of routing target. + // Specific routing targets are verified in test_routing_native_commands. + assert!( + matches!(check_for_hook(input, "claude"), HookResult::Rewrite(_)), + "'{}' should Rewrite (not Blocked)", + input + ); } } @@ -391,11 +584,16 @@ mod tests { #[test] fn test_compound_quoted_operators_not_split() { - // && inside quotes must NOT split the command + // && inside quotes must NOT split the command into a chain. 
+ // parse_chain sees one command: git commit with args ["-m", "Fix && Bug"]. + // That single command routes to rtk git commit (not rtk run -c). let input = r#"git commit -m "Fix && Bug""#; match check_for_hook(input, "claude") { HookResult::Rewrite(cmd) => { - assert!(cmd.contains("rtk run"), "Should rewrite, got '{cmd}'"); + assert!( + cmd.contains("rtk git commit"), + "Quoted && must not split; should route to rtk git commit, got '{cmd}'" + ); } other => panic!("Expected Rewrite for quoted &&, got {other:?}"), } @@ -423,10 +621,12 @@ mod tests { #[test] fn test_different_agents_same_result() { + // Both agents must Rewrite (not Block) safe commands. + // Specific routing targets verified in test_cross_agent_routing_identical. for agent in ["claude", "gemini"] { match check_for_hook("git status", agent) { - HookResult::Rewrite(cmd) => assert!(cmd.contains("rtk run")), - _ => panic!("Expected Rewrite for agent '{}'", agent), + HookResult::Rewrite(_) => {} + other => panic!("Expected Rewrite for agent '{}', got {:?}", agent, other), } } } @@ -566,4 +766,105 @@ mod tests { } // PR 2 adds: test_cross_protocol_blocked_command_denied_by_both (safety-dependent test) + + // ===================================================================== + // ROUTING TESTS — verify route_native_command dispatch + // ===================================================================== + + #[test] + fn test_routing_native_commands() { + // Table-driven: commands that route to optimized rtk subcommands. + // Each (input, expected_substr) must appear in the rewritten output. + let cases = [ + // Git: known subcommands + ("git status", "rtk git status"), + ("git log --oneline -10", "rtk git log --oneline -10"), + ("git diff HEAD", "rtk git diff HEAD"), + ("git add .", "rtk git add ."), + ("git commit -m msg", "rtk git commit"), + // GitHub CLI + ("gh pr view 156", "rtk gh pr view 156"), + // Cargo + ("cargo test", "rtk cargo test"), + ( + "cargo clippy --all-targets", + "rtk cargo clippy --all-targets", + ), + // File ops (rg → rtk grep rename) + // NOTE: PR 2 adds safety that blocks cat before reaching router; arm is defensive. + ("grep -r pattern src/", "rtk grep -r pattern src/"), + ("rg pattern src/", "rtk grep pattern src/"), + ("ls -la", "rtk ls -la"), + ("tail -n 20 file.txt", "rtk tail -n 20 file.txt"), + // JS/TS tooling + ("vitest", "rtk vitest run"), // bare → rtk vitest run + ("vitest run", "rtk vitest run"), // explicit run preserved + ("vitest run --coverage", "rtk vitest run --coverage"), + ("pnpm test", "rtk vitest run"), + ("pnpm vitest", "rtk vitest run"), + ("pnpm lint", "rtk lint"), + ("npx tsc --noEmit", "rtk tsc --noEmit"), + // Python + ("python -m pytest tests/", "rtk pytest tests/"), + ("uv pip list", "rtk pip list"), + // Go + ("go test ./...", "rtk go test ./..."), + ]; + for (input, expected) in cases { + assert_rewrite(input, expected); + } + } + + #[test] + fn test_routing_vitest_no_double_run() { + // Shell script sed bug: 's/^(pnpm )?vitest/rtk vitest run/' on + // "pnpm vitest run --coverage" produces "rtk vitest run run --coverage". + // Binary hook corrects this by using parsed args instead of regex substitution. 
+ let result = match check_for_hook("pnpm vitest run --coverage", "claude") { + HookResult::Rewrite(cmd) => cmd, + other => panic!("Expected Rewrite, got {:?}", other), + }; + assert_rewrite("pnpm vitest run --coverage", "rtk vitest run --coverage"); + assert!( + !result.contains("run run"), + "Must not double 'run' in output: '{}'", + result + ); + } + + #[test] + fn test_routing_fallbacks_to_rtk_run() { + // Unknown subcommand, chains (2+ cmds), and pipes fall back to rtk run -c. + let cases = [ + "git checkout main", // unknown git subcommand + "git add . && git commit -m msg", // chain → 2 commands → rtk run -c + "git log | grep fix", // pipe → needs_shell → rtk run -c + ]; + for input in cases { + assert_rewrite(input, "rtk run -c"); + } + } + + #[test] + fn test_cross_agent_routing_identical() { + // Both claude and gemini must route the same commands to the same output. + for cmd in ["git status", "cargo test", "ls -la"] { + let claude_result = check_for_hook(cmd, "claude"); + let gemini_result = check_for_hook(cmd, "gemini"); + match (&claude_result, &gemini_result) { + (HookResult::Rewrite(c), HookResult::Rewrite(g)) => { + assert_eq!(c, g, "claude and gemini must route '{}' identically", cmd); + assert!( + !c.contains("rtk run -c"), + "'{}' should not fall back to rtk run -c", + cmd + ); + } + _ => panic!( + "'{}' should Rewrite for both agents: claude={:?} gemini={:?}", + cmd, claude_result, gemini_result + ), + } + } + } } From a74be59f278ceb1ba7a77502341b8d82f4851470 Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Thu, 19 Feb 2026 05:07:33 -0500 Subject: [PATCH 3/5] registry.rs,hook.rs: remove no-op route, fix fallback tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit rtk has no `tail` subcommand — routing to "rtk tail" was silently broken (rtk would error "unrecognized subcommand"). Remove the Route entry so the command falls through to `rtk run -c '...'` correctly. Move the log-tailing test cases from test_routing_native_commands (which asserted the broken path) into test_routing_fallbacks_to_rtk_run where they correctly verify the rtk-run-c fallback behavior. --- src/cmd/hook.rs | 74 ++++--------- src/discover/registry.rs | 234 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 255 insertions(+), 53 deletions(-) diff --git a/src/cmd/hook.rs b/src/cmd/hook.rs index 53bc973..10aaed0 100644 --- a/src/cmd/hook.rs +++ b/src/cmd/hook.rs @@ -247,68 +247,36 @@ fn route_native_command(cmd: &analysis::NativeCommand, raw: &str) -> String { let sub = cmd.args.first().map(String::as_str).unwrap_or(""); let sub2 = cmd.args.get(1).map(String::as_str).unwrap_or(""); - match cmd.binary.as_str() { - // Git: known subcommands (global options like --no-pager fall through to fallback) - "git" - if matches!( - sub, - "status" - | "diff" - | "log" - | "add" - | "commit" - | "push" - | "pull" - | "branch" - | "fetch" - | "stash" - | "show" - ) => - { + // 1. Static routing table: O(1) lookup via HashMap (built once at startup). + // Covers all simple cases: direct routes and renames (rg→grep, eslint→lint). 
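+    // e.g. lookup("git", "status") hits a direct route (rtk_cmd == "git"), while
+    // lookup("rg", _) hits a rename route (rtk_cmd == "grep").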
+ if let Some(route) = crate::discover::registry::lookup(&cmd.binary, sub) { + return if route.rtk_cmd == cmd.binary.as_str() { + // Direct route (binary name == rtk subcommand): prepend "rtk " format!("rtk {raw}") - } - - // GitHub CLI - "gh" if matches!(sub, "pr" | "issue" | "run") => format!("rtk {raw}"), - - // Cargo: test/build/clippy/check have rtk equivalents - "cargo" if matches!(sub, "test" | "build" | "clippy" | "check") => format!("rtk {raw}"), - - // File ops — renames (rg/grep → rtk grep, cat → rtk read) - // NOTE: PR 2 adds safety rules that block cat/head/sed before reaching here. - // These arms are defensive for if RTK_BLOCK_TOKEN_WASTE=0. - "cat" => replace_first_word(raw, "cat", "rtk read"), - "grep" | "rg" => replace_first_word(raw, cmd.binary.as_str(), "rtk grep"), - "eslint" => replace_first_word(raw, "eslint", "rtk lint"), + } else { + // Rename route (rg → grep, eslint → lint): replace binary prefix + replace_first_word(raw, &cmd.binary, &format!("rtk {}", route.rtk_cmd)) + }; + } - // Direct prepend: rtk subcommand name = binary name - "ls" | "tsc" | "prettier" | "playwright" | "prisma" | "curl" | "pytest" - | "golangci-lint" => format!("rtk {raw}"), + // 2. Complex cases that require Rust logic and cannot be expressed as table entries. - // tail: may be blocked by safety (PR 2); defensive routing if allowed - "tail" => format!("rtk {raw}"), + // cat: blocked by safety rules before reaching here; defensive for RTK_BLOCK_TOKEN_WASTE=0 + if cmd.binary == "cat" { + return replace_first_word(raw, "cat", "rtk read"); + } - // vitest: bare vitest → rtk vitest run (not rtk vitest) + match cmd.binary.as_str() { + // vitest: bare invocation → rtk vitest run (not rtk vitest) "vitest" if sub.is_empty() => "rtk vitest run".to_string(), "vitest" => format!("rtk {raw}"), - // Containers: info-read subcommands only - "docker" if matches!(sub, "ps" | "images" | "logs") => format!("rtk {raw}"), - "kubectl" if matches!(sub, "get" | "logs") => format!("rtk {raw}"), - - // Go - "go" if matches!(sub, "test" | "build" | "vet") => format!("rtk {raw}"), - - // Ruff: check/format only - "ruff" if matches!(sub, "check" | "format") => format!("rtk {raw}"), - - // pip/uv: list/outdated/install/show only - "pip" if matches!(sub, "list" | "outdated" | "install" | "show") => format!("rtk {raw}"), + // uv pip: two-word prefix replacement "uv" if sub == "pip" && matches!(sub2, "list" | "outdated" | "install" | "show") => { replace_first_word(raw, "uv pip", "rtk pip") } - // python/python3 -m pytest + // python/python3 -m pytest: two-arg prefix replacement "python" | "python3" if sub == "-m" && sub2 == "pytest" => { let prefix = format!("{} -m pytest", cmd.binary); replace_first_word(raw, &prefix, "rtk pytest") @@ -322,7 +290,6 @@ fn route_native_command(cmd: &analysis::NativeCommand, raw: &str) -> String { _ => format!("rtk run -c '{}'", escape_quotes(raw)), } } - /// Format hook result for Claude (text output) /// /// Exit codes: @@ -795,7 +762,6 @@ mod tests { ("grep -r pattern src/", "rtk grep -r pattern src/"), ("rg pattern src/", "rtk grep pattern src/"), ("ls -la", "rtk ls -la"), - ("tail -n 20 file.txt", "rtk tail -n 20 file.txt"), // JS/TS tooling ("vitest", "rtk vitest run"), // bare → rtk vitest run ("vitest run", "rtk vitest run"), // explicit run preserved @@ -839,6 +805,8 @@ mod tests { "git checkout main", // unknown git subcommand "git add . 
&& git commit -m msg", // chain → 2 commands → rtk run -c "git log | grep fix", // pipe → needs_shell → rtk run -c + "tail -n 20 file.txt", // no rtk tail subcommand + "tail -f server.log", // no rtk tail subcommand ]; for input in cases { assert_rewrite(input, "rtk run -c"); diff --git a/src/discover/registry.rs b/src/discover/registry.rs index 7ef375c..dd45cc2 100644 --- a/src/discover/registry.rs +++ b/src/discover/registry.rs @@ -1,5 +1,197 @@ use lazy_static::lazy_static; use regex::{Regex, RegexSet}; +use std::collections::HashMap; +use std::sync::OnceLock; + +// --------------------------------------------------------------------------- +// Hook routing table — used by `cmd::hook` for O(1) command rewriting. +// This is the single source of truth for which external binaries route through +// RTK and exactly which subcommands are covered. +// +// # Adding a new command +// 1. Add one `Route` entry to `ROUTES`. +// 2. Add a discover entry (PATTERNS + RULES) below if needed. +// 3. Done — hook routing is automatic. +// --------------------------------------------------------------------------- + +/// Subcommand filter for a route entry. +#[derive(Debug, Clone, Copy)] +pub enum Subcmds { + /// Route ALL subcommands of this binary (e.g., ls, curl, prettier). + Any, + /// Route ONLY these specific subcommands; others fall through to `rtk run -c`. + Only(&'static [&'static str]), +} + +/// One row in the static routing table. +/// +/// - `binaries`: one or more external binary names mapping to the same RTK subcommand. +/// - `subcmds`: subcommand filter — `Any` matches everything, `Only` restricts to a list. +/// - `rtk_cmd`: the RTK subcommand name (e.g., `"grep"`, `"lint"`, `"git"`). +/// +/// For direct routes where `binary == rtk_cmd`, the hook uses `format!("rtk {raw}")`. +/// For renames (`rg` → `grep`, `eslint` → `lint`), it uses `replace_first_word`. +#[derive(Debug, Clone, Copy)] +pub struct Route { + pub binaries: &'static [&'static str], + pub subcmds: Subcmds, + pub rtk_cmd: &'static str, +} + +/// Static routing table. Single source of truth for hook routing. +/// +/// Order does not matter — lookups use a HashMap built once at startup (O(1) per call). +/// +/// Complex cases (vitest bare invocation, `uv pip`, `python -m pytest`, pnpm, npx) +/// require Rust logic and stay as match arms in `cmd::hook::route_native_command`. 
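+///
+/// For example (mirroring the lookup tests): `lookup("rg", "")` and
+/// `lookup("grep", "-r")` both resolve to the entry with `rtk_cmd = "grep"`,
+/// while `lookup("git", "rebase")` returns `None`, so the hook falls back to
+/// `rtk run -c '...'`.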
+pub const ROUTES: &[Route] = &[ + // Version control + Route { + binaries: &["git"], + subcmds: Subcmds::Only(&[ + "status", "diff", "log", "add", "commit", "push", "pull", "branch", "fetch", "stash", + "show", + ]), + rtk_cmd: "git", + }, + // GitHub CLI + Route { + binaries: &["gh"], + subcmds: Subcmds::Only(&["pr", "issue", "run"]), + rtk_cmd: "gh", + }, + // Rust build tools + Route { + binaries: &["cargo"], + subcmds: Subcmds::Only(&["test", "build", "clippy", "check"]), + rtk_cmd: "cargo", + }, + // Search — two binaries, one RTK subcommand (rename) + Route { + binaries: &["rg", "grep"], + subcmds: Subcmds::Any, + rtk_cmd: "grep", + }, + // JavaScript linting — rename + Route { + binaries: &["eslint"], + subcmds: Subcmds::Any, + rtk_cmd: "lint", + }, + // File system + Route { + binaries: &["ls"], + subcmds: Subcmds::Any, + rtk_cmd: "ls", + }, + // TypeScript compiler + Route { + binaries: &["tsc"], + subcmds: Subcmds::Any, + rtk_cmd: "tsc", + }, + // JavaScript formatting + Route { + binaries: &["prettier"], + subcmds: Subcmds::Any, + rtk_cmd: "prettier", + }, + // E2E testing + Route { + binaries: &["playwright"], + subcmds: Subcmds::Any, + rtk_cmd: "playwright", + }, + // Database ORM + Route { + binaries: &["prisma"], + subcmds: Subcmds::Any, + rtk_cmd: "prisma", + }, + // Network + Route { + binaries: &["curl"], + subcmds: Subcmds::Any, + rtk_cmd: "curl", + }, + // Python testing + Route { + binaries: &["pytest"], + subcmds: Subcmds::Any, + rtk_cmd: "pytest", + }, + // Go linting + Route { + binaries: &["golangci-lint"], + subcmds: Subcmds::Any, + rtk_cmd: "golangci-lint", + }, + // Containers — read-only subcommands only + Route { + binaries: &["docker"], + subcmds: Subcmds::Only(&["ps", "images", "logs"]), + rtk_cmd: "docker", + }, + // Kubernetes — read-only subcommands only + Route { + binaries: &["kubectl"], + subcmds: Subcmds::Only(&["get", "logs"]), + rtk_cmd: "kubectl", + }, + // Go build tools + Route { + binaries: &["go"], + subcmds: Subcmds::Only(&["test", "build", "vet"]), + rtk_cmd: "go", + }, + // Python linting/formatting + Route { + binaries: &["ruff"], + subcmds: Subcmds::Only(&["check", "format"]), + rtk_cmd: "ruff", + }, + // Python package management + Route { + binaries: &["pip"], + subcmds: Subcmds::Only(&["list", "outdated", "install", "show"]), + rtk_cmd: "pip", + }, +]; + +/// Look up the routing entry for a binary + subcommand. +/// +/// Returns `Some(route)` if the binary is in the table AND the subcommand matches +/// the entry's filter. Returns `None` if unrecognised or subcommand not in `Only` list. +/// +/// The HashMap is built once per process (OnceLock). Each binary maps to the index of +/// its `Route` in `ROUTES`. Multiple binaries from the same entry (e.g., `rg`/`grep`) +/// both point to the same index. +pub fn lookup(binary: &str, sub: &str) -> Option<&'static Route> { + static MAP: OnceLock> = OnceLock::new(); + let map = MAP.get_or_init(|| { + let mut m = HashMap::new(); + for (i, route) in ROUTES.iter().enumerate() { + for &bin in route.binaries { + m.entry(bin).or_insert(i); + } + } + m + }); + + let idx = *map.get(binary)?; + let route = &ROUTES[idx]; + + let matches = match route.subcmds { + Subcmds::Any => true, + Subcmds::Only(subs) => subs.contains(&sub), + }; + + if matches { + Some(route) + } else { + None + } +} /// A rule mapping a shell command pattern to its RTK equivalent. 
struct RtkRule { @@ -70,6 +262,12 @@ const PATTERNS: &[&str] = &[ r"^kubectl\s+(get|logs)", r"^curl\s+", r"^wget\s+", + // Python/Go tooling (added with Python & Go support) + r"^pytest(\s|$)", + r"^go\s+(test|build|vet)(\s|$)", + r"^ruff\s+(check|format)(\s|$)", + r"^(pip|pip3)\s+(list|outdated|install|show)(\s|$)", + r"^golangci-lint(\s|$)", ]; const RULES: &[RtkRule] = &[ @@ -225,6 +423,42 @@ const RULES: &[RtkRule] = &[ subcmd_savings: &[], subcmd_status: &[], }, + // Python/Go tooling (added with Python & Go support) + RtkRule { + rtk_cmd: "rtk pytest", + category: "Tests", + savings_pct: 90.0, + subcmd_savings: &[], + subcmd_status: &[], + }, + RtkRule { + rtk_cmd: "rtk go", + category: "Build", + savings_pct: 85.0, + subcmd_savings: &[("test", 90.0)], + subcmd_status: &[], + }, + RtkRule { + rtk_cmd: "rtk ruff", + category: "Build", + savings_pct: 80.0, + subcmd_savings: &[], + subcmd_status: &[], + }, + RtkRule { + rtk_cmd: "rtk pip", + category: "PackageManager", + savings_pct: 75.0, + subcmd_savings: &[], + subcmd_status: &[], + }, + RtkRule { + rtk_cmd: "rtk golangci-lint", + category: "Build", + savings_pct: 85.0, + subcmd_savings: &[], + subcmd_status: &[], + }, ]; /// Commands to ignore (shell builtins, trivial, already rtk). From f66f7e7351065abf12e41fbe6ac9acfeb8ef7977 Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Thu, 19 Feb 2026 05:13:29 -0500 Subject: [PATCH 4/5] registry.rs,hook.rs: add missing routing test coverage from main branch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Port tests added during the ROUTES table integration that were missing from the v2 worktree: registry.rs: - 12 classify tests for Python/Go commands (pytest, go×4, ruff×2, pip×3, golangci-lint) that verify PATTERNS/RULES and ROUTES alignment - 11 lookup tests (test_lookup_*, test_no_duplicate_binaries_in_routes, test_lookup_is_o1_consistent) that verify O(1) HashMap routing hook.rs: - Extend test_routing_native_commands from 20 to 47 cases covering all ROUTES entries: docker, kubectl, curl, eslint, tsc, prettier, playwright, prisma, pytest, golangci-lint, ruff, pip, gh variants - Add test_routing_subcommand_filter_fallback (14 cases) verifying that Only[] subcommand filters correctly reject unmatched subcommands Total: 545 → 569 tests (+24) --- src/cmd/hook.rs | 55 +++++++++++ src/discover/registry.rs | 208 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+) diff --git a/src/cmd/hook.rs b/src/cmd/hook.rs index 10aaed0..d4eb5b1 100644 --- a/src/cmd/hook.rs +++ b/src/cmd/hook.rs @@ -775,12 +775,67 @@ mod tests { ("uv pip list", "rtk pip list"), // Go ("go test ./...", "rtk go test ./..."), + ("go build ./...", "rtk go build ./..."), + ("go vet ./...", "rtk go vet ./..."), + // All ROUTES entries not yet covered above + ("eslint src/", "rtk lint src/"), // rename: eslint → lint + ("tsc --noEmit", "rtk tsc --noEmit"), // bare tsc (not npx tsc) + ("prettier src/", "rtk prettier src/"), + ("playwright test", "rtk playwright test"), + ("prisma migrate dev", "rtk prisma migrate dev"), + ( + "curl https://api.example.com", + "rtk curl https://api.example.com", + ), + ("pytest tests/", "rtk pytest tests/"), // bare pytest (not python -m pytest) + ("pytest -x tests/unit", "rtk pytest -x tests/unit"), + ("golangci-lint run ./...", "rtk golangci-lint run ./..."), + ("docker ps", "rtk docker ps"), + ("docker images", "rtk docker images"), + ("docker logs mycontainer", "rtk docker logs mycontainer"), + ("kubectl get pods", "rtk kubectl get 
pods"), + ("kubectl logs mypod", "rtk kubectl logs mypod"), + ("ruff check src/", "rtk ruff check src/"), + ("ruff format src/", "rtk ruff format src/"), + ("pip list", "rtk pip list"), + ("pip install requests", "rtk pip install requests"), + ("pip outdated", "rtk pip outdated"), + ("pip show requests", "rtk pip show requests"), + ("gh issue list", "rtk gh issue list"), + ("gh run view 123", "rtk gh run view 123"), + ("git stash pop", "rtk git stash pop"), + ("git fetch origin", "rtk git fetch origin"), ]; for (input, expected) in cases { assert_rewrite(input, expected); } } + #[test] + fn test_routing_subcommand_filter_fallback() { + // Commands where binary is in ROUTES but subcommand is NOT in the Only list + // must fall through to `rtk run -c '...'`. + let cases = [ + "docker build .", // docker Only: ps, images, logs + "docker run -it nginx", // docker Only: ps, images, logs + "kubectl apply -f dep.yaml", // kubectl Only: get, logs + "kubectl delete pod mypod", // kubectl Only: get, logs + "go mod tidy", // go Only: test, build, vet + "go generate ./...", // go Only: test, build, vet + "ruff lint src/", // ruff Only: check, format + "pip freeze", // pip Only: list, outdated, install, show + "pip uninstall requests", // pip Only: list, outdated, install, show + "cargo publish", // cargo Only: test, build, clippy, check + "cargo run", // cargo Only: test, build, clippy, check + "git rebase -i HEAD~3", // git Only list (rebase not included) + "git cherry-pick abc123", // git Only list + "gh repo clone foo/bar", // gh Only: pr, issue, run + ]; + for input in cases { + assert_rewrite(input, "rtk run -c"); + } + } + #[test] fn test_routing_vitest_no_double_run() { // Shell script sed bug: 's/^(pnpm )?vitest/rtk vitest run/' on diff --git a/src/discover/registry.rs b/src/discover/registry.rs index dd45cc2..1c22c62 100644 --- a/src/discover/registry.rs +++ b/src/discover/registry.rs @@ -896,6 +896,129 @@ mod tests { ); } + // --- Tests for commands added in Python/Go support (must be in both ROUTES and PATTERNS) --- + + #[test] + fn test_classify_pytest_bare() { + match classify_command("pytest tests/") { + Classification::Supported { rtk_equivalent, .. } => { + assert_eq!(rtk_equivalent, "rtk pytest") + } + other => panic!("pytest should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_pytest_flags() { + match classify_command("pytest -x tests/unit") { + Classification::Supported { rtk_equivalent, .. } => { + assert_eq!(rtk_equivalent, "rtk pytest") + } + other => panic!("pytest -x should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_go_test() { + match classify_command("go test ./...") { + Classification::Supported { rtk_equivalent, .. } => { + assert_eq!(rtk_equivalent, "rtk go") + } + other => panic!("go test should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_go_build() { + match classify_command("go build ./...") { + Classification::Supported { rtk_equivalent, .. } => { + assert_eq!(rtk_equivalent, "rtk go") + } + other => panic!("go build should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_go_vet() { + match classify_command("go vet ./...") { + Classification::Supported { rtk_equivalent, .. 
} => { + assert_eq!(rtk_equivalent, "rtk go") + } + other => panic!("go vet should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_go_unsupported_subcommand_not_matched() { + // go mod tidy is not in the Only list; should not be classified as rtk go + match classify_command("go mod tidy") { + Classification::Unsupported { .. } | Classification::Ignored => {} + Classification::Supported { rtk_equivalent, .. } => { + panic!("go mod should not match, but got rtk_equivalent={rtk_equivalent}") + } + } + } + + #[test] + fn test_classify_ruff_check() { + match classify_command("ruff check src/") { + Classification::Supported { rtk_equivalent, .. } => { + assert_eq!(rtk_equivalent, "rtk ruff") + } + other => panic!("ruff check should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_ruff_format() { + match classify_command("ruff format src/") { + Classification::Supported { rtk_equivalent, .. } => { + assert_eq!(rtk_equivalent, "rtk ruff") + } + other => panic!("ruff format should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_pip_list() { + match classify_command("pip list") { + Classification::Supported { rtk_equivalent, .. } => { + assert_eq!(rtk_equivalent, "rtk pip") + } + other => panic!("pip list should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_pip_install() { + match classify_command("pip install requests") { + Classification::Supported { rtk_equivalent, .. } => { + assert_eq!(rtk_equivalent, "rtk pip") + } + other => panic!("pip install should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_pip3_list() { + match classify_command("pip3 list") { + Classification::Supported { rtk_equivalent, .. } => { + assert_eq!(rtk_equivalent, "rtk pip") + } + other => panic!("pip3 list should be Supported, got {other:?}"), + } + } + + #[test] + fn test_classify_golangci_lint() { + match classify_command("golangci-lint run ./...") { + Classification::Supported { rtk_equivalent, .. 
} => { + assert_eq!(rtk_equivalent, "rtk golangci-lint") + } + other => panic!("golangci-lint should be Supported, got {other:?}"), + } + } + #[test] fn test_patterns_rules_length_match() { assert_eq!( @@ -966,4 +1089,89 @@ mod tests { let cmd = "cat <<'EOF'\nhello && world\nEOF"; assert_eq!(split_command_chain(cmd), vec![cmd]); } + + // --- Route lookup tests --- + + #[test] + fn test_lookup_direct_route() { + let r = lookup("git", "status").unwrap(); + assert_eq!(r.rtk_cmd, "git"); + } + + #[test] + fn test_lookup_git_unknown_subcommand_returns_none() { + assert!(lookup("git", "rebase").is_none()); + assert!(lookup("git", "bisect").is_none()); + } + + #[test] + fn test_lookup_rename_rg_to_grep() { + let r = lookup("rg", "").unwrap(); + assert_eq!(r.rtk_cmd, "grep"); + } + + #[test] + fn test_lookup_rename_grep_to_grep() { + let r = lookup("grep", "-r").unwrap(); + assert_eq!(r.rtk_cmd, "grep"); + } + + #[test] + fn test_lookup_rename_eslint_to_lint() { + let r = lookup("eslint", "src/").unwrap(); + assert_eq!(r.rtk_cmd, "lint"); + } + + #[test] + fn test_lookup_any_subcommand() { + let r = lookup("ls", "-la").unwrap(); + assert_eq!(r.rtk_cmd, "ls"); + let r2 = lookup("ls", "").unwrap(); + assert_eq!(r2.rtk_cmd, "ls"); + } + + #[test] + fn test_lookup_unknown_binary_returns_none() { + assert!(lookup("unknownbinary99", "").is_none()); + // These stay as complex Rust match arms, not in ROUTES + assert!(lookup("vitest", "").is_none()); + assert!(lookup("pnpm", "list").is_none()); + assert!(lookup("npx", "tsc").is_none()); + assert!(lookup("uv", "pip").is_none()); + } + + #[test] + fn test_lookup_docker_subcommand_filter() { + assert!(lookup("docker", "ps").is_some()); + assert!(lookup("docker", "images").is_some()); + assert!(lookup("docker", "build").is_none()); + assert!(lookup("docker", "run").is_none()); + } + + #[test] + fn test_lookup_cargo_subcommand_filter() { + assert!(lookup("cargo", "test").is_some()); + assert!(lookup("cargo", "clippy").is_some()); + assert!(lookup("cargo", "publish").is_none()); + } + + #[test] + fn test_no_duplicate_binaries_in_routes() { + let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new(); + for route in ROUTES { + for &bin in route.binaries { + assert!( + seen.insert(bin), + "Binary '{bin}' appears in multiple ROUTES entries" + ); + } + } + } + + #[test] + fn test_lookup_is_o1_consistent() { + let r1 = lookup("git", "status"); + let r2 = lookup("git", "status"); + assert_eq!(r1.map(|r| r.rtk_cmd), r2.map(|r| r.rtk_cmd)); + } } From c1b768548de4cb628066f193bd7ad0f58b3e70cc Mon Sep 17 00:00:00 2001 From: Andrew Hundt Date: Mon, 16 Feb 2026 19:17:11 -0500 Subject: [PATCH 5/5] feat: add data safety rules and extensible rule system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add 11 safety rules intercepting destructive commands: - rm → trash (cross-platform via trash crate) - git reset --hard/clean/checkout → stash first - cat/head/sed → blocked, suggest Read tool or Edit tool Extensible rtk.*.md rule files with YAML frontmatter loaded from .rtk/, .claude/, .gemini/ directories. Config system expanded from single config.rs to config/ module with rules.rs and discovery.rs. Env var opt-out: RTK_SAFE_COMMANDS=0 disables safety rules, RTK_BLOCK_TOKEN_WASTE=0 allows cat/head/sed. Config CLI expanded: rtk config get/set/list/unset/create/export-rules. 
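For reference, a minimal sketch of what a rule file can look like. The field
names follow the Rule fields consumed by src/cmd/safety.rs (patterns, action,
message, optional redirect); the authoritative schema lives in src/config/rules.rs,
and the concrete values below are illustrative only:

    ---
    name: safety.block-cat
    patterns: ["cat"]
    action: suggest_tool
    message: "BLOCK: file-reading detected; use the Read tool instead of cat"
    ---
    Optional Markdown body documenting the rule for humans.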
New dependencies: trash = "5", serde_yaml = "0.9" Closes #115 Tests: 673 total (35 safety, 4 trash, 3 restored hook safety tests) --- Cargo.lock | 149 +- Cargo.toml | 2 + src/cmd/exec.rs | 41 +- src/cmd/hook.rs | 92 +- src/cmd/mod.rs | 6 + src/cmd/safety.rs | 538 ++++++++ src/cmd/trash_cmd.rs | 76 ++ src/config.rs | 127 -- src/config/discovery.rs | 330 +++++ src/config/mod.rs | 1210 +++++++++++++++++ src/config/rules.rs | 881 ++++++++++++ src/init.rs | 2 +- src/main.rs | 108 +- src/rules/rtk.safety.block-cat.md | 11 + src/rules/rtk.safety.block-head.md | 11 + src/rules/rtk.safety.block-sed.md | 11 + src/rules/rtk.safety.git-checkout-dashdash.md | 10 + src/rules/rtk.safety.git-checkout-dot.md | 10 + src/rules/rtk.safety.git-clean-df.md | 9 + src/rules/rtk.safety.git-clean-f.md | 9 + src/rules/rtk.safety.git-clean-fd.md | 9 + src/rules/rtk.safety.git-reset-hard.md | 10 + src/rules/rtk.safety.git-stash-drop.md | 9 + src/rules/rtk.safety.rm-to-trash.md | 9 + 24 files changed, 3522 insertions(+), 148 deletions(-) create mode 100644 src/cmd/safety.rs create mode 100644 src/cmd/trash_cmd.rs delete mode 100644 src/config.rs create mode 100644 src/config/discovery.rs create mode 100644 src/config/mod.rs create mode 100644 src/config/rules.rs create mode 100644 src/rules/rtk.safety.block-cat.md create mode 100644 src/rules/rtk.safety.block-head.md create mode 100644 src/rules/rtk.safety.block-sed.md create mode 100644 src/rules/rtk.safety.git-checkout-dashdash.md create mode 100644 src/rules/rtk.safety.git-checkout-dot.md create mode 100644 src/rules/rtk.safety.git-clean-df.md create mode 100644 src/rules/rtk.safety.git-clean-f.md create mode 100644 src/rules/rtk.safety.git-clean-fd.md create mode 100644 src/rules/rtk.safety.git-reset-hard.md create mode 100644 src/rules/rtk.safety.git-stash-drop.md create mode 100644 src/rules/rtk.safety.rm-to-trash.md diff --git a/Cargo.lock b/Cargo.lock index f9baf4d..a12787f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -383,7 +383,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core", + "windows-core 0.62.2", ] [[package]] @@ -503,6 +503,31 @@ dependencies = [ "autocfg", ] +[[package]] +name = "objc2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c2599ce0ec54857b29ce62166b0ed9b4f6f1a70ccc9a71165b6154caca8c05" +dependencies = [ + "objc2-encode", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" +dependencies = [ + "bitflags", + "objc2", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -521,6 +546,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + [[package]] name = "pkg-config" version = "0.3.32" @@ -606,9 +637,11 @@ dependencies = [ "rusqlite", "serde", "serde_json", + "serde_yaml", "tempfile", "thiserror", "toml", + "trash", "walkdir", "which", ] @@ -646,6 +679,12 @@ version = "1.0.22" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + [[package]] name = "same-file" version = "1.0.6" @@ -655,6 +694,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "serde" version = "1.0.228" @@ -708,6 +753,19 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + [[package]] name = "shlex" version = "1.3.0" @@ -811,12 +869,42 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "trash" +version = "5.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9b93a14fcf658568eb11b3ac4cb406822e916e2c55cdebc421beeb0bd7c94d8" +dependencies = [ + "chrono", + "libc", + "log", + "objc2", + "objc2-foundation", + "once_cell", + "percent-encoding", + "scopeguard", + "urlencoding", + "windows", +] + [[package]] name = "unicode-ident" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8parse" version = "0.2.2" @@ -926,19 +1014,52 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "windows" +version = "0.56.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1de69df01bdf1ead2f4ac895dc77c9351aefff65b2f3db429a343f9cbf05e132" +dependencies = [ + "windows-core 0.56.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-core" +version = "0.56.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4698e52ed2d08f8658ab0c39512a7c00ee5fe2688c65f8c0a4f06750d729f2a6" +dependencies = [ + "windows-implement 0.56.0", + "windows-interface 0.56.0", + "windows-result 0.1.2", + "windows-targets 0.52.6", +] + [[package]] name = "windows-core" version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ - "windows-implement", - "windows-interface", + "windows-implement 0.60.2", + "windows-interface 0.59.3", "windows-link", - "windows-result", + "windows-result 0.4.1", "windows-strings", ] +[[package]] +name = "windows-implement" +version = "0.56.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6fc35f58ecd95a9b71c4f2329b911016e6bec66b3f2e6a4aad86bd2e99e2f9b" 
+dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-implement" version = "0.60.2" @@ -950,6 +1071,17 @@ dependencies = [ "syn", ] +[[package]] +name = "windows-interface" +version = "0.56.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08990546bf4edef8f431fa6326e032865f27138718c587dc21bc0265bbcb57cc" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-interface" version = "0.59.3" @@ -967,6 +1099,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-result" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-result" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index 4f6809f..c8ccea2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,9 @@ regex = "1" lazy_static = "1.4" serde = { version = "1", features = ["derive"] } serde_json = { version = "1", features = ["preserve_order"] } +serde_yaml = "0.9" colored = "2" +trash = "5" dirs = "5" rusqlite = { version = "0.31", features = ["bundled"] } toml = "0.8" diff --git a/src/cmd/exec.rs b/src/cmd/exec.rs index 5eabd63..59be99b 100644 --- a/src/cmd/exec.rs +++ b/src/cmd/exec.rs @@ -3,7 +3,7 @@ use anyhow::{Context, Result}; use std::process::{Command, Stdio}; -use super::{analysis, builtins, filters, lexer}; +use super::{analysis, builtins, filters, lexer, safety, trash_cmd}; use crate::tracking; /// Check if RTK is already active (recursion guard) @@ -47,13 +47,27 @@ pub fn execute(raw: &str, verbose: u8) -> Result { } fn execute_inner(raw: &str, verbose: u8) -> Result { - // PR 2 adds: crate::config::rules::try_remap() alias expansion + // === STEP 0: Remap expansion (aliases like "t" → "cargo test") === + if let Some(expanded) = crate::config::rules::try_remap(raw) { + if verbose > 0 { + eprintln!( + "rtk remap: {} → {}", + raw.split_whitespace().next().unwrap_or(raw), + expanded + ); + } + return execute_inner(&expanded, verbose); + } let tokens = lexer::tokenize(raw); // === STEP 1: Decide Native vs Passthrough === if analysis::needs_shell(&tokens) { - // PR 2 adds: safety::check_raw(raw) before passthrough + // Even in passthrough, check safety on raw string + if let safety::SafetyResult::Blocked(msg) = safety::check_raw(raw) { + eprintln!("{}", msg); + return Ok(false); + } return run_passthrough(raw, verbose); } @@ -97,7 +111,26 @@ fn run_native(commands: &[analysis::NativeCommand], verbose: u8) -> Result } // Other rtk commands: spawn as external (they have their own filters) - // PR 2 adds: safety::check() dispatch block + // === SAFETY CHECK === + match safety::check(&cmd.binary, &cmd.args) { + safety::SafetyResult::Blocked(msg) => { + eprintln!("{}", msg); + return Ok(false); + } + safety::SafetyResult::Rewritten(new_cmd) => { + // Re-execute the rewritten command + if verbose > 0 { + eprintln!("rtk safety: Rewrote command"); + } + return execute(&new_cmd, verbose); + } + safety::SafetyResult::TrashRequested(paths) => { + last_success = trash_cmd::execute(&paths)?; + prev_operator = cmd.operator.as_deref(); + continue; + } + safety::SafetyResult::Safe => {} + } // === BUILTINS === if builtins::is_builtin(&cmd.binary) { diff --git a/src/cmd/hook.rs b/src/cmd/hook.rs index d4eb5b1..2376aba 
100644 --- a/src/cmd/hook.rs +++ b/src/cmd/hook.rs @@ -36,8 +36,7 @@ //! **Gemini CLI** (JSON protocol via `gemini_hook.rs`): //! - See `gemini_hook.rs` module documentation -use super::{analysis, lexer}; -// PR 2 adds: use super::safety; +use super::{analysis, lexer, safety}; /// Hook check result #[derive(Debug, Clone)] @@ -61,22 +60,56 @@ pub fn check_for_hook(raw: &str, _agent: &str) -> HookResult { fn check_for_hook_inner(raw: &str, depth: usize) -> HookResult { if depth >= MAX_REWRITE_DEPTH { - return HookResult::Blocked("Rewrite loop detected (max depth exceeded)".to_string()); + return HookResult::Blocked( + "Safety rewrite loop detected (max depth exceeded)".to_string(), + ); } + + // Handle empty if raw.trim().is_empty() { return HookResult::Rewrite(raw.to_string()); } - // PR 2 adds: crate::config::rules::try_remap() alias expansion - // PR 2 adds: safety::check_raw() and safety::check() dispatch + + // Remap expansion (aliases like "t" → "cargo test") + if let Some(expanded) = crate::config::rules::try_remap(raw) { + return check_for_hook_inner(&expanded, depth + 1); + } let tokens = lexer::tokenize(raw); + // Check for shellisms - if present, pass through + // but still check safety if analysis::needs_shell(&tokens) { + match safety::check_raw(raw) { + safety::SafetyResult::Blocked(msg) => return HookResult::Blocked(msg), + safety::SafetyResult::Safe => {} + // check_raw currently only returns Safe/Blocked; defensive no-op + safety::SafetyResult::Rewritten(_) | safety::SafetyResult::TrashRequested(_) => {} + } + // Passthrough: just return as-is wrapped in rtk run return HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))); } + // Native mode: parse and check each command match analysis::parse_chain(tokens) { Ok(commands) => { + // Check safety on each command + for cmd in &commands { + match safety::check(&cmd.binary, &cmd.args) { + safety::SafetyResult::Blocked(msg) => { + return HookResult::Blocked(msg); + } + safety::SafetyResult::Rewritten(new_cmd) => { + return check_for_hook_inner(&new_cmd, depth + 1); + } + safety::SafetyResult::TrashRequested(_) => { + // Redirect to rtk run which handles trash + return HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))); + } + safety::SafetyResult::Safe => {} + } + } + // Single command: route to optimized RTK subcommand. // Chained commands (&&, ||, ;): wrap entire chain in rtk run -c. if commands.len() == 1 { @@ -85,7 +118,10 @@ fn check_for_hook_inner(raw: &str, depth: usize) -> HookResult { HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))) } } - Err(_) => HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))), + Err(_) => { + // Parse error - passthrough with wrapping + HookResult::Rewrite(format!("rtk run -c '{}'", escape_quotes(raw))) + } } } @@ -547,7 +583,18 @@ mod tests { } } - // PR 2 adds: test_compound_blocked_in_chain (safety-dependent test) + #[test] + fn test_compound_blocked_in_chain() { + // Safety rules catch dangerous commands even mid-chain + let cases = [ + ("cd /tmp && cat file.txt", "file-reading"), + ("echo start && sed -i 's/x/y/' f", "file-editing"), + ("git add . 
&& head -5 f.txt", "file-reading"), + ]; + for (input, expected_msg) in cases { + assert_blocked(input, expected_msg); + } + } #[test] fn test_compound_quoted_operators_not_split() { @@ -566,7 +613,20 @@ mod tests { } } - // PR 2 adds: test_blocked_commands (safety-dependent test) + // === COMMANDS THAT SHOULD BLOCK (table-driven) === + + #[test] + fn test_blocked_commands() { + let cases = [ + ("cat file.txt", "file-reading"), + ("sed -i 's/old/new/' file.txt", "file-editing"), + ("head -n 10 file.txt", "file-reading"), + ("cd /tmp && cat file.txt", "file-reading"), // cat in chain + ]; + for (input, expected_msg) in cases { + assert_blocked(input, expected_msg); + } + } // === SHELLISM PASSTHROUGH: cat/sed/head allowed with pipe/redirect === @@ -732,7 +792,21 @@ mod tests { } } - // PR 2 adds: test_cross_protocol_blocked_command_denied_by_both (safety-dependent test) + #[test] + fn test_cross_protocol_blocked_command_denied_by_both() { + // Both Claude and Gemini must block the same unsafe commands + for cmd in ["cat file.txt", "head -n 10 file.txt"] { + let claude = check_for_hook(cmd, "claude"); + let gemini = check_for_hook(cmd, "gemini"); + match (&claude, &gemini) { + (HookResult::Blocked(_), HookResult::Blocked(_)) => {} + _ => panic!( + "'{}': Claude={:?}, Gemini={:?} — both should Block", + cmd, claude, gemini + ), + } + } + } // ===================================================================== // ROUTING TESTS — verify route_native_command dispatch diff --git a/src/cmd/mod.rs b/src/cmd/mod.rs index e371cc5..60ace48 100644 --- a/src/cmd/mod.rs +++ b/src/cmd/mod.rs @@ -7,6 +7,12 @@ pub(crate) mod analysis; pub(crate) mod lexer; +// Safety engine (depends on config::rules) +pub(crate) mod safety; + +// Trash command (depends on trash crate) +pub(crate) mod trash_cmd; + // Predicates and utilities (no external deps) pub(crate) mod predicates; diff --git a/src/cmd/safety.rs b/src/cmd/safety.rs new file mode 100644 index 0000000..21ea981 --- /dev/null +++ b/src/cmd/safety.rs @@ -0,0 +1,538 @@ +//! Safety Policy Engine — unified rule-based implementation. +//! +//! All safety rules, remaps, and blocking rules are loaded from the unified +//! Rule system (`config::rules`). Rules are MD files with YAML frontmatter, +//! loaded from built-in defaults and user directories. + +use crate::config::rules::{self, Rule}; + +use super::predicates; + +/// Result of safety check +#[derive(Clone, Debug, PartialEq)] +pub enum SafetyResult { + /// Command is safe to execute as-is + Safe, + /// Command is blocked with error message + Blocked(String), + /// Command was rewritten to a new command string + Rewritten(String), + /// Request to move files to trash (built-in) + TrashRequested(Vec), +} + +/// Dispatch a matched rule into a SafetyResult. 
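+///
+/// Action → result mapping (mirrors the match below): "trash" → TrashRequested
+/// with the non-flag args as paths, "rewrite" → Rewritten with `{args}`
+/// substituted into the rule's redirect, "suggest_tool" / "block" → Blocked,
+/// "warn" → prints the message and returns Safe, anything else → Safe.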
+fn dispatch(rule: &Rule, args: &str) -> SafetyResult {
+    match rule.action.as_str() {
+        "trash" => {
+            let paths: Vec<String> = args
+                .split_whitespace()
+                .filter(|a| !a.starts_with('-'))
+                .map(String::from)
+                .collect();
+            SafetyResult::TrashRequested(paths)
+        }
+        "rewrite" => {
+            let redirect = rule.redirect.as_deref().unwrap_or(args);
+            SafetyResult::Rewritten(redirect.replace("{args}", args))
+        }
+        "suggest_tool" | "block" => {
+            // Use interactive-aware message (human vs agent)
+            let msg = if predicates::is_interactive() {
+                // For suggest_tool, human message references the tool name
+                if rule.action == "suggest_tool" {
+                    // First line of message is typically the human-friendly version
+                    rule.message
+                        .lines()
+                        .next()
+                        .unwrap_or(&rule.message)
+                        .to_string()
+                } else {
+                    rule.message.clone()
+                }
+            } else {
+                // Agent: use the full message (contains BLOCK: prefix)
+                rule.message.clone()
+            };
+            SafetyResult::Blocked(msg)
+        }
+        "warn" => {
+            eprintln!("{}", rule.message);
+            SafetyResult::Safe
+        }
+        _ => SafetyResult::Safe,
+    }
+}
+
+/// Check a parsed command against all safety rules.
+pub fn check(binary: &str, args: &[String]) -> SafetyResult {
+    let full_cmd = if args.is_empty() {
+        binary.to_string()
+    } else {
+        format!("{} {}", binary, args.join(" "))
+    };
+
+    for rule in rules::load_all() {
+        if !rules::matches_rule(rule, Some(binary), &full_cmd) {
+            continue;
+        }
+        if !rule.should_apply() {
+            continue;
+        }
+        return dispatch(rule, &args.join(" "));
+    }
+    SafetyResult::Safe
+}
+
+/// Check raw command string (for passthrough mode).
+/// Catches dangerous patterns even when we can't parse the command.
+pub fn check_raw(raw: &str) -> SafetyResult {
+    for rule in rules::load_all() {
+        if !rules::matches_rule(rule, None, raw) {
+            continue;
+        }
+        if !rule.should_apply() {
+            continue;
+        }
+        // In passthrough, suggest_tool rules don't apply (cat in pipelines is valid)
+        if rule.action == "suggest_tool" {
+            continue;
+        }
+        // In passthrough, trash becomes block (can't extract paths reliably)
+        if rule.action == "trash" {
+            return SafetyResult::Blocked(format!(
+                "Passthrough blocked: '{}' detected. 
Use native mode for safe trash.", + rule.patterns.first().map(|s| s.as_str()).unwrap_or("rm") + )); + } + return dispatch(rule, raw); + } + SafetyResult::Safe +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cmd::test_helpers::EnvGuard; + use std::env; + + // === BASIC CHECK TESTS === + + #[test] + fn test_check_safe_command() { + let _guard = EnvGuard::new(); + let result = check("ls", &["-la".to_string()]); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_check_git_status() { + let _guard = EnvGuard::new(); + let result = check("git", &["status".to_string()]); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_check_empty_args() { + let _guard = EnvGuard::new(); + let result = check("pwd", &[]); + assert_eq!(result, SafetyResult::Safe); + } + + // === RM SAFETY TESTS (RTK_SAFE_COMMANDS) === + + #[test] + fn test_check_rm_blocked_when_env_set() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check("rm", &["file.txt".to_string()]); + match result { + SafetyResult::TrashRequested(paths) => { + assert_eq!(paths, vec!["file.txt"]); + } + _ => panic!("Expected TrashRequested, got {:?}", result), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_rm_blocked_by_default() { + let _guard = EnvGuard::new(); + // rm should be redirected to trash by default now + let result = check("rm", &["file.txt".to_string()]); + match result { + SafetyResult::TrashRequested(paths) => { + assert_eq!(paths, vec!["file.txt"]); + } + _ => panic!("Expected TrashRequested by default, got {:?}", result), + } + } + + #[test] + fn test_check_rm_passes_when_disabled() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "0"); + let result = check("rm", &["file.txt".to_string()]); + assert_eq!(result, SafetyResult::Safe); + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_rm_with_flags() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check("rm", &["-rf".to_string(), "dir".to_string()]); + match result { + SafetyResult::TrashRequested(paths) => { + // Flags should be filtered out + assert_eq!(paths, vec!["dir"]); + } + _ => panic!("Expected TrashRequested"), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_rm_multiple_files() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check( + "rm", + &[ + "a.txt".to_string(), + "b.txt".to_string(), + "c.txt".to_string(), + ], + ); + match result { + SafetyResult::TrashRequested(paths) => { + assert_eq!(paths, vec!["a.txt", "b.txt", "c.txt"]); + } + _ => panic!("Expected TrashRequested"), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_rm_no_files() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check("rm", &["-rf".to_string()]); + match result { + SafetyResult::TrashRequested(paths) => { + assert!(paths.is_empty()); + } + _ => panic!("Expected TrashRequested, got {:?}", result), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + // === CAT/SED/HEAD TESTS (blocked by default, opt-out with RTK_BLOCK_TOKEN_WASTE=0) === + + #[test] + fn test_check_cat_blocked() { + let _guard = EnvGuard::new(); + let result = check("cat", &["file.txt".to_string()]); + match result { + SafetyResult::Blocked(msg) => { + assert!(msg.contains("file-reading"), "msg: {}", msg); + } + _ => panic!("Expected Blocked"), + } + } + + #[test] + fn 
test_check_cat_passes_when_disabled() { + let _guard = EnvGuard::new(); + env::set_var("RTK_BLOCK_TOKEN_WASTE", "0"); + let result = check("cat", &["file.txt".to_string()]); + env::remove_var("RTK_BLOCK_TOKEN_WASTE"); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_check_sed_blocked() { + let _guard = EnvGuard::new(); + let result = check("sed", &["-i".to_string(), "s/old/new/g".to_string()]); + match result { + SafetyResult::Blocked(msg) => { + assert!(msg.contains("file-editing"), "msg: {}", msg); + } + _ => panic!("Expected Blocked"), + } + } + + #[test] + fn test_check_head_blocked() { + let _guard = EnvGuard::new(); + let result = check( + "head", + &["-n".to_string(), "10".to_string(), "file.txt".to_string()], + ); + match result { + SafetyResult::Blocked(msg) => { + assert!(msg.contains("file-reading"), "msg: {}", msg); + } + _ => panic!("Expected Blocked"), + } + } + + // === GIT SAFETY TESTS (RTK_SAFE_COMMANDS) === + + #[test] + fn test_check_git_reset_hard_blocked_when_env_set() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + // This test may or may not trigger depending on git state + // Just ensure it doesn't panic + let _ = check("git", &["reset".to_string(), "--hard".to_string()]); + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_git_clean_fd_rewritten() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check("git", &["clean".to_string(), "-fd".to_string()]); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash -u")); + assert!(cmd.contains("clean")); + } + _ => panic!("Expected Rewritten, got {:?}", result), + } + env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_git_clean_rewritten_by_default() { + let _guard = EnvGuard::new(); + // git clean should be rewritten with stash by default + let result = check("git", &["clean".to_string(), "-fd".to_string()]); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash -u")); + } + _ => panic!("Expected Rewritten by default, got {:?}", result), + } + } + + #[test] + fn test_check_git_clean_passes_when_disabled() { + let _guard = EnvGuard::new(); + env::set_var("RTK_SAFE_COMMANDS", "0"); + let result = check("git", &["clean".to_string(), "-fd".to_string()]); + assert_eq!(result, SafetyResult::Safe); + env::remove_var("RTK_SAFE_COMMANDS"); + } + + // === CHECK_RAW TESTS === + + #[test] + fn test_check_raw_rm_detected() { + let _guard = EnvGuard::new(); + // RTK_SAFE_COMMANDS is enabled by default, so rm should be blocked + let result = check_raw("rm file.txt"); + match result { + SafetyResult::Blocked(_) => {} + _ => panic!("Expected Blocked"), + } + } + + #[test] + fn test_check_raw_sudo_rm_detected() { + let _guard = EnvGuard::new(); + // RTK_SAFE_COMMANDS is enabled by default, so sudo rm should be blocked + let result = check_raw("sudo rm file.txt"); + match result { + SafetyResult::Blocked(_) => {} + _ => panic!("Expected Blocked"), + } + } + + #[test] + fn test_check_raw_sudo_flags_rm_detected() { + let _guard = EnvGuard::new(); + let result = check_raw("sudo -u root rm file.txt"); + match result { + SafetyResult::Blocked(_) => {} + _ => panic!("Expected Blocked for sudo -u root rm"), + } + } + + #[test] + fn test_check_raw_safe_command() { + let _guard = EnvGuard::new(); + let result = check_raw("ls -la"); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_check_raw_rm_in_quoted_string() { + let _guard = EnvGuard::new(); + 
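+        // Quoted "rm" may or may not be flagged by the raw pattern check;
+        // any arm below is acceptable, this test only guards against panics.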
let result = check_raw("echo \"rm file\""); + // This will be blocked because we can't distinguish quoted rm + // That's intentional - better safe than sorry + match result { + SafetyResult::Blocked(_) => {} + SafetyResult::Safe => {} // Either is acceptable + SafetyResult::Rewritten(_) => {} + SafetyResult::TrashRequested(_) => {} + } + } + + // === NEW GIT SAFETY TESTS === + + #[test] + fn test_git_checkout_dot_stash_prepended() { + let _guard = EnvGuard::new(); + let result = check("git", &["checkout".to_string(), ".".to_string()]); + // May or may not trigger based on predicate, just ensure no panic + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash")); + assert!(cmd.contains("checkout")); + } + SafetyResult::Safe => {} // Predicate returned false (no changes) + _ => {} + } + } + + #[test] + fn test_git_checkout_dashdash_stash_prepended() { + let _guard = EnvGuard::new(); + let result = check( + "git", + &[ + "checkout".to_string(), + "--".to_string(), + "file.txt".to_string(), + ], + ); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash")); + assert!(cmd.contains("checkout")); + } + SafetyResult::Safe => {} + _ => {} + } + } + + #[test] + fn test_git_stash_drop_rewritten_to_pop() { + let _guard = EnvGuard::new(); + let result = check("git", &["stash".to_string(), "drop".to_string()]); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash pop")); + } + _ => panic!("Expected Rewritten to stash pop"), + } + } + + #[test] + fn test_git_clean_f_rewritten() { + let _guard = EnvGuard::new(); + let result = check("git", &["clean".to_string(), "-f".to_string()]); + match result { + SafetyResult::Rewritten(cmd) => { + assert!(cmd.contains("stash -u")); + assert!(cmd.contains("clean")); + } + _ => panic!("Expected Rewritten with stash -u"), + } + } + + #[test] + fn test_git_branch_checkout_safe() { + // git checkout should be safe (not matched by checkout . 
or checkout --) + let _guard = EnvGuard::new(); + let result = check("git", &["checkout".to_string(), "main".to_string()]); + assert_eq!(result, SafetyResult::Safe); + } + + #[test] + fn test_git_checkout_new_branch_safe() { + let _guard = EnvGuard::new(); + let result = check( + "git", + &[ + "checkout".to_string(), + "-b".to_string(), + "feature".to_string(), + ], + ); + assert_eq!(result, SafetyResult::Safe); + } + + // === PATTERN MATCHING FALSE POSITIVE TESTS === + + #[test] + fn test_no_false_positive_catalog() { + let _guard = EnvGuard::new(); + let result = check("catalog", &["show".to_string()]); + assert_eq!( + result, + SafetyResult::Safe, + "catalog must not match cat rule" + ); + } + + #[test] + fn test_no_false_positive_sedan() { + let _guard = EnvGuard::new(); + let result = check("sedan", &[]); + assert_eq!(result, SafetyResult::Safe, "sedan must not match sed rule"); + } + + #[test] + fn test_no_false_positive_headless() { + let _guard = EnvGuard::new(); + let result = check("headless", &["chrome".to_string()]); + assert_eq!( + result, + SafetyResult::Safe, + "headless must not match head rule" + ); + } + + #[test] + fn test_no_false_positive_rmdir() { + let _guard = EnvGuard::new(); + let result = check("rmdir", &["empty_dir".to_string()]); + assert_eq!(result, SafetyResult::Safe, "rmdir must not match rm rule"); + } + + // === CHECK_RAW WORD BOUNDARY TESTS === + + #[test] + fn test_check_raw_no_false_positive_trim() { + let _guard = EnvGuard::new(); + std::env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check_raw("trim file.txt"); + assert_eq!(result, SafetyResult::Safe, "trim must not match rm pattern"); + std::env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_raw_no_false_positive_farm() { + let _guard = EnvGuard::new(); + std::env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check_raw("farm --harvest"); + assert_eq!(result, SafetyResult::Safe, "farm must not match rm pattern"); + std::env::remove_var("RTK_SAFE_COMMANDS"); + } + + #[test] + fn test_check_raw_catches_standalone_rm() { + let _guard = EnvGuard::new(); + std::env::set_var("RTK_SAFE_COMMANDS", "1"); + let result = check_raw("rm file.txt"); + assert!( + matches!(result, SafetyResult::Blocked(_)), + "standalone rm must be caught" + ); + std::env::remove_var("RTK_SAFE_COMMANDS"); + } +} diff --git a/src/cmd/trash_cmd.rs b/src/cmd/trash_cmd.rs new file mode 100644 index 0000000..70ac9ae --- /dev/null +++ b/src/cmd/trash_cmd.rs @@ -0,0 +1,76 @@ +//! Built-in trash - mirrors rm behavior: silent on success, error on failure. 
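+//!
+//! Contract of `execute` below: missing paths are reported to stderr the way
+//! `rm` would, and it returns Ok(true) only when at least one existing path was
+//! successfully moved to the trash.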
+ +use anyhow::Result; +use std::path::Path; + +pub fn execute(paths: &[String]) -> Result { + let expanded: Vec = paths + .iter() + .filter(|p| !p.is_empty()) + .map(|p| super::predicates::expand_tilde(p)) + .collect(); + + if expanded.is_empty() { + eprintln!("trash: no paths specified"); + return Ok(false); + } + + let (existing, missing): (Vec<_>, Vec<_>) = + expanded.iter().partition(|p| Path::new(p).exists()); + + // Report missing like rm does + for p in &missing { + eprintln!("trash: cannot remove '{}': No such path", p); + } + + if existing.is_empty() { + return Ok(false); + } + + let refs: Vec<&str> = existing.iter().map(|s| s.as_str()).collect(); + match trash::delete_all(&refs) { + Ok(_) => Ok(true), + Err(e) => { + eprintln!("trash: {}", e); + Ok(false) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::path::PathBuf; + + fn tmp(name: &str) -> PathBuf { + let p = std::env::temp_dir().join(format!("rtk_{}", name)); + fs::write(&p, "x").unwrap(); + p + } + fn rm(p: &PathBuf) { + let _ = fs::remove_file(p); + } + + #[test] + fn t_empty() { + assert!(!execute(&[]).unwrap()); + } + #[test] + fn t_missing() { + assert!(!execute(&["/nope".into()]).unwrap()); + } + #[test] + fn t_single() { + let p = tmp("s"); + assert!(execute(&[p.to_string_lossy().into()]).unwrap()); + rm(&p); + } + #[test] + fn t_multi() { + let (a, b) = (tmp("a"), tmp("b")); + assert!(execute(&[a.to_string_lossy().into(), b.to_string_lossy().into()]).unwrap()); + rm(&a); + rm(&b); + } +} diff --git a/src/config.rs b/src/config.rs deleted file mode 100644 index 1015012..0000000 --- a/src/config.rs +++ /dev/null @@ -1,127 +0,0 @@ -use anyhow::Result; -use serde::{Deserialize, Serialize}; -use std::path::PathBuf; - -#[derive(Debug, Serialize, Deserialize, Default)] -pub struct Config { - #[serde(default)] - pub tracking: TrackingConfig, - #[serde(default)] - pub display: DisplayConfig, - #[serde(default)] - pub filters: FilterConfig, - #[serde(default)] - pub tee: crate::tee::TeeConfig, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct TrackingConfig { - pub enabled: bool, - pub history_days: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub database_path: Option, -} - -impl Default for TrackingConfig { - fn default() -> Self { - Self { - enabled: true, - history_days: 90, - database_path: None, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct DisplayConfig { - pub colors: bool, - pub emoji: bool, - pub max_width: usize, -} - -impl Default for DisplayConfig { - fn default() -> Self { - Self { - colors: true, - emoji: true, - max_width: 120, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct FilterConfig { - pub ignore_dirs: Vec, - pub ignore_files: Vec, -} - -impl Default for FilterConfig { - fn default() -> Self { - Self { - ignore_dirs: vec![ - ".git".into(), - "node_modules".into(), - "target".into(), - "__pycache__".into(), - ".venv".into(), - "vendor".into(), - ], - ignore_files: vec!["*.lock".into(), "*.min.js".into(), "*.min.css".into()], - } - } -} - -impl Config { - pub fn load() -> Result { - let path = get_config_path()?; - - if path.exists() { - let content = std::fs::read_to_string(&path)?; - let config: Config = toml::from_str(&content)?; - Ok(config) - } else { - Ok(Config::default()) - } - } - - pub fn save(&self) -> Result<()> { - let path = get_config_path()?; - - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent)?; - } - - let content = toml::to_string_pretty(self)?; - std::fs::write(&path, 
content)?; - Ok(()) - } - - pub fn create_default() -> Result { - let config = Config::default(); - config.save()?; - get_config_path() - } -} - -fn get_config_path() -> Result { - let config_dir = dirs::config_dir().unwrap_or_else(|| PathBuf::from(".")); - Ok(config_dir.join("rtk").join("config.toml")) -} - -pub fn show_config() -> Result<()> { - let path = get_config_path()?; - println!("Config: {}", path.display()); - println!(); - - if path.exists() { - let config = Config::load()?; - println!("{}", toml::to_string_pretty(&config)?); - } else { - println!("(default config, file not created)"); - println!(); - let config = Config::default(); - println!("{}", toml::to_string_pretty(&config)?); - } - - Ok(()) -} diff --git a/src/config/discovery.rs b/src/config/discovery.rs new file mode 100644 index 0000000..edd5112 --- /dev/null +++ b/src/config/discovery.rs @@ -0,0 +1,330 @@ +//! Directory walk-up discovery for `rtk.*.md` rule files. +//! +//! Walks from cwd to home, scanning configurable dirs in each ancestor. +//! Search dirs, global dirs, and extra rules_dirs are read from config. +//! Results cached via `OnceLock` — zero cost after first call. + +use std::collections::HashSet; +use std::path::{Path, PathBuf}; +use std::sync::OnceLock; + +static DISCOVERED: OnceLock> = OnceLock::new(); + +/// Return all `rtk.*.md` files ordered lowest→highest priority. +/// +/// Precedence (highest wins): +/// 0 (lowest). Compiled `include_str!()` defaults (handled in rules.rs, not here) +/// 1. Platform config dir + `~/.config/rtk/` (global RTK config) +/// 2. Config `discovery.rules_dirs` (explicit extra dirs) +/// 3. Config `discovery.global_dirs` under $HOME (default: `.claude/`, `.gemini/`) +/// 4. Walk up from cwd using config `discovery.search_dirs` +/// (default: `.claude/`, `.gemini/`, `.rtk/` — furthest from cwd first, cwd last) +/// 5. CLI `--rules-add` paths (highest file priority) +/// +/// If `--rules-path` is set, ONLY those paths are searched (skips all discovery). +/// All dirs configurable via `[discovery]` section in config.toml or env vars. +pub fn discover_rtk_files() -> &'static [PathBuf] { + DISCOVERED.get_or_init(discover_impl) +} + +fn discover_impl() -> Vec { + let mut seen = HashSet::new(); + let mut files = Vec::new(); + let overrides = super::cli_overrides(); + + // If --rules-path is set, use ONLY those paths (exclusive mode) + if let Some(ref exclusive_paths) = overrides.rules_path { + for dir in exclusive_paths { + collect_from_dir(dir, &mut files, &mut seen); + } + return files; + } + + let config = super::get_merged(); + + // Normal discovery + let home = match dirs::home_dir() { + Some(h) => h, + None => return files, + }; + + // 1. Platform-specific config dir (macOS: ~/Library/Application Support/rtk/) + if let Some(config_dir) = dirs::config_dir() { + let platform_rtk = config_dir.join("rtk"); + collect_from_dir(&platform_rtk, &mut files, &mut seen); + } + + // 2. Canonical RTK config dir: ~/.config/rtk/ + let canonical_rtk = home.join(".config").join("rtk"); + collect_from_dir(&canonical_rtk, &mut files, &mut seen); + + // 3. Config discovery.rules_dirs (explicit extra directories) + for dir in &config.discovery.rules_dirs { + collect_from_dir(dir, &mut files, &mut seen); + } + + // 4. Global dirs under $HOME (from config discovery.global_dirs) + for name in &config.discovery.global_dirs { + collect_from_dir(&home.join(name), &mut files, &mut seen); + } + + // 5. 
Walk up from cwd to home using config discovery.search_dirs
+    let cwd = match std::env::current_dir() {
+        Ok(c) => c,
+        Err(_) => return files,
+    };
+
+    let mut ancestors: Vec<PathBuf> = Vec::new();
+    let mut current = cwd.as_path();
+    loop {
+        ancestors.push(current.to_path_buf());
+        if current == home {
+            break;
+        }
+        match current.parent() {
+            Some(p) if p != current => current = p,
+            _ => break,
+        }
+    }
+    // Reverse: furthest ancestor first (lowest priority), cwd last (highest)
+    ancestors.reverse();
+
+    for ancestor in &ancestors {
+        for search_dir in &config.discovery.search_dirs {
+            let dir = ancestor.join(search_dir);
+            collect_from_dir(&dir, &mut files, &mut seen);
+        }
+    }
+
+    // 6. --rules-add paths (highest file priority, after all discovery)
+    for dir in &overrides.rules_add {
+        collect_from_dir(dir, &mut files, &mut seen);
+    }
+
+    files
+}
+
+/// Collect `rtk.*.md` files from a directory, deduplicating by canonical path.
+fn collect_from_dir(dir: &Path, files: &mut Vec<PathBuf>, seen: &mut HashSet<PathBuf>) {
+    let entries = match std::fs::read_dir(dir) {
+        Ok(e) => e,
+        Err(_) => return, // Silently skip unreadable dirs
+    };
+
+    let mut dir_files: Vec<PathBuf> = Vec::new();
+    for entry in entries.flatten() {
+        let name = entry.file_name();
+        let name_str = name.to_string_lossy();
+        if is_rtk_rule_file(&name_str) {
+            let path = entry.path();
+            // Canonicalize for dedup: detects symlink loops and duplicate real paths
+            let canon = match path.canonicalize() {
+                Ok(c) => c,
+                Err(_) => continue, // Broken symlink or unreadable
+            };
+            if seen.insert(canon) {
+                dir_files.push(path);
+            }
+        }
+    }
+    // Sort within directory for deterministic ordering
+    dir_files.sort();
+    files.extend(dir_files);
+}
+
+/// Match `rtk.*.md` pattern: starts with "rtk.", ends with ".md", has content between.
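To make the walk-up above concrete, here is an illustrative sketch (paths are hypothetical, not from this patch; default search_dirs assumed). Ancestors are visited root-most first, so a rule in the project root loses to one next to cwd:

    // Assume $HOME = /home/dev, cwd = /home/dev/proj/api,
    // search_dirs = [".claude", ".gemini", ".rtk"] (the defaults).
    // The walk-up then scans directories in this order, lowest to highest priority:
    let _walk_up_order = [
        "/home/dev/.claude",          "/home/dev/.gemini",          "/home/dev/.rtk",
        "/home/dev/proj/.claude",     "/home/dev/proj/.gemini",     "/home/dev/proj/.rtk",
        "/home/dev/proj/api/.claude", "/home/dev/proj/api/.gemini", "/home/dev/proj/api/.rtk",
    ];
    // /home/dev/.claude and /home/dev/.gemini were already visited by the global-dirs
    // pass above; the canonical-path dedup keeps those files from loading twice.

The helper defined next decides which filenames inside each scanned directory count as rule files.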
+fn is_rtk_rule_file(name: &str) -> bool { + name.starts_with("rtk.") && name.ends_with(".md") && name.len() > 7 +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + + #[test] + fn test_is_rtk_rule_file_valid() { + assert!(is_rtk_rule_file("rtk.safety.rm-to-trash.md")); + assert!(is_rtk_rule_file("rtk.remap.t.md")); + assert!(is_rtk_rule_file("rtk.x.md")); // minimal valid: 8 chars + } + + #[test] + fn test_is_rtk_rule_file_invalid() { + assert!(!is_rtk_rule_file("rtk.md")); // too short (7 chars, not > 7) + assert!(!is_rtk_rule_file("foo.md")); + assert!(!is_rtk_rule_file("rtk.safety.txt")); + assert!(!is_rtk_rule_file("")); + } + + #[test] + fn test_collect_from_empty_dir() { + let tmp = tempfile::tempdir().unwrap(); + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(tmp.path(), &mut files, &mut seen); + assert!(files.is_empty()); + } + + #[test] + fn test_collect_from_dir_with_rules() { + let tmp = tempfile::tempdir().unwrap(); + fs::write(tmp.path().join("rtk.test.md"), "---\nname: test\n---\n").unwrap(); + fs::write(tmp.path().join("not-a-rule.md"), "ignored").unwrap(); + fs::write(tmp.path().join("rtk.md"), "too short name").unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(tmp.path(), &mut files, &mut seen); + assert_eq!(files.len(), 1); + assert!(files[0].file_name().unwrap().to_str().unwrap() == "rtk.test.md"); + } + + #[test] + fn test_collect_deduplicates_symlinks() { + let tmp = tempfile::tempdir().unwrap(); + let real = tmp.path().join("rtk.test.md"); + fs::write(&real, "---\nname: test\n---\n").unwrap(); + + // Create a subdirectory with a symlink to the same file + let subdir = tmp.path().join("sub"); + fs::create_dir(&subdir).unwrap(); + #[cfg(unix)] + std::os::unix::fs::symlink(&real, subdir.join("rtk.test.md")).unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(tmp.path(), &mut files, &mut seen); + collect_from_dir(&subdir, &mut files, &mut seen); + + #[cfg(unix)] + assert_eq!(files.len(), 1, "Symlink should be deduplicated"); + } + + #[test] + fn test_collect_skips_unreadable_dir() { + let mut files = Vec::new(); + let mut seen = HashSet::new(); + // Non-existent directory should be silently skipped + collect_from_dir(Path::new("/nonexistent/path"), &mut files, &mut seen); + assert!(files.is_empty()); + } + + #[test] + fn test_collect_skips_file_as_dir() { + // If a file is passed instead of a directory, read_dir will fail — should be skipped + let tmp = tempfile::tempdir().unwrap(); + let file_path = tmp.path().join("not_a_dir"); + fs::write(&file_path, "i am a file").unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(&file_path, &mut files, &mut seen); + assert!(files.is_empty()); // Not a dir, silently skipped + } + + #[test] + fn test_collect_skips_broken_symlinks() { + let tmp = tempfile::tempdir().unwrap(); + + #[cfg(unix)] + { + // Create a broken symlink (target doesn't exist) + let broken_link = tmp.path().join("rtk.broken.md"); + std::os::unix::fs::symlink("/nonexistent/target", &broken_link).unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(tmp.path(), &mut files, &mut seen); + // Broken symlink: canonicalize fails → continue (skipped) + assert!(files.is_empty()); + } + } + + #[test] + fn test_collect_handles_non_utf8_filenames() { + // Files with non-UTF8 names should be handled via to_string_lossy + let tmp = tempfile::tempdir().unwrap(); + // 
Create a normal rtk rule file alongside a non-matching file + fs::write(tmp.path().join("rtk.valid.md"), "---\nname: v\n---\n").unwrap(); + fs::write(tmp.path().join("other.txt"), "not a rule").unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(tmp.path(), &mut files, &mut seen); + assert_eq!(files.len(), 1); + } + + #[test] + fn test_collect_multiple_dirs_deduplicates() { + let tmp = tempfile::tempdir().unwrap(); + let dir_a = tmp.path().join("a"); + let dir_b = tmp.path().join("b"); + fs::create_dir_all(&dir_a).unwrap(); + fs::create_dir_all(&dir_b).unwrap(); + + let real_file = dir_a.join("rtk.test.md"); + fs::write(&real_file, "---\nname: test\n---\n").unwrap(); + + #[cfg(unix)] + { + // Symlink from dir_b to same real file + std::os::unix::fs::symlink(&real_file, dir_b.join("rtk.test.md")).unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(&dir_a, &mut files, &mut seen); + collect_from_dir(&dir_b, &mut files, &mut seen); + assert_eq!( + files.len(), + 1, + "Same file via symlink should be deduplicated" + ); + } + } + + #[cfg(unix)] + #[test] + fn test_collect_permission_denied_dir_skipped() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::tempdir().unwrap(); + let restricted = tmp.path().join("restricted"); + fs::create_dir(&restricted).unwrap(); + fs::write(restricted.join("rtk.test.md"), "---\nname: t\n---\n").unwrap(); + + // Remove read permission + fs::set_permissions(&restricted, fs::Permissions::from_mode(0o000)).unwrap(); + + let mut files = Vec::new(); + let mut seen = HashSet::new(); + collect_from_dir(&restricted, &mut files, &mut seen); + // Permission denied → silently skipped + assert!(files.is_empty()); + + // Restore permissions for cleanup + fs::set_permissions(&restricted, fs::Permissions::from_mode(0o755)).unwrap(); + } + + #[test] + fn test_default_search_dirs_match_expected() { + // Verify defaults match the previously hardcoded values + let config = crate::config::DiscoveryConfig::default(); + assert_eq!(config.search_dirs, vec![".claude", ".gemini", ".rtk"]); + } + + #[test] + fn test_default_global_dirs_match_expected() { + let config = crate::config::DiscoveryConfig::default(); + assert_eq!(config.global_dirs, vec![".claude", ".gemini"]); + } + + #[test] + fn test_default_rules_dirs_empty() { + let config = crate::config::DiscoveryConfig::default(); + assert!( + config.rules_dirs.is_empty(), + "Default rules_dirs should be empty (uses ~/.config/rtk/ implicitly)" + ); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..1d5235a --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,1210 @@ +//! Configuration system: scalar config (TOML) + unified rules (MD with YAML frontmatter). +//! +//! Two config layers: +//! 1. Scalar config (`config.toml`): tracking, display, filters +//! 2. Rules (`rtk.*.md`): safety, remaps, warnings — via `rules` submodule + +pub mod discovery; +pub mod rules; + +use anyhow::{anyhow, Result}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use std::sync::OnceLock; + +/// CLI overrides for config paths. Set from main.rs before any config loading. +#[derive(Debug, Default)] +pub struct CliConfigOverrides { + /// Exclusive config paths — replaces all discovery. Multiple files merged in order. + pub config_path: Option>, + /// Additional config paths — loaded with highest priority (after env vars). + pub config_add: Vec, + /// Exclusive rule discovery paths — replaces walk-up discovery. 
+ pub rules_path: Option>, + /// Additional rule discovery paths — loaded with highest priority. + pub rules_add: Vec, +} + +static CLI_OVERRIDES: OnceLock = OnceLock::new(); + +/// Set CLI config overrides. Must be called before any config loading. +pub fn set_cli_overrides(overrides: CliConfigOverrides) { + let _ = CLI_OVERRIDES.set(overrides); +} + +/// Get CLI config overrides (or defaults if never set). +pub fn cli_overrides() -> &'static CliConfigOverrides { + CLI_OVERRIDES.get_or_init(CliConfigOverrides::default) +} + +#[derive(Debug, Serialize, Deserialize, Default, Clone)] +pub struct Config { + #[serde(default)] + pub tracking: TrackingConfig, + #[serde(default)] + pub display: DisplayConfig, + #[serde(default)] + pub filters: FilterConfig, + #[serde(default)] + pub discovery: DiscoveryConfig, + #[serde(default)] + pub tee: crate::tee::TeeConfig, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct TrackingConfig { + pub enabled: bool, + pub history_days: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub database_path: Option, +} + +impl Default for TrackingConfig { + fn default() -> Self { + Self { + enabled: true, + history_days: 90, + database_path: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct DisplayConfig { + pub colors: bool, + pub emoji: bool, + pub max_width: usize, +} + +impl Default for DisplayConfig { + fn default() -> Self { + Self { + colors: true, + emoji: true, + max_width: 120, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct FilterConfig { + pub ignore_dirs: Vec, + pub ignore_files: Vec, +} + +impl Default for FilterConfig { + fn default() -> Self { + Self { + ignore_dirs: vec![ + ".git".into(), + "node_modules".into(), + "target".into(), + "__pycache__".into(), + ".venv".into(), + "vendor".into(), + ], + ignore_files: vec!["*.lock".into(), "*.min.js".into(), "*.min.css".into()], + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct DiscoveryConfig { + /// Dirs to search in each ancestor during walk-up (e.g. [".claude", ".gemini", ".rtk"]). + pub search_dirs: Vec, + /// Global dirs under $HOME to check before walk-up (e.g. [".claude", ".gemini"]). + pub global_dirs: Vec, + /// Additional rule directories to search. First entry is also the export/write target. + /// Default: [] (uses ~/.config/rtk/ as the implicit primary). + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub rules_dirs: Vec, +} + +impl Default for DiscoveryConfig { + fn default() -> Self { + Self { + search_dirs: vec![".claude".into(), ".gemini".into(), ".rtk".into()], + global_dirs: vec![".claude".into(), ".gemini".into()], + rules_dirs: vec![], + } + } +} + +impl Config { + /// Load global config from `~/.config/rtk/config.toml`. + /// Falls back to defaults if file is missing or unreadable. + pub fn load() -> Result { + let path = match get_config_path() { + Ok(p) => p, + Err(_) => return Ok(Config::default()), + }; + + if path.exists() { + match std::fs::read_to_string(&path) { + Ok(content) => match toml::from_str(&content) { + Ok(config) => Ok(config), + Err(_) => Ok(Config::default()), // Malformed config → defaults + }, + Err(_) => Ok(Config::default()), // Unreadable → defaults + } + } else { + Ok(Config::default()) + } + } + + /// Load merged config with full precedence chain. + /// + /// Precedence (highest wins): + /// 0. CLI params: `--config-path` (exclusive) or `--config-add` (additive) + /// 1. Environment variables (RTK_*) + /// 2. 
Project-local `.rtk/config.toml` (nearest ancestor) + /// 3. Global `~/.config/rtk/config.toml` (or platform config dir) + /// 4. Compiled defaults + pub fn load_merged() -> Result { + let overrides = cli_overrides(); + + // If --config-path is set, use ONLY those files (skip global + walk-up) + let mut config = if let Some(ref exclusive_paths) = overrides.config_path { + let mut cfg = Config::default(); + for path in exclusive_paths { + if path.exists() { + if let Ok(content) = std::fs::read_to_string(path) { + if let Ok(overlay) = toml::from_str::(&content) { + overlay.apply(&mut cfg); + } + } + } + } + cfg + } else { + // Normal: start with global config + let mut cfg = Self::load()?; + + // Layer 3: Walk up from cwd looking for .rtk/config.toml + if let Ok(cwd) = std::env::current_dir() { + let mut current = cwd.as_path(); + loop { + let project_config = current.join(".rtk").join("config.toml"); + if project_config.exists() { + match std::fs::read_to_string(&project_config) { + Ok(content) => { + if let Ok(overlay) = toml::from_str::(&content) { + overlay.apply(&mut cfg); + } + } + Err(_) => {} // Silently skip unreadable project config + } + break; + } + match current.parent() { + Some(p) if p != current => current = p, + _ => break, + } + } + } + cfg + }; + + // Layer 1.5: --config-add paths (higher than project-local, lower than env vars) + for add_path in &overrides.config_add { + if add_path.exists() { + if let Ok(content) = std::fs::read_to_string(add_path) { + if let Ok(overlay) = toml::from_str::(&content) { + overlay.apply(&mut config); + } + } + } + } + + // Layer 1 (highest priority): Environment variable overrides + if let Ok(val) = std::env::var("RTK_TRACKING_ENABLED") { + if let Ok(b) = val.parse::() { + config.tracking.enabled = b; + } else if val == "0" { + config.tracking.enabled = false; + } else if val == "1" { + config.tracking.enabled = true; + } + } + if let Ok(val) = std::env::var("RTK_HISTORY_DAYS") { + if let Ok(days) = val.parse::() { + config.tracking.history_days = days; + } + } + if let Ok(path) = std::env::var("RTK_DB_PATH") { + config.tracking.database_path = Some(PathBuf::from(path)); + } + if let Ok(val) = std::env::var("RTK_DISPLAY_COLORS") { + if let Ok(b) = val.parse::() { + config.display.colors = b; + } + } + if let Ok(val) = std::env::var("RTK_DISPLAY_EMOJI") { + if let Ok(b) = val.parse::() { + config.display.emoji = b; + } + } + if let Ok(val) = std::env::var("RTK_MAX_WIDTH") { + if let Ok(w) = val.parse::() { + config.display.max_width = w; + } + } + if let Ok(val) = std::env::var("RTK_SEARCH_DIRS") { + config.discovery.search_dirs = val.split(',').map(|s| s.trim().to_string()).collect(); + } + if let Ok(val) = std::env::var("RTK_RULES_DIRS") { + config.discovery.rules_dirs = val.split(',').map(|s| PathBuf::from(s.trim())).collect(); + } + + Ok(config) + } + + pub fn save(&self) -> Result<()> { + let path = get_config_path()?; + + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + + let content = toml::to_string_pretty(self)?; + std::fs::write(&path, content)?; + Ok(()) + } + + /// Save to a specific path (for --local support). 
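A quick illustration of the precedence chain load_merged implements (values are hypothetical; the pattern mirrors the env-override tests further down): even if a project-local `.rtk/config.toml` overlay has set `tracking.history_days`, an environment variable is applied last and wins.

    std::env::set_var("RTK_HISTORY_DAYS", "7");
    let cfg = Config::load_merged().unwrap();
    assert_eq!(cfg.tracking.history_days, 7); // env var outranks project-local and global files
    std::env::remove_var("RTK_HISTORY_DAYS");

The save_to helper that follows is the write-side counterpart used by the --local code paths.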
+ pub fn save_to(&self, path: &std::path::Path) -> Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let content = toml::to_string_pretty(self)?; + std::fs::write(path, content)?; + Ok(()) + } + + pub fn create_default() -> Result { + let config = Config::default(); + config.save()?; + get_config_path() + } +} + +/// Overlay config for merging project config onto global config. +/// All fields are Option — only present fields override. +#[derive(Debug, Deserialize, Default)] +pub struct ConfigOverlay { + pub tracking: Option, + pub display: Option, + pub filters: Option, + pub discovery: Option, +} + +#[derive(Debug, Deserialize)] +pub struct TrackingOverlay { + pub enabled: Option, + pub history_days: Option, + pub database_path: Option, +} + +#[derive(Debug, Deserialize)] +pub struct DisplayOverlay { + pub colors: Option, + pub emoji: Option, + pub max_width: Option, +} + +#[derive(Debug, Deserialize)] +pub struct FilterOverlay { + pub ignore_dirs: Option>, + pub ignore_files: Option>, +} + +#[derive(Debug, Deserialize)] +pub struct DiscoveryOverlay { + pub search_dirs: Option>, + pub global_dirs: Option>, + pub rules_dirs: Option>, +} + +impl ConfigOverlay { + fn apply(&self, config: &mut Config) { + if let Some(ref t) = self.tracking { + if let Some(v) = t.enabled { + config.tracking.enabled = v; + } + if let Some(v) = t.history_days { + config.tracking.history_days = v; + } + if let Some(ref v) = t.database_path { + config.tracking.database_path = Some(v.clone()); + } + } + if let Some(ref d) = self.display { + if let Some(v) = d.colors { + config.display.colors = v; + } + if let Some(v) = d.emoji { + config.display.emoji = v; + } + if let Some(v) = d.max_width { + config.display.max_width = v; + } + } + if let Some(ref f) = self.filters { + if let Some(ref v) = f.ignore_dirs { + config.filters.ignore_dirs = v.clone(); + } + if let Some(ref v) = f.ignore_files { + config.filters.ignore_files = v.clone(); + } + } + if let Some(ref d) = self.discovery { + if let Some(ref v) = d.search_dirs { + config.discovery.search_dirs = v.clone(); + } + if let Some(ref v) = d.global_dirs { + config.discovery.global_dirs = v.clone(); + } + if let Some(ref v) = d.rules_dirs { + config.discovery.rules_dirs = v.clone(); + } + } + } +} + +/// Global config path: `~/.config/rtk/config.toml` +pub fn get_config_path() -> Result { + let config_dir = dirs::config_dir().unwrap_or_else(|| PathBuf::from(".")); + Ok(config_dir.join("rtk").join("config.toml")) +} + +/// Canonical RTK rules directory: `~/.config/rtk/` +/// +/// This is distinct from `dirs::config_dir()` which on macOS returns +/// `~/Library/Application Support/` — not appropriate for a CLI tool's +/// user-facing rule files. We use `~/.config/rtk/` on all platforms. +/// Primary rules directory (for writes/exports). First entry of rules_dirs, or ~/.config/rtk/. +pub fn get_rules_dir() -> Result { + let config = get_merged(); + if let Some(first) = config.discovery.rules_dirs.first() { + return Ok(first.clone()); + } + let home = dirs::home_dir().ok_or_else(|| anyhow!("Cannot determine home directory"))?; + Ok(home.join(".config").join("rtk")) +} + +/// Project-local config path: `.rtk/config.toml` in cwd +pub fn get_local_config_path() -> Result { + let cwd = std::env::current_dir()?; + Ok(cwd.join(".rtk").join("config.toml")) +} + +/// Cached merged config (loaded once per process). +static MERGED_CONFIG: OnceLock = OnceLock::new(); + +/// Get the merged config (cached). 
For use by tracking, display, etc. +pub fn get_merged() -> &'static Config { + MERGED_CONFIG.get_or_init(|| Config::load_merged().unwrap_or_default()) +} + +pub fn show_config() -> Result<()> { + let path = get_config_path()?; + if path.exists() { + println!("# {}", path.display()); + let config = Config::load()?; + println!("{}", toml::to_string_pretty(&config)?); + } else { + println!("# (defaults, no config file)"); + println!("{}", toml::to_string_pretty(&Config::default())?); + } + Ok(()) +} + +// === Config CRUD === + +/// Get a config value by dotted key (e.g., "tracking.enabled"). +pub fn get_value(key: &str) -> Result { + let config = Config::load_merged()?; + let toml_val = toml::Value::try_from(&config)?; + + let parts: Vec<&str> = key.split('.').collect(); + let mut current = &toml_val; + for part in &parts { + current = current + .get(part) + .ok_or_else(|| anyhow!("Unknown config key: {key}"))?; + } + + match current { + toml::Value::String(s) => Ok(s.clone()), + toml::Value::Boolean(b) => Ok(b.to_string()), + toml::Value::Integer(i) => Ok(i.to_string()), + toml::Value::Float(f) => Ok(f.to_string()), + toml::Value::Array(a) => Ok(format!("{:?}", a)), + other => Ok(other.to_string()), + } +} + +/// Set a config value by dotted key. +pub fn set_value(key: &str, value: &str, local: bool) -> Result<()> { + let path = if local { + get_local_config_path()? + } else { + get_config_path()? + }; + + let mut config = if path.exists() { + let content = std::fs::read_to_string(&path)?; + toml::from_str(&content)? + } else { + Config::default() + }; + + apply_value(&mut config, key, value)?; + + if local { + config.save_to(&path)?; + } else { + config.save()?; + } + Ok(()) +} + +/// Unset a config value (reset to default). +pub fn unset_value(key: &str, local: bool) -> Result<()> { + let path = if local { + get_local_config_path()? + } else { + get_config_path()? + }; + + if !path.exists() { + return Err(anyhow!("Config file not found: {}", path.display())); + } + + let content = std::fs::read_to_string(&path)?; + let mut toml_val: toml::Value = toml::from_str(&content)?; + + let parts: Vec<&str> = key.split('.').collect(); + if parts.len() == 2 { + if let Some(table) = toml_val.get_mut(parts[0]).and_then(|v| v.as_table_mut()) { + table.remove(parts[1]); + } + } else { + return Err(anyhow!("Invalid key format: {key}. Use section.field")); + } + + let content = toml::to_string_pretty(&toml_val)?; + std::fs::write(&path, content)?; + Ok(()) +} + +/// List all config values with optional origin info. 
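A sketch of the dotted-key CRUD surface defined above (key and value are illustrative):

    // Read a merged value; results come back as strings:
    let _width = get_value("display.max_width").unwrap(); // e.g. "120"
    // Write to the global config file, or pass `true` to target .rtk/config.toml:
    set_value("display.max_width", "100", false).unwrap();
    // Remove the override again so the default applies:
    unset_value("display.max_width", false).unwrap();

list_values, defined next, prints the fully merged result, optionally annotated with its sources when --origin is passed.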
+pub fn list_values(origin: bool) -> Result<()> { + let config = Config::load_merged()?; + let toml_str = toml::to_string_pretty(&config)?; + + if origin { + let global_path = get_config_path()?; + let has_global = global_path.exists(); + + // Check for project config + let mut has_project = false; + if let Ok(cwd) = std::env::current_dir() { + let mut current = cwd.as_path(); + loop { + if current.join(".rtk").join("config.toml").exists() { + has_project = true; + break; + } + match current.parent() { + Some(p) if p != current => current = p, + _ => break, + } + } + } + + println!("# Sources:"); + if has_global { + println!("# global: {}", global_path.display()); + } + if has_project { + println!("# project: .rtk/config.toml"); + } + if !has_global && !has_project { + println!("# (all defaults)"); + } + println!(); + } + + println!("{toml_str}"); + + // Show rules summary only with --origin flag + if origin { + let rules = rules::load_all(); + if !rules.is_empty() { + println!("# Rules ({} loaded):", rules.len()); + for rule in rules { + println!("# {} [{}] — {}", rule.name, rule.action, rule.source); + } + } + } + + Ok(()) +} + +/// Apply a string value to a config struct by dotted key. +fn apply_value(config: &mut Config, key: &str, value: &str) -> Result<()> { + match key { + "tracking.enabled" => config.tracking.enabled = value.parse()?, + "tracking.history_days" => config.tracking.history_days = value.parse()?, + "tracking.database_path" => { + config.tracking.database_path = Some(PathBuf::from(value)); + } + "display.colors" => config.display.colors = value.parse()?, + "display.emoji" => config.display.emoji = value.parse()?, + "display.max_width" => config.display.max_width = value.parse()?, + "discovery.search_dirs" => { + config.discovery.search_dirs = value.split(',').map(|s| s.trim().to_string()).collect(); + } + "discovery.global_dirs" => { + config.discovery.global_dirs = value.split(',').map(|s| s.trim().to_string()).collect(); + } + "discovery.rules_dirs" => { + config.discovery.rules_dirs = + value.split(',').map(|s| PathBuf::from(s.trim())).collect(); + } + _ => return Err(anyhow!("Unknown config key: {key}")), + } + Ok(()) +} + +/// Create or update a rule MD file. +pub fn set_rule( + name: &str, + pattern: Option<&str>, + action: Option<&str>, + redirect: Option<&str>, + local: bool, +) -> Result<()> { + let dir = if local { + let cwd = std::env::current_dir()?; + cwd.join(".rtk") + } else { + get_rules_dir()? + }; + std::fs::create_dir_all(&dir)?; + + let action_str = action.unwrap_or("rewrite"); + let filename = format!("rtk.{name}.md"); + let path = dir.join(&filename); + + let mut content = String::from("---\n"); + content.push_str(&format!("name: {name}\n")); + if let Some(pat) = pattern { + // Single pattern without quotes for simple, quoted for multi-word + if pat.contains(' ') { + content.push_str(&format!("patterns: [\"{pat}\"]\n")); + } else { + content.push_str(&format!("patterns: [{pat}]\n")); + } + } + content.push_str(&format!("action: {action_str}\n")); + if let Some(redir) = redirect { + content.push_str(&format!("redirect: \"{redir}\"\n")); + } + content.push_str("---\n\nUser-defined rule.\n"); + + std::fs::write(&path, &content)?; + println!("Created rule: {}", path.display()); + Ok(()) +} + +/// Delete a rule MD file. +pub fn unset_rule(name: &str, local: bool) -> Result<()> { + let dir = if local { + let cwd = std::env::current_dir()?; + cwd.join(".rtk") + } else { + get_rules_dir()? 
+ }; + + let filename = format!("rtk.{name}.md"); + let path = dir.join(&filename); + + if path.exists() { + std::fs::remove_file(&path)?; + println!("Removed rule: {}", path.display()); + } else { + // If it's a built-in rule, create a disabled override + let is_builtin = rules::DEFAULT_RULES.iter().any(|content| { + rules::parse_rule(content, "builtin") + .map(|r| r.name == name) + .unwrap_or(false) + }); + if is_builtin { + std::fs::create_dir_all(&dir)?; + let content = format!("---\nname: {name}\nenabled: false\n---\n\nDisabled by user.\n"); + std::fs::write(&path, content)?; + println!("Disabled built-in rule: {}", path.display()); + } else { + return Err(anyhow!("Rule file not found: {}", path.display())); + } + } + Ok(()) +} + +/// Export built-in rules to a directory. +pub fn export_rules(claude: bool) -> Result<()> { + let dir = if claude { + crate::init::resolve_claude_dir()? + } else { + get_rules_dir()? + }; + std::fs::create_dir_all(&dir)?; + + let mut count = 0; + for content in rules::DEFAULT_RULES { + let rule = rules::parse_rule(content, "builtin")?; + let filename = format!("rtk.{}.md", rule.name); + let path = dir.join(&filename); + // Skip if content unchanged; tolerate unreadable existing files + if path.exists() { + if let Ok(existing) = std::fs::read_to_string(&path) { + if existing.trim() == content.trim() { + continue; + } + } + // If unreadable, overwrite anyway + } + std::fs::write(&path, content)?; + count += 1; + } + + println!("Exported {} rules to {}", count, dir.display()); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_default() { + let config = Config::default(); + assert!(config.tracking.enabled); + assert_eq!(config.tracking.history_days, 90); + assert!(config.display.colors); + assert_eq!(config.display.max_width, 120); + } + + #[test] + fn test_config_overlay_none_fields_dont_override() { + let mut config = Config::default(); + config.tracking.history_days = 30; + config.display.max_width = 80; + + let overlay = ConfigOverlay::default(); + overlay.apply(&mut config); + + // None fields should not override + assert_eq!(config.tracking.history_days, 30); + assert_eq!(config.display.max_width, 80); + } + + #[test] + fn test_config_overlay_applies() { + let mut config = Config::default(); + + let overlay_toml = r#" +[tracking] +history_days = 30 + +[display] +max_width = 80 +"#; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + + assert_eq!(config.tracking.history_days, 30); + assert_eq!(config.display.max_width, 80); + // Unmentioned fields unchanged + assert!(config.tracking.enabled); + assert!(config.display.colors); + } + + #[test] + fn test_apply_value_tracking() { + let mut config = Config::default(); + apply_value(&mut config, "tracking.enabled", "false").unwrap(); + assert!(!config.tracking.enabled); + + apply_value(&mut config, "tracking.history_days", "30").unwrap(); + assert_eq!(config.tracking.history_days, 30); + } + + #[test] + fn test_apply_value_display() { + let mut config = Config::default(); + apply_value(&mut config, "display.max_width", "80").unwrap(); + assert_eq!(config.display.max_width, 80); + + apply_value(&mut config, "display.colors", "false").unwrap(); + assert!(!config.display.colors); + } + + #[test] + fn test_apply_value_unknown_key() { + let mut config = Config::default(); + assert!(apply_value(&mut config, "unknown.key", "value").is_err()); + } + + #[test] + fn test_get_value_existing() { + // This uses load_merged which reads from 
disk, so just test the happy path + let result = get_value("tracking.enabled"); + assert!(result.is_ok()); + let val = result.unwrap(); + assert!(val == "true" || val == "false"); + } + + #[test] + fn test_get_value_unknown() { + let result = get_value("nonexistent.key"); + assert!(result.is_err()); + } + + #[test] + fn test_load_merged_env_override() { + std::env::set_var("RTK_DB_PATH", "/tmp/test.db"); + let config = Config::load_merged().unwrap(); + assert_eq!( + config.tracking.database_path, + Some(PathBuf::from("/tmp/test.db")) + ); + std::env::remove_var("RTK_DB_PATH"); + } + + #[test] + fn test_env_overrides_all_fields() { + // Single test to avoid parallel env var interference. + // Tests all RTK_* env var overrides sequentially. + + // tracking.enabled: "false" overrides default true + std::env::set_var("RTK_TRACKING_ENABLED", "false"); + let config = Config::load_merged().unwrap(); + assert!(!config.tracking.enabled); + std::env::remove_var("RTK_TRACKING_ENABLED"); + + // tracking.enabled: "0" also disables + std::env::set_var("RTK_TRACKING_ENABLED", "0"); + let config = Config::load_merged().unwrap(); + assert!(!config.tracking.enabled); + std::env::remove_var("RTK_TRACKING_ENABLED"); + + // tracking.enabled: "1" enables + std::env::set_var("RTK_TRACKING_ENABLED", "1"); + let config = Config::load_merged().unwrap(); + assert!(config.tracking.enabled); + std::env::remove_var("RTK_TRACKING_ENABLED"); + + // tracking.history_days + std::env::set_var("RTK_HISTORY_DAYS", "7"); + let config = Config::load_merged().unwrap(); + assert_eq!(config.tracking.history_days, 7); + std::env::remove_var("RTK_HISTORY_DAYS"); + + // display.colors + std::env::set_var("RTK_DISPLAY_COLORS", "false"); + let config = Config::load_merged().unwrap(); + assert!(!config.display.colors); + std::env::remove_var("RTK_DISPLAY_COLORS"); + + // display.emoji + std::env::set_var("RTK_DISPLAY_EMOJI", "false"); + let config = Config::load_merged().unwrap(); + assert!(!config.display.emoji); + std::env::remove_var("RTK_DISPLAY_EMOJI"); + + // display.max_width + std::env::set_var("RTK_MAX_WIDTH", "200"); + let config = Config::load_merged().unwrap(); + assert_eq!(config.display.max_width, 200); + std::env::remove_var("RTK_MAX_WIDTH"); + } + + #[test] + fn test_project_local_overlay_overrides_global() { + let tmp = tempfile::tempdir().unwrap(); + let rtk_dir = tmp.path().join(".rtk"); + std::fs::create_dir_all(&rtk_dir).unwrap(); + std::fs::write( + rtk_dir.join("config.toml"), + "[tracking]\nhistory_days = 14\n", + ) + .unwrap(); + + // Simulate being in a project with .rtk/config.toml + let mut config = Config::default(); + assert_eq!(config.tracking.history_days, 90); // default + + let overlay_toml = "[tracking]\nhistory_days = 14\n"; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + assert_eq!(config.tracking.history_days, 14); // project-local overrides + } + + #[test] + fn test_env_overrides_project_local_overlay() { + // Env vars have highest priority — even over project-local config. + // Tests overlay application directly (no env var race). + let mut config = Config::default(); + let overlay_toml = "[tracking]\nhistory_days = 14\n"; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + assert_eq!(config.tracking.history_days, 14); // overlay applied + + // In load_merged, env vars are applied AFTER project overlay, + // so env vars always win. Tested via test_env_overrides_all_fields. 
+ } + + #[test] + fn test_load_robust_to_missing_config() { + // Config::load() should fall back to defaults when config doesn't exist + let config = Config::load().unwrap(); + // Should have defaults — no crash + assert!(config.tracking.enabled); + assert_eq!(config.tracking.history_days, 90); + } + + #[test] + fn test_overlay_partial_sections() { + // Only display section in overlay — tracking should be untouched + let mut config = Config::default(); + config.tracking.history_days = 45; + + let overlay_toml = "[display]\nmax_width = 60\n"; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + + assert_eq!(config.display.max_width, 60); // overridden + assert_eq!(config.tracking.history_days, 45); // untouched + } + + #[test] + fn test_overlay_partial_fields_within_section() { + // Only one field in tracking overlay — others untouched + let mut config = Config::default(); + config.tracking.enabled = false; + + let overlay_toml = "[tracking]\nhistory_days = 7\n"; + let overlay: ConfigOverlay = toml::from_str(overlay_toml).unwrap(); + overlay.apply(&mut config); + + assert_eq!(config.tracking.history_days, 7); // overridden + assert!(!config.tracking.enabled); // untouched (was false) + } + + #[test] + fn test_get_rules_dir_returns_dot_config_rtk() { + let dir = get_rules_dir().unwrap(); + let home = dirs::home_dir().unwrap(); + assert_eq!(dir, home.join(".config").join("rtk")); + } + + #[test] + fn test_env_override_invalid_value_ignored() { + // Invalid env values should be silently ignored, keeping the default + std::env::set_var("RTK_HISTORY_DAYS", "not_a_number"); + let config = Config::load_merged().unwrap(); + assert_eq!(config.tracking.history_days, 90); // default kept + std::env::remove_var("RTK_HISTORY_DAYS"); + + std::env::set_var("RTK_MAX_WIDTH", "abc"); + let config = Config::load_merged().unwrap(); + assert_eq!(config.display.max_width, 120); // default kept + std::env::remove_var("RTK_MAX_WIDTH"); + } + + #[test] + fn test_cli_overrides_default() { + // Default CLI overrides should not change behavior + let overrides = CliConfigOverrides::default(); + assert!(overrides.config_path.is_none()); // None = use normal discovery + assert!(overrides.config_add.is_empty()); + assert!(overrides.rules_path.is_none()); // None = use normal discovery + assert!(overrides.rules_add.is_empty()); + } + + #[test] + fn test_cli_config_path_multiple_files_merged() { + // --config-path a.toml --config-path b.toml merges both in order + let tmp = tempfile::tempdir().unwrap(); + let file_a = tmp.path().join("a.toml"); + let file_b = tmp.path().join("b.toml"); + std::fs::write(&file_a, "[tracking]\nhistory_days = 5\n").unwrap(); + std::fs::write(&file_b, "[display]\nmax_width = 60\n").unwrap(); + + // Simulate load_merged with exclusive paths + let mut cfg = Config::default(); + for path in &[&file_a, &file_b] { + if let Ok(content) = std::fs::read_to_string(path) { + if let Ok(overlay) = toml::from_str::(&content) { + overlay.apply(&mut cfg); + } + } + } + + assert_eq!(cfg.tracking.history_days, 5); // from a.toml + assert_eq!(cfg.display.max_width, 60); // from b.toml + assert!(cfg.tracking.enabled); // default (not in either file) + } + + #[test] + fn test_cli_config_path_exclusive() { + // --config-path loads ONLY from that file + let tmp = tempfile::tempdir().unwrap(); + let config_file = tmp.path().join("custom.toml"); + std::fs::write( + &config_file, + "[tracking]\nhistory_days = 5\nenabled = false\n", + ) + .unwrap(); + + // Simulate what 
load_merged does with exclusive path + let path = &config_file; + let config: Config = if path.exists() { + let content = std::fs::read_to_string(path).unwrap(); + toml::from_str(&content).unwrap() + } else { + Config::default() + }; + + assert_eq!(config.tracking.history_days, 5); + assert!(!config.tracking.enabled); + // Other fields get defaults since only tracking was specified + assert!(config.display.colors); + } + + #[test] + fn test_cli_config_add_overlay() { + // --config-add applies as high-priority overlay + let mut config = Config::default(); + assert_eq!(config.display.max_width, 120); + + let add_toml = "[display]\nmax_width = 60\n"; + let overlay: ConfigOverlay = toml::from_str(add_toml).unwrap(); + overlay.apply(&mut config); + + assert_eq!(config.display.max_width, 60); // overridden by --config-add + assert!(config.tracking.enabled); // untouched + } + + // === Error Robustness Tests === + + #[test] + fn test_load_robust_to_malformed_toml() { + let tmp = tempfile::tempdir().unwrap(); + let bad_config = tmp.path().join("config.toml"); + std::fs::write(&bad_config, "this is not valid toml {{{{").unwrap(); + + // Malformed TOML should parse to default (not crash) + let result: Result = toml::from_str("this is not valid toml {{{{"); + assert!(result.is_err()); + + // Config::load falls back to defaults for malformed content + let config = Config::load().unwrap(); + assert!(config.tracking.enabled); // defaults + } + + #[test] + fn test_load_robust_to_empty_config_file() { + // Empty string is valid TOML (all defaults) + let config: Config = toml::from_str("").unwrap(); + assert!(config.tracking.enabled); + assert_eq!(config.tracking.history_days, 90); + assert_eq!(config.display.max_width, 120); + } + + #[test] + fn test_load_robust_to_binary_garbage_config() { + let garbage = "\x00\x01\x02 binary garbage"; + let result: Result = toml::from_str(garbage); + assert!(result.is_err()); // Should error, not panic + } + + #[test] + fn test_overlay_robust_to_malformed_toml() { + let result: Result = toml::from_str("not valid {{{"); + assert!(result.is_err()); // Should error, not panic + } + + #[test] + fn test_overlay_from_empty_string() { + // Empty overlay should be all-None (no overrides) + let overlay: ConfigOverlay = toml::from_str("").unwrap(); + assert!(overlay.tracking.is_none()); + assert!(overlay.display.is_none()); + assert!(overlay.filters.is_none()); + } + + #[test] + fn test_config_path_exclusive_nonexistent_falls_back() { + // If --config-path points to non-existent file, use defaults + let path = PathBuf::from("/nonexistent/config.toml"); + assert!(!path.exists()); + // Simulates load_merged logic: non-existent → Config::default() + let config = Config::default(); + assert!(config.tracking.enabled); + } + + #[test] + fn test_config_add_nonexistent_path_skipped() { + // --config-add with non-existent path should be silently skipped + let path = PathBuf::from("/nonexistent/overlay.toml"); + assert!(!path.exists()); + // The load_merged code does `if add_path.exists()` — non-existent skipped + let mut config = Config::default(); + config.tracking.history_days = 42; + // Config unchanged because path doesn't exist + assert_eq!(config.tracking.history_days, 42); + } + + #[test] + fn test_config_add_malformed_file_skipped() { + let tmp = tempfile::tempdir().unwrap(); + let bad_file = tmp.path().join("bad.toml"); + std::fs::write(&bad_file, "not valid {{{{ toml").unwrap(); + + // Simulates load_merged: if let Ok(overlay) = toml::from_str(...) 
+ let content = std::fs::read_to_string(&bad_file).unwrap(); + let result = toml::from_str::(&content); + assert!(result.is_err()); // Bad TOML → no overlay applied + + // Config should remain at defaults + let config = Config::default(); + assert!(config.tracking.enabled); + } + + #[test] + fn test_set_value_creates_parent_dirs() { + let tmp = tempfile::tempdir().unwrap(); + let config_path = tmp.path().join("nested").join("deep").join("config.toml"); + + // save_to should create parent dirs + let config = Config::default(); + let result = config.save_to(&config_path); + assert!(result.is_ok()); + assert!(config_path.exists()); + } + + // === DiscoveryConfig tests === + + #[test] + fn test_default_discovery_config() { + let config = DiscoveryConfig::default(); + assert_eq!(config.search_dirs, vec![".claude", ".gemini", ".rtk"]); + assert_eq!(config.global_dirs, vec![".claude", ".gemini"]); + assert!(config.rules_dirs.is_empty()); + } + + #[test] + fn test_discovery_config_roundtrip_toml() { + let config = Config::default(); + let toml_str = toml::to_string_pretty(&config).unwrap(); + let parsed: Config = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.discovery.search_dirs, config.discovery.search_dirs); + assert_eq!(parsed.discovery.global_dirs, config.discovery.global_dirs); + assert_eq!(parsed.discovery.rules_dirs, config.discovery.rules_dirs); + } + + #[test] + fn test_discovery_config_from_toml_custom() { + let toml_str = r#" +[discovery] +search_dirs = [".rtk", ".custom"] +global_dirs = [".mytools"] +rules_dirs = ["/opt/rtk/rules", "/home/user/rules"] +"#; + let config: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(config.discovery.search_dirs, vec![".rtk", ".custom"]); + assert_eq!(config.discovery.global_dirs, vec![".mytools"]); + assert_eq!( + config.discovery.rules_dirs, + vec![ + PathBuf::from("/opt/rtk/rules"), + PathBuf::from("/home/user/rules") + ] + ); + } + + #[test] + fn test_discovery_overlay_applies() { + let mut config = Config::default(); + let overlay: ConfigOverlay = toml::from_str( + r#" +[discovery] +search_dirs = [".only-rtk"] +rules_dirs = ["/custom/rules"] +"#, + ) + .unwrap(); + overlay.apply(&mut config); + assert_eq!(config.discovery.search_dirs, vec![".only-rtk"]); + // global_dirs unchanged (not in overlay) + assert_eq!(config.discovery.global_dirs, vec![".claude", ".gemini"]); + assert_eq!( + config.discovery.rules_dirs, + vec![PathBuf::from("/custom/rules")] + ); + } + + #[test] + fn test_apply_value_discovery_search_dirs() { + let mut config = Config::default(); + apply_value(&mut config, "discovery.search_dirs", ".rtk,.custom").unwrap(); + assert_eq!(config.discovery.search_dirs, vec![".rtk", ".custom"]); + } + + #[test] + fn test_apply_value_discovery_global_dirs() { + let mut config = Config::default(); + apply_value(&mut config, "discovery.global_dirs", ".claude").unwrap(); + assert_eq!(config.discovery.global_dirs, vec![".claude"]); + } + + #[test] + fn test_apply_value_discovery_rules_dirs() { + let mut config = Config::default(); + apply_value(&mut config, "discovery.rules_dirs", "/a,/b,/c").unwrap(); + assert_eq!( + config.discovery.rules_dirs, + vec![ + PathBuf::from("/a"), + PathBuf::from("/b"), + PathBuf::from("/c") + ] + ); + } + + #[test] + fn test_get_rules_dir_default() { + // Without any config override, get_rules_dir returns ~/.config/rtk/ + let dir = get_rules_dir().unwrap(); + assert!( + dir.to_string_lossy().contains("rtk"), + "Default rules dir should contain 'rtk': {}", + dir.display() + ); + } + + #[test] + fn 
test_discovery_config_empty_rules_dirs_not_serialized() { + // Empty rules_dirs should be omitted from TOML output (skip_serializing_if) + let config = Config::default(); + let toml_str = toml::to_string_pretty(&config).unwrap(); + assert!( + !toml_str.contains("rules_dirs"), + "Empty rules_dirs should be omitted from serialization" + ); + } +} diff --git a/src/config/rules.rs b/src/config/rules.rs new file mode 100644 index 0000000..1363474 --- /dev/null +++ b/src/config/rules.rs @@ -0,0 +1,881 @@ +//! Unified Rule system: safety rules, remaps, and warnings as data-driven MD files. +//! +//! Replaces `SafetyAction`, `SafetyRule`, `rule!()` macro, and `get_rules()` from safety.rs. +//! Rules are MD files with YAML frontmatter, loaded from built-in defaults and user directories. + +use anyhow::{anyhow, Result}; +use std::collections::{BTreeMap, HashMap}; +use std::sync::OnceLock; + +/// A unified rule: safety, remap, warning, or block. +#[derive(Debug, Clone, serde::Deserialize)] +pub struct Rule { + pub name: String, + #[serde(default)] + pub patterns: Vec, + #[serde(default = "default_block")] + pub action: String, + #[serde(default)] + pub redirect: Option, + #[serde(default = "default_always")] + pub when: String, + #[serde(default)] + pub env_var: Option, + #[serde(default = "default_true")] + pub enabled: bool, + #[serde(skip)] + pub message: String, + #[serde(skip)] + pub source: String, +} + +fn default_block() -> String { + "block".into() +} +fn default_always() -> String { + "always".into() +} +fn default_true() -> bool { + true +} + +impl Rule { + /// Check if rule should apply given current env + predicates. + pub fn should_apply(&self) -> bool { + // Env var opt-out check + if let Some(ref env) = self.env_var { + if let Ok(val) = std::env::var(env) { + if val == "0" || val == "false" { + return false; + } + } + } + // When predicate + check_when(&self.when) + } +} + +// === Predicate Registry === + +type PredicateFn = fn() -> bool; + +fn predicate_registry() -> &'static HashMap<&'static str, PredicateFn> { + static REGISTRY: OnceLock> = OnceLock::new(); + REGISTRY.get_or_init(|| { + let mut m = HashMap::new(); + m.insert("always", (|| true) as PredicateFn); + m.insert( + "has_unstaged_changes", + crate::cmd::predicates::has_unstaged_changes as PredicateFn, + ); + m + }) +} + +pub fn check_when(when: &str) -> bool { + if when == "always" || when.is_empty() { + return true; + } + if let Some(func) = predicate_registry().get(when) { + return func(); + } + // Bash fallback (matches clautorun behavior) + std::process::Command::new("sh") + .args(["-c", when]) + .status() + .map(|s| s.success()) + .unwrap_or(false) +} + +// === Parse & Load === + +/// Parse a rule from MD content with YAML frontmatter. +pub fn parse_rule(content: &str, source: &str) -> Result { + let trimmed = content.trim(); + let rest = trimmed + .strip_prefix("---") + .ok_or_else(|| anyhow!("No frontmatter: missing opening ---"))?; + let end = rest + .find("\n---") + .ok_or_else(|| anyhow!("Unclosed frontmatter: missing closing ---"))?; + let yaml = &rest[..end]; + let body = rest[end + 4..].trim(); + let mut rule: Rule = serde_yaml::from_str(yaml)?; + rule.message = body.to_string(); + rule.source = source.to_string(); + Ok(rule) +} + +/// Embedded default rules (compiled into binary). 
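To make the frontmatter contract concrete, here is a hypothetical rule run through parse_rule (field names as defined in the Rule struct above; omitted fields fall back to their serde defaults):

    let content = "---\nname: example-warn\npatterns: [\"chmod 777\"]\naction: warn\n---\nPrefer a narrower mode than 777.";
    let rule = parse_rule(content, "example").unwrap();
    assert_eq!(rule.patterns, vec!["chmod 777"]);
    assert_eq!(rule.action, "warn");
    assert_eq!(rule.message, "Prefer a narrower mode than 777.");
    assert_eq!(rule.when, "always"); // default for omitted fields

Each entry in DEFAULT_RULES below is an embedded file in exactly this format.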
+pub const DEFAULT_RULES: &[&str] = &[ + include_str!("../rules/rtk.safety.rm-to-trash.md"), + include_str!("../rules/rtk.safety.git-reset-hard.md"), + include_str!("../rules/rtk.safety.git-checkout-dashdash.md"), + include_str!("../rules/rtk.safety.git-checkout-dot.md"), + include_str!("../rules/rtk.safety.git-stash-drop.md"), + include_str!("../rules/rtk.safety.git-clean-fd.md"), + include_str!("../rules/rtk.safety.git-clean-df.md"), + include_str!("../rules/rtk.safety.git-clean-f.md"), + include_str!("../rules/rtk.safety.block-cat.md"), + include_str!("../rules/rtk.safety.block-sed.md"), + include_str!("../rules/rtk.safety.block-head.md"), +]; + +static RULES_CACHE: OnceLock> = OnceLock::new(); + +/// Load all rules: embedded defaults + user overrides. Cached via OnceLock. +pub fn load_all() -> &'static [Rule] { + RULES_CACHE.get_or_init(|| { + let mut rules_by_name: BTreeMap = BTreeMap::new(); + + // 1. Embedded defaults (lowest priority) + for content in DEFAULT_RULES { + match parse_rule(content, "builtin") { + Ok(rule) if rule.enabled => { + rules_by_name.insert(rule.name.clone(), rule); + } + Ok(rule) => { + rules_by_name.remove(&rule.name); + } + Err(e) => eprintln!("rtk: bad builtin rule: {e}"), + } + } + + // 2. User files (higher priority overrides by name) + for path in super::discovery::discover_rtk_files() { + let content = match std::fs::read_to_string(path) { + Ok(c) => c, + Err(_) => continue, + }; + match parse_rule(&content, &path.display().to_string()) { + Ok(rule) if rule.enabled => { + rules_by_name.insert(rule.name.clone(), rule); + } + Ok(rule) => { + rules_by_name.remove(&rule.name); + } + Err(_) => continue, + } + } + + rules_by_name.into_values().collect() + }) +} + +// === Global Option Stripping === + +/// Strip global options that appear between a command and its subcommand. +/// +/// Tools like git, cargo, docker, and kubectl accept global options before +/// the subcommand (e.g., `git -C /path --no-pager status`). These must be +/// stripped before pattern matching so that safety rules like `"git reset --hard"` +/// still match `git --no-pager reset --hard`. +/// +/// Based on the patterns from upstream PR #99 (hooks/rtk-rewrite.sh). 
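A few worked inputs for the stripping just described (commands are illustrative; the expected outputs follow the match arms defined next):

    assert_eq!(
        strip_global_options("git -C /tmp --no-pager reset --hard HEAD~1"),
        "git reset --hard HEAD~1"
    );
    assert_eq!(strip_global_options("cargo +nightly test --lib"), "cargo test --lib");
    assert_eq!(strip_global_options("kubectl -n prod delete pod api"), "kubectl delete pod api");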
+fn strip_global_options(full_cmd: &str) -> String { + let words: Vec<&str> = full_cmd.split_whitespace().collect(); + if words.is_empty() { + return full_cmd.to_string(); + } + + let binary = words[0]; + let rest = &words[1..]; + + match binary { + "git" => { + // Strip: -C , -c , --no-pager, --no-optional-locks, + // --bare, --literal-pathspecs, --key=value + let mut result = vec!["git"]; + let mut i = 0; + while i < rest.len() { + let w = rest[i]; + if (w == "-C" || w == "-c") && i + 1 < rest.len() { + i += 2; // skip flag + argument + } else if w.starts_with("--") + && w.contains('=') + && !w.starts_with("--hard") + && !w.starts_with("--force") + { + i += 1; // skip --key=value global options + } else if matches!( + w, + "--no-pager" + | "--no-optional-locks" + | "--bare" + | "--literal-pathspecs" + | "--paginate" + | "--git-dir" + ) { + i += 1; // skip standalone boolean global options + } else { + // First non-global-option word is the subcommand; keep everything from here + result.extend_from_slice(&rest[i..]); + break; + } + } + result.join(" ") + } + "cargo" => { + // Strip: +toolchain (e.g., cargo +nightly test) + let mut result = vec!["cargo"]; + let mut i = 0; + while i < rest.len() { + let w = rest[i]; + if w.starts_with('+') { + i += 1; // skip +toolchain + } else { + result.extend_from_slice(&rest[i..]); + break; + } + } + result.join(" ") + } + "docker" => { + // Strip: -H , --context , --config , --key=value + let mut result = vec!["docker"]; + let mut i = 0; + while i < rest.len() { + let w = rest[i]; + if matches!(w, "-H" | "--context" | "--config") && i + 1 < rest.len() { + i += 2; // skip flag + argument + } else if w.starts_with("--") && w.contains('=') { + i += 1; // skip --key=value + } else { + result.extend_from_slice(&rest[i..]); + break; + } + } + result.join(" ") + } + "kubectl" => { + // Strip: --context , --kubeconfig , --namespace , -n , --key=value + let mut result = vec!["kubectl"]; + let mut i = 0; + while i < rest.len() { + let w = rest[i]; + if matches!(w, "--context" | "--kubeconfig" | "--namespace" | "-n") + && i + 1 < rest.len() + { + i += 2; // skip flag + argument + } else if w.starts_with("--") && w.contains('=') { + i += 1; // skip --key=value + } else { + result.extend_from_slice(&rest[i..]); + break; + } + } + result.join(" ") + } + _ => full_cmd.to_string(), + } +} + +// === Pattern Matching === + +/// Check if a rule matches a command. +/// +/// - Single-word pattern: exact binary match (avoids "cat" matching "catalog") +/// - Multi-word pattern: prefix match on full command string (with global option stripping) +/// - Raw mode (binary=None): word-boundary search (handles "sudo rm") +pub fn matches_rule(rule: &Rule, binary: Option<&str>, full_cmd: &str) -> bool { + rule.patterns.iter().any(|pat| { + if pat.contains(' ') { + // Multi-word: prefix match, also try with global options stripped + let normalized = strip_global_options(full_cmd); + full_cmd.starts_with(pat.as_str()) || normalized.starts_with(pat.as_str()) + } else if let Some(bin) = binary { + // Parsed mode: exact binary + bin == pat + } else { + // Raw mode: word-boundary (handles "sudo rm", "/usr/bin/rm") + full_cmd + .split_whitespace() + .any(|w| w == pat || w.ends_with(&format!("/{pat}"))) + } + }) +} + +// === Remap Helper === + +/// Try to expand a single-word remap alias (e.g., "t --lib" → "cargo test --lib"). +/// +/// Only matches single-word patterns with `action: "rewrite"`. Multi-word rewrites +/// are safety rules handled by `check()`. Order: remap → safety → execute. 
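For context on the alias expansion described above, a sketch of what a remap rule carries and how `{args}` is substituted (the rule content is hypothetical; a real one would live in a discovered file such as `rtk.remap.t.md`):

    let rule = parse_rule(
        "---\nname: remap-t\npatterns: [t]\naction: rewrite\nredirect: \"cargo test {args}\"\n---\nShorthand for cargo test.",
        "example",
    )
    .unwrap();
    // try_remap performs exactly this substitution once such a rule is discovered:
    let expanded = rule.redirect.unwrap().replace("{args}", "--lib");
    assert_eq!(expanded, "cargo test --lib");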
+pub fn try_remap(raw: &str) -> Option { + let first_word = raw.split_whitespace().next()?; + for rule in load_all() { + if rule.action != "rewrite" { + continue; + } + // Only remap single-word pattern matches (aliases like "t" → "cargo test") + if !rule + .patterns + .iter() + .any(|p| !p.contains(' ') && p == first_word) + { + continue; + } + if !rule.should_apply() { + continue; + } + if let Some(ref redirect) = rule.redirect { + let rest = raw[first_word.len()..].trim(); + return Some(redirect.replace("{args}", rest)); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_rule_valid() { + let content = "---\nname: test-rule\npatterns: [rm]\naction: trash\n---\nSafety message."; + let rule = parse_rule(content, "test").unwrap(); + assert_eq!(rule.name, "test-rule"); + assert_eq!(rule.patterns, vec!["rm"]); + assert_eq!(rule.action, "trash"); + assert_eq!(rule.message, "Safety message."); + assert_eq!(rule.source, "test"); + } + + #[test] + fn test_parse_rule_no_frontmatter() { + let content = "No frontmatter here"; + assert!(parse_rule(content, "test").is_err()); + } + + #[test] + fn test_parse_rule_unclosed_frontmatter() { + let content = "---\nname: broken\n"; + assert!(parse_rule(content, "test").is_err()); + } + + #[test] + fn test_parse_rule_message_body() { + let content = "---\nname: test\n---\n\nLine 1\n\nLine 2"; + let rule = parse_rule(content, "test").unwrap(); + assert_eq!(rule.message, "Line 1\n\nLine 2"); + } + + #[test] + fn test_parse_rule_defaults() { + let content = "---\nname: minimal\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + assert_eq!(rule.action, "block"); // default + assert_eq!(rule.when, "always"); // default + assert!(rule.enabled); // default true + assert!(rule.patterns.is_empty()); // default empty + } + + #[test] + fn test_parse_rule_all_fields() { + let content = r#"--- +name: full +patterns: ["git reset --hard"] +action: rewrite +redirect: "git stash && git reset --hard {args}" +when: has_unstaged_changes +env_var: RTK_SAFE_COMMANDS +enabled: true +--- +Full message."#; + let rule = parse_rule(content, "builtin").unwrap(); + assert_eq!(rule.name, "full"); + assert_eq!(rule.patterns, vec!["git reset --hard"]); + assert_eq!(rule.action, "rewrite"); + assert_eq!( + rule.redirect.as_deref(), + Some("git stash && git reset --hard {args}") + ); + assert_eq!(rule.when, "has_unstaged_changes"); + assert_eq!(rule.env_var.as_deref(), Some("RTK_SAFE_COMMANDS")); + assert!(rule.enabled); + assert_eq!(rule.message, "Full message."); + } + + #[test] + fn test_matches_rule_single_word_binary() { + let content = "---\nname: test\npatterns: [rm]\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + assert!(matches_rule(&rule, Some("rm"), "rm file.txt")); + assert!(!matches_rule(&rule, Some("rmdir"), "rmdir empty")); + } + + #[test] + fn test_matches_rule_multiple_patterns_in_one_rule() { + let content = + "---\nname: test\npatterns: [\"chmod -R 777\", \"chmod 777\"]\naction: warn\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + assert_eq!(rule.patterns.len(), 2); + assert!(matches_rule(&rule, Some("chmod"), "chmod -R 777 /tmp")); + assert!(matches_rule(&rule, Some("chmod"), "chmod 777 /tmp")); + assert!(!matches_rule(&rule, Some("chmod"), "chmod 755 /tmp")); + } + + #[test] + fn test_matches_rule_multi_word_prefix() { + let content = "---\nname: test\npatterns: [\"git reset --hard\"]\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + assert!(matches_rule(&rule, Some("git"), "git 
reset --hard HEAD~1")); + assert!(!matches_rule(&rule, Some("git"), "git reset --soft HEAD")); + } + + #[test] + fn test_matches_rule_raw_mode_word_boundary() { + let content = "---\nname: test\npatterns: [rm]\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + // Raw mode: None for binary + assert!(matches_rule(&rule, None, "rm file.txt")); + assert!(matches_rule(&rule, None, "sudo rm file.txt")); + assert!(matches_rule(&rule, None, "/usr/bin/rm file.txt")); + // Should NOT match substrings + assert!(!matches_rule(&rule, None, "trim file.txt")); + assert!(!matches_rule(&rule, None, "farm --harvest")); + } + + #[test] + fn test_should_apply_env_var_opt_out() { + let content = "---\nname: test\npatterns: [rm]\nenv_var: RTK_TEST_VAR\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + + // No env var set → applies (opt-out model) + assert!(rule.should_apply()); + + // Set to "0" → disabled + std::env::set_var("RTK_TEST_VAR", "0"); + assert!(!rule.should_apply()); + + // Set to "false" → disabled + std::env::set_var("RTK_TEST_VAR", "false"); + assert!(!rule.should_apply()); + + // Set to "1" → enabled + std::env::set_var("RTK_TEST_VAR", "1"); + assert!(rule.should_apply()); + + std::env::remove_var("RTK_TEST_VAR"); + } + + #[test] + fn test_should_apply_when_always() { + let content = "---\nname: test\nwhen: always\n---\n"; + let rule = parse_rule(content, "test").unwrap(); + assert!(rule.should_apply()); + } + + #[test] + fn test_load_all_includes_builtins() { + let rules = load_all(); + assert!( + rules.len() >= 11, + "Should have at least 11 built-in rules, got {}", + rules.len() + ); + // Check specific built-in names + let names: Vec<&str> = rules.iter().map(|r| r.name.as_str()).collect(); + assert!(names.contains(&"rm-to-trash")); + assert!(names.contains(&"block-cat")); + assert!(names.contains(&"git-reset-hard")); + } + + #[test] + fn test_check_when_always() { + assert!(check_when("always")); + assert!(check_when("")); + } + + #[test] + fn test_check_when_builtin_predicate() { + // has_unstaged_changes is registered - should not panic + let _ = check_when("has_unstaged_changes"); + } + + #[test] + fn test_check_when_bash_fallback() { + assert!(check_when("true")); + assert!(!check_when("false")); + } + + #[test] + fn test_try_remap_no_match() { + // "ls" is not a registered remap alias + assert!(try_remap("ls -la").is_none()); + } + + // Note: try_remap with a match requires user-defined rules in discovery dirs, + // which is tested in E2E tests rather than unit tests. 
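One more note on the `when` predicates exercised a little earlier: any string that is neither "always" nor a registered predicate is handed to `sh -c`, so rule authors can use arbitrary shell tests. An illustrative pair, following the same pattern as the check_when tests:

    assert!(check_when("test -d ."));   // shell fallback, exit 0 → predicate true
    assert!(!check_when("exit 1"));     // non-zero exit → predicate false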
+
+    #[test]
+    fn test_rule_override_by_name() {
+        // Simulate: builtin rule overridden by user rule with same name
+        let mut rules_by_name: BTreeMap<String, _> = BTreeMap::new();
+
+        let builtin = parse_rule(
+            "---\nname: rm-to-trash\npatterns: [rm]\naction: trash\n---\nBuiltin message.",
+            "builtin",
+        )
+        .unwrap();
+        rules_by_name.insert(builtin.name.clone(), builtin);
+
+        // User override: same name, different action
+        let user_rule = parse_rule(
+            "---\nname: rm-to-trash\npatterns: [rm]\naction: block\n---\nUser blocked rm.",
+            "~/.config/rtk/rtk.safety.rm-to-trash.md",
+        )
+        .unwrap();
+        rules_by_name.insert(user_rule.name.clone(), user_rule);
+
+        let rules: Vec<_> = rules_by_name.into_values().collect();
+        assert_eq!(rules.len(), 1); // Overridden, not duplicated
+        assert_eq!(rules[0].action, "block"); // User's action wins
+        assert_eq!(rules[0].message, "User blocked rm."); // User's message wins
+    }
+
+    #[test]
+    fn test_rule_disabled_override_removes() {
+        // Simulate: user disables a builtin rule
+        let mut rules_by_name: BTreeMap<String, _> = BTreeMap::new();
+
+        let builtin = parse_rule(
+            "---\nname: block-cat\npatterns: [cat]\naction: suggest_tool\n---\nUse Read.",
+            "builtin",
+        )
+        .unwrap();
+        rules_by_name.insert(builtin.name.clone(), builtin);
+
+        // User disables it
+        let disabled = parse_rule(
+            "---\nname: block-cat\nenabled: false\n---\nDisabled by user.",
+            "~/.config/rtk/rtk.safety.block-cat.md",
+        )
+        .unwrap();
+        assert!(!disabled.enabled);
+
+        // The load_all logic: enabled=false removes from map
+        if !disabled.enabled {
+            rules_by_name.remove(&disabled.name);
+        }
+
+        assert!(rules_by_name.is_empty()); // Rule removed
+    }
+
+    #[test]
+    fn test_all_builtin_rules_parse_successfully() {
+        for (i, content) in DEFAULT_RULES.iter().enumerate() {
+            let result = parse_rule(content, "builtin");
+            assert!(
+                result.is_ok(),
+                "Built-in rule #{} failed to parse: {:?}",
+                i,
+                result.err()
+            );
+            let rule = result.unwrap();
+            assert!(!rule.name.is_empty(), "Rule #{} has empty name", i);
+            assert!(
+                rule.enabled,
+                "Rule #{} ({}) should be enabled",
+                i, rule.name
+            );
+        }
+    }
+
+    #[test]
+    fn test_all_builtin_rules_have_patterns() {
+        for content in DEFAULT_RULES {
+            let rule = parse_rule(content, "builtin").unwrap();
+            assert!(
+                !rule.patterns.is_empty(),
+                "Rule '{}' has no patterns",
+                rule.name
+            );
+        }
+    }
+
+    // === Error Robustness Tests ===
+
+    #[test]
+    fn test_parse_rule_empty_string() {
+        assert!(parse_rule("", "test").is_err());
+    }
+
+    #[test]
+    fn test_parse_rule_binary_garbage() {
+        assert!(parse_rule("\x00\x01\x02 garbage", "test").is_err());
+    }
+
+    #[test]
+    fn test_parse_rule_valid_frontmatter_invalid_yaml() {
+        let content = "---\n: : : not valid yaml\n---\nbody";
+        assert!(parse_rule(content, "test").is_err());
+    }
+
+    #[test]
+    fn test_parse_rule_missing_name_field() {
+        // YAML without required 'name' field
+        let content = "---\npatterns: [rm]\n---\nbody";
+        assert!(parse_rule(content, "test").is_err());
+    }
+
+    #[test]
+    fn test_parse_rule_only_frontmatter_delimiters() {
+        let content = "---\n---\n";
+        // Empty YAML → missing name → error
+        assert!(parse_rule(content, "test").is_err());
+    }
+
+    #[test]
+    fn test_parse_rule_extra_fields_ignored() {
+        // Unknown fields in YAML should be silently ignored (serde default)
+        let content = "---\nname: test\nunknown_field: 42\nextra: true\n---\nbody";
+        let rule = parse_rule(content, "test");
+        assert!(
+            rule.is_ok(),
+            "Unknown fields should be ignored, got: {:?}",
+            rule.err()
+        );
+        assert_eq!(rule.unwrap().name, "test");
+    }
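+
+    // Sketch (not part of this patch): one combined check of the parsing
+    // contract the tests above pin down, relying only on behavior they already
+    // assert: unknown YAML keys are ignored, missing fields fall back to their
+    // defaults (action="block", when="always", enabled=true), and the text
+    // after the closing "---" becomes the message body.
+    #[test]
+    fn sketch_frontmatter_contract_combined() {
+        let content = "---\nname: contract\npatterns: [foo]\nunknown: 1\n---\nBody text.";
+        let rule = parse_rule(content, "test").unwrap();
+        assert_eq!(rule.name, "contract");
+        assert_eq!(rule.patterns, vec!["foo"]);
+        assert_eq!(rule.action, "block");
+        assert_eq!(rule.when, "always");
+        assert!(rule.enabled);
+        assert_eq!(rule.message, "Body text.");
+    }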
+
+    #[test]
+    fn test_check_when_nonexistent_command() {
+        // A nonsense bash command should return false (not panic)
+        assert!(!check_when("totally_nonexistent_command_xyz_12345"));
+    }
+
+    #[test]
+    fn test_try_remap_empty_string() {
+        assert!(try_remap("").is_none());
+    }
+
+    #[test]
+    fn test_try_remap_whitespace_only() {
+        assert!(try_remap("   ").is_none());
+    }
+
+    #[test]
+    fn test_matches_rule_empty_patterns() {
+        let content = "---\nname: no-patterns\n---\n";
+        let rule = parse_rule(content, "test").unwrap();
+        assert!(!matches_rule(&rule, Some("rm"), "rm file"));
+        assert!(!matches_rule(&rule, None, "rm file"));
+    }
+
+    // === Precedence Chain Tests ===
+
+    #[test]
+    fn test_full_precedence_chain_builtin_global_project() {
+        // Simulates the full load_all() precedence: builtin → global → project
+        let mut rules_by_name: BTreeMap<String, _> = BTreeMap::new();
+
+        // 1. Builtin (lowest priority): action=trash
+        let builtin = parse_rule(
+            "---\nname: rm-to-trash\npatterns: [rm]\naction: trash\n---\nBuiltin.",
+            "builtin",
+        )
+        .unwrap();
+        rules_by_name.insert(builtin.name.clone(), builtin);
+
+        // 2. Global user file (~/.config/rtk/): action=warn (user edited the exported file)
+        let global = parse_rule(
+            "---\nname: rm-to-trash\npatterns: [rm]\naction: warn\n---\nGlobal user override.",
+            "~/.config/rtk/rtk.safety.rm-to-trash.md",
+        )
+        .unwrap();
+        rules_by_name.insert(global.name.clone(), global);
+
+        // 3. Project-local (.rtk/): action=block (project-specific)
+        let project = parse_rule(
+            "---\nname: rm-to-trash\npatterns: [rm]\naction: block\n---\nProject override.",
+            "/project/.rtk/rtk.safety.rm-to-trash.md",
+        )
+        .unwrap();
+        rules_by_name.insert(project.name.clone(), project);
+
+        let rules: Vec<_> = rules_by_name.into_values().collect();
+        assert_eq!(rules.len(), 1, "Should be 1 rule after all overrides");
+        assert_eq!(rules[0].action, "block", "Project-local should win");
+        assert_eq!(rules[0].source, "/project/.rtk/rtk.safety.rm-to-trash.md");
+    }
+
+    #[test]
+    fn test_user_edited_export_overrides_builtin() {
+        // User exports builtins then edits one: edited file should override compiled builtin
+        let mut rules_by_name: BTreeMap<String, _> = BTreeMap::new();
+
+        // Compiled builtin
+        let builtin = parse_rule(
+            "---\nname: block-cat\npatterns: [cat]\naction: suggest_tool\nredirect: Read\n---\nBuiltin.",
+            "builtin",
+        )
+        .unwrap();
+        rules_by_name.insert(builtin.name.clone(), builtin);
+
+        // User-edited export: changed redirect
+        let edited = parse_rule(
+            "---\nname: block-cat\npatterns: [cat]\naction: suggest_tool\nredirect: \"Read (with limit=50)\"\n---\nUser customized.",
+            "~/.config/rtk/rtk.safety.block-cat.md",
+        )
+        .unwrap();
+        rules_by_name.insert(edited.name.clone(), edited);
+
+        let rules: Vec<_> = rules_by_name.into_values().collect();
+        assert_eq!(rules.len(), 1);
+        assert_eq!(
+            rules[0].redirect.as_deref(),
+            Some("Read (with limit=50)"),
+            "User-edited redirect should win"
+        );
+        assert!(rules[0].source.contains(".config/rtk/"));
+    }
+
+    #[test]
+    fn test_project_local_disable_overrides_global_and_builtin() {
+        // Project disables a rule that exists both in builtins and global
+        let mut rules_by_name: BTreeMap<String, _> = BTreeMap::new();
+
+        // Builtin
+        let builtin = parse_rule(
+            "---\nname: block-sed\npatterns: [sed]\naction: suggest_tool\n---\nBuiltin.",
+            "builtin",
+        )
+        .unwrap();
+        rules_by_name.insert(builtin.name.clone(), builtin);
+
+        // Global user file (same as builtin, maybe exported)
+        let global = parse_rule(
+            "---\nname: block-sed\npatterns: [sed]\naction: 
suggest_tool\n---\nGlobal.", + "~/.config/rtk/rtk.safety.block-sed.md", + ) + .unwrap(); + rules_by_name.insert(global.name.clone(), global); + + // Project-local disables it + let disabled = parse_rule( + "---\nname: block-sed\nenabled: false\n---\nDisabled for this project.", + "/project/.rtk/rtk.safety.block-sed.md", + ) + .unwrap(); + if !disabled.enabled { + rules_by_name.remove(&disabled.name); + } + + assert!( + rules_by_name.is_empty(), + "Project-local disable should remove rule entirely" + ); + } + + // === Global Option Stripping (PR #99 parity) === + // Table-driven: (input, expected_output) pairs covering git, cargo, docker, kubectl. + + #[test] + fn test_strip_global_options() { + let cases: &[(&str, &str)] = &[ + // Git: single flags + ("git --no-pager status", "git status"), + ("git -C /path/to/project status", "git status"), + ("git -c core.autocrlf=true diff", "git diff"), + ("git --git-dir=/path/.git status", "git status"), + ("git --no-optional-locks status", "git status"), + ("git --bare log --oneline", "git log --oneline"), + ("git --literal-pathspecs add .", "git add ."), + // Git: multiple globals stacked + ( + "git -C /path --no-pager --no-optional-locks reset --hard", + "git reset --hard", + ), + // Git: subcommand flags preserved (not stripped) + ("git reset --hard HEAD~1", "git reset --hard HEAD~1"), + ("git checkout --force main", "git checkout --force main"), + // Git: no globals (identity) + ("git status", "git status"), + ("git log --oneline -10", "git log --oneline -10"), + // Cargo: toolchain prefix + ("cargo +nightly test", "cargo test"), + ("cargo +stable build --release", "cargo build --release"), + ("cargo test", "cargo test"), // no prefix (identity) + // Docker: global flags + ("docker --context prod ps", "docker ps"), + ("docker -H tcp://host:2375 images", "docker images"), + ("docker --config /tmp/.docker run hello", "docker run hello"), + ("docker ps", "docker ps"), // no globals (identity) + // Kubectl: global flags + ("kubectl -n kube-system get pods", "kubectl get pods"), + ( + "kubectl --context prod --namespace default describe pod foo", + "kubectl describe pod foo", + ), + ("kubectl --kubeconfig=/path get svc", "kubectl get svc"), + ("kubectl get pods", "kubectl get pods"), // no globals (identity) + // Non-matching commands (identity) + ("rm -rf /tmp/foo", "rm -rf /tmp/foo"), + ("cat file.txt", "cat file.txt"), + ("echo hello", "echo hello"), + ]; + for (input, expected) in cases { + assert_eq!( + strip_global_options(input), + *expected, + "strip_global_options({input:?})" + ); + } + } + + // === Rule Matching with Global Options (PR #99 parity) === + // Multi-word safety patterns must match even with global options inserted. 
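+    //
+    // Worked example (sketch, mirroring the strip table above): after globals
+    // are stripped,
+    //     "git -C /repo --no-pager reset --hard HEAD"
+    // normalizes to "git reset --hard HEAD", which starts with the pattern
+    // "git reset --hard", so the multi-word rule matches even though the raw
+    // string does not begin with the pattern.
+    #[test]
+    fn sketch_globals_stripped_before_prefix_match() {
+        assert_eq!(
+            strip_global_options("git -C /repo --no-pager reset --hard HEAD"),
+            "git reset --hard HEAD"
+        );
+    }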
+
+    #[test]
+    fn test_matches_rule_with_global_options() {
+        let cases: &[(&str, &str, bool)] = &[
+            // (pattern, full_cmd, expected_match)
+            ("git reset --hard", "git --no-pager reset --hard HEAD", true),
+            ("git reset --hard", "git -C /path reset --hard", true),
+            (
+                "git reset --hard",
+                "git -C /p --no-pager --no-optional-locks reset --hard",
+                true,
+            ),
+            ("git checkout .", "git -C /project checkout .", true),
+            (
+                "git checkout --",
+                "git --no-pager checkout -- file.txt",
+                true,
+            ),
+            (
+                "git clean -fd",
+                "git -C /path --no-pager --no-optional-locks clean -fd",
+                true,
+            ),
+            ("git stash drop", "git --no-pager stash drop", true),
+            // No globals: direct match still works
+            ("git reset --hard", "git reset --hard HEAD~1", true),
+            ("git checkout .", "git checkout .", true),
+            // Non-matching
+            ("git reset --hard", "git reset --soft HEAD", false),
+            ("git checkout .", "git checkout main", false),
+        ];
+        for (pattern, full_cmd, expected) in cases {
+            let yaml = format!("---\nname: test\npatterns: [\"{pattern}\"]\n---\n");
+            let rule = parse_rule(&yaml, "test").unwrap();
+            let binary = full_cmd.split_whitespace().next();
+            assert_eq!(
+                matches_rule(&rule, binary, full_cmd),
+                *expected,
+                "matches_rule(pat={pattern:?}, cmd={full_cmd:?})"
+            );
+        }
+    }
+
+    #[test]
+    fn test_matches_rule_empty_command() {
+        let content = "---\nname: test\npatterns: [rm]\n---\n";
+        let rule = parse_rule(content, "test").unwrap();
+        // Parsed mode: binary match is independent of full_cmd
+        assert!(matches_rule(&rule, Some("rm"), ""));
+        // Raw mode: empty string has no words → no match
+        assert!(!matches_rule(&rule, None, ""));
+    }
+}
diff --git a/src/init.rs b/src/init.rs
index 3255374..cdc1a82 100644
--- a/src/init.rs
+++ b/src/init.rs
@@ -975,7 +975,7 @@ fn remove_rtk_block(content: &str) -> (String, bool) {
 }
 
 /// Resolve ~/.claude directory with proper home expansion
-fn resolve_claude_dir() -> Result<PathBuf> {
+pub(crate) fn resolve_claude_dir() -> Result<PathBuf> {
     dirs::home_dir()
         .map(|h| h.join(".claude"))
         .context("Cannot determine home directory. 
Is $HOME set?") diff --git a/src/main.rs b/src/main.rs index da7affa..2cdce4d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -346,9 +346,11 @@ enum Commands { format: String, }, - /// Show or create configuration file + /// Show, create, or modify configuration Config { - /// Create default config file + #[command(subcommand)] + action: Option, + /// Create default config file (backward compat) #[arg(long)] create: bool, }, @@ -557,6 +559,53 @@ enum HookCommands { Claude, } +#[derive(Subcommand)] +enum ConfigCommands { + /// Get a config value by key + Get { + /// Dotted key (e.g., "tracking.enabled") + key: String, + }, + /// Set a config value or create a rule + Set { + /// Dotted key (e.g., "display.max_width" or "rules.my-alias") + key: String, + /// Value to set (for scalar config) or redirect template (for rules) + value: Option, + /// Pattern for rule (e.g., "t" or "git reset --hard") + #[arg(long)] + pattern: Option, + /// Action for rule: block, warn, rewrite, trash, suggest_tool + #[arg(long)] + action: Option, + /// Write to project-local .rtk/config.toml + #[arg(long)] + local: bool, + }, + /// List all config values + List { + /// Show where each value comes from + #[arg(long)] + origin: bool, + }, + /// Remove a config key (reset to default) + Unset { + /// Dotted key to remove + key: String, + /// Remove from project-local .rtk/config.toml + #[arg(long)] + local: bool, + }, + /// Create default config file + Create, + /// Export built-in rules as editable MD files + ExportRules { + /// Export to ~/.claude/ instead of ~/.config/rtk/ + #[arg(long)] + claude: bool, + }, +} + #[derive(Subcommand)] enum GitCommands { /// Condensed diff output @@ -1149,11 +1198,64 @@ fn main() -> Result<()> { cc_economics::run(daily, weekly, monthly, all, &format, cli.verbose)?; } - Commands::Config { create } => { + Commands::Config { action, create } => { + // Backward compat: --create flag if create { let path = config::Config::create_default()?; println!("Created: {}", path.display()); + } else if let Some(action) = action { + match action { + ConfigCommands::Get { key } => match config::get_value(&key) { + Ok(val) => println!("{val}"), + Err(e) => { + eprintln!("Error: {e}"); + std::process::exit(1); + } + }, + ConfigCommands::Set { + key, + value, + pattern, + action, + local, + } => { + if key.starts_with("rules.") { + let rule_name = key.strip_prefix("rules.").unwrap(); + config::set_rule( + rule_name, + pattern.as_deref(), + action.as_deref(), + value.as_deref(), + local, + )?; + } else { + let val = value.ok_or_else(|| { + anyhow::anyhow!("Value required for scalar config key: {key}") + })?; + config::set_value(&key, &val, local)?; + } + } + ConfigCommands::List { origin } => { + config::list_values(origin)?; + } + ConfigCommands::Unset { key, local } => { + if key.starts_with("rules.") { + let rule_name = key.strip_prefix("rules.").unwrap(); + config::unset_rule(rule_name, local)?; + } else { + config::unset_value(&key, local)?; + } + } + ConfigCommands::Create => { + let path = config::Config::create_default()?; + println!("Created: {}", path.display()); + } + ConfigCommands::ExportRules { claude } => { + config::export_rules(claude)?; + } + } } else { + // No subcommand: show config (backward compat) config::show_config()?; } } diff --git a/src/rules/rtk.safety.block-cat.md b/src/rules/rtk.safety.block-cat.md new file mode 100644 index 0000000..7b55927 --- /dev/null +++ b/src/rules/rtk.safety.block-cat.md @@ -0,0 +1,11 @@ +--- +name: block-cat +patterns: [cat] +action: suggest_tool 
+redirect: Read +env_var: RTK_BLOCK_TOKEN_WASTE +--- + +Use the **Read tool** for large files. + +BLOCK: cat wastes tokens. Use your file-reading tool instead. diff --git a/src/rules/rtk.safety.block-head.md b/src/rules/rtk.safety.block-head.md new file mode 100644 index 0000000..01f188b --- /dev/null +++ b/src/rules/rtk.safety.block-head.md @@ -0,0 +1,11 @@ +--- +name: block-head +patterns: [head] +action: suggest_tool +redirect: "Read (with limit)" +env_var: RTK_BLOCK_TOKEN_WASTE +--- + +Use **Read tool with limit parameter** instead of head. + +BLOCK: head wastes tokens. Use your file-reading tool with a line limit instead. diff --git a/src/rules/rtk.safety.block-sed.md b/src/rules/rtk.safety.block-sed.md new file mode 100644 index 0000000..a30923e --- /dev/null +++ b/src/rules/rtk.safety.block-sed.md @@ -0,0 +1,11 @@ +--- +name: block-sed +patterns: [sed] +action: suggest_tool +redirect: Edit +env_var: RTK_BLOCK_TOKEN_WASTE +--- + +Use the **Edit tool** for validated file modifications. + +BLOCK: sed unsafe. Use your file-editing tool instead. diff --git a/src/rules/rtk.safety.git-checkout-dashdash.md b/src/rules/rtk.safety.git-checkout-dashdash.md new file mode 100644 index 0000000..173bb35 --- /dev/null +++ b/src/rules/rtk.safety.git-checkout-dashdash.md @@ -0,0 +1,10 @@ +--- +name: git-checkout-dashdash +patterns: ["git checkout --"] +action: rewrite +redirect: "git stash push -m 'RTK: checkout backup' && git checkout -- {args}" +when: has_unstaged_changes +env_var: RTK_SAFE_COMMANDS +--- + +Safety: Stashing before checkout. diff --git a/src/rules/rtk.safety.git-checkout-dot.md b/src/rules/rtk.safety.git-checkout-dot.md new file mode 100644 index 0000000..007b347 --- /dev/null +++ b/src/rules/rtk.safety.git-checkout-dot.md @@ -0,0 +1,10 @@ +--- +name: git-checkout-dot +patterns: ["git checkout ."] +action: rewrite +redirect: "git stash push -m 'RTK: checkout backup' && git checkout . {args}" +when: has_unstaged_changes +env_var: RTK_SAFE_COMMANDS +--- + +Safety: Stashing before checkout. diff --git a/src/rules/rtk.safety.git-clean-df.md b/src/rules/rtk.safety.git-clean-df.md new file mode 100644 index 0000000..ef7b1c8 --- /dev/null +++ b/src/rules/rtk.safety.git-clean-df.md @@ -0,0 +1,9 @@ +--- +name: git-clean-df +patterns: ["git clean -df"] +action: rewrite +redirect: "git stash -u -m 'RTK: clean backup' && git clean -df {args}" +env_var: RTK_SAFE_COMMANDS +--- + +Safety: Stashing untracked before clean. diff --git a/src/rules/rtk.safety.git-clean-f.md b/src/rules/rtk.safety.git-clean-f.md new file mode 100644 index 0000000..5582e91 --- /dev/null +++ b/src/rules/rtk.safety.git-clean-f.md @@ -0,0 +1,9 @@ +--- +name: git-clean-f +patterns: ["git clean -f"] +action: rewrite +redirect: "git stash -u -m 'RTK: clean backup' && git clean -f {args}" +env_var: RTK_SAFE_COMMANDS +--- + +Safety: Stashing untracked before clean. diff --git a/src/rules/rtk.safety.git-clean-fd.md b/src/rules/rtk.safety.git-clean-fd.md new file mode 100644 index 0000000..ceb2e90 --- /dev/null +++ b/src/rules/rtk.safety.git-clean-fd.md @@ -0,0 +1,9 @@ +--- +name: git-clean-fd +patterns: ["git clean -fd"] +action: rewrite +redirect: "git stash -u -m 'RTK: clean backup' && git clean -fd {args}" +env_var: RTK_SAFE_COMMANDS +--- + +Safety: Stashing untracked before clean. 
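Worked example (editorial sketch, not part of the patch): the three git-clean rules above all rewrite through the same kind of template, so a command such as

    git clean -fd build/ tmp/

is intended to become

    git stash -u -m 'RTK: clean backup' && git clean -fd build/ tmp/

assuming {args} receives whatever follows the matched pattern; the exact substitution for these safety rewrites is handled by the executor, which is not part of this hunk.
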
diff --git a/src/rules/rtk.safety.git-reset-hard.md b/src/rules/rtk.safety.git-reset-hard.md new file mode 100644 index 0000000..4b3a6d7 --- /dev/null +++ b/src/rules/rtk.safety.git-reset-hard.md @@ -0,0 +1,10 @@ +--- +name: git-reset-hard +patterns: ["git reset --hard"] +action: rewrite +redirect: "git stash push -m 'RTK: reset backup' && git reset --hard {args}" +when: has_unstaged_changes +env_var: RTK_SAFE_COMMANDS +--- + +Safety: Stashing before reset. diff --git a/src/rules/rtk.safety.git-stash-drop.md b/src/rules/rtk.safety.git-stash-drop.md new file mode 100644 index 0000000..6bd38f0 --- /dev/null +++ b/src/rules/rtk.safety.git-stash-drop.md @@ -0,0 +1,9 @@ +--- +name: git-stash-drop +patterns: ["git stash drop"] +action: rewrite +redirect: "git stash pop" +env_var: RTK_SAFE_COMMANDS +--- + +Safety: Using pop instead of drop (recoverable). diff --git a/src/rules/rtk.safety.rm-to-trash.md b/src/rules/rtk.safety.rm-to-trash.md new file mode 100644 index 0000000..49e690a --- /dev/null +++ b/src/rules/rtk.safety.rm-to-trash.md @@ -0,0 +1,9 @@ +--- +name: rm-to-trash +patterns: [rm] +action: trash +redirect: "trash {args}" +env_var: RTK_SAFE_COMMANDS +--- + +Safety: Moving to trash instead of permanent deletion.
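
Usage sketch (editorial, not part of the patch): how the rules shipped above are meant to be inspected and customized with the rtk config subcommands added in src/main.rs. Subcommand spellings assume clap's default kebab-case derivation, and the rules.t alias is a made-up example.

    # export built-in rules as editable MD files (to ~/.config/rtk/ by default)
    rtk config export-rules

    # add a project-local alias rule: "t ..." gets rewritten to "cargo test ..."
    rtk config set rules.t "cargo test {args}" --pattern t --action rewrite --local

    # inspect effective config and where each value comes from
    rtk config list --origin

    # opt out of the safety rewrites for this shell session
    export RTK_SAFE_COMMANDS=0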