Skip to content

Commit 857404f

Browse files
anton-107claude
andcommitted
Replace --name flag with auto-generated session names and reconnect support
Remove the requirement for --name in serverless SSH connect. Sessions are now auto-generated with human-readable names (e.g. databricks-gpu-a10-20260310-a1b2c3), tracked in ~/.databricks/ssh-tunnel-sessions.json, and offered for reconnection on subsequent runs. Stale sessions are cleaned up automatically. Sessions expire after 24 hours. Also fixes known_hosts key mismatches for serverless by disabling strict host key checking (identity verified via Databricks auth). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c8eceee commit 857404f

File tree

6 files changed

+541
-18
lines changed

6 files changed

+541
-18
lines changed

experimental/ssh/internal/client/client.go

Lines changed: 138 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
"github.com/databricks/cli/experimental/ssh/internal/keys"
2222
"github.com/databricks/cli/experimental/ssh/internal/proxy"
23+
"github.com/databricks/cli/experimental/ssh/internal/sessions"
2324
"github.com/databricks/cli/experimental/ssh/internal/sshconfig"
2425
"github.com/databricks/cli/experimental/ssh/internal/vscode"
2526
sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace"
@@ -99,11 +100,11 @@ type ClientOptions struct {
99100
}
100101

101102
func (o *ClientOptions) Validate() error {
102-
if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" {
103-
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
103+
if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" && o.Accelerator == "" {
104+
return errors.New("please provide --cluster or --accelerator flag")
104105
}
105-
if o.Accelerator != "" && o.ConnectionName == "" {
106-
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
106+
if o.Accelerator != "" && o.ClusterID != "" {
107+
return errors.New("--accelerator flag can only be used with serverless compute, not with --cluster")
107108
}
108109
// TODO: Remove when we add support for serverless CPU
109110
if o.ConnectionName != "" && o.Accelerator == "" {
@@ -122,7 +123,7 @@ func (o *ClientOptions) Validate() error {
122123
}
123124

124125
func (o *ClientOptions) IsServerlessMode() bool {
125-
return o.ClusterID == "" && o.ConnectionName != ""
126+
return o.ClusterID == "" && (o.ConnectionName != "" || o.Accelerator != "")
126127
}
127128

128129
// SessionIdentifier returns the unique identifier for the session.
@@ -202,9 +203,17 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
202203
cancel()
203204
}()
204205

206+
// For serverless without explicit --name: auto-generate or reconnect to existing session.
207+
if opts.IsServerlessMode() && opts.ConnectionName == "" && !opts.ProxyMode {
208+
err := resolveServerlessSession(ctx, client, &opts)
209+
if err != nil {
210+
return err
211+
}
212+
}
213+
205214
sessionID := opts.SessionIdentifier()
206215
if sessionID == "" {
207-
return errors.New("either --cluster or --name must be provided")
216+
return errors.New("either --cluster or --accelerator must be provided")
208217
}
209218

210219
if !opts.ProxyMode {
@@ -327,6 +336,20 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
327336
cmdio.LogString(ctx, "Connected!")
328337
}
329338

339+
// Persist the session for future reconnects.
340+
if opts.IsServerlessMode() && !opts.ProxyMode {
341+
err = sessions.Add(ctx, sessions.Session{
342+
Name: opts.ConnectionName,
343+
Accelerator: opts.Accelerator,
344+
WorkspaceHost: client.Config.Host,
345+
CreatedAt: time.Now(),
346+
ClusterID: clusterID,
347+
})
348+
if err != nil {
349+
log.Warnf(ctx, "Failed to save session state: %v", err)
350+
}
351+
}
352+
330353
if opts.ProxyMode {
331354
return runSSHProxy(ctx, client, serverPort, clusterID, opts)
332355
} else if opts.IDE != "" {
@@ -379,7 +402,12 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k
379402
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
380403
}
381404

382-
hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand)
405+
var hostConfig string
406+
if opts.IsServerlessMode() {
407+
hostConfig = sshconfig.GenerateServerlessHostConfig(hostName, userName, keyPath, proxyCommand)
408+
} else {
409+
hostConfig = sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand)
410+
}
383411

