diff --git a/policies/recipes/installrecipe.go b/policies/recipes/installrecipe.go index 5d9b9c9d6..e6d89cbf4 100644 --- a/policies/recipes/installrecipe.go +++ b/policies/recipes/installrecipe.go @@ -31,7 +31,7 @@ import ( func InstallRecipe(ctx context.Context, recipe *agentendpointpb.SoftwareRecipe) error { ctx = clog.WithLabels(ctx, map[string]string{"recipe_name": recipe.GetName()}) steps := recipe.InstallSteps - recipeDB, err := newRecipeDB() + recipeDB, err := newRecipeDBWithDefaults() if err != nil { return err } diff --git a/policies/recipes/recipedb.go b/policies/recipes/recipedb.go index 739d6e144..9bac1096f 100644 --- a/policies/recipes/recipedb.go +++ b/policies/recipes/recipedb.go @@ -16,10 +16,12 @@ package recipes import ( "encoding/json" + "io" "io/ioutil" "os" "path/filepath" "runtime" + "sort" "time" ) @@ -29,76 +31,108 @@ var ( dbFileName = "osconfig_recipedb" ) +type timeFunc func() time.Time + // RecipeDB represents local state of installed recipes. -type RecipeDB map[string]Recipe +type recipeDB struct { + file string + timeFunc timeFunc -// newRecipeDB instantiates a recipeDB. -func newRecipeDB() (RecipeDB, error) { - db := make(RecipeDB) - f, err := os.Open(filepath.Join(getDbDir(), dbFileName)) + recipes map[string]Recipe +} + +func newRecipeDB(path string) (*recipeDB, error) { + db := &recipeDB{ + file: path, + timeFunc: time.Now, + recipes: make(map[string]Recipe, 0), + } + + f, err := os.Open(path) if err != nil { if os.IsNotExist(err) { return db, nil } + return nil, err } defer f.Close() - bytes, err := ioutil.ReadAll(f) + + raw, err := io.ReadAll(f) if err != nil { return nil, err } - var recipelist []Recipe - if err := json.Unmarshal(bytes, &recipelist); err != nil { + + var recipes []Recipe + if err := json.Unmarshal(raw, &recipes); err != nil { return nil, err } - for _, recipe := range recipelist { - db[recipe.Name] = recipe + + for _, recipe := range recipes { + db.recipes[recipe.Name] = recipe } return db, nil } +// newRecipeDB instantiates a recipeDB. +func newRecipeDBWithDefaults() (*recipeDB, error) { + dir, fileName := getDbDir(), dbFileName + return newRecipeDB(filepath.Join(dir, fileName)) +} + // getRecipe returns the Recipe object for the given recipe name. -func (db RecipeDB) getRecipe(name string) (Recipe, bool) { - r, ok := db[name] +func (db *recipeDB) getRecipe(name string) (Recipe, bool) { + r, ok := db.recipes[name] return r, ok } // addRecipe marks a recipe as installed. -func (db RecipeDB) addRecipe(name, version string, success bool) error { +func (db *recipeDB) addRecipe(name, version string, success bool) error { versionNum, err := convertVersion(version) if err != nil { return err } - db[name] = Recipe{Name: name, Version: versionNum, InstallTime: time.Now().Unix(), Success: success} + db.recipes[name] = Recipe{Name: name, Version: versionNum, InstallTime: db.timeFunc().Unix(), Success: success} - var recipelist []Recipe - for _, recipe := range db { - recipelist = append(recipelist, recipe) + return db.saveToFS() +} + +func (db *recipeDB) saveToFS() error { + var recipes []Recipe + for _, recipe := range db.recipes { + recipes = append(recipes, recipe) } - dbBytes, err := json.Marshal(recipelist) + + sort.Slice(recipes, func(i, j int) bool { + return recipes[i].Name < recipes[j].Name + }) + + raw, err := json.Marshal(recipes) if err != nil { return err } - dbDir := getDbDir() - if err := os.MkdirAll(dbDir, 0755); err != nil { + dir := filepath.Dir(db.file) + if err := os.MkdirAll(dir, 0755); err != nil { return err } - f, err := ioutil.TempFile(dbDir, dbFileName+"_*") + fileName := filepath.Base(db.file) + f, err := ioutil.TempFile(dir, fileName+"_*") if err != nil { return err } - if _, err := f.Write(dbBytes); err != nil { + if _, err := f.Write(raw); err != nil { f.Close() return err } + if err := f.Close(); err != nil { return err } - return os.Rename(f.Name(), filepath.Join(dbDir, dbFileName)) + return os.Rename(f.Name(), db.file) } func getDbDir() string { diff --git a/policies/recipes/recipedb_test.go b/policies/recipes/recipedb_test.go new file mode 100644 index 000000000..fcdb1014a --- /dev/null +++ b/policies/recipes/recipedb_test.go @@ -0,0 +1,380 @@ +package recipes + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/GoogleCloudPlatform/osconfig/util" + "github.com/GoogleCloudPlatform/osconfig/util/utiltest" +) + +func Test_newRecipeDB(t *testing.T) { + tests := []struct { + name string + + file string + setup func(file string) error + + wantRecipes map[string]Recipe + wantErr error + }{ + { + name: "file does not exists", + + file: "/var/file_does_not_exists", + setup: func(_ string) error { + return nil + }, + + wantRecipes: make(map[string]Recipe, 0), + wantErr: nil, + }, + { + name: "file exists, but empty", + file: tempFileMust(os.TempDir(), "recipes", os.ModePerm).Name(), + setup: func(_ string) error { + return nil + }, + + wantRecipes: nil, + wantErr: fmt.Errorf("unexpected end of JSON input"), + }, + { + name: "directory set as filepath", + file: os.TempDir(), + setup: func(_ string) error { + return nil + }, + + wantRecipes: nil, + wantErr: fmt.Errorf("read %s: is a directory", os.TempDir()), + }, + { + name: "file exist with some recipe", + file: tempFileMust(os.TempDir(), "recipes", os.ModePerm).Name(), + setup: func(path string) error { + recipes := []Recipe{ + { + Name: "test", + Version: []int{1, 1}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: true, + }, + { + Name: "test2", + Version: []int{2, 2}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: false, + }, + } + + raw, err := json.Marshal(recipes) + if err != nil { + return err + } + + fd, err := os.OpenFile(path, os.O_RDWR, os.ModePerm) + if err != nil { + return err + } + + if _, err := fd.Write(raw); err != nil { + return err + } + return nil + }, + + wantRecipes: map[string]Recipe{ + "test": Recipe{ + Name: "test", + Version: []int{1, 1}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: true, + }, + "test2": Recipe{ + Name: "test2", + Version: []int{2, 2}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: false, + }, + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.setup(tt.file); err != nil { + t.Errorf("unwanted error in process of setup: %v", err) + } + + got, gotErr := newRecipeDB(tt.file) + var gotRecipes map[string]Recipe + if got != nil { + gotRecipes = got.recipes + } + + utiltest.EnsureError(t, tt.wantErr, gotErr) + utiltest.EnsureResults(t, tt.wantRecipes, gotRecipes) + }) + } +} + +func Test_recipeDB_getRecipe(t *testing.T) { + recipeDB := &recipeDB{ + file: tempFileMust(os.TempDir(), "recipes", os.ModePerm).Name(), + timeFunc: mockTimeFunc, + recipes: map[string]Recipe{ + "test": Recipe{ + Name: "test", + Version: []int{1, 1}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: true, + }, + "test2": Recipe{ + Name: "test2", + Version: []int{2, 2}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: false, + }, + }, + } + + want1 := Recipe{ + Name: "test", + Version: []int{1, 1}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: true, + } + got1, ok1 := recipeDB.getRecipe("test") + utiltest.EnsureResults(t, want1, got1) + utiltest.EnsureResults(t, true, ok1) + + want2 := Recipe{} + got2, ok2 := recipeDB.getRecipe("test5") + utiltest.EnsureResults(t, want2, got2) + utiltest.EnsureResults(t, false, ok2) +} + +func Test_recipeDB_addRecipe(t *testing.T) { + tests := []struct { + name string + + db *recipeDB + operations []struct { + recipe Recipe + wantErr error + } + wantContent string + }{ + { + name: "empty db two operations, expect no errors", + db: &recipeDB{ + file: tempFileMust(os.TempDir(), "recipes", os.ModePerm).Name(), + timeFunc: mockTimeFunc, + recipes: make(map[string]Recipe, 0), + }, + operations: []struct { + recipe Recipe + wantErr error + }{ + { + recipe: Recipe{ + Name: "test", + Version: []int{1, 1}, + Success: true, + }, + wantErr: nil, + }, + { + recipe: Recipe{ + Name: "test2", + Version: []int{2, 2}, + Success: false, + }, + wantErr: nil, + }, + }, + wantContent: `[{"Name":"test","Version":[1,1],"InstallTime":949408200,"Success":true},{"Name":"test2","Version":[2,2],"InstallTime":949408200,"Success":false}]`, + }, + { + name: "db with one entry, second entry added", + db: &recipeDB{ + file: tempFileMust(os.TempDir(), "recipes", os.ModePerm).Name(), + timeFunc: mockTimeFunc, + recipes: map[string]Recipe{ + "test2": Recipe{ + Name: "test2", + Version: []int{2, 2}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: false, + }, + }, + }, + operations: []struct { + recipe Recipe + wantErr error + }{ + { + recipe: Recipe{ + Name: "test3", + Version: []int{3, 3}, + Success: true, + }, + wantErr: nil, + }, + }, + wantContent: `[{"Name":"test2","Version":[2,2],"InstallTime":949408200,"Success":false},{"Name":"test3","Version":[3,3],"InstallTime":949408200,"Success":true}]`, + }, + { + name: "invalid entry skiped", + db: &recipeDB{ + file: tempFileMust(os.TempDir(), "recipes", os.ModePerm).Name(), + timeFunc: mockTimeFunc, + recipes: make(map[string]Recipe, 0), + }, + operations: []struct { + recipe Recipe + wantErr error + }{ + { + recipe: Recipe{ + Name: "test2", + Version: []int{2, 2}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: false, + }, + wantErr: nil, + }, + { + recipe: Recipe{ + Name: "test3", + Version: []int{-3}, + Success: true, + }, + wantErr: fmt.Errorf("invalid Version string"), + }, + }, + wantContent: `[{"Name":"test2","Version":[2,2],"InstallTime":949408200,"Success":false}]`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, operation := range tt.operations { + recipe := operation.recipe + gotErr := tt.db.addRecipe(recipe.Name, recipe.Version.String(), recipe.Success) + utiltest.EnsureError(t, operation.wantErr, gotErr) + } + + gotContent, gotContentErr := readFile(tt.db.file) + utiltest.EnsureError(t, nil, gotContentErr) + utiltest.EnsureResults(t, tt.wantContent, string(gotContent)) + }) + } +} + +func Test_recipeDB_saveToFS(t *testing.T) { + tests := []struct { + name string + + db *recipeDB + + wantContent string + wantErr error + }{ + { + name: "database with records, properly stored on the fs", + db: &recipeDB{ + file: tempFileMust(os.TempDir(), "recipes", os.ModePerm).Name(), + timeFunc: mockTimeFunc, + recipes: map[string]Recipe{ + "test": Recipe{ + Name: "test", + Version: []int{1, 1}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: true, + }, + "test2": Recipe{ + Name: "test2", + Version: []int{2, 2}, + InstallTime: time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC).Unix(), + Success: false, + }, + }, + }, + wantContent: `[{"Name":"test","Version":[1,1],"InstallTime":949408200,"Success":true},{"Name":"test2","Version":[2,2],"InstallTime":949408200,"Success":false}]`, + wantErr: nil, + }, + { + name: "path to system dir", + db: &recipeDB{ + file: string(os.PathSeparator), + timeFunc: mockTimeFunc, + recipes: make(map[string]Recipe, 0), + }, + + wantErr: fmt.Errorf("createtemp /_*: pattern contains path separator"), + wantContent: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := tt.db.saveToFS() + utiltest.EnsureError(t, tt.wantErr, gotErr) + + if tt.wantErr == nil { + gotContent, gotContentErr := readFile(tt.db.file) + + utiltest.EnsureError(t, nil, gotContentErr) + utiltest.EnsureResults(t, string(tt.wantContent), string(gotContent)) + } + }) + + } +} + +func Test_newRecipeDBWithDefaults(t *testing.T) { + wantDir := dbDirUnix + if runtime.GOOS == "windows" { + wantDir = dbDirWindows + } + wantDBPath := filepath.Join(wantDir, dbFileName) + + db, gotErr := newRecipeDBWithDefaults() + gotDBPath := db.file + + utiltest.EnsureError(t, nil, gotErr) + utiltest.EnsureResults(t, wantDBPath, gotDBPath) +} + +func tempFileMust(dir, pattern string, mode os.FileMode) *os.File { + fd, err := util.TempFile(dir, pattern, mode) + if err != nil { + panic(err) + } + + defer fd.Close() + + return fd +} + +func readFile(path string) ([]byte, error) { + fd, err := os.Open(path) + if err != nil { + return nil, err + } + + return io.ReadAll(fd) +} + +func mockTimeFunc() time.Time { + return time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC) +} diff --git a/util/utiltest/utiltest.go b/util/utiltest/utiltest.go index 930f8018a..ea52f6385 100644 --- a/util/utiltest/utiltest.go +++ b/util/utiltest/utiltest.go @@ -2,6 +2,7 @@ package utiltest import ( "errors" + "fmt" "os" "testing" @@ -81,3 +82,19 @@ func MatchSnapshot(t testReporter, actual any, snapshotFilepath string) { removeSnapshotDraft(snapshotFilepath) } } + +// EnsureError fails test and print diff if want and got errors not equals. +func EnsureError(t *testing.T, want, got error) { + if fmt.Sprintf("%s", want) != fmt.Sprintf("%s", got) { + t.Errorf("unwanted error, want: %s, got: %s", want, got) + + } +} + +// EnsureResults fails test and print diff if want and got parameters not equals. +func EnsureResults(t *testing.T, want, got any, options ...cmp.Option) { + diff := cmp.Diff(want, got, options...) + if diff != "" { + t.Errorf("unwanted diff:\n%s", diff) + } +}