diff --git a/migrator.go b/migrator.go index 955909b..7d15cd4 100644 --- a/migrator.go +++ b/migrator.go @@ -12,14 +12,22 @@ const defaultTableName = "migrations" // Migrator is the migrator implementation type Migrator struct { - tableName string - logger Logger - migrations []interface{} + tableName string + logger Logger + migrations []interface{} + createSQLFormat string + insertSQLFormat string } // Option sets options such migrations or table name. type Option func(*Migrator) +type insertArgs struct { + sql string + id int + version string +} + // TableName creates an option to allow overriding the default table name func TableName(tableName string) Option { return func(m *Migrator) { @@ -47,6 +55,20 @@ func WithLogger(logger Logger) Option { } } +// WithCreateSQLFormat creates an option to allow overriding the create SQL script format +func WithCreateSQLFormat(createSQL string) Option { + return func(m *Migrator) { + m.createSQLFormat = createSQL + } +} + +// WithInsertSQLFormat creates an option to allow overriding the insert SQL script format +func WithInsertSQLFormat(insertSQL string) Option { + return func(m *Migrator) { + m.insertSQLFormat = insertSQL + } +} + // Migrations creates an option with provided migrations func Migrations(migrations ...interface{}) Option { return func(m *Migrator) { @@ -59,6 +81,14 @@ func New(opts ...Option) (*Migrator, error) { m := &Migrator{ logger: log.New(os.Stdout, "migrator: ", 0), tableName: defaultTableName, + createSQLFormat: ` + CREATE TABLE IF NOT EXISTS %s ( + id INT8 NOT NULL, + version VARCHAR(255) NOT NULL, + PRIMARY KEY (id) + ); + `, + insertSQLFormat: "INSERT INTO %s (id, version) VALUES (?, ?)", } for _, opt := range opts { opt(m) @@ -83,13 +113,7 @@ func New(opts ...Option) (*Migrator, error) { // Migrate applies all available migrations func (m *Migrator) Migrate(db *sql.DB) error { // create migrations table if doesn't exist - _, err := db.Exec(fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s ( - id INT8 NOT NULL, - version VARCHAR(255) NOT NULL, - PRIMARY KEY (id) - ); - `, m.tableName)) + _, err := db.Exec(fmt.Sprintf(m.createSQLFormat, m.tableName)) if err != nil { return err } @@ -106,14 +130,17 @@ func (m *Migrator) Migrate(db *sql.DB) error { // plan migrations for idx, migration := range m.migrations[count:len(m.migrations)] { - insertVersion := fmt.Sprintf("INSERT INTO %s (id, version) VALUES (%d, '%s')", m.tableName, idx+count, migration.(fmt.Stringer).String()) + sqlStmt := fmt.Sprintf(m.insertSQLFormat, m.tableName) + insertID := idx + count + insertVersion := migration.(fmt.Stringer).String() + switch mig := migration.(type) { case *Migration: - if err := migrate(db, m.logger, insertVersion, mig); err != nil { + if err := migrate(db, m.logger, insertArgs{sql: sqlStmt, id: insertID, version: insertVersion,}, mig); err != nil { return fmt.Errorf("migrator: error while running migrations: %v", err) } case *MigrationNoTx: - if err := migrateNoTx(db, m.logger, insertVersion, mig); err != nil { + if err := migrateNoTx(db, m.logger, insertArgs{sql: sqlStmt, id: insertID, version: insertVersion,}, mig); err != nil { return fmt.Errorf("migrator: error while running migrations: %v", err) } } @@ -173,7 +200,7 @@ func (m *MigrationNoTx) String() string { return m.Name } -func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migration) error { +func migrate(db *sql.DB, logger Logger, args insertArgs, migration *Migration) error { tx, err := db.Begin() if err != nil { return err @@ -191,7 +218,8 @@ func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migrati if err = migration.Func(tx); err != nil { return fmt.Errorf("error executing golang migration: %s", err) } - if _, err = tx.Exec(insertVersion); err != nil { + + if _, err = tx.Exec(args.sql, args.id, args.version); err != nil { return fmt.Errorf("error updating migration versions: %s", err) } logger.Printf("applied migration named '%s'", migration.Name) @@ -199,12 +227,14 @@ func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migrati return err } -func migrateNoTx(db *sql.DB, logger Logger, insertVersion string, migration *MigrationNoTx) error { +func migrateNoTx(db *sql.DB, logger Logger, args insertArgs, migration *MigrationNoTx) error { logger.Printf("applying no tx migration named '%s'...", migration.Name) + if err := migration.Func(db); err != nil { return fmt.Errorf("error executing golang migration: %s", err) } - if _, err := db.Exec(insertVersion); err != nil { + + if _, err := db.Exec(args.sql, args.id, args.version); err != nil { return fmt.Errorf("error updating migration versions: %s", err) } logger.Printf("applied no tx migration named '%s'", migration.Name)