Skip to content

Commit cb290d4

Browse files
pieternclaude
andauthored
Use env.UserHomeDir(ctx) in experimental/ssh package (#4658)
## Summary - Accept `context.Context` in SSH config and key path functions to use `env.UserHomeDir(ctx)` instead of `os.UserHomeDir()` - Follow-up to #4654 which made this change for the rest of the codebase ## Test plan - [x] `go test ./experimental/ssh/...` passes 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 50aa8ba commit cb290d4

File tree

6 files changed

+45
-43
lines changed

6 files changed

+45
-43
lines changed

experimental/ssh/internal/client/client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
220220
return fmt.Errorf("failed to get or generate SSH key pair from secrets: %w", err)
221221
}
222222

223-
keyPath, err := keys.GetLocalSSHKeyPath(sessionID, opts.SSHKeysDir)
223+
keyPath, err := keys.GetLocalSSHKeyPath(ctx, sessionID, opts.SSHKeysDir)
224224
if err != nil {
225225
return fmt.Errorf("failed to get local keys folder: %w", err)
226226
}
@@ -323,7 +323,7 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k
323323
databricksUserName := currentUser.UserName
324324

325325
// Ensure SSH config entry exists
326-
configPath, err := sshconfig.GetMainConfigPath()
326+
configPath, err := sshconfig.GetMainConfigPath(ctx)
327327
if err != nil {
328328
return fmt.Errorf("failed to get SSH config path: %w", err)
329329
}
@@ -354,7 +354,7 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k
354354

