diff --git a/.gitignore b/.gitignore index daf913b..54088fd 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ _testmain.go *.exe *.test *.prof + + +.idea \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1b45e77 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/theamazinghari/s3update + +go 1.14 + +require ( + github.com/aws/aws-sdk-go v1.30.7 // indirect + github.com/mitchellh/ioprogress v0.0.0-20180201004757-6a23b12fa88e // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..376b51f --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/aws/aws-sdk-go v1.30.7 h1:IaXfqtioP6p9SFAnNfsqdNczbR5UNbYqvcZUSsCAdTY= +github.com/aws/aws-sdk-go v1.30.7/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= +github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= +github.com/mitchellh/ioprogress v0.0.0-20180201004757-6a23b12fa88e h1:Qa6dnn8DlasdXRnacluu8HzPts0S1I9zvvUPDbBnXFI= +github.com/mitchellh/ioprogress v0.0.0-20180201004757-6a23b12fa88e/go.mod h1:waEya8ee1Ro/lgxpVhkJI4BVASzkm3UZqkx/cFJiYHM= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/s3update.go b/s3update.go index 1a3e3e8..6d1e876 100644 --- a/s3update.go +++ b/s3update.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "io/ioutil" + "log" "os" "runtime" "strconv" @@ -17,6 +18,8 @@ import ( "github.com/mitchellh/ioprogress" ) +var remoteFileSize int64 + type Updater struct { // CurrentVersion represents the current binary version. // This is generally set at the compilation time with -ldflags "-X main.Version=42" @@ -89,12 +92,11 @@ func runAutoUpdate(u Updater) error { return fmt.Errorf("invalid local version") } - svc := s3.New(session.New(), &aws.Config{Region: aws.String(u.S3Region)}) + svc := s3.New(session.Must(session.NewSession()), &aws.Config{Region: aws.String(u.S3Region)}) resp, err := svc.GetObject(&s3.GetObjectInput{Bucket: aws.String(u.S3Bucket), Key: aws.String(u.S3VersionKey)}) if err != nil { return err } - defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -114,7 +116,7 @@ func runAutoUpdate(u Updater) error { if err != nil { return err } - defer resp.Body.Close() + remoteFileSize = *resp.ContentLength progressR := &ioprogress.Reader{ Reader: resp.Body, Size: *resp.ContentLength, @@ -130,41 +132,84 @@ func runAutoUpdate(u Updater) error { return err } - // Move the old version to a backup path that we can recover from - // in case the upgrade fails destBackup := dest + ".bak" - if _, err := os.Stat(dest); err == nil { - os.Rename(dest, destBackup) + + // Create a temp file + tempFile, err := ioutil.TempFile("", "s3update_tmp_download") + if err != nil { + return err } + tempFilePath := tempFile.Name() + // Download to tempFile // Use the same flags that ioutil.WriteFile uses - f, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) + f, err := os.OpenFile(tempFile.Name(), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) if err != nil { - os.Rename(destBackup, dest) + _ = os.Remove(tempFile.Name()) + return err + } + + if err := tempFile.Close(); err != nil { return err } - defer f.Close() - fmt.Printf("s3update: downloading new version to %s\n", dest) if _, err := io.Copy(f, progressR); err != nil { - os.Rename(destBackup, dest) return err } - // The file must be closed already so we can execute it in the next step - f.Close() - // Removing backup - os.Remove(destBackup) + // Close the response stream + if err := resp.Body.Close(); err != nil { + return err + } + + // The file must be closed so we can execute it in the next step + if err := f.Close(); err != nil { + return err + } + + if err := finalizeUpdate(dest, destBackup, tempFilePath); err != nil { + return err + } fmt.Printf("s3update: updated with success to version %d\nRestarting application\n", remoteVersion) // The update completed, we can now restart the application without requiring any user action. if err := syscall.Exec(dest, os.Args, os.Environ()); err != nil { + fmt.Println(err) return err } os.Exit(0) } - return nil } + +func finalizeUpdate(originalFilePath, backupFilePath, tempFilePath string) (err error) { + if downloadSucceeded(tempFilePath) { + // Backup current binary + if _, err = os.Stat(originalFilePath); err == nil { + err = os.Rename(originalFilePath, backupFilePath) + } + + // Replace old binary by downloaded file + if err = os.Rename(tempFilePath, originalFilePath); err != nil { + // revert backup file + err = os.Rename(backupFilePath, originalFilePath) + } + } else { // Do nothing + return fmt.Errorf("inconsistent file size") + } + return +} + +func downloadSucceeded(tempFile string) bool { + return fileSize(tempFile) == remoteFileSize +} + +func fileSize(path string) int64 { + fileInfo, err := os.Stat(path) + if err != nil { + log.Fatal(err) + } + return fileInfo.Size() +}