From fed3f09d6424a3cb4506ecd54533b103139a00c5 Mon Sep 17 00:00:00 2001 From: Winston Zhao Date: Mon, 2 Mar 2026 01:25:36 -0800 Subject: [PATCH] improve add branch naming ergonomics --- src/cli.rs | 85 ++++++++++++++++++++++++++------ src/command/add.rs | 7 ++- src/config.rs | 22 +++++++++ src/llm.rs | 118 +++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 207 insertions(+), 25 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index 8d1541e6..0824ceeb 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -179,8 +179,12 @@ enum Commands { Add { /// Name of the branch (creates if it doesn't exist) or remote ref (e.g., origin/feature). /// When used with --pr, this becomes the custom local branch name. - #[arg(required_unless_present_any = ["pr", "auto_name"], value_parser = GitBranchParser::new())] - branch_name: Option, + #[arg( + required_unless_present_any = ["pr", "auto_name"], + value_parser = GitBranchParser::new(), + num_args = 1.. + )] + branch_name: Vec, /// Pull request number to checkout #[arg(long, conflicts_with_all = ["base", "auto_name"])] @@ -604,19 +608,22 @@ pub fn run() -> Result<()> { multi, wait, session, - } => command::add::run( - branch_name.as_deref(), - pr, - auto_name, - base.as_deref(), - name, - prompt, - setup, - rescue, - multi, - wait, - session, - ), + } => { + let branch_name = normalize_branch_name_input(&branch_name); + command::add::run( + branch_name.as_deref(), + pr, + auto_name, + base.as_deref(), + name, + prompt, + setup, + rescue, + multi, + wait, + session, + ) + } Commands::Open { name, run_hooks, @@ -767,3 +774,51 @@ fn print_bash_dynamic_completion() { fn print_fish_dynamic_completion() { print!("{}", include_str!("scripts/completions/fish_dynamic.fish")); } + +fn normalize_branch_name_input(parts: &[String]) -> Option { + if parts.is_empty() { + return None; + } + + let raw = if parts.len() == 1 { + parts[0].clone() + } else { + parts.join(" ") + }; + + if raw.chars().any(char::is_whitespace) { + return Some(slug::slugify(&raw)); + } + + Some(raw) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn add_accepts_unquoted_multi_word_branch_name() { + let parsed = Cli::try_parse_from([ + "workmux", "add", "this", "branch", "name", "--background", + ]); + + assert!(parsed.is_ok()); + } + + #[test] + fn normalize_branch_name_slugifies_multi_word_input() { + let normalized = normalize_branch_name_input(&[ + "this".to_string(), + "branch".to_string(), + "name".to_string(), + ]); + assert_eq!(normalized.as_deref(), Some("this-branch-name")); + } + + #[test] + fn normalize_branch_name_preserves_single_token_refs() { + let normalized = normalize_branch_name_input(&["origin/feature".to_string()]); + assert_eq!(normalized.as_deref(), Some("origin/feature")); + } +} diff --git a/src/command/add.rs b/src/command/add.rs index 24aa5b86..d0fa6973 100644 --- a/src/command/add.rs +++ b/src/command/add.rs @@ -42,7 +42,12 @@ fn generate_branch_name_with_spinner( .and_then(|c| c.system_prompt.as_deref()); let generated = spinner::with_spinner("Generating branch name", || { - crate::llm::generate_branch_name(prompt_text, model, system_prompt) + crate::llm::generate_branch_name( + prompt_text, + model, + system_prompt, + config.auto_name.as_ref().and_then(|c| c.command.as_deref()), + ) })?; println!(" Branch: {}", generated); diff --git a/src/config.rs b/src/config.rs index 2386ec34..6e847701 100644 --- a/src/config.rs +++ b/src/config.rs @@ -56,6 +56,16 @@ pub struct AutoNameConfig { /// If not set, uses llm's default model. pub model: Option, + /// Optional command used instead of `llm` for branch name generation. + /// + /// The configured command is executed with the composed prompt appended as + /// the final argument. Examples: + /// - "opencode run" + /// - "claude -p" + /// - "gemini -p" + /// - "codex -p" + pub command: Option, + /// Custom system prompt for branch name generation. /// If not set, uses the default prompt that asks for a kebab-case branch name. pub system_prompt: Option, @@ -1629,6 +1639,7 @@ impl Config { # LLM-based branch name generation (`workmux add -A`). # auto_name: # model: "gpt-4o-mini" +# command: "opencode run" # system_prompt: "Generate a kebab-case git branch name." # background: true # Always run in background when using --auto-name @@ -3015,6 +3026,17 @@ windows: assert!(windows[1].name.is_none()); } + #[test] + fn parse_auto_name_command() { + let yaml = r#" +auto_name: + command: "opencode run" +"#; + let config: Config = serde_yaml::from_str(yaml).unwrap(); + let auto_name = config.auto_name.unwrap(); + assert_eq!(auto_name.command.as_deref(), Some("opencode run")); + } + #[test] fn validate_windows_config_empty_errors() { let result = validate_windows_config(&[]); diff --git a/src/llm.rs b/src/llm.rs index b4056e51..19367f4f 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -1,4 +1,4 @@ -use anyhow::{Context, Result, anyhow}; +use anyhow::{anyhow, Context, Result}; use std::io::Write; use std::process::{Command, Stdio}; @@ -9,10 +9,65 @@ pub fn generate_branch_name( prompt: &str, model: Option<&str>, system_prompt: Option<&str>, + command: Option<&str>, ) -> Result { let system = system_prompt.unwrap_or(DEFAULT_SYSTEM_PROMPT); let full_prompt = format!("{}\n\nUser Input:\n{}", system, prompt); + let raw = run_generator_command(&full_prompt, model, command)?; + let branch_name = sanitize_branch_name(raw.trim()); + + if branch_name.is_empty() { + return Err(anyhow!("LLM returned empty branch name")); + } + + Ok(branch_name) +} + +fn run_generator_command( + full_prompt: &str, + model: Option<&str>, + command: Option<&str>, +) -> Result { + let configured_command = command.and_then(|c| { + let trimmed = c.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed) + } + }); + + if let Some(command_line) = configured_command { + return run_custom_command(full_prompt, command_line); + } + + run_llm_command(full_prompt, model) +} + +fn run_custom_command(full_prompt: &str, command_line: &str) -> Result { + let (program, rest) = crate::config::split_first_token(command_line) + .ok_or_else(|| anyhow!("auto_name.command cannot be empty"))?; + + let mut cmd = Command::new(program); + if !rest.trim().is_empty() { + cmd.args(rest.split_whitespace()); + } + cmd.arg(full_prompt); + + let output = cmd + .output() + .with_context(|| format!("Failed to run '{}' command. Is it installed?", program))?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(anyhow!("{} command failed: {}", program, stderr)); + } + + Ok(String::from_utf8(output.stdout)?) +} + +fn run_llm_command(full_prompt: &str, model: Option<&str>) -> Result { let mut cmd = Command::new("llm"); if let Some(m) = model { cmd.args(["-m", m]); @@ -36,14 +91,7 @@ pub fn generate_branch_name( return Err(anyhow!("llm command failed: {}", stderr)); } - let raw = String::from_utf8(output.stdout)?; - let branch_name = sanitize_branch_name(raw.trim()); - - if branch_name.is_empty() { - return Err(anyhow!("LLM returned empty branch name")); - } - - Ok(branch_name) + Ok(String::from_utf8(output.stdout)?) } fn sanitize_branch_name(raw: &str) -> String { @@ -63,6 +111,58 @@ fn sanitize_branch_name(raw: &str) -> String { #[cfg(test)] mod tests { use super::*; + use std::fs; + #[cfg(unix)] + use std::os::unix::fs::PermissionsExt; + use tempfile::TempDir; + + #[cfg(unix)] + fn write_executable_script(path: &std::path::Path, content: &str) { + fs::write(path, content).unwrap(); + let mut perms = fs::metadata(path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(path, perms).unwrap(); + } + + #[test] + #[cfg(unix)] + fn custom_command_supports_claude_style_prompt_flag() { + let tmp = TempDir::new().unwrap(); + let script_path = tmp.path().join("fake-claude"); + let received_path = tmp.path().join("received_prompt.txt"); + write_executable_script( + &script_path, + &format!( + "#!/bin/sh\nset -e\n[ \"$1\" = \"-p\" ]\nprintf '%s' \"$2\" > \"{}\"\nprintf '%s' 'branch from claude'\n", + received_path.display() + ), + ); + + let command = format!("{} -p", script_path.display()); + let generated = + generate_branch_name("Add billing retry logic", None, None, Some(&command)).unwrap(); + + assert_eq!(generated, "branch-from-claude"); + let captured_prompt = fs::read_to_string(received_path).unwrap(); + assert!(captured_prompt.contains("User Input:\nAdd billing retry logic")); + } + + #[test] + #[cfg(unix)] + fn custom_command_supports_opencode_run_style_invocation() { + let tmp = TempDir::new().unwrap(); + let script_path = tmp.path().join("fake-opencode"); + write_executable_script( + &script_path, + "#!/bin/sh\nset -e\n[ \"$1\" = \"run\" ]\nprintf '%s' 'opencode-branch'\n", + ); + + let command = format!("{} run", script_path.display()); + let generated = + generate_branch_name("Refactor auth middleware", None, None, Some(&command)).unwrap(); + + assert_eq!(generated, "opencode-branch"); + } #[test] fn sanitize_branch_name_simple() {