Skip to content

Commit 520988b

Browse files
authored
SSH: Check for Remote SSH extension in VS Code and Cursor (#4676)
## Changes Check if the required Remote SSH extension is installed and above a minimum version, and if not, offer to install it. ## Why <!-- Why are these changes needed? Provide the context that the reviewer might be missing. For example, were there any decisions behind the change that are not reflected in the code itself? --> ## Tests Existing and manually. Windows manual tests are WIP
1 parent a16cdc1 commit 520988b

File tree

3 files changed

+303
-16
lines changed

3 files changed

+303
-16
lines changed

experimental/ssh/internal/client/client.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
211211
if err := vscode.CheckIDECommand(opts.IDE); err != nil {
212212
return err
213213
}
214+
if err := vscode.CheckIDESSHExtension(ctx, opts.IDE); err != nil {
215+
return err
216+
}
214217
}
215218

216219
// Check and update IDE settings for serverless mode, where we must set up

experimental/ssh/internal/vscode/run.go

Lines changed: 93 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ import (
55
"fmt"
66
"os"
77
"os/exec"
8+
"strings"
89

910
"github.com/databricks/cli/libs/cmdio"
11+
"golang.org/x/mod/semver"
1012
)
1113

1214
// Options as they can be set via --ide flag.
@@ -16,27 +18,38 @@ const (
1618
)
1719

1820
type ideDescriptor struct {
19-
Option string
20-
Command string
21-
Name string
22-
InstallURL string
23-
AppName string
21+
Option string
22+
Command string
23+
Name string
24+
InstallURL string
25+
AppName string
26+
SSHExtensionID string
27+
SSHExtensionName string
28+
MinSSHExtensionVersion string
2429
}
2530

2631
var vsCodeIDE = ideDescriptor{
27-
Option: VSCodeOption,
28-
Command: "code",
29-
Name: "VS Code",
30-
InstallURL: "https://code.visualstudio.com/",
31-
AppName: "Code",
32+
Option: VSCodeOption,
33+
Command: "code",
34+
Name: "VS Code",
35+
InstallURL: "https://code.visualstudio.com/",
36+
AppName: "Code",
37+
SSHExtensionID: "ms-vscode-remote.remote-ssh",
38+
SSHExtensionName: "Remote - SSH",
39+
// Earlier versions might work too, 0.120.0 is a safe not-too-old pick
40+
MinSSHExtensionVersion: "0.120.0",
3241
}
3342

3443
var cursorIDE = ideDescriptor{
35-
Option: CursorOption,
36-
Command: "cursor",
37-
Name: "Cursor",
38-
InstallURL: "https://cursor.com/",
39-
AppName: "Cursor",
44+
Option: CursorOption,
45+
Command: "cursor",
46+
Name: "Cursor",
47+
InstallURL: "https://cursor.com/",
48+
AppName: "Cursor",
49+
SSHExtensionID: "anysphere.remote-ssh",
50+
SSHExtensionName: "Remote - SSH",
51+
// Earlier versions don't support remote.SSH.serverPickPortsFromRange option
52+
MinSSHExtensionVersion: "1.0.32",
4053
}
4154

