Skip to content

Commit c34f8d9

Browse files
authored
SSH: fix Include directive escaping and refactor backup logic (#4861)
## Summary The backup refactoring is actually secondary, the main fix is quoting paths inside "Include" statement that we add to ssh config. Noticed that we also didn't do backups for ssh configs, so here are backup related changes: - Extracts `BackupFile` from `vscode/settings.go` into a shared `fileutil` package with exported suffix constants (`SuffixOriginalBak`, `SuffixLatestBak`), eliminating hardcoded magic strings across callers and tests - Replaces `strings.Contains` with a line-aware `containsLine` helper in `sshconfig` to avoid false positives when the Include path appears in a comment or as a substring of another line - Adds migration from unquoted `Include` directives (written by older CLI versions) to the quoted form, which handles paths with spaces - A user visible change is that we consider backup errors as hard errors and don't proceed with the flow ## Tests Existing, new, and manually This pull request was AI-assisted by Isaac.
1 parent 518b428 commit c34f8d9

6 files changed

Lines changed: 385 additions & 50 deletions

File tree

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package fileutil
2+
3+
import (
4+
"context"
5+
"os"
6+
"path/filepath"
7+
8+
"github.com/databricks/cli/libs/log"
9+
)
10+
11+
const (
12+
SuffixOriginalBak = ".original.bak"
13+
SuffixLatestBak = ".latest.bak"
14+
)
15+
16+
// BackupFile saves data to path+".original.bak" on the first call, and
17+
// path+".latest.bak" on subsequent calls. Skips if data is empty.
18+
func BackupFile(ctx context.Context, path string, data []byte) error {
19+
if len(data) == 0 {
20+
return nil
21+
}
22+
originalBak := path + SuffixOriginalBak
23+
latestBak := path + SuffixLatestBak
24+
var bakPath string
25+
_, statErr := os.Stat(originalBak)
26+
if statErr != nil && !os.IsNotExist(statErr) {
27+
return statErr
28+
}
29+
if os.IsNotExist(statErr) {
30+
bakPath = originalBak
31+
} else {
32+
bakPath = latestBak
33+
}
34+
if err := os.WriteFile(bakPath, data, 0o600); err != nil {
35+
return err
36+
}
37+
log.Infof(ctx, "Backed up %s to %s", filepath.ToSlash(path), filepath.ToSlash(bakPath))
38+
return nil
39+
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package fileutil_test
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"runtime"
7+
"testing"
8+
9+
"github.com/databricks/cli/experimental/ssh/internal/fileutil"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestBackupFile_EmptyData(t *testing.T) {
15+
tmpDir := t.TempDir()
16+
path := filepath.Join(tmpDir, "file.json")
17+
18+
err := fileutil.BackupFile(t.Context(), path, []byte{})
19+
require.NoError(t, err)
20+
21+
_, err = os.Stat(path + fileutil.SuffixOriginalBak)
22+
assert.True(t, os.IsNotExist(err))
23+
}
24+
25+
func TestBackupFile_FirstBackup(t *testing.T) {
26+
tmpDir := t.TempDir()
27+
path := filepath.Join(tmpDir, "file.json")
28+
data := []byte(`{"key": "value"}`)
29+
30+
err := fileutil.BackupFile(t.Context(), path, data)
31+
require.NoError(t, err)
32+
33+
content, err := os.ReadFile(path + fileutil.SuffixOriginalBak)
34+
require.NoError(t, err)
35+
assert.Equal(t, data, content)
36+
37+
_, err = os.Stat(path + fileutil.SuffixLatestBak)
38+
assert.True(t, os.IsNotExist(err))
39+
}
40+
41+
func TestBackupFile_SubsequentBackup(t *testing.T) {
42+
tmpDir := t.TempDir()
43+
path := filepath.Join(tmpDir, "file.json")
44+
original := []byte(`{"key": "value"}`)
45+
updated := []byte(`{"key": "updated"}`)
46+
47+
err := fileutil.BackupFile(t.Context(), path, original)
48+
require.NoError(t, err)
49+
50+
err = fileutil.BackupFile(t.Context(), path, updated)
51+
require.NoError(t, err)
52+
53+
// .original.bak must remain unchanged
54+
content, err := os.ReadFile(path + fileutil.SuffixOriginalBak)
55+
require.NoError(t, err)
56+
assert.Equal(t, original, content)
57+
58+
// .latest.bak should have the updated content
59+
content, err = os.ReadFile(path + fileutil.SuffixLatestBak)
60+
require.NoError(t, err)
61+
assert.Equal(t, updated, content)
62+
}
63+
64+
func TestBackupFile_WriteError(t *testing.T) {
65+
err := fileutil.BackupFile(t.Context(), "/nonexistent/dir/file.json", []byte("data"))
66+
assert.Error(t, err)
67+
}
68+
69+
func TestBackupFile_StatError(t *testing.T) {
70+
if runtime.GOOS == "windows" {
71+
t.Skip("chmod is not supported on windows")
72+
}
73+
74+
tmpDir := t.TempDir()
75+
path := filepath.Join(tmpDir, "file.json")
76+
77+
// Create the .original.bak file so os.Stat would find it — but make the
78+
// parent directory unreadable so Stat returns a permission error instead.
79+
require.NoError(t, os.WriteFile(path+fileutil.SuffixOriginalBak, []byte("existing"), 0o600))
80+
require.NoError(t, os.Chmod(tmpDir, 0o000))
81+
t.Cleanup(func() { _ = os.Chmod(tmpDir, 0o700) })
82+
83+
err := fileutil.BackupFile(t.Context(), path, []byte("data"))
84+
assert.Error(t, err)
85+
}

experimental/ssh/internal/sshconfig/sshconfig.go

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"path/filepath"
88
"strings"
99

10+
"github.com/databricks/cli/experimental/ssh/internal/fileutil"
1011
"github.com/databricks/cli/libs/cmdio"
1112
"github.com/databricks/cli/libs/env"
1213
)
@@ -80,11 +81,24 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error {
8081
// Convert path to forward slashes for SSH config compatibility across platforms
8182
configDirUnix := filepath.ToSlash(configDir)
8283

83-
includeLine := fmt.Sprintf("Include %s/*", configDirUnix)
84-
if strings.Contains(string(content), includeLine) {
84+
// Quoted to handle paths with spaces; OpenSSH still expands globs inside quotes.
85+
includeLine := fmt.Sprintf(`Include "%s/*"`, configDirUnix)
86+
if containsLine(content, includeLine) {
8587
return nil
8688
}
8789

90+
// Migrate unquoted Include written by older versions of the CLI.
91+
oldIncludeLine := fmt.Sprintf("Include %s/*", configDirUnix)
92+
if containsLine(content, oldIncludeLine) {
93+
if err := fileutil.BackupFile(ctx, configPath, content); err != nil {
94+
return fmt.Errorf("failed to backup SSH config before migration: %w", err)
95+
}
96+
return os.WriteFile(configPath, replaceLine(content, oldIncludeLine, includeLine), 0o600)
97+
}
98+
99+
if err := fileutil.BackupFile(ctx, configPath, content); err != nil {
100+
return fmt.Errorf("failed to backup SSH config: %w", err)
101+
}
88102
newContent := includeLine + "\n"
89103
if len(content) > 0 && !strings.HasPrefix(string(content), "\n") {
90104
newContent += "\n"
@@ -99,6 +113,31 @@ func EnsureIncludeDirective(ctx context.Context, configPath string) error {
99113
return nil
100114
}
101115

116+
// containsLine reports whether data contains line as a line match,
117+
// trimming leading whitespace and \r (Windows line endings) before comparing.
118+
func containsLine(data []byte, line string) bool {
119+
for l := range strings.SplitSeq(string(data), "\n") {
120+
if strings.TrimLeft(strings.TrimRight(l, "\r"), " \t") == line {
121+
return true
122+
}
123+
}
124+
return false
125+
}
126+
127+
// replaceLine replaces the first line in data whose trimmed content matches old
128+
// with new. Uses the same trim logic as containsLine. Returns data unchanged if
129+
// no match.
130+
func replaceLine(data []byte, old, new string) []byte {
131+
lines := strings.Split(string(data), "\n")
132+
for i, l := range lines {
133+
if strings.TrimLeft(strings.TrimRight(l, "\r"), " \t") == old {
134+
lines[i] = new
135+
break
136+
}
137+
}
138+
return []byte(strings.Join(lines, "\n"))
139+
}
140+
102141
func GetHostConfigPath(ctx context.Context, hostName string) (string, error) {
103142
configDir, err := GetConfigDir(ctx)
104143
if err != nil {

experimental/ssh/internal/sshconfig/sshconfig_test.go

Lines changed: 176 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) {
8080
configDir, err := GetConfigDir(t.Context())
8181
require.NoError(t, err)
8282

83-
// Use forward slashes as that's what SSH config uses
83+
// Use forward slashes and quotes as that's what SSH config uses
8484
configDirUnix := filepath.ToSlash(configDir)
85-
existingContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n"
85+
existingContent := `Include "` + configDirUnix + `/*"` + "\n\nHost example\n User test\n"
8686
err = os.MkdirAll(filepath.Dir(configPath), 0o700)
8787
require.NoError(t, err)
8888
err = os.WriteFile(configPath, []byte(existingContent), 0o600)
@@ -96,6 +96,59 @@ func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) {
9696
assert.Equal(t, existingContent, string(content))
9797
}
9898

99+
func TestEnsureIncludeDirective_MigratesOldUnquotedFormat(t *testing.T) {
100+
tmpDir := t.TempDir()
101+
t.Setenv("HOME", tmpDir)
102+
t.Setenv("USERPROFILE", tmpDir)
103+
104+
configPath := filepath.Join(tmpDir, ".ssh", "config")
105+
106+
configDir, err := GetConfigDir(t.Context())
107+
require.NoError(t, err)
108+
109+
configDirUnix := filepath.ToSlash(configDir)
110+
oldContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n"
111+
err = os.MkdirAll(filepath.Dir(configPath), 0o700)
112+
require.NoError(t, err)
113+
err = os.WriteFile(configPath, []byte(oldContent), 0o600)
114+
require.NoError(t, err)
115+
116+
err = EnsureIncludeDirective(t.Context(), configPath)
117+
assert.NoError(t, err)
118+
119+
content, err := os.ReadFile(configPath)
120+
assert.NoError(t, err)
121+
configStr := string(content)
122+
123+
assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`)
124+
assert.NotContains(t, configStr, "Include "+configDirUnix+"/*\n")
125+
assert.Contains(t, configStr, "Host example")
126+
}
127+
128+
func TestEnsureIncludeDirective_NotFooledBySubstring(t *testing.T) {
129+
tmpDir := t.TempDir()
130+
t.Setenv("HOME", tmpDir)
131+
t.Setenv("USERPROFILE", tmpDir)
132+
133+
configPath := filepath.Join(tmpDir, ".ssh", "config")
134+
135+
configDir, err := GetConfigDir(t.Context())
136+
require.NoError(t, err)
137+
138+
configDirUnix := filepath.ToSlash(configDir)
139+
// The include path appears only inside a comment, not as a standalone directive.
140+
existingContent := `# Include "` + configDirUnix + `/*"` + "\nHost example\n User test\n"
141+
require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700))
142+
require.NoError(t, os.WriteFile(configPath, []byte(existingContent), 0o600))
143+
144+
err = EnsureIncludeDirective(t.Context(), configPath)
145+
require.NoError(t, err)
146+
147+
content, err := os.ReadFile(configPath)
148+
require.NoError(t, err)
149+
assert.Contains(t, string(content), `Include "`+configDirUnix+`/*"`)
150+
}
151+
99152
func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) {
100153
tmpDir := t.TempDir()
101154
configPath := filepath.Join(tmpDir, ".ssh", "config")
@@ -127,6 +180,127 @@ func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) {
127180
assert.Less(t, includeIndex, hostIndex, "Include directive should come before existing content")
128181
}
129182

183+
func TestContainsLine(t *testing.T) {
184+
tests := []struct {
185+
name string
186+
data string
187+
line string
188+
found bool
189+
}{
190+
{"exact match", `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true},
191+
{"not present", "Host example\n", `Include "/path/*"`, false},
192+
{"substring only", `# Include "/path/*"`, `Include "/path/*"`, false},
193+
{"commented line", `# Include "/path/*"` + "\n" + `Include "/path/*"` + "\n", `Include "/path/*"`, true},
194+
{"windows line ending", `Include "/path/*"` + "\r\nHost example\r\n", `Include "/path/*"`, true},
195+
{"empty data", "", `Include "/path/*"`, false},
196+
{"indented with spaces", " " + `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true},
197+
{"indented with tab", "\t" + `Include "/path/*"` + "\nHost example\n", `Include "/path/*"`, true},
198+
}
199+
for _, tc := range tests {
200+
t.Run(tc.name, func(t *testing.T) {
201+
assert.Equal(t, tc.found, containsLine([]byte(tc.data), tc.line))
202+
})
203+
}
204+
}
205+
206+
func TestReplaceLine(t *testing.T) {
207+
tests := []struct {
208+
name string
209+
data string
210+
old string
211+
new string
212+
expected string
213+
}{
214+
{
215+
"exact match",
216+
`Include "/p/*"` + "\nHost x\n",
217+
`Include "/p/*"`, `Include "/p/*" NEW`,
218+
`Include "/p/*" NEW` + "\nHost x\n",
219+
},
220+
{
221+
"indented match",
222+
" " + `Include "/p/*"` + "\nHost x\n",
223+
`Include "/p/*"`, `Include "/p/*" NEW`,
224+
`Include "/p/*" NEW` + "\nHost x\n",
225+
},
226+
{
227+
"no match",
228+
"Host x\n",
229+
`Include "/p/*"`, `Include "/p/*" NEW`,
230+
"Host x\n",
231+
},
232+
{
233+
"substring in comment — must not be replaced",
234+
`# Include "/p/*"` + "\nHost x\n",
235+
`Include "/p/*"`, `Include "/p/*" NEW`,
236+
`# Include "/p/*"` + "\nHost x\n",
237+
},
238+
}
239+
for _, tc := range tests {
240+
t.Run(tc.name, func(t *testing.T) {
241+
got := replaceLine([]byte(tc.data), tc.old, tc.new)
242+
assert.Equal(t, tc.expected, string(got))
243+
})
244+
}
245+
}
246+
247+
func TestEnsureIncludeDirective_MigratesIndentedOldFormat(t *testing.T) {
248+
tmpDir := t.TempDir()
249+
t.Setenv("HOME", tmpDir)
250+
t.Setenv("USERPROFILE", tmpDir)
251+
252+
configPath := filepath.Join(tmpDir, ".ssh", "config")
253+
254+
configDir, err := GetConfigDir(t.Context())
255+
require.NoError(t, err)
256+
257+
configDirUnix := filepath.ToSlash(configDir)
258+
// Old format with leading whitespace — should still be detected and migrated.
259+
oldContent := " Include " + configDirUnix + "/*\n\nHost example\n User test\n"
260+
require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700))
261+
require.NoError(t, os.WriteFile(configPath, []byte(oldContent), 0o600))
262+
263+
err = EnsureIncludeDirective(t.Context(), configPath)
264+
require.NoError(t, err)
265+
266+
content, err := os.ReadFile(configPath)
267+
require.NoError(t, err)
268+
configStr := string(content)
269+
270+
assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`)
271+
assert.NotContains(t, configStr, " Include "+configDirUnix+"/*")
272+
assert.Contains(t, configStr, "Host example")
273+
}
274+
275+
func TestEnsureIncludeDirective_NotFooledByOldFormatSubstring(t *testing.T) {
276+
tmpDir := t.TempDir()
277+
t.Setenv("HOME", tmpDir)
278+
t.Setenv("USERPROFILE", tmpDir)
279+
280+
configPath := filepath.Join(tmpDir, ".ssh", "config")
281+
282+
configDir, err := GetConfigDir(t.Context())
283+
require.NoError(t, err)
284+
285+
configDirUnix := filepath.ToSlash(configDir)
286+
// Old unquoted form appears only inside a comment — must not be migrated.
287+
existingContent := "# Include " + configDirUnix + "/*\nHost example\n User test\n"
288+
require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700))
289+
require.NoError(t, os.WriteFile(configPath, []byte(existingContent), 0o600))
290+
291+
err = EnsureIncludeDirective(t.Context(), configPath)
292+
require.NoError(t, err)
293+
294+
content, err := os.ReadFile(configPath)
295+
require.NoError(t, err)
296+
configStr := string(content)
297+
298+
// New quoted directive should have been prepended (not a migration).
299+
assert.Contains(t, configStr, `Include "`+configDirUnix+`/*"`)
300+
// Comment line must be preserved unchanged.
301+
assert.Contains(t, configStr, "# Include "+configDirUnix+"/*")
302+
}
303+
130304
func TestGetHostConfigPath(t *testing.T) {
131305
path, err := GetHostConfigPath(t.Context(), "test-host")
132306
assert.NoError(t, err)

0 commit comments

Comments
 (0)