Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@ type PGConfig struct {
Dir string // Directory for storing database files, removed for non-persistent configs
IsPersistent bool // Whether to make the current configuraton persistent or not
AdditionalArgs []string // Additional arguments to pass to the postgres command
FSync bool // Sets -F flag when false
DbName string
Password string
}

func New() *PGConfig {
return &PGConfig{
BinDir: "",
Dir: "",
IsPersistent: false,
FSync: false, // disable fsync by default - just go fast.
DbName: "test",
Password: "",
}
}

Expand Down Expand Up @@ -40,6 +46,21 @@ func (c *PGConfig) WithAdditionalArgs(args ...string) *PGConfig {
return c
}

func (c *PGConfig) EnableFSync() *PGConfig {
c.FSync = true
return c
}

func (c *PGConfig) SetPassword(password string) *PGConfig {
c.Password = password
return c
}

func (c *PGConfig) SetDbName(dbName string) *PGConfig {
c.DbName = dbName
return c
}

func (c *PGConfig) Start() (*PG, error) {
return start(c)
}
5 changes: 4 additions & 1 deletion config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ import (
func TestPGConfig(t *testing.T) {
assert := assert.New(t)

config := pgtest.New().From("/usr/bin").DataDir("/tmp/data").Persistent().WithAdditionalArgs("-c", "log_statement=all")
config := pgtest.New().From("/usr/bin").DataDir("/tmp/data").Persistent().EnableFSync().SetDbName("mydb").SetPassword("mypassword").WithAdditionalArgs("-c", "log_statement=all")

assert.True(config.IsPersistent)
assert.EqualValues("/tmp/data", config.Dir)
assert.EqualValues("/usr/bin", config.BinDir)
assert.EqualValues([]string{"-c", "log_statement=all"}, config.AdditionalArgs)
assert.EqualValues(true, config.FSync)
assert.EqualValues("mydb", config.DbName)
assert.EqualValues("mypassword", config.Password)
}
129 changes: 107 additions & 22 deletions pgtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,52 @@ func StartPersistent(folder string) (*PG, error) {
return start(New().DataDir(folder).Persistent())
}

// When the password is provided this function writes the password
// to a temp file, and returns its path.
func createTempPasswordFile(password string, pgUID, pgGID int) (string, func(), error) {
// Create a temporary file
pwdTempFile, err := os.CreateTemp("", "pg_pwd_*")
if err != nil {
return "", nil, fmt.Errorf("failed to create temp file: %w", err)
}
tempFilePath := pwdTempFile.Name()

// Set file permissions to 0600 (read/write for owner only)
if err := pwdTempFile.Chmod(0600); err != nil {
pwdTempFile.Close()
os.Remove(tempFilePath)
return "", nil, fmt.Errorf("failed to set file permissions: %w", err)
}

// Write the password to the file
if _, err := pwdTempFile.WriteString(password); err != nil {
pwdTempFile.Close()
os.Remove(tempFilePath)
return "", nil, fmt.Errorf("failed to write password to file: %w", err)
}

// Close the file
if err := pwdTempFile.Close(); err != nil {
os.Remove(tempFilePath)
return "", nil, fmt.Errorf("failed to close temp file: %w", err)
}

// Change the owner
err = os.Chown(tempFilePath, pgUID, pgGID)
if err != nil {
return "", nil, fmt.Errorf("failed to set permissions on the temp file: %w", err)
}

cleanupPwdFile := func() {
err := os.Remove(tempFilePath)
if err != nil {
fmt.Println("error while removing pwdfile: %w", err)
}
}

return tempFilePath, cleanupPwdFile, nil
}

// start Starts a new PostgreSQL database
//
// Will listen on a unix socket and initialize the database in the given
Expand Down Expand Up @@ -142,9 +188,28 @@ func start(config *PGConfig) (*PG, error) {
// Initialize PostgreSQL data directory
_, err = os.Stat(filepath.Join(dataDir, "postgresql.conf"))
if os.IsNotExist(err) {
init := prepareCommand(isRoot, filepath.Join(binPath, "initdb"),
args := []string{
"-D", dataDir,
"--no-sync",
}

// setup password if specified
if config.Password != "" {
pwdFile, cleanupPwdFile, err := createTempPasswordFile(config.Password, pgUID, pgGID)
if err != nil {
return nil, fmt.Errorf("failed to create password file: %w", err)
}
// for more info on how this is being done
// https://www.postgresql.org/docs/current/app-initdb.html#APP-INITDB-OPTION-AUTH
// https://www.postgresql.org/docs/current/app-initdb.html#APP-INITDB-OPTION-PWFILE
args = append(args, "-A", "scram-sha-256", fmt.Sprintf("--pwfile=%s", pwdFile))
// remove the password file, when we return from start()
defer cleanupPwdFile()
}

// execute the command
init := prepareCommand(isRoot, filepath.Join(binPath, "initdb"),
args...,
)
out, err := init.CombinedOutput()
if err != nil {
Expand All @@ -157,8 +222,17 @@ func start(config *PGConfig) (*PG, error) {
"-D", dataDir, // Data directory
"-k", sockDir, // Location for the UNIX socket
"-h", "", // Disable TCP listening
"-F", // No fsync, just go fast
}

// by default config.Fsync is false,
// so fsync is disabled, unless it is set
// true by the user, in which case we
// skip the -F flag
if !config.FSync {
// no fsync, just go fast
args = append(args, "-F")
}

if len(config.AdditionalArgs) > 0 {
args = append(args, config.AdditionalArgs...)
}
Expand All @@ -181,23 +255,24 @@ func start(config *PGConfig) (*PG, error) {
return nil, abort("Failed to start PostgreSQL", cmd, stderr, stdout, err)
}

// Connect to DB
dsn := makeDSN(sockDir, "postgres", isRoot)
dsn := makeDSN(sockDir, "postgres", config.Password)
db, err := sql.Open("postgres", dsn)
if err != nil {
return nil, abort("Failed to connect to DB", cmd, stderr, stdout, err)
}

// Prepare test database
// Prepare database with password
err = retry(func() error {
// Ensure db exists
var exists bool
err = db.QueryRow("SELECT 1 FROM pg_database WHERE datname = 'test'").Scan(&exists)
if exists {
return nil
err = db.QueryRow(fmt.Sprintf("SELECT 1 FROM pg_database WHERE datname = '%s'", config.DbName)).Scan(&exists)
if !exists {
_, err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", config.DbName))
if err != nil {
return err
}
}

_, err := db.Exec("CREATE DATABASE test")
return err
return nil
}, 1000, 10*time.Millisecond)
if err != nil {
return nil, abort("Failed to prepare test DB", cmd, stderr, stdout, err)
Expand All @@ -208,8 +283,8 @@ func start(config *PGConfig) (*PG, error) {
return nil, abort("Failed to disconnect", cmd, stderr, stdout, err)
}

// Connect to it properly
dsn = makeDSN(sockDir, "test", isRoot)
// Connect to db with password
dsn = makeDSN(sockDir, config.DbName, config.Password)
db, err = sql.Open("postgres", dsn)
if err != nil {
return nil, abort("Failed to connect to test DB", cmd, stderr, stdout, err)
Expand All @@ -222,8 +297,8 @@ func start(config *PGConfig) (*PG, error) {
DB: db,

Host: sockDir,
User: pgUser(isRoot),
Name: "test",
User: pgUser(),
Name: config.DbName,

persistent: config.IsPersistent,

Expand Down Expand Up @@ -326,21 +401,31 @@ func findBinPath(binDir string) (string, error) {
return "", fmt.Errorf("Did not find PostgreSQL executables installed")
}

func pgUser(isRoot bool) string {
user := ""
func pgUser() string {
currentUser, err := user.Current()
isRoot := currentUser.Username == "root"
if isRoot {
user = "postgres"
return "postgres"
}
if err != nil {
return "postgres" // fallback to postgres if we can't get the current user
}
return user
return currentUser.Username
}

func makeDSN(sockDir, dbname string, isRoot bool) string {
func makeDSN(sockDir, dbname, password string) string {
dsnUser := ""
user := pgUser(isRoot)
dsnPassword := ""
user := pgUser()
// add user if defined
if user != "" {
dsnUser = fmt.Sprintf("user=%s", user)
}
return fmt.Sprintf("host=%s dbname=%s %s", sockDir, dbname, dsnUser)
// add password if defined
if password != "" {
dsnPassword = fmt.Sprintf("password=%s", password)
}
return fmt.Sprintf("host=%s dbname=%s %s %s", sockDir, dbname, dsnUser, dsnPassword)
}

func retry(fn func() error, attempts int, interval time.Duration) error {
Expand Down
92 changes: 92 additions & 0 deletions pgtest_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package pgtest_test

import (
"database/sql"
"fmt"
"os"
"os/user"
"path/filepath"
"reflect"
"testing"

"github.com/rubenv/pgtest"
Expand Down Expand Up @@ -96,3 +101,90 @@ func TestAdditionalArgs(t *testing.T) {
err = pg.Stop()
assert.NoError(err)
}

func TestWrongDbNameAndPassword(t *testing.T) {
testDbWithUserNameAndPassword(t, "wrongdbName", "wrongPassword", false)
}

func TestWrongDbName(t *testing.T) {
testDbWithUserNameAndPassword(t, "wrongdbName", "correctpassword", false)
}

func TestWrongDbPassword(t *testing.T) {
testDbWithUserNameAndPassword(t, "correctdbname", "wrongpassword", false)
}

func TestCorrectCredentials(t *testing.T) {
testDbWithUserNameAndPassword(t, "correctdbname", "correctpassword", true)
}

// util functions for the dbname/password tests
func testDbWithUserNameAndPassword(t *testing.T, databaseName, password string, assertErrorNil bool) {
t.Parallel()

assert := assert.New(t)

pg, err := pgtest.New().SetDbName("correctdbname").SetPassword("correctpassword").Start()
assert.NoError(err)
assert.NotNil(pg)

// connect using username and password via a different connection
// using the sockDir.
dsn := makeDsn(getSockDir(pg, t), databaseName, password)
// not testing the error returned by Open, because
// sometimes it returns without connecting.
// so we use .Ping to get the actual error.
connection, _ := sql.Open("postgres", dsn)
err = connection.Ping()
if assertErrorNil {
assert.NoError(err)
} else {
assert.Error(err)
}

err = connection.Close()
assert.NoError(err)

err = pg.Stop()
assert.NoError(err)
}

func pgUser() string {
currentUser, err := user.Current()
isRoot := currentUser.Username == "root"
if isRoot {
return "postgres"
}
if err != nil {
return "postgres" // fallback to postgres if we can't get the current user
}
return currentUser.Username
}

func makeDsn(sockDir, dbname, password string) string {
dsnUser := ""
dsnPassword := ""
user := pgUser()
// add user if defined
if user != "" {
dsnUser = fmt.Sprintf("user=%s", user)
}
// add password if defined
if password != "" {
dsnPassword = fmt.Sprintf("password=%s", password)
}
return fmt.Sprintf("host=%s dbname=%s %s %s", sockDir, dbname, dsnUser, dsnPassword)
}

func getSockDir(pg *pgtest.PG, t *testing.T) string {
// Use reflection to access the private 'dir' field
pgValue := reflect.ValueOf(pg).Elem()
dirField := pgValue.FieldByName("dir")
if !dirField.IsValid() {
t.Fatal("Unable to find 'dir' field in PostgreSQL struct")
}

dbRoot := dirField.String()

return filepath.Join(dbRoot, "sock")
}