From 5707dce1e14770c3adc1d6423509f815aab09f18 Mon Sep 17 00:00:00 2001 From: Hai Vo Date: Tue, 14 Apr 2020 16:24:51 +0200 Subject: [PATCH 1/2] Simple fix --- .gitignore | 3 +++ go.mod | 8 ++++++ go.sum | 18 +++++++++++++ s3update.go | 78 +++++++++++++++++++++++++++++++++++++++++------------ 4 files changed, 90 insertions(+), 17 deletions(-) create mode 100644 go.mod create mode 100644 go.sum diff --git a/.gitignore b/.gitignore index daf913b..54088fd 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ _testmain.go *.exe *.test *.prof + + +.idea \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1b45e77 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/theamazinghari/s3update + +go 1.14 + +require ( + github.com/aws/aws-sdk-go v1.30.7 // indirect + github.com/mitchellh/ioprogress v0.0.0-20180201004757-6a23b12fa88e // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..376b51f --- /dev/null +++ b/go.sum @@ -0,0 +1,18 @@ +github.com/aws/aws-sdk-go v1.30.7 h1:IaXfqtioP6p9SFAnNfsqdNczbR5UNbYqvcZUSsCAdTY= +github.com/aws/aws-sdk-go v1.30.7/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= +github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= +github.com/mitchellh/ioprogress v0.0.0-20180201004757-6a23b12fa88e h1:Qa6dnn8DlasdXRnacluu8HzPts0S1I9zvvUPDbBnXFI= +github.com/mitchellh/ioprogress v0.0.0-20180201004757-6a23b12fa88e/go.mod h1:waEya8ee1Ro/lgxpVhkJI4BVASzkm3UZqkx/cFJiYHM= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/s3update.go b/s3update.go index 1a3e3e8..fb46be1 100644 --- a/s3update.go +++ b/s3update.go @@ -1,9 +1,11 @@ package s3update import ( + "errors" "fmt" "io" "io/ioutil" + "log" "os" "runtime" "strconv" @@ -17,6 +19,9 @@ import ( "github.com/mitchellh/ioprogress" ) +var downloadSize int64 +var ErrInconsistentFileSize = errors.New("inconsistent file size") + type Updater struct { // CurrentVersion represents the current binary version. // This is generally set at the compilation time with -ldflags "-X main.Version=42" @@ -89,12 +94,11 @@ func runAutoUpdate(u Updater) error { return fmt.Errorf("invalid local version") } - svc := s3.New(session.New(), &aws.Config{Region: aws.String(u.S3Region)}) + svc := s3.New(session.Must(session.NewSession()), &aws.Config{Region: aws.String(u.S3Region)}) resp, err := svc.GetObject(&s3.GetObjectInput{Bucket: aws.String(u.S3Bucket), Key: aws.String(u.S3VersionKey)}) if err != nil { return err } - defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -114,7 +118,7 @@ func runAutoUpdate(u Updater) error { if err != nil { return err } - defer resp.Body.Close() + downloadSize = *resp.ContentLength progressR := &ioprogress.Reader{ Reader: resp.Body, Size: *resp.ContentLength, @@ -130,31 +134,40 @@ func runAutoUpdate(u Updater) error { return err } - // Move the old version to a backup path that we can recover from - // in case the upgrade fails destBackup := dest + ".bak" - if _, err := os.Stat(dest); err == nil { - os.Rename(dest, destBackup) + + // Create a temp file + tempFile, err := ioutil.TempFile("", "tmp_download") + if err != nil { + printError(err) + return err } + tempFilePath := tempFile.Name() + // Download to tempFile // Use the same flags that ioutil.WriteFile uses - f, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) + f, err := os.OpenFile(tempFile.Name(), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) if err != nil { - os.Rename(destBackup, dest) + os.Remove(tempFile.Name()) return err } - defer f.Close() + err = tempFile.Close() + printError(err) - fmt.Printf("s3update: downloading new version to %s\n", dest) + fmt.Printf("s3update: downloading new version to %s\n", tempFile.Name()) if _, err := io.Copy(f, progressR); err != nil { - os.Rename(destBackup, dest) + printError(err) return err } - // The file must be closed already so we can execute it in the next step - f.Close() + // Close the response stream + err = resp.Body.Close() + printError(err) + // The file must be closed so we can execute it in the next step + err = f.Close() + printError(err) - // Removing backup - os.Remove(destBackup) + err = finalizeUpdate(dest, destBackup, tempFilePath) + printError(err) fmt.Printf("s3update: updated with success to version %d\nRestarting application\n", remoteVersion) @@ -165,6 +178,37 @@ func runAutoUpdate(u Updater) error { os.Exit(0) } - return nil } + +func printError(err error) { + if err != nil { + fmt.Println("s3Update: ", err) + } +} + +func finalizeUpdate(originalFile, backupFile, tempFile string) (err error) { + if fileSize(tempFile) == downloadSize { // backup old binary then replace it with downloaded file + // Backup current binary + if _, err = os.Stat(originalFile); err == nil { + err = os.Rename(originalFile, backupFile) + } + + // Replace old binary by downloaded file + if err = os.Rename(tempFile, originalFile); err != nil { + // revert backup file + err = os.Rename(backupFile, originalFile) + } + } else { // Do nothing + err = ErrInconsistentFileSize + } + return +} + +func fileSize(path string) int64 { + fileInfo, err := os.Stat(path) + if err != nil { + log.Fatal(err) + } + return fileInfo.Size() +} From c98da0a06718c1ca334d0fd5eac71a5c171e811a Mon Sep 17 00:00:00 2001 From: Hai Vo Date: Wed, 15 Apr 2020 12:05:55 +0200 Subject: [PATCH 2/2] clean up some code --- s3update.go | 61 +++++++++++++++++++++++++++-------------------------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/s3update.go b/s3update.go index fb46be1..6d1e876 100644 --- a/s3update.go +++ b/s3update.go @@ -1,7 +1,6 @@ package s3update import ( - "errors" "fmt" "io" "io/ioutil" @@ -19,8 +18,7 @@ import ( "github.com/mitchellh/ioprogress" ) -var downloadSize int64 -var ErrInconsistentFileSize = errors.New("inconsistent file size") +var remoteFileSize int64 type Updater struct { // CurrentVersion represents the current binary version. @@ -118,7 +116,7 @@ func runAutoUpdate(u Updater) error { if err != nil { return err } - downloadSize = *resp.ContentLength + remoteFileSize = *resp.ContentLength progressR := &ioprogress.Reader{ Reader: resp.Body, Size: *resp.ContentLength, @@ -137,9 +135,8 @@ func runAutoUpdate(u Updater) error { destBackup := dest + ".bak" // Create a temp file - tempFile, err := ioutil.TempFile("", "tmp_download") + tempFile, err := ioutil.TempFile("", "s3update_tmp_download") if err != nil { - printError(err) return err } tempFilePath := tempFile.Name() @@ -148,31 +145,37 @@ func runAutoUpdate(u Updater) error { // Use the same flags that ioutil.WriteFile uses f, err := os.OpenFile(tempFile.Name(), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) if err != nil { - os.Remove(tempFile.Name()) + _ = os.Remove(tempFile.Name()) + return err + } + + if err := tempFile.Close(); err != nil { return err } - err = tempFile.Close() - printError(err) - fmt.Printf("s3update: downloading new version to %s\n", tempFile.Name()) if _, err := io.Copy(f, progressR); err != nil { - printError(err) return err } + // Close the response stream - err = resp.Body.Close() - printError(err) + if err := resp.Body.Close(); err != nil { + return err + } + // The file must be closed so we can execute it in the next step - err = f.Close() - printError(err) + if err := f.Close(); err != nil { + return err + } - err = finalizeUpdate(dest, destBackup, tempFilePath) - printError(err) + if err := finalizeUpdate(dest, destBackup, tempFilePath); err != nil { + return err + } fmt.Printf("s3update: updated with success to version %d\nRestarting application\n", remoteVersion) // The update completed, we can now restart the application without requiring any user action. if err := syscall.Exec(dest, os.Args, os.Environ()); err != nil { + fmt.Println(err) return err } @@ -181,30 +184,28 @@ func runAutoUpdate(u Updater) error { return nil } -func printError(err error) { - if err != nil { - fmt.Println("s3Update: ", err) - } -} - -func finalizeUpdate(originalFile, backupFile, tempFile string) (err error) { - if fileSize(tempFile) == downloadSize { // backup old binary then replace it with downloaded file +func finalizeUpdate(originalFilePath, backupFilePath, tempFilePath string) (err error) { + if downloadSucceeded(tempFilePath) { // Backup current binary - if _, err = os.Stat(originalFile); err == nil { - err = os.Rename(originalFile, backupFile) + if _, err = os.Stat(originalFilePath); err == nil { + err = os.Rename(originalFilePath, backupFilePath) } // Replace old binary by downloaded file - if err = os.Rename(tempFile, originalFile); err != nil { + if err = os.Rename(tempFilePath, originalFilePath); err != nil { // revert backup file - err = os.Rename(backupFile, originalFile) + err = os.Rename(backupFilePath, originalFilePath) } } else { // Do nothing - err = ErrInconsistentFileSize + return fmt.Errorf("inconsistent file size") } return } +func downloadSucceeded(tempFile string) bool { + return fileSize(tempFile) == remoteFileSize +} + func fileSize(path string) int64 { fileInfo, err := os.Stat(path) if err != nil {