From 7cb4b75a69a89c385c66bdecd40e6df9288c88b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20D=C3=B6tsch?= Date: Thu, 25 Sep 2025 07:44:17 +0200 Subject: [PATCH 1/3] refactor SSH connection management and autocomplete logic for improved clarity and functionality --- Makefile | 2 +- pkg/autocomplete.go | 50 ++++++---- pkg/commands.go | 141 ++++++++------------------ pkg/format.go | 24 +++++ pkg/format_test.go | 234 ++++++++++++++++++++++++++++++++++++++++++++ pkg/interactive.go | 39 ++------ pkg/main_test.go | 172 ++++---------------------------- pkg/ssh.go | 102 ++++++++++--------- 8 files changed, 408 insertions(+), 356 deletions(-) create mode 100644 pkg/format_test.go diff --git a/Makefile b/Makefile index 5e26593..acbeaeb 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ all: test lint build build: @echo "Building $(BINARY_NAME)..." @mkdir -p $(BUILD_DIR) - @CGO_ENABLED=0 go build -mod=mod -trimpath -ldflags="-s -w -extldflags '-static'" -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd + @CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd test: @echo "Running tests..." diff --git a/pkg/autocomplete.go b/pkg/autocomplete.go index 83ccaf4..7c8efd1 100644 --- a/pkg/autocomplete.go +++ b/pkg/autocomplete.go @@ -3,8 +3,10 @@ package pkg import ( "bytes" "context" + "log" "os/exec" "strings" + "time" ) const maxCompletions = 10 @@ -19,9 +21,9 @@ func limitCompletions(completions []string) []string { // customCompleter implements readline.AutoCompleter interface type customCompleter struct { - hosts []string - noColor bool - connMgr *SSHConnectionManager + firstHost string + noColor bool + connMgr *SSHConnectionManager } // Do implements the AutoCompleter interface @@ -45,7 +47,7 @@ func (c *customCompleter) Do(line []rune, pos int) ([][]rune, int) { currentWord := string(line[wordStart:pos]) // Get completions using our logic - pass both line and current word - completions := completerWithWord(lineStr, currentWord, c.hosts, c.connMgr) + completions := completerWithWord(lineStr, currentWord, c.firstHost, c.connMgr) completions = limitCompletions(completions) // Convert completions back to rune slices @@ -91,14 +93,17 @@ func getLocalFileCompletions(prefix string) []string { // getSSHCompletions runs completion command on the first host using socket connection func getSSHCompletions(word string, firstHost string, connMgr *SSHConnectionManager) []string { - if firstHost == "" { - return []string{} - } - socketPath := connMgr.getSocketPath(firstHost) + var args []string - args = append(args, "-S", socketPath, "-o", "BatchMode=yes") - args = append(args, firstHost) + args = append(args, + "-S", socketPath, + "-o", SSHBatchMode, + ) + + if connMgr.user != "" { + args = append([]string{"-l", connMgr.user}, args...) + } // Build compgen command to run on remote host var compgenCmd string @@ -107,28 +112,35 @@ func getSSHCompletions(word string, firstHost string, connMgr *SSHConnectionMana compgenCmd = "compgen -c" case strings.Contains(word, "/"): // Handle path completion - compgenCmd = "compgen -d '" + word + "' || compgen -f '" + word + "'" + compgenCmd = "(compgen -d '" + word + "' || compgen -f '" + word + "')" default: // Try command completion first, then file completion - compgenCmd = "compgen -c '" + word + "' || compgen -f '" + word + "'" + compgenCmd = "(compgen -c '" + word + "' || compgen -f '" + word + "')" } - args = append(args, compgenCmd) + args = append(args, firstHost, compgenCmd) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + + cmd := exec.CommandContext(ctx, "ssh", args...) + if Verbose { + log.Printf("SSH completion command: %s\n", cmd.String()) + } - cmd := exec.CommandContext(context.Background(), "ssh", args...) var stdout, stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr err := cmd.Run() if err != nil { + if Verbose { + log.Printf("SSH completion command: %s\nErr: %s, %s\n", cmd.String(), stderr.String(), err) + } return []string{} } output := strings.TrimSpace(stdout.String()) - if output == "" { - return []string{} - } lines := strings.Split(output, "\n") var completions []string @@ -143,7 +155,7 @@ func getSSHCompletions(word string, firstHost string, connMgr *SSHConnectionMana } // completerWithWord handles tab completion with proper word-based logic -func completerWithWord(line string, currentWord string, hosts []string, connMgr *SSHConnectionManager) []string { +func completerWithWord(line string, currentWord string, firstHost string, connMgr *SSHConnectionManager) []string { if line == "" { return []string{} } @@ -202,7 +214,7 @@ func completerWithWord(line string, currentWord string, hosts []string, connMgr } // For regular commands, use SSH completion on the first host - sshCompletions := getSSHCompletions(currentWord, hosts[0], connMgr) + sshCompletions := getSSHCompletions(currentWord, firstHost, connMgr) // For all completions (commands and paths), return suffixes as expected by readline var filteredCompletions []string diff --git a/pkg/commands.go b/pkg/commands.go index 4d6c34a..5f88177 100644 --- a/pkg/commands.go +++ b/pkg/commands.go @@ -1,7 +1,6 @@ package pkg import ( - "bufio" "context" "fmt" "os" @@ -9,16 +8,16 @@ import ( "os/signal" "strings" "sync" + "time" ) // executeCommandStreaming runs a command on all hosts using persistent SSH connections with streaming output and context cancellation -func executeCommandStreaming(ctx context.Context, cm *SSHConnectionManager, hosts []string, command string, noColor bool) { - maxHostLen := maxLen(hosts) +func executeCommandStreaming(ctx context.Context, cm *SSHConnectionManager, command string) { var wg sync.WaitGroup - for i, host := range hosts { + for _, connection := range cm.connections { wg.Go(func() { - cm.runSSHStreaming(ctx, host, command, i, maxHostLen, noColor) + cm.runSSHStreaming(ctx, connection, command) }) } @@ -43,12 +42,21 @@ func ExecuteCommand(hosts []string, command, user string, noColor bool) { cancel() }() + // Create SSH connection manager for persistent connections + connManager := NewSSHConnectionManager(user) + defer connManager.closeAllConnections() // Ensure cleanup on exit + maxHostLen := maxLen(hosts) var wg sync.WaitGroup - for i, host := range hosts { wg.Go(func() { - runSSHStreaming(ctx, host, command, user, i, maxHostLen, noColor) + err, conn := connManager.establishConnection(host, i, maxHostLen, noColor) + if err != nil { + fmt.Printf("Failed to establish connection to %s: %v\n", host, err) + return + } + + connManager.runSSHStreaming(ctx, conn, command) }) } @@ -56,31 +64,34 @@ func ExecuteCommand(hosts []string, command, user string, noColor bool) { } // uploadFile uploads a file to all hosts in parallel -func uploadFile(hosts []string, filepath, user string, noColor bool) { +func uploadFile(connManager *SSHConnectionManager, filepath string) { // Check if local file exists if _, err := os.Stat(filepath); os.IsNotExist(err) { fmt.Printf("❌ Error: File '%s' does not exist\n", filepath) return } - maxHostLen := maxLen(hosts) var wg sync.WaitGroup - - for i, host := range hosts { + for _, conn := range connManager.connections { wg.Go(func() { - runSCP(host, filepath, user, i, maxHostLen, noColor) + runSCP(connManager, conn, filepath) }) } wg.Wait() } -// runSCP uploads a file to a single host using scp -func runSCP(host, filepath, user string, idx, maxHostLen int, noColor bool) { - args := []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes"} +// runSCP uploads a file to a single host using scp with direct connection and progress +func runSCP(cm *SSHConnectionManager, conn *SSHConnection, filepath string) { + args := []string{ + "-o", SSHConnectTimeout, + "-o", SSHBatchMode, + "-o", "ControlPath=" + conn.socketPath, // Use the control socket for persistent connection + "-v", // Verbose mode for progress output + } - if user != "" { - args = append(args, "-o", "User="+user) + if cm.user != "" { + args = append(args, "-o", "User="+cm.user) } // Get just the filename for the destination @@ -90,100 +101,26 @@ func runSCP(host, filepath, user string, idx, maxHostLen int, noColor bool) { filename = parts[len(parts)-1] } + start := time.Now() + // scp source destination - args = append(args, filepath, host+":"+filename) + args = append(args, filepath, conn.host+":"+filename) cmd := exec.CommandContext(context.Background(), "scp", args...) - output, err := cmd.CombinedOutput() - prefix := formatHostPrefix(host, idx, maxHostLen, noColor) - - if err != nil { - fmt.Printf("%s: ❌ UPLOAD ERROR: %v\n", prefix, err) - if len(output) > 0 { - fmt.Printf("%s: %s\n", prefix, strings.TrimSpace(string(output))) - } - return - } - - fmt.Printf("%s: ✅ Upload successful: %s\n", prefix, filename) -} - -// runSSHStreaming executes SSH command for a single host with real-time streaming output -func runSSHStreaming(ctx context.Context, host, command, user string, idx, maxHostLen int, noColor bool) { - args := []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes"} - - if user != "" { - args = append(args, "-l", user) + if Verbose { + fmt.Printf("SCP connand: %s\n", cmd.String()) } - args = append(args, host, command) - cmd := exec.CommandContext(ctx, "ssh", args...) + output, err := cmd.CombinedOutput() + duration := time.Since(start) - // Get stdout and stderr pipes - stdout, err := cmd.StdoutPipe() - if err != nil { - prefix := formatHostPrefix(host, idx, maxHostLen, noColor) - fmt.Printf("%s: ERROR: Failed to get stdout pipe: %v\n", prefix, err) - return - } - stderr, err := cmd.StderrPipe() if err != nil { - prefix := formatHostPrefix(host, idx, maxHostLen, noColor) - fmt.Printf("%s: ERROR: Failed to get stderr pipe: %v\n", prefix, err) - return - } - - prefix := formatHostPrefix(host, idx, maxHostLen, noColor) - - // Start goroutines to read and display output in real-time - var wg sync.WaitGroup - - // Handle stdout - wg.Go(func() { - scanner := bufio.NewScanner(stdout) - for scanner.Scan() { - select { - case <-ctx.Done(): - return - default: - line := scanner.Text() - fmt.Printf("%s: %s\n", prefix, line) - } - } - if err := scanner.Err(); err != nil && ctx.Err() == nil { - fmt.Printf("%s: ERROR: Failed to read stdout: %v\n", prefix, err) - } - }) - - // Handle stderr - wg.Go(func() { - scanner := bufio.NewScanner(stderr) - for scanner.Scan() { - select { - case <-ctx.Done(): - return - default: - line := scanner.Text() - fmt.Printf("%s: %s\n", prefix, line) - } - } - if err := scanner.Err(); err != nil && ctx.Err() == nil { - fmt.Printf("%s: ERROR: Failed to read stderr: %v\n", prefix, err) + fmt.Printf("%s: ❌ UPLOAD ERROR (%.2fs): %v\n", conn.prefix, duration.Seconds(), err) + if len(output) > 0 { + fmt.Printf("%s: %s\n", conn.prefix, strings.TrimSpace(string(output))) } - }) - - // Start the command - if err := cmd.Start(); err != nil { - fmt.Printf("%s: ERROR: Failed to start command: %v\n", prefix, err) return } - // Wait for command to complete and output readers to finish - if err := cmd.Wait(); err != nil { - // Only show error if context wasn't cancelled - if ctx.Err() == nil { - fmt.Printf("%s: ERROR: Command failed: %v\n", prefix, err) - } - } - wg.Wait() + fmt.Printf("%s: ✅ Upload successful: %s (%.1fs)\n", conn.prefix, filename, duration.Seconds()) } diff --git a/pkg/format.go b/pkg/format.go index 04bcc3e..fc15042 100644 --- a/pkg/format.go +++ b/pkg/format.go @@ -47,3 +47,27 @@ func maxLen(strings []string) int { } return maxLength } + +// printProgressBar displays a simple text-based progress bar +func printProgressBar(current, total int, width int) { + if total == 0 { + return + } + + percentage := float64(current) / float64(total) + filled := int(float64(width) * percentage) + + bar := "" + for i := range width { + if i < filled { + bar += "█" + } else { + bar += "░" + } + } + + fmt.Printf("\r[%s] %d/%d (%.1f%%)", bar, current, total, percentage*100) + if current == total { + fmt.Println() // New line when complete + } +} diff --git a/pkg/format_test.go b/pkg/format_test.go new file mode 100644 index 0000000..4e11cb6 --- /dev/null +++ b/pkg/format_test.go @@ -0,0 +1,234 @@ +package pkg + +import ( + "io" + "os" + "strings" + "testing" +) + +func TestMaxLen(t *testing.T) { + tests := []struct { + input []string + expected int + }{ + {[]string{"a", "bb", "ccc"}, 3}, + {[]string{"hello", "world"}, 5}, + {[]string{}, 0}, + {[]string{"single"}, 6}, + } + + for _, test := range tests { + result := maxLen(test.input) + if result != test.expected { + t.Errorf("maxLen(%v) = %d, expected %d", test.input, result, test.expected) + } + } +} + +func TestFormatHost(t *testing.T) { + tests := []struct { + name string + host string + idx int + maxLen int + noColor bool + contains string + notContains string + }{ + {"no color padding", "host1", 0, 10, true, "host1 ", "\033["}, + {"short host no color", "short", 1, 10, true, "short ", "\033["}, + {"with color codes", "host1", 0, 10, false, "host1 ", ""}, + {"with color has ANSI", "host1", 0, 10, false, "\033[", ""}, + {"with color has reset", "host1", 0, 10, false, "\033[0m", ""}, + {"different colors", "host2", 1, 10, false, "\033[", ""}, + {"exact length", "exactly10c", 0, 10, true, "exactly10c", "\033["}, + {"longer than max", "verylonghost", 0, 8, true, "verylonghost", "\033["}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := formatHostPrefix(test.host, test.idx, test.maxLen, test.noColor) + + if test.contains != "" && !strings.Contains(result, test.contains) { + t.Errorf("formatHostPrefix(%q, %d, %d, %t) = %q, should contain %q", + test.host, test.idx, test.maxLen, test.noColor, result, test.contains) + } + + if test.notContains != "" && strings.Contains(result, test.notContains) { + t.Errorf("formatHostPrefix(%q, %d, %d, %t) = %q, should not contain %q", + test.host, test.idx, test.maxLen, test.noColor, result, test.notContains) + } + }) + } +} + +func TestFormatHostColorCycling(t *testing.T) { + // Test that different indices produce different colors + host := "test" + maxLen := 10 + noColor := false + + results := make([]string, len(colors)) + for i := range colors { + results[i] = formatHostPrefix(host, i, maxLen, noColor) + } + + // Test that cycling works (index beyond colors length) + cycledResult := formatHostPrefix(host, len(colors), maxLen, noColor) + if cycledResult != results[0] { + t.Errorf("Color cycling failed: index %d should match index 0", len(colors)) + } +} + +func TestFormatHostPrefixEdgeCases(t *testing.T) { + tests := []struct { + name string + host string + idx int + maxLen int + noColor bool + description string + }{ + {"empty host", "", 0, 5, true, "empty hostname should pad correctly"}, + {"very long host", "verylonghostname", 0, 5, true, "long hostname should be truncated by padding"}, + {"negative maxLen", "host", 0, -1, true, "negative maxLen should work with fmt.Sprintf"}, + {"special chars", "host@domain.com", 0, 15, true, "host with special characters"}, + {"unicode host", "héllo", 0, 8, true, "host with unicode characters"}, + {"large index", "host", 100, 8, false, "large index should cycle colors"}, + {"zero maxLen", "host", 0, 0, true, "zero maxLen should work"}, + {"very large maxLen", "host", 0, 100, true, "very large maxLen should work"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := formatHostPrefix(test.host, test.idx, test.maxLen, test.noColor) + + // Basic sanity checks + if test.noColor && strings.Contains(result, "\033[") { + t.Errorf("Expected no color codes when noColor=true, got: %q", result) + } + if !test.noColor && !strings.Contains(result, "\033[") { + t.Errorf("Expected color codes when noColor=false, got: %q", result) + } + if !test.noColor && !strings.Contains(result, "\033[0m") { + t.Errorf("Expected reset code when noColor=false, got: %q", result) + } + + // Length check (padding should work regardless of maxLen value) + if test.maxLen >= 0 && len(result) < len(test.host) && !test.noColor { + // For colored output, length should be at least host length + color codes + minExpectedLen := len(test.host) + len("\033[1;31m") + len("\033[0m") + if len(result) < minExpectedLen { + t.Errorf("Expected result length >= %d for colored output, got %d", minExpectedLen, len(result)) + } + } + }) + } +} + +func TestMaxLenEdgeCases(t *testing.T) { + tests := []struct { + name string + input []string + expected int + }{ + {"nil slice", nil, 0}, + {"empty strings", []string{"", ""}, 0}, + {"mixed empty and non-empty", []string{"", "a", ""}, 1}, + {"unicode strings", []string{"hello", "héllo", "wörld"}, 6}, // "wörld" has 6 bytes + {"equal lengths", []string{"aaa", "bbb", "ccc"}, 3}, + {"single character variations", []string{"a", "b", "c"}, 1}, + {"very long strings", []string{"short", "thisisaverylongstringthatshouldbethelongest"}, 43}, + {"numbers as strings", []string{"1", "22", "333", "4444"}, 4}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := maxLen(test.input) + if result != test.expected { + t.Errorf("maxLen(%v) = %d, expected %d", test.input, result, test.expected) + } + }) + } +} + +func TestPrintProgressBarEdgeCases(t *testing.T) { + tests := []struct { + name string + current int + total int + width int + expected string // partial expected output + }{ + {"zero progress", 0, 10, 5, "[░░░░░] 0/10"}, + {"half progress", 5, 10, 5, "[██░░░] 5/10"}, + {"full progress", 10, 10, 5, "[█████] 10/10"}, + {"zero total", 0, 0, 5, ""}, // Should not print anything + {"single item", 1, 1, 3, "[███] 1/1"}, + {"current exceeds total", 15, 10, 5, "[█████] 15/10"}, // Should show 100% filled + {"negative current", -1, 10, 5, "[░░░░░] -1/10"}, // Should show 0% filled + {"negative total", 5, -1, 5, "[░░░░░] 5/-1"}, // Division by zero should show 0% + {"zero width", 1, 2, 0, ""}, // Should show empty bar + {"large width", 1, 2, 20, "[██████████░░░░░░░░░░] 1/2"}, + {"very large numbers", 1000, 1000, 10, "[██████████] 1000/1000"}, + {"fractional progress", 1, 3, 6, "[██░░░░] 1/3"}, // Should show ~33% filled + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Capture stdout + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + printProgressBar(test.current, test.total, test.width) + + // Restore stdout + w.Close() + os.Stdout = oldStdout + + // Read captured output + output, _ := io.ReadAll(r) + outputStr := string(output) + + if test.total == 0 { + if outputStr != "" { + t.Errorf("Expected no output for zero total, got %q", outputStr) + } + return + } + + // Check that the output contains expected elements + if test.expected != "" && !strings.Contains(outputStr, test.expected) { + t.Errorf("Expected output to contain %q, got %q", test.expected, outputStr) + } + + // Check for progress bar characters (when width > 0) + if test.width > 0 && (!strings.Contains(outputStr, "[") || !strings.Contains(outputStr, "]")) { + t.Errorf("Expected output to contain progress bar brackets, got %q", outputStr) + } + + // For completed progress, should have newline + if test.current == test.total && test.total > 0 { + if !strings.HasSuffix(outputStr, "\n") { + t.Errorf("Expected output to end with newline for completed progress, got %q", outputStr) + } + } + + // Check bar length matches width (count runes, not bytes) + if test.width > 0 && strings.Contains(outputStr, "[") { + // Extract the bar content between brackets + start := strings.Index(outputStr, "[") + end := strings.Index(outputStr, "]") + if start >= 0 && end > start { + barContent := outputStr[start+1 : end] + runeCount := len([]rune(barContent)) + if runeCount != test.width { + t.Errorf("Expected bar width %d, got %d in %q", test.width, runeCount, barContent) + } + } + } + }) + } +} diff --git a/pkg/interactive.go b/pkg/interactive.go index 5b73c46..b9255b7 100644 --- a/pkg/interactive.go +++ b/pkg/interactive.go @@ -41,9 +41,9 @@ func InteractiveMode(hosts []string, user string, noColor bool, verbose bool) { var wg sync.WaitGroup // Start connections in parallel - for _, host := range hosts { + for idx, host := range hosts { wg.Go(func() { - err := connManager.establishConnection(host) + err, _ := connManager.establishConnection(host, idx, maxLen(hosts), noColor) resultChan <- connectionResult{host: host, error: err} }) } @@ -94,10 +94,11 @@ func InteractiveMode(hosts []string, user string, noColor bool, verbose bool) { prompt := fmt.Sprintf("🖥️ [%d]> ", len(connectedHosts)) config := &readline.Config{ Prompt: prompt, + // todo tweak AutoComplete: &customCompleter{ - hosts: connectedHosts, - noColor: noColor, - connMgr: connManager, + firstHost: connectedHosts[0], + noColor: noColor, + connMgr: connManager, }, HistoryFile: os.Getenv("HOME") + "/.gosh_history", } @@ -136,7 +137,7 @@ func InteractiveMode(hosts []string, user string, noColor bool, verbose bool) { fmt.Println("📁 Usage: :upload ") continue } - uploadFile(connectedHosts, filepath, user, noColor) + uploadFile(connManager, filepath) case line == ":verbose": Verbose = !Verbose status := "disabled" @@ -161,7 +162,7 @@ func InteractiveMode(hosts []string, user string, noColor bool, verbose bool) { }() // Execute command with interruptible context - executeCommandStreaming(ctx, connManager, connectedHosts, line, noColor) + executeCommandStreaming(ctx, connManager, line) // Clean up cancel() @@ -186,27 +187,3 @@ func showHelp() { fmt.Println(" ls -la - List files on all connected hosts") fmt.Println(" :upload script.sh - Upload script.sh to all connected hosts") } - -// printProgressBar displays a simple text-based progress bar -func printProgressBar(current, total int, width int) { - if total == 0 { - return - } - - percentage := float64(current) / float64(total) - filled := int(float64(width) * percentage) - - bar := "" - for i := range width { - if i < filled { - bar += "█" - } else { - bar += "░" - } - } - - fmt.Printf("\r[%s] %d/%d (%.1f%%)", bar, current, total, percentage*100) - if current == total { - fmt.Println() // New line when complete - } -} diff --git a/pkg/main_test.go b/pkg/main_test.go index dcf966e..d87387c 100644 --- a/pkg/main_test.go +++ b/pkg/main_test.go @@ -26,80 +26,6 @@ type HostStatus struct { Error error } -func TestMaxLen(t *testing.T) { - tests := []struct { - input []string - expected int - }{ - {[]string{"a", "bb", "ccc"}, 3}, - {[]string{"hello", "world"}, 5}, - {[]string{}, 0}, - {[]string{"single"}, 6}, - } - - for _, test := range tests { - result := maxLen(test.input) - if result != test.expected { - t.Errorf("maxLen(%v) = %d, expected %d", test.input, result, test.expected) - } - } -} - -func TestFormatHost(t *testing.T) { - tests := []struct { - name string - host string - idx int - maxLen int - noColor bool - contains string - notContains string - }{ - {"no color padding", "host1", 0, 10, true, "host1 ", "\033["}, - {"short host no color", "short", 1, 10, true, "short ", "\033["}, - {"with color codes", "host1", 0, 10, false, "host1 ", ""}, - {"with color has ANSI", "host1", 0, 10, false, "\033[", ""}, - {"with color has reset", "host1", 0, 10, false, "\033[0m", ""}, - {"different colors", "host2", 1, 10, false, "\033[", ""}, - {"exact length", "exactly10c", 0, 10, true, "exactly10c", "\033["}, - {"longer than max", "verylonghost", 0, 8, true, "verylonghost", "\033["}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - result := formatHostPrefix(test.host, test.idx, test.maxLen, test.noColor) - - if test.contains != "" && !strings.Contains(result, test.contains) { - t.Errorf("formatHostPrefix(%q, %d, %d, %t) = %q, should contain %q", - test.host, test.idx, test.maxLen, test.noColor, result, test.contains) - } - - if test.notContains != "" && strings.Contains(result, test.notContains) { - t.Errorf("formatHostPrefix(%q, %d, %d, %t) = %q, should not contain %q", - test.host, test.idx, test.maxLen, test.noColor, result, test.notContains) - } - }) - } -} - -func TestFormatHostColorCycling(t *testing.T) { - // Test that different indices produce different colors - host := "test" - maxLen := 10 - noColor := false - - results := make([]string, len(colors)) - for i := range colors { - results[i] = formatHostPrefix(host, i, maxLen, noColor) - } - - // Test that cycling works (index beyond colors length) - cycledResult := formatHostPrefix(host, len(colors), maxLen, noColor) - if cycledResult != results[0] { - t.Errorf("Color cycling failed: index %d should match index 0", len(colors)) - } -} - // startFakeSSHServer starts a fake SSH server for testing purposes func startFakeSSHServer(t *testing.T, addr string, response string) net.Listener { t.Helper() @@ -370,7 +296,14 @@ func TestUploadFile(t *testing.T) { t.Run(test.name, func(_ *testing.T) { // This test verifies the function doesn't panic and handles file existence // Actual SCP execution is tested in integration tests - uploadFile(test.hosts, test.filepath, test.user, test.noColor) + connManager := NewSSHConnectionManager(test.user) + connManager.connections = make(map[string]*SSHConnection) + for _, host := range test.hosts { + connManager.connections[host] = &SSHConnection{ + host: host, + } + } + uploadFile(connManager, test.filepath) }) } } @@ -419,21 +352,21 @@ func TestSSHCommandConstruction(t *testing.T) { "example.com", "echo test", "", - []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes", "example.com", "echo test"}, + []string{"-o", SSHConnectTimeout, "-o", SSHBatchMode, "example.com", "echo test"}, }, { "with user", "example.com", "whoami", "testuser", - []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes", "-l", "testuser", "example.com", "whoami"}, + []string{"-o", SSHConnectTimeout, "-o", SSHBatchMode, "-l", "testuser", "example.com", "whoami"}, }, { "complex command", "192.168.1.1", "ls -la /home", "admin", - []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes", "-l", "admin", "192.168.1.1", "ls -la /home"}, + []string{"-o", SSHConnectTimeout, "-o", SSHBatchMode, "-l", "admin", "192.168.1.1", "ls -la /home"}, }, } @@ -441,7 +374,7 @@ func TestSSHCommandConstruction(t *testing.T) { t.Run(test.name, func(t *testing.T) { // We can't easily test runSSH directly due to exec.Command, // but we can verify the argument construction logic - args := []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes"} + args := []string{"-o", SSHConnectTimeout, "-o", SSHBatchMode} if test.user != "" { args = append(args, "-l", test.user) } @@ -473,28 +406,28 @@ func TestSCPCommandConstruction(t *testing.T) { "example.com", "test.txt", "", - []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes", "test.txt", "example.com:test.txt"}, + []string{"-o", SSHConnectTimeout, "-o", SSHBatchMode, "test.txt", "example.com:test.txt"}, }, { "with user", "example.com", "script.sh", "testuser", - []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes", "-o", "User=testuser", "script.sh", "example.com:script.sh"}, + []string{"-o", SSHConnectTimeout, "-o", SSHBatchMode, "-o", "User=testuser", "script.sh", "example.com:script.sh"}, }, { "path with directory", "192.168.1.1", "/home/user/data.csv", "admin", - []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes", "-o", "User=admin", "/home/user/data.csv", "192.168.1.1:data.csv"}, + []string{"-o", SSHConnectTimeout, "-o", SSHBatchMode, "-o", "User=admin", "/home/user/data.csv", "192.168.1.1:data.csv"}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Test the SCP argument construction logic - args := []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes"} + args := []string{"-o", SSHConnectTimeout, "-o", SSHBatchMode} if test.user != "" { args = append(args, "-o", "User="+test.user) } @@ -759,9 +692,9 @@ func TestTestAllConnections(t *testing.T) { func TestCustomCompleterDo(t *testing.T) { completer := &customCompleter{ - hosts: []string{"host1", "host2"}, - noColor: true, - connMgr: &SSHConnectionManager{}, + firstHost: "host1", + noColor: true, + connMgr: &SSHConnectionManager{}, } tests := []struct { @@ -944,19 +877,15 @@ func TestCompleterWithWord(t *testing.T) { {"help command", ":help", "help", "internal"}, {"regular command", "ls", "ls", "ssh"}, {"path command", "cd /home", "/home", "ssh"}, - {"no hosts", "", "", "none"}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var testHosts []string - if test.expectedType != "none" { - testHosts = hosts - } + testHosts := hosts conMgr := NewSSHConnectionManager(user) conMgr.connections = map[string]*SSHConnection{} - completions := completerWithWord(test.line, test.currentWord, testHosts, conMgr) + completions := completerWithWord(test.line, test.currentWord, testHosts[0], conMgr) // completions should be a valid slice (can be nil or empty) // The function may return nil in some error cases @@ -1042,62 +971,3 @@ func TestGetLocalFileCompletionsExtended(t *testing.T) { }) } } - -func TestPrintProgressBar(t *testing.T) { - tests := []struct { - name string - current int - total int - width int - expected string // partial expected output - }{ - {"zero progress", 0, 10, 5, "[░░░░░] 0/10"}, - {"half progress", 5, 10, 5, "[██░░░] 5/10"}, - {"full progress", 10, 10, 5, "[█████] 10/10"}, - {"zero total", 0, 0, 5, ""}, // Should not print anything - {"single item", 1, 1, 3, "[███] 1/1"}, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Capture stdout - oldStdout := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - printProgressBar(test.current, test.total, test.width) - - // Restore stdout - w.Close() - os.Stdout = oldStdout - - // Read captured output - output, _ := io.ReadAll(r) - outputStr := string(output) - - if test.total == 0 { - if outputStr != "" { - t.Errorf("Expected no output for zero total, got %q", outputStr) - } - return - } - - // Check that the output contains expected elements - if !strings.Contains(outputStr, test.expected) { - t.Errorf("Expected output to contain %q, got %q", test.expected, outputStr) - } - - // Check for progress bar characters - if !strings.Contains(outputStr, "[") || !strings.Contains(outputStr, "]") { - t.Errorf("Expected output to contain progress bar brackets, got %q", outputStr) - } - - // For completed progress, should have newline - if test.current == test.total && test.total > 0 { - if !strings.HasSuffix(outputStr, "\n") { - t.Errorf("Expected output to end with newline for completed progress, got %q", outputStr) - } - } - }) - } -} diff --git a/pkg/ssh.go b/pkg/ssh.go index e65c044..9e60da5 100644 --- a/pkg/ssh.go +++ b/pkg/ssh.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "fmt" + "io" "os" "os/exec" "path/filepath" @@ -12,6 +13,15 @@ import ( "sync" ) +const ( + // SSHConnectTimeout is the SSH connection timeout in seconds + SSHConnectTimeout = "ConnectTimeout=5" + // SSHControlPersist is how long SSH control connections should persist + SSHControlPersist = "ControlPersist=10m" + // SSHBatchMode enables non-interactive operation + SSHBatchMode = "BatchMode=yes" +) + // runCmdWithSeparateOutput runs a command and returns stdout, stderr, and error separately func runCmdWithSeparateOutput(cmd *exec.Cmd) (string, string, error) { var stdout, stderr bytes.Buffer @@ -34,6 +44,7 @@ type SSHConnectionManager struct { type SSHConnection struct { host string socketPath string + prefix string } // NewSSHConnectionManager creates a new connection manager @@ -57,16 +68,16 @@ func (cm *SSHConnectionManager) getSocketPath(host string) string { } // establishConnection establishes a persistent SSH connection to a host -func (cm *SSHConnectionManager) establishConnection(host string) error { +func (cm *SSHConnectionManager) establishConnection(host string, idx int, maxLen int, noColor bool) (error, *SSHConnection) { socketPath := cm.getSocketPath(host) // Establish new connection args := []string{ "-M", // Enable ControlMaster "-S", socketPath, // Control socket path - "-o", "ControlPersist=10m", // Keep connection alive for 10 minutes - "-o", "ConnectTimeout=5", - "-o", "BatchMode=yes", + "-o", SSHControlPersist, // Keep connection alive for 10 minutes + "-o", SSHConnectTimeout, + "-o", SSHBatchMode, "-f", // Go to background after establishing connection } @@ -78,85 +89,75 @@ func (cm *SSHConnectionManager) establishConnection(host string) error { cmd := exec.CommandContext(context.Background(), "ssh", args...) if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to establish SSH connection to %s: %w", host, err) + return fmt.Errorf("failed to establish SSH connection to %s: %w", host, err), nil } - // Store connection info - cm.mu.Lock() - cm.connections[host] = &SSHConnection{ + conn := &SSHConnection{ host: host, socketPath: socketPath, + prefix: formatHostPrefix(host, idx, maxLen, noColor), } + // Store connection info, only for successful hosts + cm.mu.Lock() + cm.connections[host] = conn cm.mu.Unlock() - return nil + return nil, conn } // runSSHStreaming executes SSH command using persistent connection with real-time streaming output and context cancellation -func (cm *SSHConnectionManager) runSSHStreaming(ctx context.Context, host, command string, idx, maxHostLen int, noColor bool) { - socketPath := cm.getSocketPath(host) - +func (cm *SSHConnectionManager) runSSHStreaming(ctx context.Context, conn *SSHConnection, command string) { args := []string{ - "-S", socketPath, // Use existing control socket - "-o", "BatchMode=yes", + "-S", conn.socketPath, // Use existing control socket + "-o", SSHBatchMode, } if cm.user != "" { args = append(args, "-l", cm.user) } - args = append(args, host, command) + args = append(args, conn.host, command) cmd := exec.CommandContext(ctx, "ssh", args...) // Get stdout and stderr pipes stdout, err := cmd.StdoutPipe() if err != nil { - prefix := formatHostPrefix(host, idx, maxHostLen, noColor) - fmt.Printf("%s: ERROR: Failed to get stdout pipe: %v\n", prefix, err) + fmt.Printf("%s: ERROR: Failed to get stdout pipe: %v\n", conn.prefix, err) return } stderr, err := cmd.StderrPipe() if err != nil { - prefix := formatHostPrefix(host, idx, maxHostLen, noColor) - fmt.Printf("%s: ERROR: Failed to get stderr pipe: %v\n", prefix, err) + fmt.Printf("%s: ERROR: Failed to get stderr pipe: %v\n", conn.prefix, err) return } - prefix := formatHostPrefix(host, idx, maxHostLen, noColor) - - // Start goroutines to read and display output in real-time - var wg sync.WaitGroup - // Handle stdout - wg.Go(func() { - scanner := bufio.NewScanner(stdout) - for scanner.Scan() { - select { - case <-ctx.Done(): - return - default: - line := scanner.Text() - fmt.Printf("%s: %s\n", prefix, line) + // Helper function to handle output streams + handleStream := func(reader io.Reader, output *os.File) func() { + return func() { + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + select { + case <-ctx.Done(): + return + default: + line := scanner.Text() + _, _ = fmt.Fprintf(output, "%s: %s\n", conn.prefix, line) + } } - } - }) - - // Handle stderr - wg.Go(func() { - scanner := bufio.NewScanner(stderr) - for scanner.Scan() { - select { - case <-ctx.Done(): - return - default: - line := scanner.Text() - fmt.Printf("%s: %s\n", prefix, line) + if err := scanner.Err(); err != nil && ctx.Err() == nil { + _, _ = fmt.Fprintf(os.Stderr, "%s: ERROR: Failed to read: %v\n", conn.prefix, err) } } - }) + } + + // Start goroutines to read and display output in real-time + var wg sync.WaitGroup + wg.Go(handleStream(stdout, os.Stdout)) + wg.Go(handleStream(stderr, os.Stderr)) // Start the command if err := cmd.Start(); err != nil { - fmt.Printf("%s: ERROR: Failed to start command: %v\n", prefix, err) + fmt.Printf("%s: ERROR: Failed to start command: %v\n", conn.prefix, err) return } @@ -164,7 +165,7 @@ func (cm *SSHConnectionManager) runSSHStreaming(ctx context.Context, host, comma if err := cmd.Wait(); err != nil { // Only show error if context wasn't cancelled if ctx.Err() == nil { - fmt.Printf("%s: ERROR: Command failed: %v\n", prefix, err) + fmt.Printf("%s: ERROR: Command failed: %v\n", conn.prefix, err) } } wg.Wait() @@ -195,7 +196,4 @@ func (cm *SSHConnectionManager) closeAllConnections() { for host := range cm.connections { cm.closeConnection(host) } - - // Remove socket directory - _ = os.RemoveAll(cm.socketDir) } From 95c3162921772b53cbfdd57b282fdab151d80ea3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20D=C3=B6tsch?= Date: Thu, 25 Sep 2025 08:02:46 +0200 Subject: [PATCH 2/3] refactor SSH connection management and autocomplete logic for improved clarity and functionality made linter happy --- pkg/commands.go | 2 +- pkg/interactive.go | 2 +- pkg/ssh.go | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/commands.go b/pkg/commands.go index 5f88177..afc8eac 100644 --- a/pkg/commands.go +++ b/pkg/commands.go @@ -50,7 +50,7 @@ func ExecuteCommand(hosts []string, command, user string, noColor bool) { var wg sync.WaitGroup for i, host := range hosts { wg.Go(func() { - err, conn := connManager.establishConnection(host, i, maxHostLen, noColor) + conn, err := connManager.establishConnection(host, i, maxHostLen, noColor) if err != nil { fmt.Printf("Failed to establish connection to %s: %v\n", host, err) return diff --git a/pkg/interactive.go b/pkg/interactive.go index b9255b7..bef5867 100644 --- a/pkg/interactive.go +++ b/pkg/interactive.go @@ -43,7 +43,7 @@ func InteractiveMode(hosts []string, user string, noColor bool, verbose bool) { // Start connections in parallel for idx, host := range hosts { wg.Go(func() { - err, _ := connManager.establishConnection(host, idx, maxLen(hosts), noColor) + _, err := connManager.establishConnection(host, idx, maxLen(hosts), noColor) resultChan <- connectionResult{host: host, error: err} }) } diff --git a/pkg/ssh.go b/pkg/ssh.go index 9e60da5..dcde90b 100644 --- a/pkg/ssh.go +++ b/pkg/ssh.go @@ -68,7 +68,7 @@ func (cm *SSHConnectionManager) getSocketPath(host string) string { } // establishConnection establishes a persistent SSH connection to a host -func (cm *SSHConnectionManager) establishConnection(host string, idx int, maxLen int, noColor bool) (error, *SSHConnection) { +func (cm *SSHConnectionManager) establishConnection(host string, idx int, maxLen int, noColor bool) (*SSHConnection, error) { socketPath := cm.getSocketPath(host) // Establish new connection @@ -89,7 +89,7 @@ func (cm *SSHConnectionManager) establishConnection(host string, idx int, maxLen cmd := exec.CommandContext(context.Background(), "ssh", args...) if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to establish SSH connection to %s: %w", host, err), nil + return nil, fmt.Errorf("failed to establish SSH connection to %s: %w", host, err) } conn := &SSHConnection{ @@ -102,7 +102,7 @@ func (cm *SSHConnectionManager) establishConnection(host string, idx int, maxLen cm.connections[host] = conn cm.mu.Unlock() - return nil, conn + return conn, nil } // runSSHStreaming executes SSH command using persistent connection with real-time streaming output and context cancellation From d334ed9a6e58bad2fbea3ac33c7ea34a0b921701 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20D=C3=B6tsch?= Date: Mon, 12 Jan 2026 09:06:56 +0100 Subject: [PATCH 3/3] cleanups --- pkg/commands.go | 2 +- pkg/ssh.go | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/pkg/commands.go b/pkg/commands.go index afc8eac..d14045d 100644 --- a/pkg/commands.go +++ b/pkg/commands.go @@ -52,7 +52,7 @@ func ExecuteCommand(hosts []string, command, user string, noColor bool) { wg.Go(func() { conn, err := connManager.establishConnection(host, i, maxHostLen, noColor) if err != nil { - fmt.Printf("Failed to establish connection to %s: %v\n", host, err) + fmt.Printf("❌ Failed to establish connection to %s: %v\n", host, err) return } diff --git a/pkg/ssh.go b/pkg/ssh.go index dcde90b..b8f43ed 100644 --- a/pkg/ssh.go +++ b/pkg/ssh.go @@ -88,8 +88,17 @@ func (cm *SSHConnectionManager) establishConnection(host string, idx int, maxLen args = append(args, host, "true") // Simple command to establish connection cmd := exec.CommandContext(context.Background(), "ssh", args...) - if err := cmd.Run(); err != nil { - return nil, fmt.Errorf("failed to establish SSH connection to %s: %w", host, err) + stdout, stderr, err := runCmdWithSeparateOutput(cmd) + if err != nil { + // Extract meaningful error message from stderr if available + errorMsg := err.Error() + if stderr != "" { + // Use the actual SSH error message from stderr + errorMsg = strings.TrimSpace(stderr) + } + // Suppress unused variable warning for stdout + _ = stdout + return nil, fmt.Errorf("failed to establish SSH connection to %s: %s", host, errorMsg) } conn := &SSHConnection{