355355
func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error {
356356
// Ensure the Include directive exists in the main SSH config
357-
err := sshconfig.EnsureIncludeDirective(configPath)
357+
err := sshconfig.EnsureIncludeDirective(ctx, configPath)
358358
if err != nil {
359359
return err
360360
}

experimental/ssh/internal/keys/keys.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@ import (
1010
"os"
1111
"path/filepath"
1212

13+
"github.com/databricks/cli/libs/env"
1314
"github.com/databricks/databricks-sdk-go"
1415
"golang.org/x/crypto/ssh"
1516
)
1617

1718
// We use different client keys for each session as a good practice for better isolation and control.
1819
// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
19-
func GetLocalSSHKeyPath(sessionID, keysDir string) (string, error) {
20+
func GetLocalSSHKeyPath(ctx context.Context, sessionID, keysDir string) (string, error) {
2021
if keysDir == "" {
21-
homeDir, err := os.UserHomeDir() //nolint:forbidigo // TODO: thread ctx through public API
22+
homeDir, err := env.UserHomeDir(ctx)
2223
if err != nil {
2324
return "", fmt.Errorf("failed to get home directory: %w", err)
2425
}

experimental/ssh/internal/setup/setup.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClie
4343
return nil
4444
}
4545

46-
func generateHostConfig(opts SetupOptions) (string, error) {
47-
identityFilePath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir)
46+
func generateHostConfig(ctx context.Context, opts SetupOptions) (string, error) {
47+
identityFilePath, err := keys.GetLocalSSHKeyPath(ctx, opts.ClusterID, opts.SSHKeysDir)
4848
if err != nil {
4949
return "", fmt.Errorf("failed to get local keys folder: %w", err)
5050
}
@@ -90,22 +90,22 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp
9090
return err
9191
}
9292

93-
configPath, err := sshconfig.GetMainConfigPathOrDefault(opts.SSHConfigPath)
93+
configPath, err := sshconfig.GetMainConfigPathOrDefault(ctx, opts.SSHConfigPath)
9494
if err != nil {
9595
return err
9696
}
9797

98-
err = sshconfig.EnsureIncludeDirective(configPath)
98+
err = sshconfig.EnsureIncludeDirective(ctx, configPath)
9999
if err != nil {
100100
return err
101101
}
102102

103-
hostConfig, err := generateHostConfig(opts)
103+
hostConfig, err := generateHostConfig(ctx, opts)
104104
if err != nil {
105105
return err
106106
}
107107

108-
exists, err := sshconfig.HostConfigExists(opts.HostName)
108+
exists, err := sshconfig.HostConfigExists(ctx, opts.HostName)
109109
if err != nil {
110110
return err
111111
}
@@ -129,7 +129,7 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp
129129
return err
130130
}
131131

132-
hostConfigPath, err := sshconfig.GetHostConfigPath(opts.HostName)
132+
hostConfigPath, err := sshconfig.GetHostConfigPath(ctx, opts.HostName)
133133
if err != nil {
134134
return err
135135
}

experimental/ssh/internal/setup/setup_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func TestGenerateHostConfig_Valid(t *testing.T) {
137137
ProxyCommand: proxyCommand,
138138
}
139139

140-
result, err := generateHostConfig(opts)
140+
result, err := generateHostConfig(t.Context(), opts)
141141
assert.NoError(t, err)
142142

143143
assert.Contains(t, result, "Host test-host")
@@ -172,7 +172,7 @@ func TestGenerateHostConfig_WithoutProfile(t *testing.T) {
172172
ProxyCommand: proxyCommand,
173173
}
174174

175-
result, err := generateHostConfig(opts)
175+
result, err := generateHostConfig(t.Context(), opts)
176176
assert.NoError(t, err)
177177

178178
assert.NotContains(t, result, "--profile=")
@@ -193,7 +193,7 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) {
193193
ShutdownDelay: 30 * time.Second,
194194
}
195195

196-
result, err := generateHostConfig(opts)
196+
result, err := generateHostConfig(t.Context(), opts)
197197
assert.NoError(t, err)
198198

199199
// Check that quotes are properly escaped

experimental/ssh/internal/sshconfig/sshconfig.go

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,35 @@ import (
88
"strings"
99

1010
"github.com/databricks/cli/libs/cmdio"
11+
"github.com/databricks/cli/libs/env"
1112
)
1213

1314
const (
1415
// configDirName is the directory name for Databricks SSH tunnel configs, relative to the user's home directory.
1516
configDirName = ".databricks/ssh-tunnel-configs"
1617
)
1718

18-
func GetConfigDir() (string, error) {
19-
homeDir, err := os.UserHomeDir() //nolint:forbidigo // TODO: thread ctx through public API
19+
func GetConfigDir(ctx context.Context) (string, error) {
20+
homeDir, err := env.UserHomeDir(ctx)
2021
if err != nil {
2122
return "", fmt.Errorf("failed to get home directory: %w", err)
2223
}
2324
return filepath.Join(homeDir, configDirName), nil
2425
}
2526

26-
func GetMainConfigPath() (string, error) {
27-
homeDir, err := os.UserHomeDir() //nolint:forbidigo // TODO: thread ctx through public API
27+
func GetMainConfigPath(ctx context.Context) (string, error) {
28+
homeDir, err := env.UserHomeDir(ctx)
2829
if err != nil {
2930
return "", fmt.Errorf("failed to get home directory: %w", err)
3031
}
3132
return filepath.Join(homeDir, ".ssh", "config"), nil
3233
}
3334

34-
func GetMainConfigPathOrDefault(configPath string) (string, error) {
35+
func GetMainConfigPathOrDefault(ctx context.Context, configPath string) (string, error) {
3536
if configPath != "" {
3637
return configPath, nil
3738
}
38-
return GetMainConfigPath()
39+
return GetMainConfigPath(ctx)
3940
}
4041

4142
func EnsureMainConfigExists(configPath string) error {
@@ -55,8 +56,8 @@ func EnsureMainConfigExists(configPath string) error {
5556
return err
5657
}
5758

58-
func EnsureIncludeDirective(configPath string) error {
59-
configDir, err := GetConfigDir()
59+
func EnsureIncludeDirective(ctx context.Context, configPath string) error {
60+
configDir, err := GetConfigDir(ctx)
6061
if err != nil {
6162
return err
6263
}
@@ -98,16 +99,16 @@ func EnsureIncludeDirective(configPath string) error {
9899
return nil
99100
}
100101

101-
func GetHostConfigPath(hostName string) (string, error) {
102-
configDir, err := GetConfigDir()
102+
func GetHostConfigPath(ctx context.Context, hostName string) (string, error) {
103+
configDir, err := GetConfigDir(ctx)
103104
if err != nil {
104105
return "", err
105106
}
106107
return filepath.Join(configDir, hostName), nil
107108
}
108109

109-
func HostConfigExists(hostName string) (bool, error) {
110-
configPath, err := GetHostConfigPath(hostName)
110+
func HostConfigExists(ctx context.Context, hostName string) (bool, error) {
111+
configPath, err := GetHostConfigPath(ctx, hostName)
111112
if err != nil {
112113
return false, err
113114
}
@@ -123,12 +124,12 @@ func HostConfigExists(hostName string) (bool, error) {
123124

124125
// Returns true if the config was created/updated, false if it was skipped.
125126
func CreateOrUpdateHostConfig(ctx context.Context, hostName, hostConfig string, recreate bool) (bool, error) {
126-
configPath, err := GetHostConfigPath(hostName)
127+
configPath, err := GetHostConfigPath(ctx, hostName)
127128
if err != nil {
128129
return false, err
129130
}
130131

131-
exists, err := HostConfigExists(hostName)
132+
exists, err := HostConfigExists(ctx, hostName)
132133
if err != nil {
133134
return false, err
134135
}

experimental/ssh/internal/sshconfig/sshconfig_test.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,23 @@ import (
1111
)
1212

1313
func TestGetConfigDir(t *testing.T) {
14-
dir, err := GetConfigDir()
14+
dir, err := GetConfigDir(t.Context())
1515
assert.NoError(t, err)
1616
assert.Contains(t, dir, filepath.Join(".databricks", "ssh-tunnel-configs"))
1717
}
1818

1919
func TestGetMainConfigPath(t *testing.T) {
20-
path, err := GetMainConfigPath()
20+
path, err := GetMainConfigPath(t.Context())
2121
assert.NoError(t, err)
2222
assert.Contains(t, path, filepath.Join(".ssh", "config"))
2323
}
2424

2525
func TestGetMainConfigPathOrDefault(t *testing.T) {
26-
path, err := GetMainConfigPathOrDefault("/custom/path")
26+
path, err := GetMainConfigPathOrDefault(t.Context(), "/custom/path")
2727
assert.NoError(t, err)
2828
assert.Equal(t, "/custom/path", path)
2929

30-
path, err = GetMainConfigPathOrDefault("")
30+
path, err = GetMainConfigPathOrDefault(t.Context(), "")
3131
assert.NoError(t, err)
3232
assert.Contains(t, path, filepath.Join(".ssh", "config"))
3333
}
@@ -58,7 +58,7 @@ func TestEnsureIncludeDirective_NewConfig(t *testing.T) {
5858
t.Setenv("HOME", tmpDir)
5959
t.Setenv("USERPROFILE", tmpDir)
6060

61-
err := EnsureIncludeDirective(configPath)
61+
err := EnsureIncludeDirective(t.Context(), configPath)
6262
assert.NoError(t, err)
6363

6464
content, err := os.ReadFile(configPath)
@@ -77,7 +77,7 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) {
7777

7878
configPath := filepath.Join(tmpDir, ".ssh", "config")
7979

80-
configDir, err := GetConfigDir()
80+
configDir, err := GetConfigDir(t.Context())
8181
require.NoError(t, err)
8282

8383
// Use forward slashes as that's what SSH config uses
@@ -88,7 +88,7 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) {
8888
err = os.WriteFile(configPath, []byte(existingContent), 0o600)
8989
require.NoError(t, err)
9090

91-
err = EnsureIncludeDirective(configPath)
91+
err = EnsureIncludeDirective(t.Context(), configPath)
9292
assert.NoError(t, err)
9393

9494
content, err := os.ReadFile(configPath)
@@ -110,7 +110,7 @@ func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) {
110110
err = os.WriteFile(configPath, []byte(existingContent), 0o600)
111111
require.NoError(t, err)
112112

113-
err = EnsureIncludeDirective(configPath)
113+
err = EnsureIncludeDirective(t.Context(), configPath)
114114
assert.NoError(t, err)
115115

116116
content, err := os.ReadFile(configPath)
@@ -128,7 +128,7 @@ func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) {
128128
}
129129

130130
func TestGetHostConfigPath(t *testing.T) {
131-
path, err := GetHostConfigPath("test-host")
131+
path, err := GetHostConfigPath(t.Context(), "test-host")
132132
assert.NoError(t, err)
133133
assert.Contains(t, path, filepath.Join(".databricks", "ssh-tunnel-configs", "test-host"))
134134
}
@@ -138,7 +138,7 @@ func TestHostConfigExists(t *testing.T) {
138138
t.Setenv("HOME", tmpDir)
139139
t.Setenv("USERPROFILE", tmpDir)
140140

141-
exists, err := HostConfigExists("nonexistent")
141+
exists, err := HostConfigExists(t.Context(), "nonexistent")
142142
assert.NoError(t, err)
143143
assert.False(t, exists)
144144

@@ -148,7 +148,7 @@ func TestHostConfigExists(t *testing.T) {
148148
err = os.WriteFile(filepath.Join(configDir, "existing-host"), []byte("config"), 0o600)
149149
require.NoError(t, err)
150150

151-
exists, err = HostConfigExists("existing-host")
151+
exists, err = HostConfigExists(t.Context(), "existing-host")
152152
assert.NoError(t, err)
153153
assert.True(t, exists)
154154
}
@@ -164,7 +164,7 @@ func TestCreateOrUpdateHostConfig_NewConfig(t *testing.T) {
164164
assert.NoError(t, err)
165165
assert.True(t, created)
166166

167-
configPath, err := GetHostConfigPath("test-host")
167+
configPath, err := GetHostConfigPath(ctx, "test-host")
168168
require.NoError(t, err)
169169
content, err := os.ReadFile(configPath)
170170
assert.NoError(t, err)
@@ -189,7 +189,7 @@ func TestCreateOrUpdateHostConfig_ExistingConfigNoRecreate(t *testing.T) {
189189
assert.NoError(t, err)
190190
assert.False(t, created)
191191

192-
configPath, err := GetHostConfigPath("test-host")
192+
configPath, err := GetHostConfigPath(ctx, "test-host")
193193
require.NoError(t, err)
194194
content, err := os.ReadFile(configPath)
195195
assert.NoError(t, err)
@@ -214,7 +214,7 @@ func TestCreateOrUpdateHostConfig_ExistingConfigWithRecreate(t *testing.T) {
214214
assert.NoError(t, err)
215215
assert.True(t, created)
216216

217-
configPath, err := GetHostConfigPath("test-host")
217+
configPath, err := GetHostConfigPath(ctx, "test-host")
218218
require.NoError(t, err)
219219
content, err := os.ReadFile(configPath)
220220
assert.NoError(t, err)

0 commit comments

Comments
 (0)