Skip to content

Commit 217deb5

Browse files
authored
SSH: better validation for "ssh connect" options (#4623)
## Changes - Add validation for connection name to ensure it only contains letters, numbers, dashes, and underscores - Add validation for IDE value to ensure it is either "vscode" or "cursor" - Move existing and new validation logic to the ClientOptions level - Add tests for ClientOptions Validate and ToProxyCommand ## Why We didn't have validation for the --name flag, and --ide validation was too late in the process ## Tests New unit tests
1 parent 3b7b63d commit 217deb5

File tree

3 files changed

+178
-21
lines changed

3 files changed

+178
-21
lines changed

experimental/ssh/cmd/connect.go

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package ssh
22

33
import (
4-
"errors"
54
"time"
65

76
"github.com/databricks/cli/cmd/root"
@@ -82,22 +81,6 @@ the SSH server and handling the connection proxy.
8281
cmd.RunE = func(cmd *cobra.Command, args []string) error {
8382
ctx := cmd.Context()
8483
wsClient := cmdctx.WorkspaceClient(ctx)
85-
86-
if !proxyMode && clusterID == "" && connectionName == "" {
87-
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
88-
}
89-
90-
if accelerator != "" && connectionName == "" {
91-
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
92-
}
93-
94-
// Remove when we add support for serverless CPU
95-
if connectionName != "" && accelerator == "" {
96-
return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)")
97-
}
98-
99-
// TODO: validate connectionName if provided
100-
10184
opts := client.ClientOptions{
10285
Profile: wsClient.Config.Profile,
10386
ClusterID: clusterID,
@@ -120,6 +103,9 @@ the SSH server and handling the connection proxy.
120103
SkipSettingsCheck: skipSettingsCheck,
121104
AdditionalArgs: args,
122105
}
106+
if err := opts.Validate(); err != nil {
107+
return err
108+
}
123109
return client.Run(ctx, wsClient, opts)
124110
}
125111

experimental/ssh/internal/client/client.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"os/exec"
1313
"os/signal"
1414
"path/filepath"
15+
"regexp"
1516
"strconv"
1617
"strings"
1718
"syscall"
@@ -38,6 +39,8 @@ var sshServerBootstrapScript string
3839

3940
var errServerMetadata = errors.New("server metadata error")
4041

42+
var connectionNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
43+
4144
const (
4245
sshServerTaskKey = "start_ssh_server"
4346
serverlessEnvironmentKey = "ssh_tunnel_serverless"
@@ -97,6 +100,26 @@ type ClientOptions struct {
97100
SkipSettingsCheck bool
98101
}
99102

103+
func (o *ClientOptions) Validate() error {
104+
if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" {
105+
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
106+
}
107+
if o.Accelerator != "" && o.ConnectionName == "" {
108+
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
109+
}
110+
// TODO: Remove when we add support for serverless CPU
111+
if o.ConnectionName != "" && o.Accelerator == "" {
112+
return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)")
113+
}
114+
if o.ConnectionName != "" && !connectionNameRegex.MatchString(o.ConnectionName) {
115+
return fmt.Errorf("connection name %q must consist of letters, numbers, dashes, and underscores", o.ConnectionName)
116+
}
117+
if o.IDE != "" && o.IDE != VSCodeOption && o.IDE != CursorOption {
118+
return fmt.Errorf("invalid IDE value: %q, expected %q or %q", o.IDE, VSCodeOption, CursorOption)
119+
}
120+
return nil
121+
}
122+
100123
func (o *ClientOptions) IsServerlessMode() bool {
101124
return o.ClusterID == "" && o.ConnectionName != ""
102125
}
@@ -287,10 +310,6 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
287310
}
288311