4255
func getIDE(option string) ideDescriptor {
@@ -62,7 +75,71 @@ func CheckIDECommand(option string) error {
6275
return nil
6376
}
6477

65-
// LaunchIDE launches the IDE with a remote SSH connection.
78+
// parseExtensionVersion finds the version of the given extension in the output
79+
// of "<command> --list-extensions --show-versions" (one "name@version" per line).
80+
func parseExtensionVersion(output, extensionID string) (string, bool) {
81+
for line := range strings.SplitSeq(output, "\n") {
82+
name, version, ok := strings.Cut(strings.TrimSpace(line), "@")
83+
if ok && name == extensionID {
84+
return version, true
85+
}
86+
}
87+
return "", false
88+
}
89+
90+
func isExtensionVersionAtLeast(version, minVersion string) bool {
91+
v := "v" + version
92+
return semver.IsValid(v) && semver.Compare(v, "v"+minVersion) >= 0
93+
}
94+
95+
// CheckIDESSHExtension verifies that the required Remote SSH extension is installed
96+
// with a compatible version, and offers to install/update it if not.
97+
func CheckIDESSHExtension(ctx context.Context, option string) error {
98+
ide := getIDE(option)
99+
100+
out, err := exec.CommandContext(ctx, ide.Command, "--list-extensions", "--show-versions").Output()
101+
if err != nil {
102+
return fmt.Errorf("failed to list %s extensions: %w", ide.Name, err)
103+
}
104+
105+
version, found := parseExtensionVersion(string(out), ide.SSHExtensionID)
106+
if found && isExtensionVersionAtLeast(version, ide.MinSSHExtensionVersion) {
107+
return nil
108+
}
109+
110+
var msg string
111+
if !found {
112+
msg = fmt.Sprintf("Required extension %q is not installed in %s.", ide.SSHExtensionName, ide.Name)
113+
} else {
114+
msg = fmt.Sprintf("Extension %q version %s is installed, but version >= %s is required.",
115+
ide.SSHExtensionName, version, ide.MinSSHExtensionVersion)
116+
}
117+
118+
if !cmdio.IsPromptSupported(ctx) {
119+
return fmt.Errorf("%s Install it with: %s --install-extension %s",
120+
msg, ide.Command, ide.SSHExtensionID)
121+
}
122+
123+
shouldInstall, err := cmdio.AskYesOrNo(ctx, msg+" Would you like to install it?")
124+
if err != nil {
125+
return fmt.Errorf("failed to prompt user: %w", err)
126+
}
127+
if !shouldInstall {
128+
return fmt.Errorf("%s Install it with: %s --install-extension %s",
129+
msg, ide.Command, ide.SSHExtensionID)
130+
}
131+
132+
cmdio.LogString(ctx, fmt.Sprintf("Installing %q...", ide.SSHExtensionName))
133+
installCmd := exec.CommandContext(ctx, ide.Command, "--install-extension", ide.SSHExtensionID, "--force")
134+
installCmd.Stdout = os.Stdout
135+
installCmd.Stderr = os.Stderr
136+
if err := installCmd.Run(); err != nil {
137+
return fmt.Errorf("failed to install extension %q: %w", ide.SSHExtensionName, err)
138+
}
139+
return nil
140+
}
141+
142+
// LaunchIDE launches the IDE with a remote SSH connection using special "ssh-remote" URI format.
66143
func LaunchIDE(ctx context.Context, ideOption, connectionName, userName, databricksUserName string) error {
67144
ide := getIDE(ideOption)
68145

experimental/ssh/internal/vscode/run_test.go

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package vscode
22

33
import (
4+
"fmt"
45
"os"
56
"path/filepath"
67
"runtime"
78
"testing"
89

10+
"github.com/databricks/cli/libs/cmdio"
911
"github.com/stretchr/testify/assert"
1012
"github.com/stretchr/testify/require"
1113
)
@@ -88,3 +90,208 @@ func TestCheckIDECommand_Found(t *testing.T) {
8890
})
8991
}
9092
}
93+
94+
func TestParseExtensionVersion(t *testing.T) {
95+
tests := []struct {
96+
name string
97+
output string
98+
extensionID string
99+
wantVersion string
100+
wantFound bool
101+
minVersion string
102+
wantAtLeast bool
103+
}{
104+
{
105+
name: "found and above minimum",
106+
output: "ms-python.python@2024.1.1\nms-vscode-remote.remote-ssh@0.123.0\n",
107+
extensionID: "ms-vscode-remote.remote-ssh",
108+
wantVersion: "0.123.0",
109+
wantFound: true,
110+
minVersion: "0.120.0",
111+
wantAtLeast: true,
112+
},
113+
{
114+
name: "found but below minimum",
115+
output: "ms-vscode-remote.remote-ssh@0.100.0\n",
116+
extensionID: "ms-vscode-remote.remote-ssh",
117+
wantVersion: "0.100.0",
118+
wantFound: true,
119+
minVersion: "0.120.0",
120+
wantAtLeast: false,
121+
},
122+
{
123+
name: "not found",
124+
output: "ms-python.python@2024.1.1\n",
125+
extensionID: "ms-vscode-remote.remote-ssh",
126+
wantVersion: "",
127+
wantFound: false,
128+
},
129+
{
130+
name: "empty output",
131+
output: "",
132+
extensionID: "ms-vscode-remote.remote-ssh",
133+
wantVersion: "",
134+
wantFound: false,
135+
},
136+
{
137+
name: "multiple extensions",
138+
output: "ext.a@1.0.0\next.b@2.0.0\next.c@3.0.0\n",
139+
extensionID: "ext.b",
140+
wantVersion: "2.0.0",
141+
wantFound: true,
142+
minVersion: "1.0.0",
143+
wantAtLeast: true,
144+
},
145+
{
146+
name: "prerelease is less than release",
147+
output: "ms-vscode-remote.remote-ssh@0.120.0-beta.1\n",
148+
extensionID: "ms-vscode-remote.remote-ssh",
149+
wantVersion: "0.120.0-beta.1",
150+
wantFound: true,
151+
minVersion: "0.120.0",
152+
wantAtLeast: false,
153+
},
154+
{
155+
name: "line with whitespace",
156+
output: " ms-vscode-remote.remote-ssh@0.123.0 \n",
157+
extensionID: "ms-vscode-remote.remote-ssh",
158+
wantVersion: "0.123.0",
159+
wantFound: true,
160+
minVersion: "0.120.0",
161+
wantAtLeast: true,
162+
},
163+
{
164+
name: "windows CRLF line endings",
165+
output: "ms-python.python@2024.1.1\r\nms-vscode-remote.remote-ssh@0.123.0\r\n",
166+
extensionID: "ms-vscode-remote.remote-ssh",
167+
wantVersion: "0.123.0",
168+
wantFound: true,
169+
minVersion: "0.120.0",
170+
wantAtLeast: true,
171+
},
172+
}
173+
174+
for _, tt := range tests {
175+
t.Run(tt.name, func(t *testing.T) {
176+
version, found := parseExtensionVersion(tt.output, tt.extensionID)
177+
assert.Equal(t, tt.wantFound, found)
178+
assert.Equal(t, tt.wantVersion, version)
179+
if found {
180+
assert.Equal(t, tt.wantAtLeast, isExtensionVersionAtLeast(version, tt.minVersion))
181+
}
182+
})
183+
}
184+
}
185+
186+
func TestIsExtensionVersionAtLeast(t *testing.T) {
187+
tests := []struct {
188+
name string
189+
version string
190+
minVersion string
191+
want bool
192+
}{
193+
{name: "above minimum", version: "0.123.0", minVersion: "0.120.0", want: true},
194+
{name: "exact minimum", version: "0.120.0", minVersion: "0.120.0", want: true},
195+
{name: "below minimum", version: "0.100.0", minVersion: "0.120.0", want: false},
196+
{name: "major version ahead", version: "1.0.0", minVersion: "0.120.0", want: true},
197+
{name: "prerelease below release", version: "0.120.0-beta.1", minVersion: "0.120.0", want: false},
198+
{name: "prerelease above prior release", version: "0.121.0-beta.1", minVersion: "0.120.0", want: true},
199+
{name: "two-component version is valid", version: "1.0", minVersion: "0.120.0", want: true},
200+
{name: "empty version", version: "", minVersion: "0.120.0", want: false},
201+
{name: "garbage version", version: "abc", minVersion: "0.120.0", want: false},
202+
{name: "four-component version is invalid", version: "0.120.0.1", minVersion: "0.120.0", want: false},
203+
{name: "cursor exact minimum", version: "1.0.32", minVersion: "1.0.32", want: true},
204+
{name: "cursor above minimum", version: "1.1.0", minVersion: "1.0.32", want: true},
205+
{name: "cursor below minimum", version: "1.0.31", minVersion: "1.0.32", want: false},
206+
}
207+
208+
for _, tt := range tests {
209+
t.Run(tt.name, func(t *testing.T) {
210+
assert.Equal(t, tt.want, isExtensionVersionAtLeast(tt.version, tt.minVersion))
211+
})
212+
}
213+
}
214+
215+
// createFakeIDEExecutable creates a fake IDE command that outputs the given text
216+
// when called with --list-extensions --show-versions.
217+
func createFakeIDEExecutable(t *testing.T, dir, command, output string) {
218+
t.Helper()
219+
if runtime.GOOS == "windows" {
220+
// Write output to a temp file and use "type" to print it, avoiding escaping issues.
221+
payloadPath := filepath.Join(dir, command+"-payload.txt")
222+
err := os.WriteFile(payloadPath, []byte(output), 0o644)
223+
require.NoError(t, err)
224+
script := fmt.Sprintf("@echo off\ntype \"%s\"\n", payloadPath)
225+
err = os.WriteFile(filepath.Join(dir, command+".cmd"), []byte(script), 0o755)
226+
require.NoError(t, err)
227+
} else {
228+
// Use printf (a shell builtin) instead of cat to avoid PATH issues in tests.
229+
script := fmt.Sprintf("#!/bin/sh\nprintf '%%s' '%s'\n", output)
230+
err := os.WriteFile(filepath.Join(dir, command), []byte(script), 0o755)
231+
require.NoError(t, err)
232+
}
233+
}
234+
235+
func TestCheckIDESSHExtension_UpToDate(t *testing.T) {
236+
tmpDir := t.TempDir()
237+
t.Setenv("PATH", tmpDir)
238+
ctx, _ := cmdio.NewTestContextWithStdout(t.Context())
239+
240+
extensionOutput := "ms-python.python@2024.1.1\nms-vscode-remote.remote-ssh@0.123.0\n"
241+
createFakeIDEExecutable(t, tmpDir, "code", extensionOutput)
242+
243+
err := CheckIDESSHExtension(ctx, VSCodeOption)
244+
assert.NoError(t, err)
245+
}
246+
247+
func TestCheckIDESSHExtension_ExactMinVersion(t *testing.T) {
248+
tmpDir := t.TempDir()
249+
t.Setenv("PATH", tmpDir)
250+
ctx, _ := cmdio.NewTestContextWithStdout(t.Context())
251+
252+
extensionOutput := "ms-vscode-remote.remote-ssh@0.120.0\n"
253+
createFakeIDEExecutable(t, tmpDir, "code", extensionOutput)
254+
255+
err := CheckIDESSHExtension(ctx, VSCodeOption)
256+
assert.NoError(t, err)
257+
}
258+
259+
func TestCheckIDESSHExtension_Missing(t *testing.T) {
260+
tmpDir := t.TempDir()
261+
t.Setenv("PATH", tmpDir)
262+
ctx, _ := cmdio.NewTestContextWithStdout(t.Context())
263+
264+
extensionOutput := "ms-python.python@2024.1.1\n"
265+
createFakeIDEExecutable(t, tmpDir, "code", extensionOutput)
266+
267+
err := CheckIDESSHExtension(ctx, VSCodeOption)
268+
require.Error(t, err)
269+
assert.Contains(t, err.Error(), `"Remote - SSH"`)
270+
assert.Contains(t, err.Error(), "not installed")
271+
}
272+
273+
func TestCheckIDESSHExtension_Outdated(t *testing.T) {
274+
tmpDir := t.TempDir()
275+
t.Setenv("PATH", tmpDir)
276+
ctx, _ := cmdio.NewTestContextWithStdout(t.Context())
277+
278+
extensionOutput := "ms-vscode-remote.remote-ssh@0.100.0\n"
279+
createFakeIDEExecutable(t, tmpDir, "code", extensionOutput)
280+
281+
err := CheckIDESSHExtension(ctx, VSCodeOption)
282+
require.Error(t, err)
283+
assert.Contains(t, err.Error(), "0.100.0")
284+
assert.Contains(t, err.Error(), ">= 0.120.0")
285+
}
286+
287+
func TestCheckIDESSHExtension_Cursor(t *testing.T) {
288+
tmpDir := t.TempDir()
289+
t.Setenv("PATH", tmpDir)
290+
ctx, _ := cmdio.NewTestContextWithStdout(t.Context())
291+
292+
extensionOutput := "anysphere.remote-ssh@1.0.32\n"
293+
createFakeIDEExecutable(t, tmpDir, "cursor", extensionOutput)
294+
295+
err := CheckIDESSHExtension(ctx, CursorOption)
296+
assert.NoError(t, err)
297+
}

0 commit comments

Comments
 (0)