Skip to content

Commit e3315c2

Browse files
authored
Merge pull request #8 from ncode/juliano/tests
increase coverage
2 parents f1d6f85 + e9d8da4 commit e3315c2

File tree

11 files changed

+748
-29
lines changed

11 files changed

+748
-29
lines changed

cmd/hosts_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,52 @@ func TestPromptFlagOverridesConfig(t *testing.T) {
232232
t.Fatalf("expected CLI prompt to override config, got %q", got)
233233
}
234234
}
235+
236+
func TestParsePortValue(t *testing.T) {
237+
tests := []struct {
238+
name string
239+
input interface{}
240+
want int
241+
wantErr bool
242+
}{
243+
{name: "int", input: 22, want: 22},
244+
{name: "int64", input: int64(2222), want: 2222},
245+
{name: "float64 integer", input: float64(2200), want: 2200},
246+
{name: "string", input: "2022", want: 2022},
247+
{name: "float64 non integer", input: float64(22.5), wantErr: true},
248+
{name: "string invalid", input: "abc", wantErr: true},
249+
{name: "out of range", input: 70000, wantErr: true},
250+
{name: "invalid type", input: true, wantErr: true},
251+
}
252+
253+
for _, tc := range tests {
254+
t.Run(tc.name, func(t *testing.T) {
255+
got, err := parsePortValue(tc.input)
256+
if tc.wantErr {
257+
if err == nil {
258+
t.Fatalf("expected error")
259+
}
260+
return
261+
}
262+
if err != nil {
263+
t.Fatalf("unexpected error: %v", err)
264+
}
265+
if got != tc.want {
266+
t.Fatalf("expected %d, got %d", tc.want, got)
267+
}
268+
})
269+
}
270+
}
271+
272+
func TestValidatePort(t *testing.T) {
273+
if got, err := validatePort(22); err != nil || got != 22 {
274+
t.Fatalf("expected port 22, got %d err=%v", got, err)
275+
}
276+
277+
if _, err := validatePort(0); err == nil {
278+
t.Fatalf("expected out of range error")
279+
}
280+
if _, err := validatePort(65536); err == nil {
281+
t.Fatalf("expected out of range error")
282+
}
283+
}

cmd/root.go

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ var cfgFile string
3434
var hostsFile string
3535
var hostGroup string
3636

37+
var loadSSHConfigFunc = sshConn.LoadSSHConfig
38+
39+
var resolveHostFunc = func(resolver *sshConn.SSHConfigResolver, spec sshConn.HostSpec, fallbackUser string) (sshConn.ResolvedHost, error) {
40+
return resolver.ResolveHost(spec, fallbackUser)
41+
}
42+
43+
var spawnShellFunc = shell.Spawn
44+
3745
// RootCmd represents the base command when called without any subcommands
3846
var RootCmd = &cobra.Command{
3947
Use: "pretty",
@@ -50,19 +58,17 @@ usage:
5058
}
5159
return nil
5260
},
53-
Run: func(cmd *cobra.Command, args []string) {
61+
RunE: func(cmd *cobra.Command, args []string) error {
5462
argsLen := len(args)
5563
hostSpecs, err := parseArgsHosts(args)
5664
if err != nil {
57-
fmt.Println(err)
58-
os.Exit(1)
65+
return err
5966
}
6067

6168
if hostGroup != "" {
6269
groupSpecs, err := parseGroupSpecs(viper.Get(fmt.Sprintf("groups.%s", hostGroup)), hostGroup)
6370
if err != nil {
64-
fmt.Println(err)
65-
os.Exit(1)
71+
return err
6672
}
6773
if argsLen > 1 {
6874
hostSpecs = append(hostSpecs, groupSpecs...)
@@ -74,13 +80,11 @@ usage:
7480
if hostsFile != "" {
7581
data, err := ioutil.ReadFile(hostsFile)
7682
if err != nil {
77-
fmt.Printf("unable to read hostsFile: %v\n", err)
78-
os.Exit(1)
83+
return fmt.Errorf("unable to read hostsFile: %w", err)
7984
}
8085
fileSpecs, err := parseHostsFile(data)
8186
if err != nil {
82-
fmt.Println(err)
83-
os.Exit(1)
87+
return err
8488
}
8589
hostSpecs = append(hostSpecs, fileSpecs...)
8690
}
@@ -110,13 +114,12 @@ usage:
110114
if home, err := os.UserHomeDir(); err == nil {
111115
userConfigPath = filepath.Join(home, ".ssh", "config")
112116
}
113-
resolver, err := sshConn.LoadSSHConfig(sshConn.SSHConfigPaths{
117+
resolver, err := loadSSHConfigFunc(sshConn.SSHConfigPaths{
114118
User: userConfigPath,
115119
System: "/etc/ssh/ssh_config",
116120
})
117121
if err != nil {
118-
fmt.Printf("unable to load ssh config: %v\n", err)
119-
os.Exit(1)
122+
return fmt.Errorf("unable to load ssh config: %w", err)
120123
}
121124

122125
globalUser := strings.TrimSpace(viper.GetString("username"))
@@ -135,10 +138,9 @@ usage:
135138
resolveSpec.User = globalUser
136139
resolveSpec.UserSet = true
137140
}
138-
resolved, err := resolver.ResolveHost(resolveSpec, "")
141+
resolved, err := resolveHostFunc(resolver, resolveSpec, "")
139142
if err != nil {
140-
fmt.Printf("unable to resolve host %q: %v\n", spec.Host, err)
141-
os.Exit(1)
143+
return fmt.Errorf("unable to resolve host %q: %w", spec.Host, err)
142144
}
143145
jumps := make([]sshConn.ResolvedHost, 0, len(resolved.ProxyJump))
144146
for _, jumpAlias := range resolved.ProxyJump {
@@ -147,10 +149,9 @@ usage:
147149
jumpSpec.User = globalUser
148150
jumpSpec.UserSet = true
149151
}
150-
jumpResolved, err := resolver.ResolveHost(jumpSpec, "")
152+
jumpResolved, err := resolveHostFunc(resolver, jumpSpec, "")
151153
if err != nil {
152-
fmt.Printf("unable to resolve jump host %q: %v\n", jumpAlias, err)
153-
os.Exit(1)
154+
return fmt.Errorf("unable to resolve jump host %q: %w", jumpAlias, err)
154155
}
155156
jumps = append(jumps, jumpResolved)
156157
}
@@ -167,17 +168,15 @@ usage:
167168
}
168169
hostList.AddHost(host)
169170
}
170-
shell.Spawn(hostList)
171+
spawnShellFunc(hostList)
172+
return nil
171173
},
172174
}
173175

174176
// Execute adds all child commands to the root command and sets flags appropriately.
175177
// This is called by main.main(). It only needs to happen once to the rootCmd.
176-
func Execute() {
177-
if err := RootCmd.Execute(); err != nil {
178-
fmt.Println(err)
179-
os.Exit(1)
180-
}
178+
func Execute() error {
179+
return RootCmd.Execute()
181180
}
182181

183182
func init() {

0 commit comments

Comments
 (0)