Skip to content

Commit e8fa14a

Browse files
committed
Address review comments: identity tracking, safe cleanup, no StrictHostKeyChecking=no
- Add UserName to Session struct and include in FindMatching to prevent cross-identity session mixups when switching profiles
- Only run cleanup on definitive errServerMetadata errors; log and skip on transient failures (network, auth) to avoid deleting live sessions
- Add workspace host hash to generated session names to avoid SSH known_hosts conflicts across workspaces, removing the need for StrictHostKeyChecking=no and UserKnownHostsFile=/dev/null
- Prune expired sessions from disk during FindMatching
- Make resolveServerlessSession a method on ClientOptions
- Handle rand.Read error explicitly

Co-authored-by: Isaac
1 parent 083521b commit e8fa14a

File tree

5 files changed

+105
-63
lines changed

5 files changed

+105
-63
lines changed

experimental/ssh/internal/client/client.go

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
211211

212212
// For serverless without explicit --name: auto-generate or reconnect to existing session.
213213
if opts.IsServerlessMode() && opts.ConnectionName == "" && !opts.ProxyMode {
214-
err := resolveServerlessSession(ctx, client, &opts)
215-
if err != nil {
214+
if err := opts.resolveServerlessSession(ctx, client); err != nil {
216215
return err
217216
}
218217
}
@@ -343,10 +342,16 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
343342

344343
// Persist the session for future reconnects.
345344
if opts.IsServerlessMode() && !opts.ProxyMode {
345+
currentUser, userErr := client.CurrentUser.Me(ctx)
346+
sessionUserName := ""
347+
if userErr == nil {
348+
sessionUserName = currentUser.UserName
349+
}
346350
err = sessions.Add(ctx, sessions.Session{
347351
Name: opts.ConnectionName,
348352
Accelerator: opts.Accelerator,
349353
WorkspaceHost: client.Config.Host,
354+
UserName: sessionUserName,
350355
CreatedAt: time.Now(),
351356
ClusterID: clusterID,
352357
})
@@ -407,12 +412,7 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k
407412
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
408413
}
409414

410-
var hostConfig string
411-
if opts.IsServerlessMode() {
412-
hostConfig = sshconfig.GenerateServerlessHostConfig(hostName, userName, keyPath, proxyCommand)
413-
} else {
414-
hostConfig = sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand)
415-
}
415+
hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand)
416416

417417
_, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true)
418418
if err != nil {
@@ -580,22 +580,15 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server
580580

581581
hostName := opts.SessionIdentifier()
582582

583-
hostKeyChecking := "StrictHostKeyChecking=accept-new"
584-
if opts.IsServerlessMode() {
585-
hostKeyChecking = "StrictHostKeyChecking=no"
586-
}
587-
588583
sshArgs := []string{
589584
"-l", userName,
590585
"-i", privateKeyPath,
591586
"-o", "IdentitiesOnly=yes",
592-
"-o", hostKeyChecking,
587+
"-o", "StrictHostKeyChecking=accept-new",
593588
"-o", "ConnectTimeout=360",
594589
"-o", "ProxyCommand=" + proxyCommand,
595590
}
596-
if opts.IsServerlessMode() {
597-
sshArgs = append(sshArgs, "-o", "UserKnownHostsFile=/dev/null")
598-
} else if opts.UserKnownHostsFile != "" {
591+
if opts.UserKnownHostsFile != "" {
599592
sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile)
600593
}
601594
sshArgs = append(sshArgs, hostName)
@@ -740,12 +733,17 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
740733
}
741734