384412
_, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true)
385413
if err != nil {
@@ -547,15 +575,22 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server
547575

548576
hostName := opts.SessionIdentifier()
549577

578+
hostKeyChecking := "StrictHostKeyChecking=accept-new"
579+
if opts.IsServerlessMode() {
580+
hostKeyChecking = "StrictHostKeyChecking=no"
581+
}
582+
550583
sshArgs := []string{
551584
"-l", userName,
552585
"-i", privateKeyPath,
553586
"-o", "IdentitiesOnly=yes",
554-
"-o", "StrictHostKeyChecking=accept-new",
587+
"-o", hostKeyChecking,
555588
"-o", "ConnectTimeout=360",
556589
"-o", "ProxyCommand=" + proxyCommand,
557590
}
558-
if opts.UserKnownHostsFile != "" {
591+
if opts.IsServerlessMode() {
592+
sshArgs = append(sshArgs, "-o", "UserKnownHostsFile=/dev/null")
593+
} else if opts.UserKnownHostsFile != "" {
559594
sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile)
560595
}
561596
sshArgs = append(sshArgs, hostName)
@@ -703,3 +738,97 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC
703738

704739
return userName, serverPort, effectiveClusterID, nil
705740
}
741+
742+
// resolveServerlessSession handles auto-generation and reconnection for serverless sessions.
743+
// It checks local state for existing sessions matching the workspace and accelerator,
744+
// probes them to see if they're still alive, and prompts the user to reconnect or create new.
745+
func resolveServerlessSession(ctx context.Context, client *databricks.WorkspaceClient, opts *ClientOptions) error {
746+
version := build.GetInfo().Version
747+
748+
matching, err := sessions.FindMatching(ctx, client.Config.Host, opts.Accelerator)
749+
if err != nil {
750+
log.Warnf(ctx, "Failed to load session state: %v", err)
751+
}
752+
753+
// Probe sessions to find alive ones (limit to 5 most recent to avoid latency).
754+
const maxProbe = 5
755+
if len(matching) > maxProbe {
756+
matching = matching[len(matching)-maxProbe:]
757+
}
758+
759+
var alive []sessions.Session
760+
for _, s := range matching {
761+
_, _, _, probeErr := getServerMetadata(ctx, client, s.Name, s.ClusterID, version, opts.Liteswap)
762+
if probeErr == nil {
763+
alive = append(alive, s)
764+
} else {
765+
cleanupStaleSession(ctx, client, s, version)
766+
}
767+
}
768+
769+
if len(alive) > 0 && cmdio.IsPromptSupported(ctx) {
770+
choices := make([]string, 0, len(alive)+1)
771+
for _, s := range alive {
772+
choices = append(choices, fmt.Sprintf("Reconnect to %s (started %s)", s.Name, s.CreatedAt.Format(time.RFC822)))
773+
}
774+
choices = append(choices, "Create new session")
775+
776+
choice, choiceErr := cmdio.AskSelect(ctx, "Found existing sessions:", choices)
777+
if choiceErr != nil {
778+
return fmt.Errorf("failed to prompt user: %w", choiceErr)
779+
}
780+
781+
for i, s := range alive {
782+
if choice == choices[i] {
783+
opts.ConnectionName = s.Name
784+
cmdio.LogString(ctx, fmt.Sprintf("Reconnecting to session: %s", s.Name))
785+
return nil
786+
}
787+
}
788+
}
789+
790+
// No alive session selected — generate a new name.
791+
opts.ConnectionName = sessions.GenerateSessionName(opts.Accelerator)
792+
cmdio.LogString(ctx, fmt.Sprintf("Creating new session: %s", opts.ConnectionName))
793+
return nil
794+
}
795+
796+
// cleanupStaleSession removes all local and remote artifacts for a stale session.
797+
func cleanupStaleSession(ctx context.Context, client *databricks.WorkspaceClient, s sessions.Session, version string) {
798+
// Remove local SSH keys.
799+
keyPath, err := keys.GetLocalSSHKeyPath(ctx, s.Name, "")
800+
if err == nil {
801+
os.RemoveAll(filepath.Dir(keyPath))
802+
}
803+
804+
// Remove SSH config entry.
805+
if err := sshconfig.RemoveHostConfig(ctx, s.Name); err != nil {
806+
log.Debugf(ctx, "Failed to remove SSH config for %s: %v", s.Name, err)
807+
}
808+
809+
// Delete secret scope (best-effort).
810+
me, err := client.CurrentUser.Me(ctx)
811+
if err == nil {
812+
scopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, s.Name)
813+
deleteErr := client.Secrets.DeleteScope(ctx, workspace.DeleteScope{Scope: scopeName})
814+
if deleteErr != nil {
815+
log.Debugf(ctx, "Failed to delete secret scope %s: %v", scopeName, deleteErr)
816+
}
817+
}
818+
819+
// Remove workspace content directory (best-effort).
820+
contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, s.Name)
821+
if err == nil {
822+
deleteErr := client.Workspace.Delete(ctx, workspace.Delete{Path: contentDir, Recursive: true})
823+
if deleteErr != nil {
824+
log.Debugf(ctx, "Failed to delete workspace content for %s: %v", s.Name, deleteErr)
825+
}
826+
}
827+
828+
// Remove from local state.
829+
if err := sessions.Remove(ctx, s.Name); err != nil {
830+
log.Debugf(ctx, "Failed to remove session %s from state: %v", s.Name, err)
831+
}
832+
833+
log.Infof(ctx, "Cleaned up stale session: %s", s.Name)
834+
}

