From ffb6246c09dddc6183334bc3e7b68481abe99ffa Mon Sep 17 00:00:00 2001 From: Pablo Compagni Date: Fri, 10 Apr 2020 19:39:55 -0300 Subject: [PATCH 1/2] make sql statements customizable via options --- migrator.go | 64 +++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/migrator.go b/migrator.go index 955909b..4f2dbd4 100644 --- a/migrator.go +++ b/migrator.go @@ -12,14 +12,22 @@ const defaultTableName = "migrations" // Migrator is the migrator implementation type Migrator struct { - tableName string - logger Logger - migrations []interface{} + tableName string + logger Logger + migrations []interface{} + createSqlFormat string + insertSqlFormat string } // Option sets options such migrations or table name. type Option func(*Migrator) +type insertArgs struct { + sql string + id int + version string +} + // TableName creates an option to allow overriding the default table name func TableName(tableName string) Option { return func(m *Migrator) { @@ -47,6 +55,20 @@ func WithLogger(logger Logger) Option { } } +// WithCreateSqlFormat creates an option to allow overriding the create SQL script format +func WithCreateSqlFormat(createSql string) Option { + return func(m *Migrator) { + m.createSqlFormat = createSql + } +} + +// WithInsertSqlFormat creates an option to allow overriding the insert SQL script format +func WithInsertSqlFormat(insertSql string) Option { + return func(m *Migrator) { + m.insertSqlFormat = insertSql + } +} + // Migrations creates an option with provided migrations func Migrations(migrations ...interface{}) Option { return func(m *Migrator) { @@ -59,6 +81,14 @@ func New(opts ...Option) (*Migrator, error) { m := &Migrator{ logger: log.New(os.Stdout, "migrator: ", 0), tableName: defaultTableName, + createSqlFormat: ` + CREATE TABLE IF NOT EXISTS %s ( + id INT8 NOT NULL, + version VARCHAR(255) NOT NULL, + PRIMARY KEY (id) + ); + `, + insertSqlFormat: "INSERT INTO %s (id, version) VALUES (?, ?)", } for _, opt := range opts { opt(m) @@ -83,13 +113,7 @@ func New(opts ...Option) (*Migrator, error) { // Migrate applies all available migrations func (m *Migrator) Migrate(db *sql.DB) error { // create migrations table if doesn't exist - _, err := db.Exec(fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s ( - id INT8 NOT NULL, - version VARCHAR(255) NOT NULL, - PRIMARY KEY (id) - ); - `, m.tableName)) + _, err := db.Exec(fmt.Sprintf(m.createSqlFormat, m.tableName)) if err != nil { return err } @@ -106,14 +130,17 @@ func (m *Migrator) Migrate(db *sql.DB) error { // plan migrations for idx, migration := range m.migrations[count:len(m.migrations)] { - insertVersion := fmt.Sprintf("INSERT INTO %s (id, version) VALUES (%d, '%s')", m.tableName, idx+count, migration.(fmt.Stringer).String()) + sqlStmt := fmt.Sprintf(m.insertSqlFormat, m.tableName) + insertId := idx + count + insertVersion := migration.(fmt.Stringer).String() + switch mig := migration.(type) { case *Migration: - if err := migrate(db, m.logger, insertVersion, mig); err != nil { + if err := migrate(db, m.logger, insertArgs{sql: sqlStmt, id: insertId, version: insertVersion,}, mig); err != nil { return fmt.Errorf("migrator: error while running migrations: %v", err) } case *MigrationNoTx: - if err := migrateNoTx(db, m.logger, insertVersion, mig); err != nil { + if err := migrateNoTx(db, m.logger, insertArgs{sql: sqlStmt, id: insertId, version: insertVersion,}, mig); err != nil { return fmt.Errorf("migrator: error while running migrations: %v", err) } } @@ -173,7 +200,7 @@ func (m *MigrationNoTx) String() string { return m.Name } -func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migration) error { +func migrate(db *sql.DB, logger Logger, args insertArgs, migration *Migration) error { tx, err := db.Begin() if err != nil { return err @@ -191,7 +218,8 @@ func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migrati if err = migration.Func(tx); err != nil { return fmt.Errorf("error executing golang migration: %s", err) } - if _, err = tx.Exec(insertVersion); err != nil { + + if _, err = tx.Exec(args.sql, args.id, args.version); err != nil { return fmt.Errorf("error updating migration versions: %s", err) } logger.Printf("applied migration named '%s'", migration.Name) @@ -199,12 +227,14 @@ func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migrati return err } -func migrateNoTx(db *sql.DB, logger Logger, insertVersion string, migration *MigrationNoTx) error { +func migrateNoTx(db *sql.DB, logger Logger, args insertArgs, migration *MigrationNoTx) error { logger.Printf("applying no tx migration named '%s'...", migration.Name) + if err := migration.Func(db); err != nil { return fmt.Errorf("error executing golang migration: %s", err) } - if _, err := db.Exec(insertVersion); err != nil { + + if _, err := db.Exec(args.sql, args.id, args.version); err != nil { return fmt.Errorf("error updating migration versions: %s", err) } logger.Printf("applied no tx migration named '%s'", migration.Name) From 5508ff9ad83a5fc33622bd2729ab08b14be19ae1 Mon Sep 17 00:00:00 2001 From: Pablo Compagni Date: Fri, 10 Apr 2020 21:09:57 -0300 Subject: [PATCH 2/2] renamed fields to comply with linter --- migrator.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/migrator.go b/migrator.go index 4f2dbd4..7d15cd4 100644 --- a/migrator.go +++ b/migrator.go @@ -15,8 +15,8 @@ type Migrator struct { tableName string logger Logger migrations []interface{} - createSqlFormat string - insertSqlFormat string + createSQLFormat string + insertSQLFormat string } // Option sets options such migrations or table name. @@ -55,17 +55,17 @@ func WithLogger(logger Logger) Option { } } -// WithCreateSqlFormat creates an option to allow overriding the create SQL script format -func WithCreateSqlFormat(createSql string) Option { +// WithCreateSQLFormat creates an option to allow overriding the create SQL script format +func WithCreateSQLFormat(createSQL string) Option { return func(m *Migrator) { - m.createSqlFormat = createSql + m.createSQLFormat = createSQL } } -// WithInsertSqlFormat creates an option to allow overriding the insert SQL script format -func WithInsertSqlFormat(insertSql string) Option { +// WithInsertSQLFormat creates an option to allow overriding the insert SQL script format +func WithInsertSQLFormat(insertSQL string) Option { return func(m *Migrator) { - m.insertSqlFormat = insertSql + m.insertSQLFormat = insertSQL } } @@ -81,14 +81,14 @@ func New(opts ...Option) (*Migrator, error) { m := &Migrator{ logger: log.New(os.Stdout, "migrator: ", 0), tableName: defaultTableName, - createSqlFormat: ` + createSQLFormat: ` CREATE TABLE IF NOT EXISTS %s ( id INT8 NOT NULL, version VARCHAR(255) NOT NULL, PRIMARY KEY (id) ); `, - insertSqlFormat: "INSERT INTO %s (id, version) VALUES (?, ?)", + insertSQLFormat: "INSERT INTO %s (id, version) VALUES (?, ?)", } for _, opt := range opts { opt(m) @@ -113,7 +113,7 @@ func New(opts ...Option) (*Migrator, error) { // Migrate applies all available migrations func (m *Migrator) Migrate(db *sql.DB) error { // create migrations table if doesn't exist - _, err := db.Exec(fmt.Sprintf(m.createSqlFormat, m.tableName)) + _, err := db.Exec(fmt.Sprintf(m.createSQLFormat, m.tableName)) if err != nil { return err } @@ -130,17 +130,17 @@ func (m *Migrator) Migrate(db *sql.DB) error { // plan migrations for idx, migration := range m.migrations[count:len(m.migrations)] { - sqlStmt := fmt.Sprintf(m.insertSqlFormat, m.tableName) - insertId := idx + count + sqlStmt := fmt.Sprintf(m.insertSQLFormat, m.tableName) + insertID := idx + count insertVersion := migration.(fmt.Stringer).String() switch mig := migration.(type) { case *Migration: - if err := migrate(db, m.logger, insertArgs{sql: sqlStmt, id: insertId, version: insertVersion,}, mig); err != nil { + if err := migrate(db, m.logger, insertArgs{sql: sqlStmt, id: insertID, version: insertVersion,}, mig); err != nil { return fmt.Errorf("migrator: error while running migrations: %v", err) } case *MigrationNoTx: - if err := migrateNoTx(db, m.logger, insertArgs{sql: sqlStmt, id: insertId, version: insertVersion,}, mig); err != nil { + if err := migrateNoTx(db, m.logger, insertArgs{sql: sqlStmt, id: insertID, version: insertVersion,}, mig); err != nil { return fmt.Errorf("migrator: error while running migrations: %v", err) } }