289312
func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error {
290-
if opts.IDE != VSCodeOption && opts.IDE != CursorOption {
291-
return fmt.Errorf("invalid IDE value: %s, expected '%s' or '%s'", opts.IDE, VSCodeOption, CursorOption)
292-
}
293-
294313
connectionName := opts.SessionIdentifier()
295314
if connectionName == "" {
296315
return errors.New("connection name is required for IDE integration")
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package client_test
2+
3+
import (
4+
"fmt"
5+
"os"
6+
"testing"
7+
"time"
8+
9+
"github.com/databricks/cli/experimental/ssh/internal/client"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestValidate(t *testing.T) {
15+
tests := []struct {
16+
name string
17+
opts client.ClientOptions
18+
wantErr string
19+
}{
20+
{
21+
name: "no cluster or connection name",
22+
opts: client.ClientOptions{},
23+
wantErr: "please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)",
24+
},
25+
{
26+
name: "proxy mode skips cluster/name check",
27+
opts: client.ClientOptions{ProxyMode: true},
28+
},
29+
{
30+
name: "cluster ID only",
31+
opts: client.ClientOptions{ClusterID: "abc-123"},
32+
},
33+
{
34+
name: "accelerator without connection name",
35+
opts: client.ClientOptions{ClusterID: "abc-123", Accelerator: "GPU_1xA10"},
36+
wantErr: "--accelerator flag can only be used with serverless compute (--name flag)",
37+
},
38+
{
39+
name: "connection name without accelerator",
40+
opts: client.ClientOptions{ConnectionName: "my-conn"},
41+
wantErr: "--name flag requires --accelerator to be set (for now we only support serverless GPU compute)",
42+
},
43+
{
44+
name: "invalid connection name characters",
45+
opts: client.ClientOptions{ConnectionName: "my conn!", Accelerator: "GPU_1xA10"},
46+
wantErr: `connection name "my conn!" must consist of letters, numbers, dashes, and underscores`,
47+
},
48+
{
49+
name: "connection name with leading dash",
50+
opts: client.ClientOptions{ConnectionName: "-my-conn", Accelerator: "GPU_1xA10"},
51+
wantErr: `connection name "-my-conn" must consist of letters, numbers, dashes, and underscores`,
52+
},
53+
{
54+
name: "valid connection name with accelerator",
55+
opts: client.ClientOptions{ConnectionName: "my-conn_1", Accelerator: "GPU_1xA10"},
56+
},
57+
{
58+
name: "both cluster ID and connection name",
59+
opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn", Accelerator: "GPU_1xA10"},
60+
},
61+
{
62+
name: "proxy mode with invalid connection name",
63+
opts: client.ClientOptions{ProxyMode: true, ConnectionName: "bad name!", Accelerator: "GPU_1xA10"},
64+
wantErr: `connection name "bad name!" must consist of letters, numbers, dashes, and underscores`,
65+
},
66+
{
67+
name: "invalid IDE value",
68+
opts: client.ClientOptions{ClusterID: "abc-123", IDE: "vim"},
69+
wantErr: `invalid IDE value: "vim", expected "vscode" or "cursor"`,
70+
},
71+
{
72+
name: "valid IDE vscode",
73+
opts: client.ClientOptions{ClusterID: "abc-123", IDE: "vscode"},
74+
},
75+
{
76+
name: "valid IDE cursor",
77+
opts: client.ClientOptions{ClusterID: "abc-123", IDE: "cursor"},
78+
},
79+
}
80+
81+
for _, tt := range tests {
82+
t.Run(tt.name, func(t *testing.T) {
83+
err := tt.opts.Validate()
84+
if tt.wantErr == "" {
85+
assert.NoError(t, err)
86+
} else {
87+
assert.EqualError(t, err, tt.wantErr)
88+
}
89+
})
90+
}
91+
}
92+
93+
func TestToProxyCommand(t *testing.T) {
94+
exe, err := os.Executable()
95+
require.NoError(t, err)
96+
quoted := fmt.Sprintf("%q", exe)
97+
98+
tests := []struct {
99+
name string
100+
opts client.ClientOptions
101+
want string
102+
}{
103+
{
104+
name: "dedicated cluster",
105+
opts: client.ClientOptions{ClusterID: "abc-123", ShutdownDelay: 5 * time.Minute},
106+
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=5m0s",
107+
},
108+
{
109+
name: "dedicated cluster with auto-start",
110+
opts: client.ClientOptions{ClusterID: "abc-123", AutoStartCluster: true, ShutdownDelay: 5 * time.Minute},
111+
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=true --shutdown-delay=5m0s",
112+
},
113+
{
114+
name: "serverless",
115+
opts: client.ClientOptions{ConnectionName: "my-conn", ShutdownDelay: 2 * time.Minute},
116+
want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s",
117+
},
118+
{
119+
name: "serverless with accelerator",
120+
opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", ShutdownDelay: 2 * time.Minute},
121+
want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --accelerator=GPU_1xA10",
122+
},
123+
{
124+
name: "with metadata",
125+
opts: client.ClientOptions{ClusterID: "abc-123", ServerMetadata: "user,2222,abc-123"},
126+
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --metadata=user,2222,abc-123",
127+
},
128+
{
129+
name: "with handover timeout",
130+
opts: client.ClientOptions{ClusterID: "abc-123", HandoverTimeout: 10 * time.Minute},
131+
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --handover-timeout=10m0s",
132+
},
133+
{
134+
name: "with profile",
135+
opts: client.ClientOptions{ClusterID: "abc-123", Profile: "my-profile"},
136+
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --profile=my-profile",
137+
},
138+
{
139+
name: "with liteswap",
140+
opts: client.ClientOptions{ClusterID: "abc-123", Liteswap: "test-env"},
141+
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --liteswap=test-env",
142+
},
143+
}
144+
145+
for _, tt := range tests {
146+
t.Run(tt.name, func(t *testing.T) {
147+
got, err := tt.opts.ToProxyCommand()
148+
require.NoError(t, err)
149+
assert.Equal(t, tt.want, got)
150+
})
151+
}
152+
}

0 commit comments

Comments
 (0)