Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ all: test lint build
build:
@echo "Building $(BINARY_NAME)..."
@mkdir -p $(BUILD_DIR)
@CGO_ENABLED=0 go build -mod=mod -trimpath -ldflags="-s -w -extldflags '-static'" -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd
@CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd

test:
@echo "Running tests..."
Expand Down
50 changes: 31 additions & 19 deletions pkg/autocomplete.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package pkg
import (
"bytes"
"context"
"log"
"os/exec"
"strings"
"time"
)

const maxCompletions = 10
Expand All @@ -19,9 +21,9 @@ func limitCompletions(completions []string) []string {

// customCompleter implements readline.AutoCompleter interface
type customCompleter struct {
hosts []string
noColor bool
connMgr *SSHConnectionManager
firstHost string
noColor bool
connMgr *SSHConnectionManager
}

// Do implements the AutoCompleter interface
Expand All @@ -45,7 +47,7 @@ func (c *customCompleter) Do(line []rune, pos int) ([][]rune, int) {
currentWord := string(line[wordStart:pos])

// Get completions using our logic - pass both line and current word
completions := completerWithWord(lineStr, currentWord, c.hosts, c.connMgr)
completions := completerWithWord(lineStr, currentWord, c.firstHost, c.connMgr)
completions = limitCompletions(completions)

// Convert completions back to rune slices
Expand Down Expand Up @@ -91,14 +93,17 @@ func getLocalFileCompletions(prefix string) []string {

// getSSHCompletions runs completion command on the first host using socket connection
func getSSHCompletions(word string, firstHost string, connMgr *SSHConnectionManager) []string {
if firstHost == "" {
return []string{}
}

socketPath := connMgr.getSocketPath(firstHost)

var args []string
args = append(args, "-S", socketPath, "-o", "BatchMode=yes")
args = append(args, firstHost)
args = append(args,
"-S", socketPath,
"-o", SSHBatchMode,
)

if connMgr.user != "" {
args = append([]string{"-l", connMgr.user}, args...)
}

// Build compgen command to run on remote host
var compgenCmd string
Expand All @@ -107,28 +112,35 @@ func getSSHCompletions(word string, firstHost string, connMgr *SSHConnectionMana
compgenCmd = "compgen -c"
case strings.Contains(word, "/"):
// Handle path completion
compgenCmd = "compgen -d '" + word + "' || compgen -f '" + word + "'"
compgenCmd = "(compgen -d '" + word + "' || compgen -f '" + word + "')"
default:
// Try command completion first, then file completion
compgenCmd = "compgen -c '" + word + "' || compgen -f '" + word + "'"
compgenCmd = "(compgen -c '" + word + "' || compgen -f '" + word + "')"
}

args = append(args, compgenCmd)
args = append(args, firstHost, compgenCmd)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
defer cancel()

cmd := exec.CommandContext(ctx, "ssh", args...)
if Verbose {
log.Printf("SSH completion command: %s\n", cmd.String())
}

cmd := exec.CommandContext(context.Background(), "ssh", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr

err := cmd.Run()
if err != nil {
if Verbose {
log.Printf("SSH completion command: %s\nErr: %s, %s\n", cmd.String(), stderr.String(), err)
}
return []string{}
}

output := strings.TrimSpace(stdout.String())
if output == "" {
return []string{}
}

lines := strings.Split(output, "\n")
var completions []string
Expand All @@ -143,7 +155,7 @@ func getSSHCompletions(word string, firstHost string, connMgr *SSHConnectionMana
}

// completerWithWord handles tab completion with proper word-based logic
func completerWithWord(line string, currentWord string, hosts []string, connMgr *SSHConnectionManager) []string {
func completerWithWord(line string, currentWord string, firstHost string, connMgr *SSHConnectionManager) []string {
if line == "" {
return []string{}
}
Expand Down Expand Up @@ -202,7 +214,7 @@ func completerWithWord(line string, currentWord string, hosts []string, connMgr
}

// For regular commands, use SSH completion on the first host
sshCompletions := getSSHCompletions(currentWord, hosts[0], connMgr)
sshCompletions := getSSHCompletions(currentWord, firstHost, connMgr)

// For all completions (commands and paths), return suffixes as expected by readline
var filteredCompletions []string
Expand Down
141 changes: 39 additions & 102 deletions pkg/commands.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
package pkg

import (
"bufio"
"context"
"fmt"
"os"
"os/exec"
"os/signal"
"strings"
"sync"
"time"
)

// executeCommandStreaming runs a command on all hosts using persistent SSH connections with streaming output and context cancellation
func executeCommandStreaming(ctx context.Context, cm *SSHConnectionManager, hosts []string, command string, noColor bool) {
maxHostLen := maxLen(hosts)
func executeCommandStreaming(ctx context.Context, cm *SSHConnectionManager, command string) {
var wg sync.WaitGroup

for i, host := range hosts {
for _, connection := range cm.connections {
wg.Go(func() {
cm.runSSHStreaming(ctx, host, command, i, maxHostLen, noColor)
cm.runSSHStreaming(ctx, connection, command)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Goroutine Capture and Concurrent Map Iteration

The connection variable in the for...range loop is captured by reference in each goroutine, which may cause commands to run on the wrong host. Also, iterating cm.connections without mutex protection creates a race condition if the map is modified concurrently.

Fix in Cursor Fix in Web

})
}

Expand All @@ -43,44 +42,56 @@ func ExecuteCommand(hosts []string, command, user string, noColor bool) {
cancel()
}()

// Create SSH connection manager for persistent connections
connManager := NewSSHConnectionManager(user)
defer connManager.closeAllConnections() // Ensure cleanup on exit

maxHostLen := maxLen(hosts)
var wg sync.WaitGroup

for i, host := range hosts {
wg.Go(func() {
runSSHStreaming(ctx, host, command, user, i, maxHostLen, noColor)
conn, err := connManager.establishConnection(host, i, maxHostLen, noColor)
if err != nil {
fmt.Printf("❌ Failed to establish connection to %s: %v\n", host, err)
return
}

connManager.runSSHStreaming(ctx, conn, command)
})
}

wg.Wait()
}

// uploadFile uploads a file to all hosts in parallel
func uploadFile(hosts []string, filepath, user string, noColor bool) {
func uploadFile(connManager *SSHConnectionManager, filepath string) {
// Check if local file exists
if _, err := os.Stat(filepath); os.IsNotExist(err) {
fmt.Printf("❌ Error: File '%s' does not exist\n", filepath)
return
}

maxHostLen := maxLen(hosts)
var wg sync.WaitGroup

for i, host := range hosts {
for _, conn := range connManager.connections {
wg.Go(func() {
runSCP(host, filepath, user, i, maxHostLen, noColor)
runSCP(connManager, conn, filepath)
})
}

wg.Wait()
}

// runSCP uploads a file to a single host using scp
func runSCP(host, filepath, user string, idx, maxHostLen int, noColor bool) {
args := []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes"}
// runSCP uploads a file to a single host using scp with direct connection and progress
func runSCP(cm *SSHConnectionManager, conn *SSHConnection, filepath string) {
args := []string{
"-o", SSHConnectTimeout,
"-o", SSHBatchMode,
"-o", "ControlPath=" + conn.socketPath, // Use the control socket for persistent connection
"-v", // Verbose mode for progress output
}

if user != "" {
args = append(args, "-o", "User="+user)
if cm.user != "" {
args = append(args, "-o", "User="+cm.user)
}

// Get just the filename for the destination
Expand All @@ -90,100 +101,26 @@ func runSCP(host, filepath, user string, idx, maxHostLen int, noColor bool) {
filename = parts[len(parts)-1]
}

start := time.Now()

// scp source destination
args = append(args, filepath, host+":"+filename)
args = append(args, filepath, conn.host+":"+filename)
cmd := exec.CommandContext(context.Background(), "scp", args...)

output, err := cmd.CombinedOutput()
prefix := formatHostPrefix(host, idx, maxHostLen, noColor)

if err != nil {
fmt.Printf("%s: ❌ UPLOAD ERROR: %v\n", prefix, err)
if len(output) > 0 {
fmt.Printf("%s: %s\n", prefix, strings.TrimSpace(string(output)))
}
return
}

fmt.Printf("%s: ✅ Upload successful: %s\n", prefix, filename)
}

// runSSHStreaming executes SSH command for a single host with real-time streaming output
func runSSHStreaming(ctx context.Context, host, command, user string, idx, maxHostLen int, noColor bool) {
args := []string{"-o", "ConnectTimeout=5", "-o", "BatchMode=yes"}

if user != "" {
args = append(args, "-l", user)
if Verbose {
fmt.Printf("SCP connand: %s\n", cmd.String())
}

args = append(args, host, command)
cmd := exec.CommandContext(ctx, "ssh", args...)
output, err := cmd.CombinedOutput()
duration := time.Since(start)

// Get stdout and stderr pipes
stdout, err := cmd.StdoutPipe()
if err != nil {
prefix := formatHostPrefix(host, idx, maxHostLen, noColor)
fmt.Printf("%s: ERROR: Failed to get stdout pipe: %v\n", prefix, err)
return
}
stderr, err := cmd.StderrPipe()
if err != nil {
prefix := formatHostPrefix(host, idx, maxHostLen, noColor)
fmt.Printf("%s: ERROR: Failed to get stderr pipe: %v\n", prefix, err)
return
}

prefix := formatHostPrefix(host, idx, maxHostLen, noColor)

// Start goroutines to read and display output in real-time
var wg sync.WaitGroup

// Handle stdout
wg.Go(func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
select {
case <-ctx.Done():
return
default:
line := scanner.Text()
fmt.Printf("%s: %s\n", prefix, line)
}
}
if err := scanner.Err(); err != nil && ctx.Err() == nil {
fmt.Printf("%s: ERROR: Failed to read stdout: %v\n", prefix, err)
}
})

// Handle stderr
wg.Go(func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
select {
case <-ctx.Done():
return
default:
line := scanner.Text()
fmt.Printf("%s: %s\n", prefix, line)
}
}
if err := scanner.Err(); err != nil && ctx.Err() == nil {
fmt.Printf("%s: ERROR: Failed to read stderr: %v\n", prefix, err)
fmt.Printf("%s: ❌ UPLOAD ERROR (%.2fs): %v\n", conn.prefix, duration.Seconds(), err)
if len(output) > 0 {
fmt.Printf("%s: %s\n", conn.prefix, strings.TrimSpace(string(output)))
}
})

// Start the command
if err := cmd.Start(); err != nil {
fmt.Printf("%s: ERROR: Failed to start command: %v\n", prefix, err)
return
}

// Wait for command to complete and output readers to finish
if err := cmd.Wait(); err != nil {
// Only show error if context wasn't cancelled
if ctx.Err() == nil {
fmt.Printf("%s: ERROR: Command failed: %v\n", prefix, err)
}
}
wg.Wait()
fmt.Printf("%s: ✅ Upload successful: %s (%.1fs)\n", conn.prefix, filename, duration.Seconds())
}
24 changes: 24 additions & 0 deletions pkg/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,27 @@
}
return maxLength
}

// printProgressBar displays a simple text-based progress bar
func printProgressBar(current, total int, width int) {
if total == 0 {
return
}

percentage := float64(current) / float64(total)
filled := int(float64(width) * percentage)

bar := ""
for i := range width {
if i < filled {
bar += "█"

Check failure on line 63 in pkg/format.go

View workflow job for this annotation

GitHub Actions / build (1.25.x, ubuntu-latest, linux/amd64)

concat-loop: string concatenation in a loop (perfsprint)

Check failure on line 63 in pkg/format.go

View workflow job for this annotation

GitHub Actions / build (1.25.x, ubuntu-latest, windows/amd64)

concat-loop: string concatenation in a loop (perfsprint)

Check failure on line 63 in pkg/format.go

View workflow job for this annotation

GitHub Actions / build (1.25.x, ubuntu-latest, darwin/amd64)

concat-loop: string concatenation in a loop (perfsprint)
} else {
bar += "░"
}
}

fmt.Printf("\r[%s] %d/%d (%.1f%%)", bar, current, total, percentage*100)
if current == total {
fmt.Println() // New line when complete
}
}
Loading
Loading