Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 70 additions & 15 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,12 @@ enum Commands {
Add {
/// Name of the branch (creates if it doesn't exist) or remote ref (e.g., origin/feature).
/// When used with --pr, this becomes the custom local branch name.
#[arg(required_unless_present_any = ["pr", "auto_name"], value_parser = GitBranchParser::new())]
branch_name: Option<String>,
#[arg(
required_unless_present_any = ["pr", "auto_name"],
value_parser = GitBranchParser::new(),
num_args = 1..
)]
branch_name: Vec<String>,

/// Pull request number to checkout
#[arg(long, conflicts_with_all = ["base", "auto_name"])]
Expand Down Expand Up @@ -604,19 +608,22 @@ pub fn run() -> Result<()> {
multi,
wait,
session,
} => command::add::run(
branch_name.as_deref(),
pr,
auto_name,
base.as_deref(),
name,
prompt,
setup,
rescue,
multi,
wait,
session,
),
} => {
let branch_name = normalize_branch_name_input(&branch_name);
command::add::run(
branch_name.as_deref(),
pr,
auto_name,
base.as_deref(),
name,
prompt,
setup,
rescue,
multi,
wait,
session,
)
}
Commands::Open {
name,
run_hooks,
Expand Down Expand Up @@ -767,3 +774,51 @@ fn print_bash_dynamic_completion() {
fn print_fish_dynamic_completion() {
print!("{}", include_str!("scripts/completions/fish_dynamic.fish"));
}

fn normalize_branch_name_input(parts: &[String]) -> Option<String> {
if parts.is_empty() {
return None;
}

let raw = if parts.len() == 1 {
parts[0].clone()
} else {
parts.join(" ")
};

if raw.chars().any(char::is_whitespace) {
return Some(slug::slugify(&raw));
}

Some(raw)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn add_accepts_unquoted_multi_word_branch_name() {
let parsed = Cli::try_parse_from([
"workmux", "add", "this", "branch", "name", "--background",
]);

assert!(parsed.is_ok());
}

#[test]
fn normalize_branch_name_slugifies_multi_word_input() {
let normalized = normalize_branch_name_input(&[
"this".to_string(),
"branch".to_string(),
"name".to_string(),
]);
assert_eq!(normalized.as_deref(), Some("this-branch-name"));
}

#[test]
fn normalize_branch_name_preserves_single_token_refs() {
let normalized = normalize_branch_name_input(&["origin/feature".to_string()]);
assert_eq!(normalized.as_deref(), Some("origin/feature"));
}
}
7 changes: 6 additions & 1 deletion src/command/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ fn generate_branch_name_with_spinner(
.and_then(|c| c.system_prompt.as_deref());

let generated = spinner::with_spinner("Generating branch name", || {
crate::llm::generate_branch_name(prompt_text, model, system_prompt)
crate::llm::generate_branch_name(
prompt_text,
model,
system_prompt,
config.auto_name.as_ref().and_then(|c| c.command.as_deref()),
)
})?;
println!(" Branch: {}", generated);

Expand Down
22 changes: 22 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ pub struct AutoNameConfig {
/// If not set, uses llm's default model.
pub model: Option<String>,

/// Optional command used instead of `llm` for branch name generation.
///
/// The configured command is executed with the composed prompt appended as
/// the final argument. Examples:
/// - "opencode run"
/// - "claude -p"
/// - "gemini -p"
/// - "codex -p"
pub command: Option<String>,

/// Custom system prompt for branch name generation.
/// If not set, uses the default prompt that asks for a kebab-case branch name.
pub system_prompt: Option<String>,
Expand Down Expand Up @@ -1629,6 +1639,7 @@ impl Config {
# LLM-based branch name generation (`workmux add -A`).
# auto_name:
# model: "gpt-4o-mini"
# command: "opencode run"
# system_prompt: "Generate a kebab-case git branch name."
# background: true # Always run in background when using --auto-name

Expand Down Expand Up @@ -3015,6 +3026,17 @@ windows:
assert!(windows[1].name.is_none());
}

#[test]
fn parse_auto_name_command() {
let yaml = r#"
auto_name:
command: "opencode run"
"#;
let config: Config = serde_yaml::from_str(yaml).unwrap();
let auto_name = config.auto_name.unwrap();
assert_eq!(auto_name.command.as_deref(), Some("opencode run"));
}

#[test]
fn validate_windows_config_empty_errors() {
let result = validate_windows_config(&[]);
Expand Down
118 changes: 109 additions & 9 deletions src/llm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::{Context, Result, anyhow};
use anyhow::{anyhow, Context, Result};
use std::io::Write;
use std::process::{Command, Stdio};

Expand All @@ -9,10 +9,65 @@ pub fn generate_branch_name(
prompt: &str,
model: Option<&str>,
system_prompt: Option<&str>,
command: Option<&str>,
) -> Result<String> {
let system = system_prompt.unwrap_or(DEFAULT_SYSTEM_PROMPT);
let full_prompt = format!("{}\n\nUser Input:\n{}", system, prompt);

let raw = run_generator_command(&full_prompt, model, command)?;
let branch_name = sanitize_branch_name(raw.trim());

if branch_name.is_empty() {
return Err(anyhow!("LLM returned empty branch name"));
}

Ok(branch_name)
}

fn run_generator_command(
full_prompt: &str,
model: Option<&str>,
command: Option<&str>,
) -> Result<String> {
let configured_command = command.and_then(|c| {
let trimmed = c.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed)
}
});

if let Some(command_line) = configured_command {
return run_custom_command(full_prompt, command_line);
}

run_llm_command(full_prompt, model)
}

fn run_custom_command(full_prompt: &str, command_line: &str) -> Result<String> {
let (program, rest) = crate::config::split_first_token(command_line)
.ok_or_else(|| anyhow!("auto_name.command cannot be empty"))?;

let mut cmd = Command::new(program);
if !rest.trim().is_empty() {
cmd.args(rest.split_whitespace());
}
cmd.arg(full_prompt);

let output = cmd
.output()
.with_context(|| format!("Failed to run '{}' command. Is it installed?", program))?;

if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(anyhow!("{} command failed: {}", program, stderr));
}

Ok(String::from_utf8(output.stdout)?)
}

fn run_llm_command(full_prompt: &str, model: Option<&str>) -> Result<String> {
let mut cmd = Command::new("llm");
if let Some(m) = model {
cmd.args(["-m", m]);
Expand All @@ -36,14 +91,7 @@ pub fn generate_branch_name(
return Err(anyhow!("llm command failed: {}", stderr));
}

let raw = String::from_utf8(output.stdout)?;
let branch_name = sanitize_branch_name(raw.trim());

if branch_name.is_empty() {
return Err(anyhow!("LLM returned empty branch name"));
}

Ok(branch_name)
Ok(String::from_utf8(output.stdout)?)
}

fn sanitize_branch_name(raw: &str) -> String {
Expand All @@ -63,6 +111,58 @@ fn sanitize_branch_name(raw: &str) -> String {
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
use tempfile::TempDir;

#[cfg(unix)]
fn write_executable_script(path: &std::path::Path, content: &str) {
fs::write(path, content).unwrap();
let mut perms = fs::metadata(path).unwrap().permissions();
perms.set_mode(0o755);
fs::set_permissions(path, perms).unwrap();
}

#[test]
#[cfg(unix)]
fn custom_command_supports_claude_style_prompt_flag() {
let tmp = TempDir::new().unwrap();
let script_path = tmp.path().join("fake-claude");
let received_path = tmp.path().join("received_prompt.txt");
write_executable_script(
&script_path,
&format!(
"#!/bin/sh\nset -e\n[ \"$1\" = \"-p\" ]\nprintf '%s' \"$2\" > \"{}\"\nprintf '%s' 'branch from claude'\n",
received_path.display()
),
);

let command = format!("{} -p", script_path.display());
let generated =
generate_branch_name("Add billing retry logic", None, None, Some(&command)).unwrap();

assert_eq!(generated, "branch-from-claude");
let captured_prompt = fs::read_to_string(received_path).unwrap();
assert!(captured_prompt.contains("User Input:\nAdd billing retry logic"));
}

#[test]
#[cfg(unix)]
fn custom_command_supports_opencode_run_style_invocation() {
let tmp = TempDir::new().unwrap();
let script_path = tmp.path().join("fake-opencode");
write_executable_script(
&script_path,
"#!/bin/sh\nset -e\n[ \"$1\" = \"run\" ]\nprintf '%s' 'opencode-branch'\n",
);

let command = format!("{} run", script_path.display());
let generated =
generate_branch_name("Refactor auth middleware", None, None, Some(&command)).unwrap();

assert_eq!(generated, "opencode-branch");
}

#[test]
fn sanitize_branch_name_simple() {
Expand Down