Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 85 additions & 24 deletions pkg/cmd/copy/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
)

var (
copyLong = "Copy files and directories between your local machine and remote instance"
copyLong = "Copy files and directories between your local machine and remote instance (uses rsync by default and falls back to scp)"
copyExample = "brev copy instance_name:/path/to/remote/file /path/to/local/file\nbrev copy /path/to/local/file instance_name:/path/to/remote/file\nbrev copy ./local-directory/ instance_name:/remote/path/"
)

Expand Down Expand Up @@ -87,7 +87,7 @@ func runCopyCommand(t *terminal.Terminal, cstore CopyStore, source, dest string,

_ = writeconnectionevent.WriteWCEOnEnv(cstore, workspace.DNS)

err = runSCP(t, sshName, localPath, remotePath, isUpload)
err = runCopyWithFallback(t, sshName, localPath, remotePath, isUpload)
if err != nil {
return breverrors.WrapAndTrace(err)
}
Expand Down Expand Up @@ -202,33 +202,26 @@ func parseWorkspacePath(path string) (workspace, filePath string, err error) {
return parts[0], parts[1], nil
}

func runSCP(t *terminal.Terminal, sshAlias, localPath, remotePath string, isUpload bool) error {
var scpCmd *exec.Cmd
var source, dest string
type commandRunner func(name string, args ...string) ([]byte, error)

startTime := time.Now()

scpArgs := []string{"scp"}

if isUpload {
if isDirectory(localPath) {
scpArgs = append(scpArgs, "-r")
}
scpArgs = append(scpArgs, localPath, fmt.Sprintf("%s:%s", sshAlias, remotePath))
source = localPath
dest = fmt.Sprintf("%s:%s", sshAlias, remotePath)
} else {
scpArgs = append(scpArgs, "-r")
scpArgs = append(scpArgs, fmt.Sprintf("%s:%s", sshAlias, remotePath), localPath)
source = fmt.Sprintf("%s:%s", sshAlias, remotePath)
dest = localPath
func combinedOutputRunner(name string, args ...string) ([]byte, error) {
cmd := exec.Command(name, args...) //nolint:gosec // Command and args come from internal call sites using fixed binaries/flags (rsync/scp).
output, err := cmd.CombinedOutput()
if err != nil {
return output, fmt.Errorf("run %s command: %w", name, err)
}
return output, nil
}

scpCmd = exec.Command(scpArgs[0], scpArgs[1:]...) //nolint:gosec //sshAlias is validated workspace identifier
func runCopyWithFallback(t *terminal.Terminal, sshAlias, localPath, remotePath string, isUpload bool) error {
source, dest := transferEndpoints(sshAlias, localPath, remotePath, isUpload)

output, err := scpCmd.CombinedOutput()
startTime := time.Now()
err := transferWithFallback(sshAlias, localPath, remotePath, isUpload, combinedOutputRunner, func() {
t.Vprint(t.Yellow("rsync failed, falling back to scp...\n"))
})
if err != nil {
return breverrors.WrapAndTrace(fmt.Errorf("scp failed: %s\nOutput: %s", err.Error(), string(output)))
return breverrors.WrapAndTrace(err)
}

duration := time.Since(startTime)
Expand All @@ -238,6 +231,74 @@ func runSCP(t *terminal.Terminal, sshAlias, localPath, remotePath string, isUplo
return nil
}

func transferWithFallback(sshAlias, localPath, remotePath string, isUpload bool, runner commandRunner, onFallback func()) error {
err := runRsyncCommand(sshAlias, localPath, remotePath, isUpload, runner)
if err == nil {
return nil
}

if onFallback != nil {
onFallback()
}

scpErr := runSCPCommand(sshAlias, localPath, remotePath, isUpload, runner)
if scpErr != nil {
return fmt.Errorf("%v\nscp fallback failed: %w", err, scpErr)
}

return nil
}

func runRsyncCommand(sshAlias, localPath, remotePath string, isUpload bool, runner commandRunner) error {
rsyncArgs := buildRsyncArgs(sshAlias, localPath, remotePath, isUpload)
output, err := runner("rsync", rsyncArgs...)
if err != nil {
return fmt.Errorf("rsync failed: %s\nOutput: %s", err.Error(), string(output))
}
return nil
}

func runSCPCommand(sshAlias, localPath, remotePath string, isUpload bool, runner commandRunner) error {
scpArgs := buildSCPArgs(sshAlias, localPath, remotePath, isUpload)
output, err := runner("scp", scpArgs...)
if err != nil {
return fmt.Errorf("scp failed: %s\nOutput: %s", err.Error(), string(output))
}
return nil
}

func buildRsyncArgs(sshAlias, localPath, remotePath string, isUpload bool) []string {
source, dest := transferEndpoints(sshAlias, localPath, remotePath, isUpload)

rsyncArgs := []string{"-z", "-e", "ssh"}
if !isUpload || isDirectory(localPath) {
rsyncArgs = append(rsyncArgs, "-r")
}
rsyncArgs = append(rsyncArgs, source, dest)

return rsyncArgs
}

func buildSCPArgs(sshAlias, localPath, remotePath string, isUpload bool) []string {
source, dest := transferEndpoints(sshAlias, localPath, remotePath, isUpload)

scpArgs := []string{}
if !isUpload || isDirectory(localPath) {
scpArgs = append(scpArgs, "-r")
}
scpArgs = append(scpArgs, source, dest)

return scpArgs
}

func transferEndpoints(sshAlias, localPath, remotePath string, isUpload bool) (source, dest string) {
remoteTarget := fmt.Sprintf("%s:%s", sshAlias, remotePath)
if isUpload {
return localPath, remoteTarget
}
return remoteTarget, localPath
}

func waitForSSHToBeAvailable(sshAlias string, s *spinner.Spinner) error {
counter := 0
s.Suffix = " waiting for SSH connection to be available"
Expand Down
106 changes: 106 additions & 0 deletions pkg/cmd/copy/copy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package copy

import (
"errors"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func TestBuildRsyncArgs(t *testing.T) {
t.Run("upload file", func(t *testing.T) {
args := buildRsyncArgs("ws", "/tmp/local.txt", "/remote/path", true)
assert.Equal(t, []string{"-z", "-e", "ssh", "/tmp/local.txt", "ws:/remote/path"}, args)
})

t.Run("upload directory", func(t *testing.T) {
tmpDir := t.TempDir()
localDir := filepath.Join(tmpDir, "mydir")
err := os.MkdirAll(localDir, 0o755)
assert.NoError(t, err)

args := buildRsyncArgs("ws", localDir, "/remote/path", true)
assert.Equal(t, []string{"-z", "-e", "ssh", "-r", localDir, "ws:/remote/path"}, args)
})

t.Run("download path", func(t *testing.T) {
args := buildRsyncArgs("ws", "/tmp/local.txt", "/remote/path", false)
assert.Equal(t, []string{"-z", "-e", "ssh", "-r", "ws:/remote/path", "/tmp/local.txt"}, args)
})
}

func TestBuildSCPArgs(t *testing.T) {
t.Run("upload file", func(t *testing.T) {
args := buildSCPArgs("ws", "/tmp/local.txt", "/remote/path", true)
assert.Equal(t, []string{"/tmp/local.txt", "ws:/remote/path"}, args)
})

t.Run("upload directory", func(t *testing.T) {
tmpDir := t.TempDir()
localDir := filepath.Join(tmpDir, "mydir")
err := os.MkdirAll(localDir, 0o755)
assert.NoError(t, err)

args := buildSCPArgs("ws", localDir, "/remote/path", true)
assert.Equal(t, []string{"-r", localDir, "ws:/remote/path"}, args)
})

t.Run("download path", func(t *testing.T) {
args := buildSCPArgs("ws", "/tmp/local.txt", "/remote/path", false)
assert.Equal(t, []string{"-r", "ws:/remote/path", "/tmp/local.txt"}, args)
})
}

func TestTransferWithFallback(t *testing.T) {
t.Run("rsync success", func(t *testing.T) {
calls := []string{}
runner := func(name string, args ...string) ([]byte, error) {
calls = append(calls, name)
return []byte("ok"), nil
}

onFallbackCalled := false
err := transferWithFallback("ws", "/tmp/local.txt", "/remote/path", true, runner, func() {
onFallbackCalled = true
})
assert.NoError(t, err)
assert.False(t, onFallbackCalled)
assert.Equal(t, []string{"rsync"}, calls)
})

t.Run("rsync fails and scp succeeds", func(t *testing.T) {
calls := []string{}
runner := func(name string, args ...string) ([]byte, error) {
calls = append(calls, name)
if name == "rsync" {
return []byte("rsync failed"), errors.New("exit status 1")
}
return []byte("scp ok"), nil
}

err := transferWithFallback("ws", "/tmp/local.txt", "/remote/path", true, runner, func() {
calls = append(calls, "fallback")
})
assert.NoError(t, err)
assert.Equal(t, []string{"rsync", "fallback", "scp"}, calls)
})

t.Run("rsync fails and scp fails", func(t *testing.T) {
runner := func(name string, args ...string) ([]byte, error) {
if name == "rsync" {
return []byte("rsync output"), errors.New("exit status 1")
}
return []byte("scp output"), errors.New("exit status 1")
}

err := transferWithFallback("ws", "/tmp/local.txt", "/remote/path", true, runner, func() {})
assert.Error(t, err)
assert.Contains(t, err.Error(), "rsync failed: exit status 1")
assert.Contains(t, err.Error(), "scp fallback failed")
assert.Contains(t, err.Error(), "rsync output")
assert.Contains(t, err.Error(), "scp output")
assert.NotContains(t, err.Error(), "rsync failed: rsync failed:")
})
}