diff --git a/cmd/mutagen/forward/create.go b/cmd/mutagen/forward/create.go index 7b4b468ab..62c93fada 100644 --- a/cmd/mutagen/forward/create.go +++ b/cmd/mutagen/forward/create.go @@ -101,6 +101,31 @@ func createMain(_ *cobra.Command, arguments []string) error { return fmt.Errorf("unable to parse destination URL: %w", err) } + if source.Protocol == url.Protocol_SSH { + sshConfigPath := os.Getenv("MUTAGEN_SSH_CONFIG_SOURCE") + if sshConfigPath == "" { + sshConfigPath = os.Getenv("MUTAGEN_SSH_CONFIG") + } + if sshConfigPath != "" { + if source.Parameters == nil { + source.Parameters = make(map[string]string) + } + source.Parameters["ssh-config-path"] = sshConfigPath + } + } + if destination.Protocol == url.Protocol_SSH { + sshConfigPath := os.Getenv("MUTAGEN_SSH_CONFIG_DESTINATION") + if sshConfigPath == "" { + sshConfigPath = os.Getenv("MUTAGEN_SSH_CONFIG") + } + if sshConfigPath != "" { + if destination.Parameters == nil { + destination.Parameters = make(map[string]string) + } + destination.Parameters["ssh-config-path"] = sshConfigPath + } + } + // Validate the name. if err := selection.EnsureNameValid(createConfiguration.name); err != nil { return fmt.Errorf("invalid session name: %w", err) diff --git a/cmd/mutagen/forward/list_monitor_common.go b/cmd/mutagen/forward/list_monitor_common.go index e97a00761..49fdf35e5 100644 --- a/cmd/mutagen/forward/list_monitor_common.go +++ b/cmd/mutagen/forward/list_monitor_common.go @@ -30,6 +30,15 @@ func printEndpoint(name string, url *url.URL, configuration *forwarding.Configur // Print the URL. fmt.Println("\tURL:", terminal.NeutralizeControlCharacters(url.Format("\n\t\t"))) + // Print parameters, if any. + if len(url.Parameters) > 0 { + fmt.Println("\tParameters:") + keys := selection.ExtractAndSortLabelKeys(url.Parameters) + for _, key := range keys { + fmt.Printf("\t\t%s: %s\n", key, terminal.NeutralizeControlCharacters(url.Parameters[key])) + } + } + // Print configuration information if desired. if mode == common.SessionDisplayModeListLong || mode == common.SessionDisplayModeMonitorLong { // Print configuration header. diff --git a/cmd/mutagen/sync/create.go b/cmd/mutagen/sync/create.go index 2ac6fb22e..f86c06049 100644 --- a/cmd/mutagen/sync/create.go +++ b/cmd/mutagen/sync/create.go @@ -109,6 +109,31 @@ func createMain(_ *cobra.Command, arguments []string) error { return fmt.Errorf("unable to parse beta URL: %w", err) } + if alpha.Protocol == url.Protocol_SSH { + sshConfigPath := os.Getenv("MUTAGEN_SSH_CONFIG_ALPHA") + if sshConfigPath == "" { + sshConfigPath = os.Getenv("MUTAGEN_SSH_CONFIG") + } + if sshConfigPath != "" { + if alpha.Parameters == nil { + alpha.Parameters = make(map[string]string) + } + alpha.Parameters["ssh-config-path"] = sshConfigPath + } + } + if beta.Protocol == url.Protocol_SSH { + sshConfigPath := os.Getenv("MUTAGEN_SSH_CONFIG_BETA") + if sshConfigPath == "" { + sshConfigPath = os.Getenv("MUTAGEN_SSH_CONFIG") + } + if sshConfigPath != "" { + if beta.Parameters == nil { + beta.Parameters = make(map[string]string) + } + beta.Parameters["ssh-config-path"] = sshConfigPath + } + } + // Validate the name. if err := selection.EnsureNameValid(createConfiguration.name); err != nil { return fmt.Errorf("invalid session name: %w", err) diff --git a/cmd/mutagen/sync/list_monitor_common.go b/cmd/mutagen/sync/list_monitor_common.go index 1f25a0c01..07380cc76 100644 --- a/cmd/mutagen/sync/list_monitor_common.go +++ b/cmd/mutagen/sync/list_monitor_common.go @@ -90,6 +90,15 @@ func printEndpoint(name string, url *urlpkg.URL, configuration *synchronization. // Print the URL. fmt.Println("\tURL:", terminal.NeutralizeControlCharacters(url.Format("\n\t\t"))) + // Print parameters, if any. + if len(url.Parameters) > 0 { + fmt.Println("\tParameters:") + keys := selection.ExtractAndSortLabelKeys(url.Parameters) + for _, key := range keys { + fmt.Printf("\t\t%s: %s\n", key, terminal.NeutralizeControlCharacters(url.Parameters[key])) + } + } + // Print configuration information if desired. if mode == common.SessionDisplayModeListLong || mode == common.SessionDisplayModeMonitorLong { // Print configuration header. diff --git a/pkg/agent/transport/ssh/transport.go b/pkg/agent/transport/ssh/transport.go index b65f9ff65..237090638 100644 --- a/pkg/agent/transport/ssh/transport.go +++ b/pkg/agent/transport/ssh/transport.go @@ -49,15 +49,18 @@ type sshTransport struct { port uint16 // prompter is the prompter identifier to use for prompting. prompter string + // configPath is the path to the SSH config file to use, if specified. + configPath string } // NewTransport creates a new SSH transport using the specified parameters. -func NewTransport(user, host string, port uint16, prompter string) (agent.Transport, error) { +func NewTransport(user, host string, port uint16, prompter, configPath string) (agent.Transport, error) { return &sshTransport{ - user: user, - host: host, - port: port, - prompter: prompter, + user: user, + host: host, + port: port, + prompter: prompter, + configPath: configPath, }, nil } @@ -95,6 +98,7 @@ func (t *sshTransport) Copy(localPath, remoteName string) error { // Set up arguments. var scpArguments []string + scpArguments = append(scpArguments, ssh.ConfigFlags(t.configPath)...) scpArguments = append(scpArguments, ssh.CompressionFlag()) scpArguments = append(scpArguments, ssh.ConnectTimeoutFlag(connectTimeoutSeconds)) scpArguments = append(scpArguments, ssh.ServerAliveFlags(serverAliveIntervalSeconds, serverAliveCountMax)...) @@ -155,6 +159,7 @@ func (t *sshTransport) Command(command string) (*exec.Cmd, error) { // more efficient to compress at that layer, even with the slower Go // implementation. var sshArguments []string + sshArguments = append(sshArguments, ssh.ConfigFlags(t.configPath)...) sshArguments = append(sshArguments, ssh.ConnectTimeoutFlag(connectTimeoutSeconds)) sshArguments = append(sshArguments, ssh.ServerAliveFlags(serverAliveIntervalSeconds, serverAliveCountMax)...) if t.port != 0 { diff --git a/pkg/forwarding/protocols/ssh/protocol.go b/pkg/forwarding/protocols/ssh/protocol.go index 67a32b9ad..c40a035d2 100644 --- a/pkg/forwarding/protocols/ssh/protocol.go +++ b/pkg/forwarding/protocols/ssh/protocol.go @@ -51,8 +51,14 @@ func (p *protocolHandler) Connect( return nil, fmt.Errorf("unable to parse target specification: %w", err) } + // Extract SSH config path from URL parameters if present. + var sshConfigPath string + if url.Parameters != nil { + sshConfigPath = url.Parameters["ssh-config-path"] + } + // Create an SSH agent transport. - transport, err := ssh.NewTransport(url.User, url.Host, uint16(url.Port), prompter) + transport, err := ssh.NewTransport(url.User, url.Host, uint16(url.Port), prompter, sshConfigPath) if err != nil { return nil, fmt.Errorf("unable to create SSH transport: %w", err) } diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index b6c3a38c9..508e58adb 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -49,6 +49,15 @@ func ServerAliveFlags(interval, countMax int) []string { } } +// ConfigFlags returns flags that can be passed to scp or ssh to specify a +// custom SSH config file. Returns an empty slice if configPath is empty. +func ConfigFlags(configPath string) []string { + if configPath == "" { + return nil + } + return []string{"-F", configPath} +} + // sshCommandPath returns the full path to use for invoking ssh. It will use the // MUTAGEN_SSH_PATH environment variable if provided, otherwise falling back to // a platform-specific implementation. diff --git a/pkg/synchronization/protocols/ssh/protocol.go b/pkg/synchronization/protocols/ssh/protocol.go index ad4229281..36e8301f4 100644 --- a/pkg/synchronization/protocols/ssh/protocol.go +++ b/pkg/synchronization/protocols/ssh/protocol.go @@ -44,8 +44,14 @@ func (h *protocolHandler) Connect( panic("non-SSH URL dispatched to SSH protocol handler") } + // Extract SSH config path from URL parameters if present. + var sshConfigPath string + if url.Parameters != nil { + sshConfigPath = url.Parameters["ssh-config-path"] + } + // Create an SSH agent transport. - transport, err := ssh.NewTransport(url.User, url.Host, uint16(url.Port), prompter) + transport, err := ssh.NewTransport(url.User, url.Host, uint16(url.Port), prompter, sshConfigPath) if err != nil { return nil, fmt.Errorf("unable to create SSH transport: %w", err) }