Skip to content

Commit 8ab5487

Browse files
committed
test: cover additional root error branches
1 parent cd3b80d commit 8ab5487

2 files changed

Lines changed: 144 additions & 3 deletions

File tree

cmd/root.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ var cfgFile string
3434
var hostsFile string
3535
var hostGroup string
3636

37+
var loadSSHConfigFunc = sshConn.LoadSSHConfig
38+
39+
var resolveHostFunc = func(resolver *sshConn.SSHConfigResolver, spec sshConn.HostSpec, fallbackUser string) (sshConn.ResolvedHost, error) {
40+
return resolver.ResolveHost(spec, fallbackUser)
41+
}
42+
3743
// RootCmd represents the base command when called without any subcommands
3844
var RootCmd = &cobra.Command{
3945
Use: "pretty",
@@ -106,7 +112,7 @@ usage:
106112
if home, err := os.UserHomeDir(); err == nil {
107113
userConfigPath = filepath.Join(home, ".ssh", "config")
108114
}
109-
resolver, err := sshConn.LoadSSHConfig(sshConn.SSHConfigPaths{
115+
resolver, err := loadSSHConfigFunc(sshConn.SSHConfigPaths{
110116
User: userConfigPath,
111117
System: "/etc/ssh/ssh_config",
112118
})
@@ -130,7 +136,7 @@ usage:
130136
resolveSpec.User = globalUser
131137
resolveSpec.UserSet = true
132138
}
133-
resolved, err := resolver.ResolveHost(resolveSpec, "")
139+
resolved, err := resolveHostFunc(resolver, resolveSpec, "")
134140
if err != nil {
135141
return fmt.Errorf("unable to resolve host %q: %w", spec.Host, err)
136142
}
@@ -141,7 +147,7 @@ usage:
141147
jumpSpec.User = globalUser
142148
jumpSpec.UserSet = true
143149
}
144-
jumpResolved, err := resolver.ResolveHost(jumpSpec, "")
150+
jumpResolved, err := resolveHostFunc(resolver, jumpSpec, "")
145151
if err != nil {
146152
return fmt.Errorf("unable to resolve jump host %q: %w", jumpAlias, err)
147153
}

cmd/root_test.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package cmd
22

33
import (
4+
"errors"
5+
"os"
46
"strings"
57
"testing"
68

9+
"github.com/ncode/pretty/internal/sshConn"
710
"github.com/spf13/viper"
811
)
912

@@ -102,3 +105,135 @@ func TestExecuteReturnsErrorForInvalidGroupSpec(t *testing.T) {
102105
t.Fatalf("unexpected error: %v", err)
103106
}
104107
}
108+
109+
func TestExecuteReturnsErrorForInvalidHostsFileContent(t *testing.T) {
110+
prevHostGroup := hostGroup
111+
prevHostsFile := hostsFile
112+
t.Cleanup(func() {
113+
hostGroup = prevHostGroup
114+
hostsFile = prevHostsFile
115+
RootCmd.SetArgs(nil)
116+
})
117+
118+
f, err := os.CreateTemp(t.TempDir(), "hosts-*.txt")
119+
if err != nil {
120+
t.Fatalf("unexpected temp file error: %v", err)
121+
}
122+
if _, err := f.WriteString("host1:badport\n"); err != nil {
123+
t.Fatalf("unexpected write error: %v", err)
124+
}
125+
if err := f.Close(); err != nil {
126+
t.Fatalf("unexpected close error: %v", err)
127+
}
128+
129+
hostGroup = ""
130+
hostsFile = f.Name()
131+
RootCmd.SetArgs([]string{"host2"})
132+
133+
err = Execute()
134+
if err == nil {
135+
t.Fatalf("expected error")
136+
}
137+
if !strings.Contains(err.Error(), "invalid port") {
138+
t.Fatalf("unexpected error: %v", err)
139+
}
140+
}
141+
142+
func TestExecuteReturnsErrorWhenLoadingSSHConfigFails(t *testing.T) {
143+
prevHostGroup := hostGroup
144+
prevHostsFile := hostsFile
145+
prevLoad := loadSSHConfigFunc
146+
t.Cleanup(func() {
147+
hostGroup = prevHostGroup
148+
hostsFile = prevHostsFile
149+
loadSSHConfigFunc = prevLoad
150+
RootCmd.SetArgs(nil)
151+
})
152+
153+
loadSSHConfigFunc = func(paths sshConn.SSHConfigPaths) (*sshConn.SSHConfigResolver, error) {
154+
return nil, errors.New("config boom")
155+
}
156+
157+
hostGroup = ""
158+
hostsFile = ""
159+
RootCmd.SetArgs([]string{"host1"})
160+
161+
err := Execute()
162+
if err == nil {
163+
t.Fatalf("expected error")
164+
}
165+
if !strings.Contains(err.Error(), "unable to load ssh config") {
166+
t.Fatalf("unexpected error: %v", err)
167+
}
168+
}
169+
170+
func TestExecuteReturnsErrorWhenResolveHostFails(t *testing.T) {
171+
prevHostGroup := hostGroup
172+
prevHostsFile := hostsFile
173+
prevLoad := loadSSHConfigFunc
174+
prevResolve := resolveHostFunc
175+
t.Cleanup(func() {
176+
hostGroup = prevHostGroup
177+
hostsFile = prevHostsFile
178+
loadSSHConfigFunc = prevLoad
179+
resolveHostFunc = prevResolve
180+
RootCmd.SetArgs(nil)
181+
})
182+
183+
loadSSHConfigFunc = func(paths sshConn.SSHConfigPaths) (*sshConn.SSHConfigResolver, error) {
184+
return &sshConn.SSHConfigResolver{}, nil
185+
}
186+
resolveHostFunc = func(resolver *sshConn.SSHConfigResolver, spec sshConn.HostSpec, fallbackUser string) (sshConn.ResolvedHost, error) {
187+
return sshConn.ResolvedHost{}, errors.New("resolve boom")
188+
}
189+
190+
hostGroup = ""
191+
hostsFile = ""
192+
RootCmd.SetArgs([]string{"host1"})
193+
194+
err := Execute()
195+
if err == nil {
196+
t.Fatalf("expected error")
197+
}
198+
if !strings.Contains(err.Error(), "unable to resolve host") {
199+
t.Fatalf("unexpected error: %v", err)
200+
}
201+
}
202+
203+
func TestExecuteReturnsErrorWhenResolveJumpFails(t *testing.T) {
204+
prevHostGroup := hostGroup
205+
prevHostsFile := hostsFile
206+
prevLoad := loadSSHConfigFunc
207+
prevResolve := resolveHostFunc
208+
t.Cleanup(func() {
209+
hostGroup = prevHostGroup
210+
hostsFile = prevHostsFile
211+
loadSSHConfigFunc = prevLoad
212+
resolveHostFunc = prevResolve
213+
RootCmd.SetArgs(nil)
214+
})
215+
216+
loadSSHConfigFunc = func(paths sshConn.SSHConfigPaths) (*sshConn.SSHConfigResolver, error) {
217+
return &sshConn.SSHConfigResolver{}, nil
218+
}
219+
call := 0
220+
resolveHostFunc = func(resolver *sshConn.SSHConfigResolver, spec sshConn.HostSpec, fallbackUser string) (sshConn.ResolvedHost, error) {
221+
call++
222+
if call == 1 {
223+
return sshConn.ResolvedHost{Alias: spec.Alias, Host: spec.Host, Port: 22, ProxyJump: []string{"jump1"}}, nil
224+
}
225+
return sshConn.ResolvedHost{}, errors.New("jump boom")
226+
}
227+
228+
hostGroup = ""
229+
hostsFile = ""
230+
RootCmd.SetArgs([]string{"host1"})
231+
232+
err := Execute()
233+
if err == nil {
234+
t.Fatalf("expected error")
235+
}
236+
if !strings.Contains(err.Error(), "unable to resolve jump host") {
237+
t.Fatalf("unexpected error: %v", err)
238+
}
239+
}

0 commit comments

Comments
 (0)