diff --git a/cmd/sup/main.go b/cmd/sup/main.go index e1f35ee..e4a138b 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -109,6 +109,11 @@ func cmdUsage(conf *sup.Supfile) { func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { var commands []*sup.Command + // In case of the conf.Env needs an initialization + if conf.Env == nil { + conf.Env = make(sup.EnvList, 0) + } + args := flag.Args() if len(args) < 1 { networkUsage(conf) @@ -122,26 +127,14 @@ func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { return nil, nil, ErrUnknownNetwork } - // Parse CLI --env flag env vars, override values defined in Network env. - for _, env := range envVars { - if len(env) == 0 { - continue - } - i := strings.Index(env, "=") - if i < 0 { - if len(env) > 0 { - network.Env.Set(env, "") - } - continue - } - network.Env.Set(env[:i], env[i+1:]) - } - - hosts, err := network.ParseInventory() + hosts, err := network.ParseInventory(conf.Env) if err != nil { return nil, nil, err } network.Hosts = append(network.Hosts, hosts...) + if network.Env == nil { + network.Env = make(sup.EnvList, 0) + } // Does the have at least one host? if len(network.Hosts) == 0 { @@ -155,27 +148,9 @@ func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { return nil, nil, ErrUsage } - // In case of the network.Env needs an initialization - if network.Env == nil { - network.Env = make(sup.EnvList, 0) - } - // Add default env variable with current network network.Env.Set("SUP_NETWORK", args[0]) - // Add default nonce - network.Env.Set("SUP_TIME", time.Now().UTC().Format(time.RFC3339)) - if os.Getenv("SUP_TIME") != "" { - network.Env.Set("SUP_TIME", os.Getenv("SUP_TIME")) - } - - // Add user - if os.Getenv("SUP_USER") != "" { - network.Env.Set("SUP_USER", os.Getenv("SUP_USER")) - } else { - network.Env.Set("SUP_USER", os.Getenv("USER")) - } - for _, cmd := range args[1:] { // Target? target, isTarget := conf.Targets.Get(cmd) @@ -248,7 +223,14 @@ func main() { os.Exit(1) } } - conf, err := sup.NewSupfile(data) + conf, err := sup.NewSupfile(data, + // SUPFILE_DIR might change as sup invocations are chained. + sup.WithEnv("SUPFILE_DIR", filepath.Dir(supfile)), + // Add default nonce, but inherit from previous invocation. + sup.WithInheritEnv("SUP_TIME", time.Now().UTC().Format(time.RFC3339)), + // Add user, but inherit from previous invocation. + sup.WithInheritEnv("SUP_USER", os.Getenv("USER")), + ) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) @@ -331,15 +313,6 @@ func main() { } } - var vars sup.EnvList - for _, val := range append(conf.Env, network.Env...) { - vars.Set(val.Key, val.Value) - } - if err := vars.ResolveValues(); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) - } - // Parse CLI --env flag env vars, define $SUP_ENV and override values defined in Supfile. var cliVars sup.EnvList for _, env := range envVars { @@ -349,11 +322,10 @@ func main() { i := strings.Index(env, "=") if i < 0 { if len(env) > 0 { - vars.Set(env, "") + cliVars.Set(env, "") } continue } - vars.Set(env[:i], env[i+1:]) cliVars.Set(env[:i], env[i+1:]) } @@ -363,7 +335,7 @@ func main() { for _, v := range cliVars { supEnv += fmt.Sprintf(" -e %v=%q", v.Key, v.Value) } - vars.Set("SUP_ENV", strings.TrimSpace(supEnv)) + cliVars.Set("SUP_ENV", strings.TrimSpace(supEnv)) // Create new Stackup app. app, err := sup.New(conf) @@ -375,7 +347,7 @@ func main() { app.Prefix(!disablePrefix) // Run all the commands in the given network. - err = app.Run(network, vars, commands...) + err = app.Run(network, cliVars, commands...) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) diff --git a/envlist_test.go b/envlist_test.go new file mode 100644 index 0000000..eb4d0ca --- /dev/null +++ b/envlist_test.go @@ -0,0 +1,52 @@ +package sup + +import ( + "reflect" + "testing" + + "gopkg.in/yaml.v2" +) + +func TestEnvListUnmarshalYAML(t *testing.T) { + type holder struct { + Env EnvList `yaml:"env"` + } + + testCases := []struct { + input string + expect holder + }{ + { + + input: ` +env: + MY_KEY: abc123 +`, + expect: holder{ + Env: EnvList{ + &EnvVar{Key: "MY_KEY", Value: "abc123"}, + }, + }, + }, + { + + input: ` +env: + MY_KEY: $(echo abc123) +`, + expect: holder{ + Env: EnvList{ + &EnvVar{Key: "MY_KEY", Value: "abc123"}, + }, + }, + }, + } + + for _, tc := range testCases { + h := holder{} + yaml.Unmarshal([]byte(tc.input), &h) + if !reflect.DeepEqual(h, tc.expect) { + t.Errorf("Unmarshalling yaml did not produce the expected result. Got:\n%#v\nExpected: %#v\n", h, tc.expect) + } + } +} diff --git a/example/Supfile b/example/Supfile index 5140496..9307c7a 100644 --- a/example/Supfile +++ b/example/Supfile @@ -14,6 +14,8 @@ env: networks: # Groups of hosts local: + env: + SUP_LOCAL: yessir hosts: - localhost @@ -53,7 +55,7 @@ commands: upload: - src: ./ dst: /tmp/$IMAGE - script: ./scripts/docker-build.sh + script: $SUPFILE_DIR/scripts/docker-build.sh once: true pull: @@ -126,6 +128,10 @@ commands: curl -X POST --data-urlencode 'payload={"channel": "#_team_", "text": "['$SUP_NETWORK'] '$SUP_USER' deployed '$NAME'"}' \ https://hooks.slack.com/services/X/Y/Z + env: + desc: Print environment + local: env + bash: desc: Interactive shell on all hosts stdin: true diff --git a/localhost.go b/localhost.go index ebdc495..eb58c3b 100644 --- a/localhost.go +++ b/localhost.go @@ -6,11 +6,9 @@ import ( "os" "os/exec" "os/user" - - "github.com/pkg/errors" ) -// Client is a wrapper over the SSH connection/sessions. +// LocalhostClient is a wrapper over the SSH connection/sessions. type LocalhostClient struct { cmd *exec.Cmd user string @@ -105,15 +103,3 @@ func (c *LocalhostClient) WriteClose() error { func (c *LocalhostClient) Signal(sig os.Signal) error { return c.cmd.Process.Signal(sig) } - -func ResolveLocalPath(cwd, path, env string) (string, error) { - // Check if file exists first. Use bash to resolve $ENV_VARs. - cmd := exec.Command("bash", "-c", env+"echo -n "+path) - cmd.Dir = cwd - resolvedFilename, err := cmd.Output() - if err != nil { - return "", errors.Wrap(err, "resolving path failed") - } - - return string(resolvedFilename), nil -} diff --git a/ssh.go b/ssh.go index eb3cefb..7544575 100644 --- a/ssh.go +++ b/ssh.go @@ -15,7 +15,7 @@ import ( "golang.org/x/crypto/ssh/agent" ) -// Client is a wrapper over the SSH connection/sessions. +// SSHClient is a wrapper over the SSH connection/sessions. type SSHClient struct { conn *ssh.Client sess *ssh.Session @@ -219,16 +219,16 @@ func (c *SSHClient) Wait() error { } // DialThrough will create a new connection from the ssh server sc is connected to. DialThrough is an SSHDialer. -func (sc *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - conn, err := sc.conn.Dial(net, addr) +func (c *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + conn, err := c.conn.Dial(net, addr) if err != nil { return nil, err } - c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + sc, chans, reqs, err := ssh.NewClientConn(conn, addr, config) if err != nil { return nil, err } - return ssh.NewClient(c, chans, reqs), nil + return ssh.NewClient(sc, chans, reqs), nil } diff --git a/sup.go b/sup.go index d815068..c1e55aa 100644 --- a/sup.go +++ b/sup.go @@ -30,12 +30,18 @@ func New(conf *Supfile) (*Stackup, error) { // Run runs set of commands on multiple hosts defined by network sequentially. // TODO: This megamoth method needs a big refactor and should be split // to multiple smaller methods. -func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) error { +func (sup *Stackup) Run(network *Network, cliVars EnvList, commands ...*Command) error { if len(commands) == 0 { return errors.New("no commands to be run") } - env := envVars.AsExport() + // Order is important here. + // Least specific (most general) env vars first, + // then the network specific ones and finally + // the command line vars. + // Semantics are last-write-wins. + env := append(sup.conf.Env, network.Env...) + env = append(env, cliVars...) // Create clients for every host (either SSH or Localhost). var bastion *SSHClient @@ -58,7 +64,7 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) // Localhost client. if host == "localhost" { local := &LocalhostClient{ - env: env + `export SUP_HOST="` + host + `";`, + env: env.AsExport() + `export SUP_HOST="` + host + `";`, } if err := local.Connect(host); err != nil { errCh <- errors.Wrap(err, "connecting to localhost failed") @@ -70,7 +76,7 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) // SSH client. remote := &SSHClient{ - env: env + `export SUP_HOST="` + host + `";`, + env: env.AsExport() + `export SUP_HOST="` + host + `";`, user: network.User, color: Colors[i%len(Colors)], } @@ -112,7 +118,7 @@ func (sup *Stackup) Run(network *Network, envVars EnvList, commands ...*Command) // Run command or run multiple commands defined by target sequentially. for _, cmd := range commands { // Translate command into task(s). - tasks, err := sup.createTasks(cmd, clients, env) + tasks, err := sup.createTasks(cmd, clients) if err != nil { return errors.Wrap(err, "creating task failed") } diff --git a/supfile.go b/supfile.go index 2cf88b5..74605bc 100644 --- a/supfile.go +++ b/supfile.go @@ -192,7 +192,7 @@ func (e *EnvList) UnmarshalYAML(unmarshal func(interface{}) error) error { e.Set(fmt.Sprintf("%v", v.Key), fmt.Sprintf("%v", v.Value)) } - return nil + return e.ResolveValues() } // Set key to be equal value in this list. @@ -262,14 +262,47 @@ func (e ErrUnsupportedSupfileVersion) Error() string { return fmt.Sprintf("%v\n\nCheck your Supfile version (available latest version: v0.5)", e.Msg) } +// SupfileOption modifies a Supfile in some way and maybe returns an error. +type SupfileOption func(*Supfile) error + +// WithInheritEnv exports the environment variable env as value, val, in the context of a supfile. +// env will be read from the process runtime and the values here have precedence. This allows users +// to chain Supfile invocation and only have the top-level value set the SUP_* env vars. +func WithInheritEnv(env, val string) SupfileOption { + if envVal := os.Getenv(env); envVal != "" { + val = envVal + } + return WithEnv(env, val) +} + +// WithEnv forces the environment variable env to val in the context of a Supfile. +func WithEnv(env, val string) SupfileOption { + return func(s *Supfile) error { + for network := range s.Networks.nets { + n := s.Networks.nets[network] + n.Env.Set(env, val) + s.Networks.nets[network] = n + } + s.Env.Set(env, val) + return nil + } +} + // NewSupfile parses configuration file and returns Supfile or error. -func NewSupfile(data []byte) (*Supfile, error) { +func NewSupfile(data []byte, opts ...SupfileOption) (*Supfile, error) { var conf Supfile if err := yaml.Unmarshal(data, &conf); err != nil { return nil, err } + for _, opt := range opts { + err := opt(&conf) + if err != nil { + return nil, err + } + } + // API backward compatibility. Will be deprecated in v1.0. switch conf.Version { case "": @@ -327,16 +360,28 @@ func NewSupfile(data []byte) (*Supfile, error) { return &conf, nil } +func (s *Supfile) ResolveLocalPath(cwd, path string) (string, error) { + // Check if file exists first. Use bash to resolve $ENV_VARs. + cmd := exec.Command("bash", "-c", s.Env.AsExport()+" echo -n "+path) + cmd.Dir = cwd + resolvedFilename, err := cmd.Output() + if err != nil { + return "", errors.Wrap(err, "resolving path failed") + } + + return string(resolvedFilename), nil +} + // ParseInventory runs the inventory command, if provided, and appends // the command's output lines to the manually defined list of hosts. -func (n Network) ParseInventory() ([]string, error) { +func (n Network) ParseInventory(ctx EnvList) ([]string, error) { if n.Inventory == "" { return nil, nil } cmd := exec.Command("/bin/sh", "-c", n.Inventory) cmd.Env = os.Environ() - cmd.Env = append(cmd.Env, n.Env.Slice()...) + cmd.Env = append(cmd.Env, append(ctx.Slice(), n.Env.Slice()...)...) cmd.Stderr = os.Stderr output, err := cmd.Output() if err != nil { diff --git a/task.go b/task.go index eebc3c7..1b1bc29 100644 --- a/task.go +++ b/task.go @@ -17,7 +17,7 @@ type Task struct { TTY bool } -func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]*Task, error) { +func (sup *Stackup) createTasks(cmd *Command, clients []Client) ([]*Task, error) { var tasks []*Task cwd, err := os.Getwd() @@ -27,7 +27,7 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]* // Anything to upload? for _, upload := range cmd.Upload { - uploadFile, err := ResolveLocalPath(cwd, upload.Src, env) + uploadFile, err := sup.conf.ResolveLocalPath(cwd, upload.Src) if err != nil { return nil, errors.Wrap(err, "upload: "+upload.Src) } @@ -64,7 +64,11 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]* // Script. Read the file as a multiline input command. if cmd.Script != "" { - f, err := os.Open(cmd.Script) + script, err := sup.conf.ResolveLocalPath(cwd, cmd.Script) + if err != nil { + return nil, errors.Wrap(err, "can't resolve script path") + } + f, err := os.Open(script) if err != nil { return nil, errors.Wrap(err, "can't open script") } @@ -106,7 +110,7 @@ func (sup *Stackup) createTasks(cmd *Command, clients []Client, env string) ([]* // Local command. if cmd.Local != "" { local := &LocalhostClient{ - env: env + `export SUP_HOST="localhost";`, + env: sup.conf.Env.AsExport() + `export SUP_HOST="localhost";`, } local.Connect("localhost") task := &Task{