experimental/ssh/internal/client/client_test.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ func TestValidate(t *testing.T) {
1818
wantErr string
1919
}{
2020
{
21-
name: "no cluster or connection name",
21+
name: "no cluster or connection name or accelerator",
2222
opts: client.ClientOptions{},
23-
wantErr: "please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)",
23+
wantErr: "please provide --cluster or --accelerator flag",
2424
},
2525
{
2626
name: "proxy mode skips cluster/name check",
@@ -31,9 +31,13 @@ func TestValidate(t *testing.T) {
3131
opts: client.ClientOptions{ClusterID: "abc-123"},
3232
},
3333
{
34-
name: "accelerator without connection name",
34+
name: "accelerator with cluster ID",
3535
opts: client.ClientOptions{ClusterID: "abc-123", Accelerator: "GPU_1xA10"},
36-
wantErr: "--accelerator flag can only be used with serverless compute (--name flag)",
36+
wantErr: "--accelerator flag can only be used with serverless compute, not with --cluster",
37+
},
38+
{
39+
name: "accelerator only (auto-generate session name)",
40+
opts: client.ClientOptions{Accelerator: "GPU_1xA10"},
3741
},
3842
{
3943
name: "connection name without accelerator",
@@ -55,8 +59,9 @@ func TestValidate(t *testing.T) {
5559
opts: client.ClientOptions{ConnectionName: "my-conn_1", Accelerator: "GPU_1xA10"},
5660
},
5761
{
58-
name: "both cluster ID and connection name",
59-
opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn", Accelerator: "GPU_1xA10"},
62+
name: "both cluster ID and connection name (no accelerator)",
63+
opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn"},
64+
wantErr: "--name flag requires --accelerator to be set (for now we only support serverless GPU compute)",
6065
},
6166
{
6267
name: "proxy mode with invalid connection name",
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package sessions
2+
3+
import (
4+
"crypto/rand"
5+
"encoding/hex"
6+
"strings"
7+
"time"
8+
)
9+
10+
// acceleratorPrefixes maps known accelerator types to short human-readable prefixes.
11+
var acceleratorPrefixes = map[string]string{
12+
"GPU_1xA10": "gpu-a10",
13+
"GPU_8xH100": "gpu-h100",
14+
}
15+
16+
// GenerateSessionName creates a human-readable session name from the accelerator type.
17+
// Format: <prefix>-<random_hex>, e.g. "gpu-a10-f3a2b1c0".
18+
func GenerateSessionName(accelerator string) string {
19+
prefix, ok := acceleratorPrefixes[accelerator]
20+
if !ok {
21+
prefix = strings.ToLower(strings.ReplaceAll(accelerator, "_", "-"))
22+
}
23+
24+
date := time.Now().Format("20060102")
25+
b := make([]byte, 3)
26+
_, _ = rand.Read(b)
27+
return "databricks-" + prefix + "-" + date + "-" + hex.EncodeToString(b)
28+
}

0 commit comments

Comments
 (0)