742735
// resolveServerlessSession handles auto-generation and reconnection for serverless sessions.
743-
// It checks local state for existing sessions matching the workspace and accelerator,
736+
// It checks local state for existing sessions matching the workspace, accelerator, and user,
744737
// probes them to see if they're still alive, and prompts the user to reconnect or create new.
745-
func resolveServerlessSession(ctx context.Context, client *databricks.WorkspaceClient, opts *ClientOptions) error {
738+
func (opts *ClientOptions) resolveServerlessSession(ctx context.Context, client *databricks.WorkspaceClient) error {
746739
version := build.GetInfo().Version
747740

748-
matching, err := sessions.FindMatching(ctx, client.Config.Host, opts.Accelerator)
741+
me, err := client.CurrentUser.Me(ctx)
742+
if err != nil {
743+
return fmt.Errorf("failed to get current user: %w", err)
744+
}
745+
746+
matching, err := sessions.FindMatching(ctx, client.Config.Host, opts.Accelerator, me.UserName)
749747
if err != nil {
750748
log.Warnf(ctx, "Failed to load session state: %v", err)
751749
}
@@ -761,8 +759,12 @@ func resolveServerlessSession(ctx context.Context, client *databricks.WorkspaceC
761759
_, _, _, probeErr := getServerMetadata(ctx, client, s.Name, s.ClusterID, version, opts.Liteswap)
762760
if probeErr == nil {
763761
alive = append(alive, s)
764-
} else {
762+
} else if errors.Is(probeErr, errServerMetadata) {
763+
// Only clean up when the server is definitively gone (metadata endpoint returns not-found).
764+
// Transient errors (network, auth) should not trigger cleanup.
765765
cleanupStaleSession(ctx, client, s, version)
766+
} else {
767+
log.Warnf(ctx, "Transient error probing session %s, skipping: %v", s.Name, probeErr)
766768
}
767769
}
768770

@@ -788,7 +790,7 @@ func resolveServerlessSession(ctx context.Context, client *databricks.WorkspaceC
788790
}
789791

790792
// No alive session selected — generate a new name.
791-
opts.ConnectionName = sessions.GenerateSessionName(opts.Accelerator)
793+
opts.ConnectionName = sessions.GenerateSessionName(opts.Accelerator, client.Config.Host)
792794
cmdio.LogString(ctx, "Creating new session: "+opts.ConnectionName)
793795
return nil
794796
}
Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package sessions
22

33
import (
4+
"crypto/md5"
45
"crypto/rand"
56
"encoding/hex"
7+
"fmt"
68
"strings"
79
"time"
810
)
@@ -13,16 +15,26 @@ var acceleratorPrefixes = map[string]string{
1315
"GPU_8xH100": "gpu-h100",
1416
}
1517

16-
// GenerateSessionName creates a human-readable session name from the accelerator type.
17-
// Format: <prefix>-<random_hex>, e.g. "gpu-a10-f3a2b1c0".
18-
func GenerateSessionName(accelerator string) string {
18+
// GenerateSessionName creates a human-readable session name from the accelerator type
19+
// and workspace host. The workspace host is hashed into the name to avoid SSH known_hosts
20+
// conflicts when connecting to different workspaces.
21+
// Format: databricks-<prefix>-<date>-<workspace_hash><random_hex>.
22+
func GenerateSessionName(accelerator, workspaceHost string) string {
1923
prefix, ok := acceleratorPrefixes[accelerator]
2024
if !ok {
2125
prefix = strings.ToLower(strings.ReplaceAll(accelerator, "_", "-"))
2226
}
2327

2428
date := time.Now().Format("20060102")
29+
30+
// Include a short hash of the workspace host to avoid known_hosts conflicts
31+
// when connecting to different workspaces.
32+
wsHash := md5.Sum([]byte(workspaceHost))
33+
wsHashStr := hex.EncodeToString(wsHash[:])[:4]
34+
2535
b := make([]byte, 3)
26-
_, _ = rand.Read(b)
27-
return "databricks-" + prefix + "-" + date + "-" + hex.EncodeToString(b)
36+
if _, err := rand.Read(b); err != nil {
37+
panic(fmt.Sprintf("crypto/rand.Read failed: %v", err))
38+
}
39+
return "databricks-" + prefix + "-" + date + "-" + wsHashStr + hex.EncodeToString(b)
2840
}

experimental/ssh/internal/sessions/sessions.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Session struct {
2323
Name string `json:"name"`
2424
Accelerator string `json:"accelerator"`
2525
WorkspaceHost string `json:"workspace_host"`
26+
UserName string `json:"user_name,omitempty"`
2627
CreatedAt time.Time `json:"created_at"`
2728
ClusterID string `json:"cluster_id,omitempty"`
2829
}
@@ -129,17 +130,35 @@ func Remove(ctx context.Context, name string) error {
129130
return Save(ctx, store)
130131
}
131132

132-
// FindMatching returns non-expired sessions that match the given workspace host and accelerator.
133-
func FindMatching(ctx context.Context, workspaceHost, accelerator string) ([]Session, error) {
133+
// FindMatching returns non-expired sessions that match the given workspace host, accelerator,
134+
// and user name. Expired sessions are pruned from the store on disk.
135+
func FindMatching(ctx context.Context, workspaceHost, accelerator, userName string) ([]Session, error) {
134136
store, err := Load(ctx)
135137
if err != nil {
136138
return nil, err
137139
}
138140

139141
cutoff := time.Now().Add(-sessionMaxAge)
140-
var result []Session
142+
143+
// Prune expired sessions from the store.
144+
active := store.Sessions[:0]
145+
pruned := false
141146
for _, s := range store.Sessions {
142-
if s.WorkspaceHost == workspaceHost && s.Accelerator == accelerator && s.CreatedAt.After(cutoff) {
147+
if s.CreatedAt.After(cutoff) {
148+
active = append(active, s)
149+
} else {
150+
pruned = true
151+
}
152+
}
153+
if pruned {
154+
store.Sessions = active
155+
// Best-effort save; don't fail the operation if pruning fails.
156+
_ = Save(ctx, store)
157+
}
158+
159+
var result []Session
160+
for _, s := range active {
161+
if s.WorkspaceHost == workspaceHost && s.Accelerator == accelerator && s.UserName == userName {
143162
result = append(result, s)
144163
}
145164
}

experimental/ssh/internal/sessions/sessions_test.go

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,32 +90,41 @@ func TestFindMatching(t *testing.T) {
9090

9191
ctx := t.Context()
9292
host := "https://test.databricks.com"
93+
user := "alice@example.com"
9394

9495
now := time.Now()
9596

96-
err := Add(ctx, Session{Name: "s1", Accelerator: "GPU_1xA10", WorkspaceHost: host, CreatedAt: now})
97+
err := Add(ctx, Session{Name: "s1", Accelerator: "GPU_1xA10", WorkspaceHost: host, UserName: user, CreatedAt: now})
9798
require.NoError(t, err)
98-
err = Add(ctx, Session{Name: "s2", Accelerator: "GPU_8xH100", WorkspaceHost: host, CreatedAt: now})
99+
err = Add(ctx, Session{Name: "s2", Accelerator: "GPU_8xH100", WorkspaceHost: host, UserName: user, CreatedAt: now})
99100
require.NoError(t, err)
100-
err = Add(ctx, Session{Name: "s3", Accelerator: "GPU_1xA10", WorkspaceHost: "https://other.com", CreatedAt: now})
101+
err = Add(ctx, Session{Name: "s3", Accelerator: "GPU_1xA10", WorkspaceHost: "https://other.com", UserName: user, CreatedAt: now})
101102
require.NoError(t, err)
102-
err = Add(ctx, Session{Name: "s4", Accelerator: "GPU_1xA10", WorkspaceHost: host, CreatedAt: now})
103+
err = Add(ctx, Session{Name: "s4", Accelerator: "GPU_1xA10", WorkspaceHost: host, UserName: user, CreatedAt: now})
104+
require.NoError(t, err)
105+
err = Add(ctx, Session{Name: "s5", Accelerator: "GPU_1xA10", WorkspaceHost: host, UserName: "bob@example.com", CreatedAt: now})
103106
require.NoError(t, err)
104107

105-
matches, err := FindMatching(ctx, host, "GPU_1xA10")
108+
matches, err := FindMatching(ctx, host, "GPU_1xA10", user)
106109
require.NoError(t, err)
107110
assert.Len(t, matches, 2)
108111
assert.Equal(t, "s1", matches[0].Name)
109112
assert.Equal(t, "s4", matches[1].Name)
110113

111-
matches, err = FindMatching(ctx, host, "GPU_8xH100")
114+
matches, err = FindMatching(ctx, host, "GPU_8xH100", user)
112115
require.NoError(t, err)
113116
assert.Len(t, matches, 1)
114117
assert.Equal(t, "s2", matches[0].Name)
115118

116-
matches, err = FindMatching(ctx, host, "GPU_4xA100")
119+
matches, err = FindMatching(ctx, host, "GPU_4xA100", user)
117120
require.NoError(t, err)
118121
assert.Empty(t, matches)
122+
123+
// Different user should not see alice's sessions
124+
matches, err = FindMatching(ctx, host, "GPU_1xA10", "bob@example.com")
125+
require.NoError(t, err)
126+
assert.Len(t, matches, 1)
127+
assert.Equal(t, "s5", matches[0].Name)
119128
}
120129

121130
func TestFindMatchingExpiresOldSessions(t *testing.T) {
@@ -125,16 +134,22 @@ func TestFindMatchingExpiresOldSessions(t *testing.T) {
125134

126135
ctx := t.Context()
127136
host := "https://test.databricks.com"
137+
user := "alice@example.com"
128138

129-
err := Add(ctx, Session{Name: "old", Accelerator: "GPU_1xA10", WorkspaceHost: host, CreatedAt: time.Now().Add(-25 * time.Hour)})
139+
err := Add(ctx, Session{Name: "old", Accelerator: "GPU_1xA10", WorkspaceHost: host, UserName: user, CreatedAt: time.Now().Add(-25 * time.Hour)})
130140
require.NoError(t, err)
131-
err = Add(ctx, Session{Name: "recent", Accelerator: "GPU_1xA10", WorkspaceHost: host, CreatedAt: time.Now()})
141+
err = Add(ctx, Session{Name: "recent", Accelerator: "GPU_1xA10", WorkspaceHost: host, UserName: user, CreatedAt: time.Now()})
132142
require.NoError(t, err)
133143

134-
matches, err := FindMatching(ctx, host, "GPU_1xA10")
144+
matches, err := FindMatching(ctx, host, "GPU_1xA10", user)
135145
require.NoError(t, err)
136146
require.Len(t, matches, 1)
137147
assert.Equal(t, "recent", matches[0].Name)
148+
149+
// Verify expired sessions were pruned from disk.
150+
store, err := Load(ctx)
151+
require.NoError(t, err)
152+
assert.Len(t, store.Sessions, 1, "expired sessions should be pruned from disk")
138153
}
139154

140155
func TestStateFilePath(t *testing.T) {
@@ -151,6 +166,7 @@ func TestStateFilePath(t *testing.T) {
151166
var connectionNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
152167

153168
func TestGenerateSessionName(t *testing.T) {
169+
host := "https://test.databricks.com"
154170
tests := []struct {
155171
accelerator string
156172
wantPrefix string
@@ -163,7 +179,7 @@ func TestGenerateSessionName(t *testing.T) {
163179

164180
for _, tt := range tests {
165181
t.Run(tt.accelerator, func(t *testing.T) {
166-
name := GenerateSessionName(tt.accelerator)
182+
name := GenerateSessionName(tt.accelerator, host)
167183
assert.Greater(t, len(name), len(tt.wantPrefix), "name should be longer than prefix")
168184
assert.Equal(t, tt.wantPrefix, name[:len(tt.wantPrefix)])
169185
// Verify date component is present (starts with "20" for 2000s dates).
@@ -173,10 +189,20 @@ func TestGenerateSessionName(t *testing.T) {
173189
}
174190
}
175191

192+
func TestGenerateSessionNameDiffersByWorkspace(t *testing.T) {
193+
name1 := GenerateSessionName("GPU_1xA10", "https://workspace-a.databricks.com")
194+
name2 := GenerateSessionName("GPU_1xA10", "https://workspace-b.databricks.com")
195+
// The workspace hash portion (after the date-) should differ.
196+
// Names have format: databricks-gpu-a10-YYYYMMDD-<wshash><random>
197+
// Extract after the date prefix to compare workspace hash parts.
198+
assert.NotEqual(t, name1, name2, "names for different workspaces should differ")
199+
}
200+
176201
func TestGenerateSessionNameUniqueness(t *testing.T) {
202+
host := "https://test.databricks.com"
177203
seen := make(map[string]bool)
178204
for range 100 {
179-
name := GenerateSessionName("GPU_1xA10")
205+
name := GenerateSessionName("GPU_1xA10", host)
180206
assert.False(t, seen[name], "duplicate name generated: %s", name)
181207
seen[name] = true
182208
}

experimental/ssh/internal/sshconfig/sshconfig.go

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -175,30 +175,13 @@ func RemoveHostConfig(ctx context.Context, hostName string) error {
175175

176176
// GenerateHostConfig generates an SSH host config block.
177177
func GenerateHostConfig(hostName, userName, identityFile, proxyCommand string) string {
178-
return generateHostConfig(hostName, userName, identityFile, proxyCommand, false)
179-
}
180-
181-
// GenerateServerlessHostConfig generates an SSH host config block for serverless compute.
182-
// It disables strict host key checking since serverless containers generate fresh keys each time,
183-
// and identity is already verified through Databricks authentication and Driver Proxy.
184-
func GenerateServerlessHostConfig(hostName, userName, identityFile, proxyCommand string) string {
185-
return generateHostConfig(hostName, userName, identityFile, proxyCommand, true)
186-
}
187-
188-
func generateHostConfig(hostName, userName, identityFile, proxyCommand string, serverless bool) string {
189-
hostKeyChecking := "StrictHostKeyChecking accept-new"
190-
knownHostsLine := ""
191-
if serverless {
192-
hostKeyChecking = "StrictHostKeyChecking no"
193-
knownHostsLine = " UserKnownHostsFile /dev/null\n"
194-
}
195178
return fmt.Sprintf(`
196179
Host %s
197180
User %s
198181
ConnectTimeout 360
199-
%s
200-
%s IdentitiesOnly yes
182+
StrictHostKeyChecking accept-new
183+
IdentitiesOnly yes
201184
IdentityFile %q
202185
ProxyCommand %s
203-
`, hostName, userName, hostKeyChecking, knownHostsLine, identityFile, proxyCommand)
186+
`, hostName, userName, identityFile, proxyCommand)
204187
}

0 commit comments

Comments
 (0)