diff --git a/config.go b/config.go index 66d7178..a04c5af 100644 --- a/config.go +++ b/config.go @@ -5,6 +5,9 @@ type PGConfig struct { Dir string // Directory for storing database files, removed for non-persistent configs IsPersistent bool // Whether to make the current configuraton persistent or not AdditionalArgs []string // Additional arguments to pass to the postgres command + FSync bool // Sets -F flag when false + DbName string + Password string } func New() *PGConfig { @@ -12,6 +15,9 @@ func New() *PGConfig { BinDir: "", Dir: "", IsPersistent: false, + FSync: false, // disable fsync by default - just go fast. + DbName: "test", + Password: "", } } @@ -40,6 +46,21 @@ func (c *PGConfig) WithAdditionalArgs(args ...string) *PGConfig { return c } +func (c *PGConfig) EnableFSync() *PGConfig { + c.FSync = true + return c +} + +func (c *PGConfig) SetPassword(password string) *PGConfig { + c.Password = password + return c +} + +func (c *PGConfig) SetDbName(dbName string) *PGConfig { + c.DbName = dbName + return c +} + func (c *PGConfig) Start() (*PG, error) { return start(c) } diff --git a/config_test.go b/config_test.go index 6cad64a..5156052 100644 --- a/config_test.go +++ b/config_test.go @@ -10,10 +10,13 @@ import ( func TestPGConfig(t *testing.T) { assert := assert.New(t) - config := pgtest.New().From("/usr/bin").DataDir("/tmp/data").Persistent().WithAdditionalArgs("-c", "log_statement=all") + config := pgtest.New().From("/usr/bin").DataDir("/tmp/data").Persistent().EnableFSync().SetDbName("mydb").SetPassword("mypassword").WithAdditionalArgs("-c", "log_statement=all") assert.True(config.IsPersistent) assert.EqualValues("/tmp/data", config.Dir) assert.EqualValues("/usr/bin", config.BinDir) assert.EqualValues([]string{"-c", "log_statement=all"}, config.AdditionalArgs) + assert.EqualValues(true, config.FSync) + assert.EqualValues("mydb", config.DbName) + assert.EqualValues("mypassword", config.Password) } diff --git a/pgtest.go b/pgtest.go index 903a12c..81426e3 100644 --- a/pgtest.go +++ b/pgtest.go @@ -56,6 +56,52 @@ func StartPersistent(folder string) (*PG, error) { return start(New().DataDir(folder).Persistent()) } +// When the password is provided this function writes the password +// to a temp file, and returns its path. +func createTempPasswordFile(password string, pgUID, pgGID int) (string, func(), error) { + // Create a temporary file + pwdTempFile, err := os.CreateTemp("", "pg_pwd_*") + if err != nil { + return "", nil, fmt.Errorf("failed to create temp file: %w", err) + } + tempFilePath := pwdTempFile.Name() + + // Set file permissions to 0600 (read/write for owner only) + if err := pwdTempFile.Chmod(0600); err != nil { + pwdTempFile.Close() + os.Remove(tempFilePath) + return "", nil, fmt.Errorf("failed to set file permissions: %w", err) + } + + // Write the password to the file + if _, err := pwdTempFile.WriteString(password); err != nil { + pwdTempFile.Close() + os.Remove(tempFilePath) + return "", nil, fmt.Errorf("failed to write password to file: %w", err) + } + + // Close the file + if err := pwdTempFile.Close(); err != nil { + os.Remove(tempFilePath) + return "", nil, fmt.Errorf("failed to close temp file: %w", err) + } + + // Change the owner + err = os.Chown(tempFilePath, pgUID, pgGID) + if err != nil { + return "", nil, fmt.Errorf("failed to set permissions on the temp file: %w", err) + } + + cleanupPwdFile := func() { + err := os.Remove(tempFilePath) + if err != nil { + fmt.Println("error while removing pwdfile: %w", err) + } + } + + return tempFilePath, cleanupPwdFile, nil +} + // start Starts a new PostgreSQL database // // Will listen on a unix socket and initialize the database in the given @@ -142,9 +188,28 @@ func start(config *PGConfig) (*PG, error) { // Initialize PostgreSQL data directory _, err = os.Stat(filepath.Join(dataDir, "postgresql.conf")) if os.IsNotExist(err) { - init := prepareCommand(isRoot, filepath.Join(binPath, "initdb"), + args := []string{ "-D", dataDir, "--no-sync", + } + + // setup password if specified + if config.Password != "" { + pwdFile, cleanupPwdFile, err := createTempPasswordFile(config.Password, pgUID, pgGID) + if err != nil { + return nil, fmt.Errorf("failed to create password file: %w", err) + } + // for more info on how this is being done + // https://www.postgresql.org/docs/current/app-initdb.html#APP-INITDB-OPTION-AUTH + // https://www.postgresql.org/docs/current/app-initdb.html#APP-INITDB-OPTION-PWFILE + args = append(args, "-A", "scram-sha-256", fmt.Sprintf("--pwfile=%s", pwdFile)) + // remove the password file, when we return from start() + defer cleanupPwdFile() + } + + // execute the command + init := prepareCommand(isRoot, filepath.Join(binPath, "initdb"), + args..., ) out, err := init.CombinedOutput() if err != nil { @@ -157,8 +222,17 @@ func start(config *PGConfig) (*PG, error) { "-D", dataDir, // Data directory "-k", sockDir, // Location for the UNIX socket "-h", "", // Disable TCP listening - "-F", // No fsync, just go fast } + + // by default config.Fsync is false, + // so fsync is disabled, unless it is set + // true by the user, in which case we + // skip the -F flag + if !config.FSync { + // no fsync, just go fast + args = append(args, "-F") + } + if len(config.AdditionalArgs) > 0 { args = append(args, config.AdditionalArgs...) } @@ -181,23 +255,24 @@ func start(config *PGConfig) (*PG, error) { return nil, abort("Failed to start PostgreSQL", cmd, stderr, stdout, err) } - // Connect to DB - dsn := makeDSN(sockDir, "postgres", isRoot) + dsn := makeDSN(sockDir, "postgres", config.Password) db, err := sql.Open("postgres", dsn) if err != nil { return nil, abort("Failed to connect to DB", cmd, stderr, stdout, err) } - // Prepare test database + // Prepare database with password err = retry(func() error { + // Ensure db exists var exists bool - err = db.QueryRow("SELECT 1 FROM pg_database WHERE datname = 'test'").Scan(&exists) - if exists { - return nil + err = db.QueryRow(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", config.DbName)).Scan(&exists) + if !exists { + _, err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", config.DbName)) + if err != nil { + return err + } } - - _, err := db.Exec("CREATE DATABASE test") - return err + return nil }, 1000, 10*time.Millisecond) if err != nil { return nil, abort("Failed to prepare test DB", cmd, stderr, stdout, err) @@ -208,8 +283,8 @@ func start(config *PGConfig) (*PG, error) { return nil, abort("Failed to disconnect", cmd, stderr, stdout, err) } - // Connect to it properly - dsn = makeDSN(sockDir, "test", isRoot) + // Connect to db with password + dsn = makeDSN(sockDir, config.DbName, config.Password) db, err = sql.Open("postgres", dsn) if err != nil { return nil, abort("Failed to connect to test DB", cmd, stderr, stdout, err) @@ -222,8 +297,8 @@ func start(config *PGConfig) (*PG, error) { DB: db, Host: sockDir, - User: pgUser(isRoot), - Name: "test", + User: pgUser(), + Name: config.DbName, persistent: config.IsPersistent, @@ -326,21 +401,31 @@ func findBinPath(binDir string) (string, error) { return "", fmt.Errorf("Did not find PostgreSQL executables installed") } -func pgUser(isRoot bool) string { - user := "" +func pgUser() string { + currentUser, err := user.Current() + isRoot := currentUser.Username == "root" if isRoot { - user = "postgres" + return "postgres" + } + if err != nil { + return "postgres" // fallback to postgres if we can't get the current user } - return user + return currentUser.Username } -func makeDSN(sockDir, dbname string, isRoot bool) string { +func makeDSN(sockDir, dbname, password string) string { dsnUser := "" - user := pgUser(isRoot) + dsnPassword := "" + user := pgUser() + // add user if defined if user != "" { dsnUser = fmt.Sprintf("user=%s", user) } - return fmt.Sprintf("host=%s dbname=%s %s", sockDir, dbname, dsnUser) + // add password if defined + if password != "" { + dsnPassword = fmt.Sprintf("password=%s", password) + } + return fmt.Sprintf("host=%s dbname=%s %s %s", sockDir, dbname, dsnUser, dsnPassword) } func retry(fn func() error, attempts int, interval time.Duration) error { diff --git a/pgtest_test.go b/pgtest_test.go index 11c89b9..2554661 100644 --- a/pgtest_test.go +++ b/pgtest_test.go @@ -1,7 +1,12 @@ package pgtest_test import ( + "database/sql" + "fmt" "os" + "os/user" + "path/filepath" + "reflect" "testing" "github.com/rubenv/pgtest" @@ -96,3 +101,90 @@ func TestAdditionalArgs(t *testing.T) { err = pg.Stop() assert.NoError(err) } + +func TestWrongDbNameAndPassword(t *testing.T) { + testDbWithUserNameAndPassword(t, "wrongdbName", "wrongPassword", false) +} + +func TestWrongDbName(t *testing.T) { + testDbWithUserNameAndPassword(t, "wrongdbName", "correctpassword", false) +} + +func TestWrongDbPassword(t *testing.T) { + testDbWithUserNameAndPassword(t, "correctdbname", "wrongpassword", false) +} + +func TestCorrectCredentials(t *testing.T) { + testDbWithUserNameAndPassword(t, "correctdbname", "correctpassword", true) +} + +// util functions for the dbname/password tests +func testDbWithUserNameAndPassword(t *testing.T, databaseName, password string, assertErrorNil bool) { + t.Parallel() + + assert := assert.New(t) + + pg, err := pgtest.New().SetDbName("correctdbname").SetPassword("correctpassword").Start() + assert.NoError(err) + assert.NotNil(pg) + + // connect using username and password via a different connection + // using the sockDir. + dsn := makeDsn(getSockDir(pg, t), databaseName, password) + // not testing the error returned by Open, because + // sometimes it returns without connecting. + // so we use .Ping to get the actual error. + connection, _ := sql.Open("postgres", dsn) + err = connection.Ping() + if assertErrorNil { + assert.NoError(err) + } else { + assert.Error(err) + } + + err = connection.Close() + assert.NoError(err) + + err = pg.Stop() + assert.NoError(err) +} + +func pgUser() string { + currentUser, err := user.Current() + isRoot := currentUser.Username == "root" + if isRoot { + return "postgres" + } + if err != nil { + return "postgres" // fallback to postgres if we can't get the current user + } + return currentUser.Username +} + +func makeDsn(sockDir, dbname, password string) string { + dsnUser := "" + dsnPassword := "" + user := pgUser() + // add user if defined + if user != "" { + dsnUser = fmt.Sprintf("user=%s", user) + } + // add password if defined + if password != "" { + dsnPassword = fmt.Sprintf("password=%s", password) + } + return fmt.Sprintf("host=%s dbname=%s %s %s", sockDir, dbname, dsnUser, dsnPassword) +} + +func getSockDir(pg *pgtest.PG, t *testing.T) string { + // Use reflection to access the private 'dir' field + pgValue := reflect.ValueOf(pg).Elem() + dirField := pgValue.FieldByName("dir") + if !dirField.IsValid() { + t.Fatal("Unable to find 'dir' field in PostgreSQL struct") + } + + dbRoot := dirField.String() + + return filepath.Join(dbRoot, "sock") +}