From f8d18d6e9e99f64feda8bcfa07a02c97e7f2f385 Mon Sep 17 00:00:00 2001 From: Ahmad Faiz Kamaludin Date: Sat, 9 Aug 2025 17:07:15 +0700 Subject: [PATCH 1/7] feat: refactor builder --- blueprint.go | 840 ++++++++++---------------- builder.go | 58 +- column_definition.go | 156 +++-- command.go | 43 ++ dialect.go | 20 +- dialect_test.go | 30 - examples/basic/cmd/migrate/migrate.go | 2 +- foreign_key_definition.go | 114 ++-- grammar.go | 114 ++-- index_definition.go | 12 +- mysql_builder.go | 22 +- mysql_builder_test.go | 81 +-- mysql_grammar.go | 549 +++++++++++------ mysql_grammar_test.go | 195 ++---- options.go | 26 + postgres_builder.go | 13 +- postgres_builder_test.go | 119 ++-- postgres_grammar.go | 571 ++++++++++------- postgres_grammar_test.go | 96 ++- schema.go | 4 +- schema_test.go | 5 +- util.go | 45 +- 22 files changed, 1579 insertions(+), 1536 deletions(-) create mode 100644 command.go delete mode 100644 dialect_test.go create mode 100644 options.go diff --git a/blueprint.go b/blueprint.go index 0a3f7d6..a036772 100644 --- a/blueprint.go +++ b/blueprint.go @@ -1,74 +1,57 @@ package schema import ( + "context" + "database/sql" "fmt" - "slices" + "log" ) -type columnType uint8 - -const ( - columnTypeBoolean columnType = iota - columnTypeChar - columnTypeString - columnTypeLongText - columnTypeMediumText - columnTypeText - columnTypeTinyText - columnTypeBigInteger - columnTypeInteger - columnTypeMediumInteger - columnTypeSmallInteger - columnTypeTinyInteger - columnTypeDecimal - columnTypeDouble - columnTypeFloat - columnTypeDateTime - columnTypeDateTimeTz - columnTypeDate - columnTypeTime - columnTypeTimestamp - columnTypeTimestampTz - columnTypeYear - columnTypeBinary - columnTypeJSON - columnTypeJSONB - columnTypeGeography - columnTypeGeometry - columnTypePoint - columnTypeUUID - columnTypeEnum - columnTypeCustom // Custom type for user-defined types -) - -type indexType int - const ( - indexTypeIndex indexType = iota - indexTypeUnique - indexTypePrimary - indexTypeFullText + columnTypeBoolean string = "boolean" + columnTypeChar string = "char" + columnTypeString string = "string" + columnTypeLongText string = "longText" + columnTypeMediumText string = "mediumText" + columnTypeText string = "text" + columnTypeTinyText string = "tinyText" + columnTypeBigInteger string = "bigInteger" + columnTypeInteger string = "integer" + columnTypeMediumInteger string = "mediumInteger" + columnTypeSmallInteger string = "smallInteger" + columnTypeTinyInteger string = "tinyInteger" + columnTypeDecimal string = "decimal" + columnTypeDouble string = "double" + columnTypeFloat string = "float" + columnTypeDateTime string = "dateTime" + columnTypeDateTimeTz string = "dateTimeTz" + columnTypeDate string = "date" + columnTypeTime string = "time" + columnTypeTimeTz string = "timeTz" + columnTypeTimestamp string = "timestamp" + columnTypeTimestampTz string = "timestampTz" + columnTypeYear string = "year" + columnTypeBinary string = "binary" + columnTypeJSON string = "json" + columnTypeJSONB string = "jsonb" + columnTypeGeography string = "geography" + columnTypeGeometry string = "geometry" + columnTypePoint string = "point" + columnTypeUUID string = "uuid" + columnTypeEnum string = "enum" ) // Blueprint represents a schema blueprint for creating or altering a database table. type Blueprint struct { - commands []string // commands to be executed - name string - newName string - charset string - collation string - engine string - columns []*columnDefinition - indexes []*indexDefinition - foreignKeys []*foreignKeyDefinition - dropColumns []string - renameColumns map[string]string // old column name to new column name - dropIndexes []string // indexes to be dropped - dropForeignKeys []string // foreign keys to be dropped - dropPrimaryKeys []string // primary keys to be dropped - dropUniqueKeys []string // unique keys to be dropped - dropFullText []string // fulltext indexes to be dropped - renameIndexes map[string]string // old index name to new index name + debug bool + dialect dialect + columns []*columnDefinition + commands []*command + grammar grammar + name string + charset string + collation string + engine string } // Charset sets the character set for the table in the blueprint. @@ -88,83 +71,56 @@ func (b *Blueprint) Engine(engine string) { // Column creates a new custom column definition in the blueprint with the specified name and type. func (b *Blueprint) Column(name string, columnType string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeCustom, - customColumnType: columnType, - }) + return b.addColumn(columnType, name) } // Boolean creates a new boolean column definition in the blueprint. func (b *Blueprint) Boolean(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeBoolean, - }) + return b.addColumn(columnTypeBoolean, name) } // Char creates a new char column definition in the blueprint. func (b *Blueprint) Char(name string, length ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeChar, - length: optional(0, length...), + return b.addColumn(columnTypeChar, name, &columnDefinition{ + length: optionalPtr(1, length...), }) } // String creates a new string column definition in the blueprint. func (b *Blueprint) String(name string, length ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeString, - length: optional(0, length...), + return b.addColumn(columnTypeString, name, &columnDefinition{ + length: optionalPtr(255, length...), }) } // LongText creates a new long text column definition in the blueprint. func (b *Blueprint) LongText(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeLongText, - }) + return b.addColumn(columnTypeLongText, name) } // Text creates a new text column definition in the blueprint. func (b *Blueprint) Text(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeText, - }) + return b.addColumn(columnTypeText, name) } // MediumText creates a new medium text column definition in the blueprint. func (b *Blueprint) MediumText(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeMediumText, - }) + return b.addColumn(columnTypeMediumText, name) } // TinyText creates a new tiny text column definition in the blueprint. func (b *Blueprint) TinyText(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTinyText, - }) + return b.addColumn(columnTypeTinyText, name) } // BigIncrements creates a new big increments column definition in the blueprint. func (b *Blueprint) BigIncrements(name string) ColumnDefinition { - return b.UnsignedBigInteger(name, true) + return b.UnsignedBigInteger(name).AutoIncrement() } // BigInteger creates a new big integer column definition in the blueprint. -func (b *Blueprint) BigInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeBigInteger, - autoIncrement: optional(false, autoIncrement...), - }) +func (b *Blueprint) BigInteger(name string) ColumnDefinition { + return b.addColumn(columnTypeBigInteger, name) } // Decimal creates a new decimal column definition in the blueprint. @@ -177,39 +133,25 @@ func (b *Blueprint) BigInteger(name string, autoIncrement ...bool) ColumnDefinit // // table.Decimal("price") // creates a decimal column with default total 8 and places 2 func (b *Blueprint) Decimal(name string, params ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDecimal, - total: optional(8, params...), - places: optionalAtIndex(1, 2, params...), + defaultPlaces := 2 + if len(params) > 1 { + defaultPlaces = params[1] + } + return b.addColumn(columnTypeDecimal, name, &columnDefinition{ + total: optionalPtr(8, params...), + places: ptrOf(defaultPlaces), }) } // Double creates a new double column definition in the blueprint. -func (b *Blueprint) Double(name string, params ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDouble, - total: optional(0, params...), - places: optionalAtIndex(1, 0, params...), - }) +func (b *Blueprint) Double(name string) ColumnDefinition { + return b.addColumn(columnTypeDouble, name) } // Float creates a new float column definition in the blueprint. -// -// The total and places parameters are optional. -// -// Example: -// -// table.Float("price", 10, 2) // creates a float column with total 10 and places 2 -// -// table.Float("price") // creates a float column with default total 8 and places 2 -func (b *Blueprint) Float(name string, params ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeFloat, - total: optional(8, params...), - places: optionalAtIndex(1, 2, params...), +func (b *Blueprint) Float(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeFloat, name, &columnDefinition{ + precision: optionalPtr(53, precision...), }) } @@ -222,246 +164,180 @@ func (b *Blueprint) ID(name ...string) ColumnDefinition { // Increments create a new increment column definition in the blueprint. func (b *Blueprint) Increments(name string) ColumnDefinition { - return b.UnsignedInteger(name, true) + return b.UnsignedInteger(name).AutoIncrement() } // Integer creates a new integer column definition in the blueprint. -func (b *Blueprint) Integer(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeInteger, - autoIncrement: optional(false, autoIncrement...), - }) +func (b *Blueprint) Integer(name string) ColumnDefinition { + return b.addColumn(columnTypeInteger, name) } // MediumIncrements creates a new medium increments column definition in the blueprint. func (b *Blueprint) MediumIncrements(name string) ColumnDefinition { - return b.UnsignedMediumInteger(name, true) + return b.UnsignedMediumInteger(name).AutoIncrement() } -func (b *Blueprint) MediumInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeMediumInteger, - autoIncrement: optional(false, autoIncrement...), - }) +func (b *Blueprint) MediumInteger(name string) ColumnDefinition { + return b.addColumn(columnTypeMediumInteger, name) } // SmallIncrements creates a new small increments column definition in the blueprint. func (b *Blueprint) SmallIncrements(name string) ColumnDefinition { - return b.UnsignedSmallInteger(name, true) + return b.UnsignedSmallInteger(name).AutoIncrement() } // SmallInteger creates a new small integer column definition in the blueprint. -func (b *Blueprint) SmallInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeSmallInteger, - autoIncrement: optional(false, autoIncrement...), - }) +func (b *Blueprint) SmallInteger(name string) ColumnDefinition { + return b.addColumn(columnTypeSmallInteger, name) } // TinyIncrements creates a new tiny increments column definition in the blueprint. func (b *Blueprint) TinyIncrements(name string) ColumnDefinition { - return b.UnsignedTinyInteger(name, true) + return b.UnsignedTinyInteger(name).AutoIncrement() } // TinyInteger creates a new tiny integer column definition in the blueprint. -func (b *Blueprint) TinyInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTinyInteger, - autoIncrement: optional(false, autoIncrement...), - }) +func (b *Blueprint) TinyInteger(name string) ColumnDefinition { + return b.addColumn(columnTypeTinyInteger, name) } // UnsignedBigInteger creates a new unsigned big integer column definition in the blueprint. -func (b *Blueprint) UnsignedBigInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeBigInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) +func (b *Blueprint) UnsignedBigInteger(name string) ColumnDefinition { + return b.BigInteger(name).Unsigned() } // UnsignedInteger creates a new unsigned integer column definition in the blueprint. -func (b *Blueprint) UnsignedInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) +func (b *Blueprint) UnsignedInteger(name string) ColumnDefinition { + return b.Integer(name).Unsigned() } // UnsignedMediumInteger creates a new unsigned medium integer column definition in the blueprint. -func (b *Blueprint) UnsignedMediumInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeMediumInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) +func (b *Blueprint) UnsignedMediumInteger(name string) ColumnDefinition { + return b.MediumInteger(name).Unsigned() } // UnsignedSmallInteger creates a new unsigned small integer column definition in the blueprint. -func (b *Blueprint) UnsignedSmallInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeSmallInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) +func (b *Blueprint) UnsignedSmallInteger(name string) ColumnDefinition { + return b.SmallInteger(name).Unsigned() } // UnsignedTinyInteger creates a new unsigned tiny integer column definition in the blueprint. -func (b *Blueprint) UnsignedTinyInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTinyInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) +func (b *Blueprint) UnsignedTinyInteger(name string) ColumnDefinition { + return b.TinyInteger(name).Unsigned() } // DateTime creates a new date time column definition in the blueprint. func (b *Blueprint) DateTime(name string, precision ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDateTime, - precision: optional(0, precision...), + return b.addColumn(columnTypeDateTime, name, &columnDefinition{ + precision: optionalPtr(0, precision...), }) } // DateTimeTz creates a new date time with a time zone column definition in the blueprint. func (b *Blueprint) DateTimeTz(name string, precision ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDateTimeTz, - precision: optional(0, precision...), + return b.addColumn(columnTypeDateTimeTz, name, &columnDefinition{ + precision: optionalPtr(0, precision...), }) } // Date creates a new date column definition in the blueprint. func (b *Blueprint) Date(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDate, - }) + return b.addColumn(columnTypeDate, name) } // Time creates a new time column definition in the blueprint. -func (b *Blueprint) Time(name string, precission ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTime, - precision: optional(0, precission...), +func (b *Blueprint) Time(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeTime, name, &columnDefinition{ + precision: optionalPtr(0, precision...), + }) +} + +// TimeTz creates a new time with time zone column definition in the blueprint. +func (b *Blueprint) TimeTz(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeTimeTz, name, &columnDefinition{ + precision: optionalPtr(0, precision...), }) } // Timestamp creates a new timestamp column definition in the blueprint. // The precision parameter is optional and defaults to 0 if not provided. func (b *Blueprint) Timestamp(name string, precision ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTimestamp, - precision: optional(0, precision...), + return b.addColumn(columnTypeTimestamp, name, &columnDefinition{ + precision: optionalPtr(0, precision...), }) } // TimestampTz creates a new timestamp with time zone column definition in the blueprint. // The precision parameter is optional and defaults to 0 if not provided. func (b *Blueprint) TimestampTz(name string, precision ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTimestampTz, - precision: optional(0, precision...), + return b.addColumn(columnTypeTimestampTz, name, &columnDefinition{ + precision: optionalPtr(0, precision...), }) } // Timestamps adds created_at and updated_at timestamp columns to the blueprint. -func (b *Blueprint) Timestamps() { - b.Timestamp("created_at").Nullable(false).UseCurrent() - b.Timestamp("updated_at").Nullable(false).UseCurrent().UseCurrentOnUpdate() +func (b *Blueprint) Timestamps(precision ...int) { + b.Timestamp("created_at", precision...).UseCurrent() + b.Timestamp("updated_at", precision...).UseCurrent().UseCurrentOnUpdate() } // TimestampsTz adds created_at and updated_at timestamp with time zone columns to the blueprint. -func (b *Blueprint) TimestampsTz() { - b.TimestampTz("created_at").Nullable(false).UseCurrent() - b.TimestampTz("updated_at").Nullable(false).UseCurrent().UseCurrentOnUpdate() +func (b *Blueprint) TimestampsTz(precision ...int) { + b.TimestampTz("created_at", precision...).UseCurrent() + b.TimestampTz("updated_at", precision...).UseCurrent().UseCurrentOnUpdate() } // Year creates a new year column definition in the blueprint. func (b *Blueprint) Year(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeYear, - }) + return b.addColumn(columnTypeYear, name) } // Binary creates a new binary column definition in the blueprint. -func (b *Blueprint) Binary(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeBinary, +func (b *Blueprint) Binary(name string, length ...int) ColumnDefinition { + return b.addColumn(columnTypeBinary, name, &columnDefinition{ + length: optionalNil(length...), }) } // JSON creates a new JSON column definition in the blueprint. func (b *Blueprint) JSON(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeJSON, - }) + return b.addColumn(columnTypeJSON, name) } // JSONB creates a new JSONB column definition in the blueprint. func (b *Blueprint) JSONB(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeJSONB, - }) + return b.addColumn(columnTypeJSONB, name) } // UUID creates a new UUID column definition in the blueprint. func (b *Blueprint) UUID(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeUUID, - }) + return b.addColumn(columnTypeUUID, name) } // Geography creates a new geography column definition in the blueprint. // The subType parameter is optional and can be used to specify the type of geography (e.g., "Point", "LineString", "Polygon"). // The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geography type. -func (b *Blueprint) Geography(name string, subType string, srid ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeGeography, - subType: subType, - srid: optional(4326, srid...), +func (b *Blueprint) Geography(name string, subtype string, srid ...int) ColumnDefinition { + return b.addColumn(columnTypeGeography, name, &columnDefinition{ + subtype: optionalPtr("", subtype), + srid: optionalPtr(4326, srid...), }) } // Geometry creates a new geometry column definition in the blueprint. // The subType parameter is optional and can be used to specify the type of geometry (e.g., "Point", "LineString", "Polygon"). // The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geometry type. -func (b *Blueprint) Geometry(name string, subType string, srid ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeGeometry, - subType: subType, - srid: optional(0, srid...), +func (b *Blueprint) Geometry(name string, subtype string, srid ...int) ColumnDefinition { + return b.addColumn(columnTypeGeometry, name, &columnDefinition{ + subtype: optionalPtr("", subtype), + srid: optionalNil(srid...), }) } // Point creates a new point column definition in the blueprint. func (b *Blueprint) Point(name string, srid ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypePoint, - srid: optional(4326, srid...), + return b.addColumn(columnTypePoint, name, &columnDefinition{ + srid: optionalPtr(4326, srid...), }) } @@ -472,14 +348,22 @@ func (b *Blueprint) Point(name string, srid ...int) ColumnDefinition { // // table.Enum("status", []string{"active", "inactive", "pending"}) // table.Enum("role", []string{"admin", "user", "guest"}).Comment("User role in the system") -func (b *Blueprint) Enum(name string, allowedEnums []string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeEnum, - allowedEnums: allowedEnums, +func (b *Blueprint) Enum(name string, allowed []string) ColumnDefinition { + return b.addColumn(columnTypeEnum, name, &columnDefinition{ + allowed: allowed, }) } +// DropTimestamps removes the created_at and updated_at timestamp columns from the blueprint. +func (b *Blueprint) DropTimestamps() { + b.DropColumn("created_at", "updated_at") +} + +// DropTimestampsTz removes the created_at and updated_at timestamp with time zone columns from the blueprint. +func (b *Blueprint) DropTimestampsTz() { + b.DropTimestamps() +} + // Index creates a new index definition in the blueprint. // // Example: @@ -488,14 +372,7 @@ func (b *Blueprint) Enum(name string, allowedEnums []string) ColumnDefinition { // table.Index("email", "username") // creates a composite index // table.Index("email").Algorithm("btree") // creates a btree index func (b *Blueprint) Index(column string, otherColumns ...string) IndexDefinition { - index := &indexDefinition{ - indexType: indexTypeIndex, - columns: append([]string{column}, otherColumns...), - } - b.indexes = append(b.indexes, index) - b.addCommand("index") - - return index + return b.indexCommand(commandIndex, append([]string{column}, otherColumns...)...) } // Unique creates a new unique index definition in the blueprint. @@ -505,14 +382,7 @@ func (b *Blueprint) Index(column string, otherColumns ...string) IndexDefinition // table.Unique("email") // table.Unique("email", "username") // creates a composite unique index func (b *Blueprint) Unique(column string, otherColumns ...string) IndexDefinition { - index := &indexDefinition{ - indexType: indexTypeUnique, - columns: append([]string{column}, otherColumns...), - } - b.indexes = append(b.indexes, index) - b.addCommand("unique") - - return index + return b.indexCommand(commandUnique, append([]string{column}, otherColumns...)...) } // Primary creates a new primary key index definition in the blueprint. @@ -522,25 +392,12 @@ func (b *Blueprint) Unique(column string, otherColumns ...string) IndexDefinitio // table.Primary("id") // table.Primary("id", "email") // creates a composite primary key func (b *Blueprint) Primary(column string, otherColumns ...string) IndexDefinition { - index := &indexDefinition{ - indexType: indexTypePrimary, - columns: append([]string{column}, otherColumns...), - } - b.indexes = append(b.indexes, index) - b.addCommand("primary") - return index + return b.indexCommand(commandPrimary, append([]string{column}, otherColumns...)...) } // FullText creates a new fulltext index definition in the blueprint. func (b *Blueprint) FullText(column string, otherColumns ...string) IndexDefinition { - index := &indexDefinition{ - indexType: indexTypeFullText, - columns: append([]string{column}, otherColumns...), - } - b.indexes = append(b.indexes, index) - b.addCommand("fullText") - - return index + return b.indexCommand(commandFullText, append([]string{column}, otherColumns...)...) } // Foreign creates a new foreign key definition in the blueprint. @@ -549,13 +406,10 @@ func (b *Blueprint) FullText(column string, otherColumns ...string) IndexDefinit // // table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE") func (b *Blueprint) Foreign(column string) ForeignKeyDefinition { - fk := &foreignKeyDefinition{ - tableName: b.name, - column: column, - } - b.foreignKeys = append(b.foreignKeys, fk) - b.addCommand("foreign") - return fk + command := b.addCommand(commandForeign, &command{ + columns: []string{column}, + }) + return &foreignKeyDefinition{command: command} } // DropColumn adds a column to be dropped from the table. @@ -565,8 +419,9 @@ func (b *Blueprint) Foreign(column string) ForeignKeyDefinition { // table.DropColumn("old_column") // table.DropColumn("old_column", "another_old_column") // drops multiple columns func (b *Blueprint) DropColumn(column string, otherColumns ...string) { - b.dropColumns = append(b.dropColumns, append([]string{column}, otherColumns...)...) - b.addCommand("dropColumn") + b.addCommand(commandDropColumn, &command{ + columns: append([]string{column}, otherColumns...), + }) } // RenameColumn changes the name of the table in the blueprint. @@ -575,40 +430,34 @@ func (b *Blueprint) DropColumn(column string, otherColumns ...string) { // // table.RenameColumn("old_table_name", "new_table_name") func (b *Blueprint) RenameColumn(oldColumn string, newColumn string) { - if b.renameColumns == nil { - b.renameColumns = make(map[string]string) - } - b.renameColumns[oldColumn] = newColumn - b.addCommand("renameColumn") + b.addCommand(commandRenameColumn, &command{ + from: oldColumn, + to: newColumn, + }) } // DropIndex adds an index to be dropped from the table. -func (b *Blueprint) DropIndex(indexName string) { - b.dropIndexes = append(b.dropIndexes, indexName) - b.addCommand("dropIndex") +func (b *Blueprint) DropIndex(index any) { + b.dropIndexCommand(commandDropIndex, commandIndex, index) } // DropForeign adds a foreign key to be dropped from the table. -func (b *Blueprint) DropForeign(foreignKeyName string) { - b.dropForeignKeys = append(b.dropForeignKeys, foreignKeyName) - b.addCommand("dropForeign") +func (b *Blueprint) DropForeign(index any) { + b.dropIndexCommand(commandDropForeign, commandForeign, index) } // DropPrimary adds a primary key to be dropped from the table. -func (b *Blueprint) DropPrimary(primaryKeyName string) { - b.dropPrimaryKeys = append(b.dropPrimaryKeys, primaryKeyName) - b.addCommand("dropPrimary") +func (b *Blueprint) DropPrimary(index any) { + b.dropIndexCommand(commandDropPrimary, commandPrimary, index) } // DropUnique adds a unique key to be dropped from the table. -func (b *Blueprint) DropUnique(uniqueKeyName string) { - b.dropUniqueKeys = append(b.dropUniqueKeys, uniqueKeyName) - b.addCommand("dropUnique") +func (b *Blueprint) DropUnique(index any) { + b.dropIndexCommand(commandDropUnique, commandUnique, index) } -func (b *Blueprint) DropFulltext(indexName string) { - b.dropFullText = append(b.dropFullText, indexName) - b.addCommand("dropFullText") +func (b *Blueprint) DropFulltext(index any) { + b.dropIndexCommand(commandDropFullText, commandFullText, index) } // RenameIndex changes the name of an index in the blueprint. @@ -616,17 +465,16 @@ func (b *Blueprint) DropFulltext(indexName string) { // // table.RenameIndex("old_index_name", "new_index_name") func (b *Blueprint) RenameIndex(oldIndexName string, newIndexName string) { - if b.renameIndexes == nil { - b.renameIndexes = make(map[string]string) - } - b.renameIndexes[oldIndexName] = newIndexName - b.addCommand("renameIndex") + b.addCommand(commandRenameIndex, &command{ + from: oldIndexName, + to: newIndexName, + }) } func (b *Blueprint) getAddedColumns() []*columnDefinition { var addedColumns []*columnDefinition for _, col := range b.columns { - if !col.changed { + if !col.change { addedColumns = append(addedColumns, col) } } @@ -636,7 +484,7 @@ func (b *Blueprint) getAddedColumns() []*columnDefinition { func (b *Blueprint) getChangedColumns() []*columnDefinition { var changedColumns []*columnDefinition for _, col := range b.columns { - if col.changed { + if col.change { changedColumns = append(changedColumns, col) } } @@ -644,12 +492,12 @@ func (b *Blueprint) getChangedColumns() []*columnDefinition { } func (b *Blueprint) create() { - b.addCommand("create") + b.addCommand(commandCreate) } func (b *Blueprint) creating() bool { for _, command := range b.commands { - if command == "create" || command == "createIfNotExists" { + if command.name == commandCreate || command.name == commandCreateIfNotExists { return true } } @@ -657,45 +505,130 @@ func (b *Blueprint) creating() bool { } func (b *Blueprint) createIfNotExists() { - b.addCommand("createIfNotExists") + b.addCommand(commandCreateIfNotExists) } func (b *Blueprint) drop() { - b.addCommand("drop") + b.addCommand(commandDrop) } func (b *Blueprint) dropIfExists() { - b.addCommand("dropIfNotExists") + b.addCommand(commandDropIfExists) } -func (b *Blueprint) rename() { - b.addCommand("rename") +func (b *Blueprint) rename(to string) { + b.addCommand(commandRename, &command{ + to: to, + }) } func (b *Blueprint) addImpliedCommands() { - if len(b.getAddedColumns()) > 0 && !b.creating() { - b.commands = append([]string{"add"}, b.commands...) + b.addFluentIndexes() + + if !b.creating() { + if len(b.getAddedColumns()) > 0 { + b.commands = append([]*command{{name: commandAdd}}, b.commands...) + } + if len(b.getChangedColumns()) > 0 { + for _, col := range b.getChangedColumns() { + b.commands = append([]*command{{name: commandChange, column: col}}, b.commands...) + } + } + } +} + +func (b *Blueprint) addFluentIndexes() { + for _, col := range b.columns { + if col.primary != nil { + if b.dialect == dialectMySQL { + continue + } + if !*col.primary && col.change { + b.DropPrimary([]string{col.name}) + col.primary = nil + } + } + if col.index != nil { + if *col.index { + b.Index(col.name).Name(col.indexName) + col.index = nil + } else if !*col.index && col.change { + b.DropIndex([]string{col.name}) + col.index = nil + } + } + + if col.unique != nil { + if *col.unique { + b.Unique(col.name).Name(col.uniqueName) + col.unique = nil + } else if !*col.unique && col.change { + b.DropUnique([]string{col.name}) + col.unique = nil + } + } + } +} + +func (b *Blueprint) getFluentStatements() []string { + var statements []string + for _, column := range b.columns { + for _, fluentCommand := range b.grammar.getFluentCommands() { + if statement := fluentCommand(b, &command{column: column}); statement != "" { + statements = append(statements, statement) + } + } + } + return statements +} + +func (b *Blueprint) build(ctx context.Context, tx *sql.Tx) error { + statements, err := b.toSql() + if err != nil { + return err } - if len(b.getChangedColumns()) > 0 && !b.creating() { - b.commands = append([]string{"change"}, b.commands...) + for _, statement := range statements { + if b.debug { + log.Print(statement) + } + if _, err := tx.ExecContext(ctx, statement); err != nil { + return err + } } + return nil } -func (b *Blueprint) toSql(grammar grammar) ([]string, error) { +func (b *Blueprint) toSql() ([]string, error) { b.addImpliedCommands() var statements []string - commandMap := map[string]func(*Blueprint) (string, error){ - "create": grammar.compileCreate, - "createIfNotExists": grammar.compileCreateIfNotExists, - "add": grammar.compileAdd, - "drop": grammar.compileDrop, - "dropIfNotExists": grammar.compileDropIfExists, - "rename": grammar.compileRename, + mainCommandMap := map[string]func(*Blueprint) (string, error){ + commandCreate: b.grammar.compileCreate, + commandCreateIfNotExists: b.grammar.compileCreateIfNotExists, + commandAdd: b.grammar.compileAdd, + commandDrop: b.grammar.compileDrop, + commandDropIfExists: b.grammar.compileDropIfExists, } - for _, command := range b.commands { - if compileFunc, exists := commandMap[command]; exists { + secondaryCommandMap := map[string]func(*Blueprint, *command) (string, error){ + commandChange: b.grammar.compileChange, + commandDropColumn: b.grammar.compileDropColumn, + commandDropIndex: b.grammar.compileDropIndex, + commandDropForeign: b.grammar.compileDropForeign, + commandDropFullText: b.grammar.compileDropFulltext, + commandDropPrimary: b.grammar.compileDropPrimary, + commandDropUnique: b.grammar.compileDropUnique, + commandForeign: b.grammar.compileForeign, + commandFullText: b.grammar.compileFullText, + commandIndex: b.grammar.compileIndex, + commandPrimary: b.grammar.compilePrimary, + commandRename: b.grammar.compileRename, + commandRenameColumn: b.grammar.compileRenameColumn, + commandRenameIndex: b.grammar.compileRenameIndex, + commandUnique: b.grammar.compileUnique, + } + for _, cmd := range b.commands { + if compileFunc, exists := mainCommandMap[cmd.name]; exists { sql, err := compileFunc(b) if err != nil { return nil, err @@ -703,185 +636,76 @@ func (b *Blueprint) toSql(grammar grammar) ([]string, error) { if sql != "" { statements = append(statements, sql) } + continue } - switch command { - case "create", "createIfNotExists", "add": - for _, col := range b.getAddedColumns() { - if col.index { - indexDef := &indexDefinition{ - indexType: indexTypeIndex, - columns: []string{col.name}, - name: col.indexName, - } - sql, err := grammar.compileIndex(b, indexDef) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - } - case "change": - changedStatements, err := grammar.compileChange(b) - if err != nil { - return nil, err - } - statements = append(statements, changedStatements...) - case "index": - indexStatements, err := b.getIndexStatements(grammar, indexTypeIndex) - if err != nil { - return nil, err - } - statements = append(statements, indexStatements...) - case "unique": - uniqueStatements, err := b.getIndexStatements(grammar, indexTypeUnique) - if err != nil { - return nil, err - } - statements = append(statements, uniqueStatements...) - case "primary": - primaryStatements, err := b.getIndexStatements(grammar, indexTypePrimary) - if err != nil { - return nil, err - } - statements = append(statements, primaryStatements...) - case "fullText": - fulltextStatements, err := b.getIndexStatements(grammar, indexTypeFullText) - if err != nil { - return nil, err - } - statements = append(statements, fulltextStatements...) - case "foreign": - for _, foreignKey := range b.foreignKeys { - sql, err := grammar.compileForeign(b, foreignKey) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropColumn": - sql, err := grammar.compileDropColumn(b) + if compileFunc, exists := secondaryCommandMap[cmd.name]; exists { + sql, err := compileFunc(b, cmd) if err != nil { return nil, err } if sql != "" { statements = append(statements, sql) } - case "renameColumn": - for oldName, newName := range b.renameColumns { - sql, err := grammar.compileRenameColumn(b, oldName, newName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropIndex": - for _, indexName := range b.dropIndexes { - sql, err := grammar.compileDropIndex(b, indexName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropUnique": - for _, uniqueKeyName := range b.dropUniqueKeys { - sql, err := grammar.compileDropUnique(b, uniqueKeyName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropFullText": - for _, fullTextIndexName := range b.dropFullText { - sql, err := grammar.compileDropFulltext(b, fullTextIndexName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropPrimary": - for _, primaryKeyName := range b.dropPrimaryKeys { - sql, err := grammar.compileDropPrimary(b, primaryKeyName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "renameIndex": - for oldName, newName := range b.renameIndexes { - sql, err := grammar.compileRenameIndex(b, oldName, newName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropForeign": - for _, foreignKeyName := range b.dropForeignKeys { - sql, err := grammar.compileDropForeign(b, foreignKeyName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } + continue } + return nil, fmt.Errorf("unknown command: %s", cmd.name) } + statements = append(statements, b.getFluentStatements()...) + return statements, nil } -func (b *Blueprint) getIndexStatements(grammar grammar, idxType indexType) ([]string, error) { - indexCommandMap := map[indexType]func(*Blueprint, *indexDefinition) (string, error){ - indexTypeIndex: grammar.compileIndex, - indexTypeUnique: grammar.compileUnique, - indexTypePrimary: grammar.compilePrimary, - indexTypeFullText: grammar.compileFullText, - } - var statements []string - for _, index := range b.indexes { - if index.indexType == idxType { - compileFunc, exists := indexCommandMap[idxType] - if !exists { - return nil, fmt.Errorf("unsupported index type: %d", idxType) - } - sql, err := compileFunc(b, index) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } +func (b *Blueprint) addColumn(colType string, name string, columnDefs ...*columnDefinition) *columnDefinition { + var col *columnDefinition + if len(columnDefs) > 0 { + col = columnDefs[0] + } else { + col = &columnDefinition{} } - return statements, nil + col.columnType = colType + col.name = name + + return b.addColumnDefinition(col) } -func (b *Blueprint) addColumn(col *columnDefinition) *columnDefinition { +func (b *Blueprint) addColumnDefinition(col *columnDefinition) *columnDefinition { b.columns = append(b.columns, col) return col } -func (b *Blueprint) addCommand(command string) { - if command == "" { - return +func (b *Blueprint) indexCommand(name string, columns ...string) IndexDefinition { + command := b.addCommand(name, &command{ + columns: columns, + }) + return &indexDefinition{command} +} + +func (b *Blueprint) dropIndexCommand(name string, indexType string, index any) { + switch index := index.(type) { + case string: + b.addCommand(name, &command{ + index: index, + }) + case []string: + indexName := b.grammar.createIndexName(b, indexType, index...) + b.addCommand(name, &command{ + index: indexName, + }) + default: + panic(fmt.Sprintf("unsupported index type: %T", index)) } - if !slices.Contains(b.commands, command) { - b.commands = append(b.commands, command) +} + +func (b *Blueprint) addCommand(name string, parameters ...*command) *command { + var parameter *command + if len(parameters) > 0 { + parameter = parameters[0] + } else { + parameter = &command{} } + parameter.name = name + b.commands = append(b.commands, parameter) + + return parameter } diff --git a/builder.go b/builder.go index e63c1aa..aff8d9a 100644 --- a/builder.go +++ b/builder.go @@ -40,22 +40,27 @@ type Builder interface { // It returns an error if the dialect is not supported. // // Supported dialects are "postgres", "pgx", "mysql", and "mariadb". -func NewBuilder(dialect string) (Builder, error) { +func NewBuilder(dialect string, options ...Option) (Builder, error) { dialectVal := dialectFromString(dialect) switch dialectVal { case dialectMySQL: - return newMysqlBuilder(), nil + return newMysqlBuilder(options...), nil case dialectPostgres: - return newPostgresBuilder(), nil + return newPostgresBuilder(options...), nil default: return nil, errors.New("unsupported dialect: " + dialect) } } type baseBuilder struct { + debug bool grammar grammar } +func (b *baseBuilder) newBlueprint(name string) *Blueprint { + return &Blueprint{name: name, grammar: b.grammar, debug: b.debug} +} + func (b *baseBuilder) validateTxAndName(tx *sql.Tx, name string) error { if name == "" { return errors.New("table name is empty") @@ -84,16 +89,15 @@ func (b *baseBuilder) Create(ctx context.Context, tx *sql.Tx, name string, bluep return err } - bp := &Blueprint{name: name} + bp := b.newBlueprint(name) bp.create() blueprint(bp) - statements, err := bp.toSql(b.grammar) - if err != nil { + if err := bp.build(ctx, tx); err != nil { return err } - return execContext(ctx, tx, statements...) + return nil } func (b *baseBuilder) CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { @@ -101,16 +105,15 @@ func (b *baseBuilder) CreateIfNotExists(ctx context.Context, tx *sql.Tx, name st return err } - bp := &Blueprint{name: name} + bp := b.newBlueprint(name) bp.createIfNotExists() blueprint(bp) - statements, err := bp.toSql(b.grammar) - if err != nil { + if err := bp.build(ctx, tx); err != nil { return err } - return execContext(ctx, tx, statements...) + return nil } func (b *baseBuilder) Drop(ctx context.Context, tx *sql.Tx, name string) error { @@ -118,14 +121,14 @@ func (b *baseBuilder) Drop(ctx context.Context, tx *sql.Tx, name string) error { return err } - bp := &Blueprint{name: name} + bp := b.newBlueprint(name) bp.drop() - statements, err := bp.toSql(b.grammar) - if err != nil { + + if err := bp.build(ctx, tx); err != nil { return err } - return execContext(ctx, tx, statements...) + return nil } func (b *baseBuilder) DropIfExists(ctx context.Context, tx *sql.Tx, name string) error { @@ -133,14 +136,14 @@ func (b *baseBuilder) DropIfExists(ctx context.Context, tx *sql.Tx, name string) return err } - bp := &Blueprint{name: name} + bp := b.newBlueprint(name) bp.dropIfExists() - statements, err := bp.toSql(b.grammar) - if err != nil { + + if err := bp.build(ctx, tx); err != nil { return err } - return execContext(ctx, tx, statements...) + return nil } func (b *baseBuilder) Rename(ctx context.Context, tx *sql.Tx, oldName string, newName string) error { @@ -150,14 +153,14 @@ func (b *baseBuilder) Rename(ctx context.Context, tx *sql.Tx, oldName string, ne if tx == nil { return errors.New("transaction is nil") } - bp := &Blueprint{name: oldName, newName: newName} - bp.rename() - statements, err := bp.toSql(b.grammar) - if err != nil { + bp := b.newBlueprint(oldName) + bp.rename(newName) + + if err := bp.build(ctx, tx); err != nil { return err } - return execContext(ctx, tx, statements...) + return nil } func (b *baseBuilder) Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { @@ -165,13 +168,12 @@ func (b *baseBuilder) Table(ctx context.Context, tx *sql.Tx, name string, bluepr return err } - bp := &Blueprint{name: name} + bp := b.newBlueprint(name) blueprint(bp) - statements, err := bp.toSql(b.grammar) - if err != nil { + if err := bp.build(ctx, tx); err != nil { return err } - return execContext(ctx, tx, statements...) + return nil } diff --git a/column_definition.go b/column_definition.go index 516bea9..f6d48aa 100644 --- a/column_definition.go +++ b/column_definition.go @@ -1,7 +1,5 @@ package schema -import "slices" - // ColumnDefinition defines the interface for defining a column in a database table. type ColumnDefinition interface { // AutoIncrement sets the column to auto-increment. @@ -9,18 +7,24 @@ type ColumnDefinition interface { AutoIncrement() ColumnDefinition // Change changes the column definition. Change() ColumnDefinition + // Charset sets the character set for the column. + Charset(charset string) ColumnDefinition + // Collation sets the collation for the column. + Collation(collation string) ColumnDefinition // Comment adds a comment to the column definition. Comment(comment string) ColumnDefinition // Default sets a default value for the column. Default(value any) ColumnDefinition // Index adds an index to the column. - Index(indexName ...string) ColumnDefinition + Index(params ...any) ColumnDefinition // Nullable sets the column to be nullable or not. Nullable(value ...bool) ColumnDefinition + // OnUpdate sets the value to be used when the column is updated. + OnUpdate(value any) ColumnDefinition // Primary sets the column as a primary key. - Primary() ColumnDefinition + Primary(value ...bool) ColumnDefinition // Unique sets the column to be unique. - Unique(indexName ...string) ColumnDefinition + Unique(params ...any) ColumnDefinition // Unsigned sets the column to be unsigned (applicable for numeric types). Unsigned() ColumnDefinition // UseCurrent sets the column to use the current timestamp as default. @@ -29,109 +33,133 @@ type ColumnDefinition interface { UseCurrentOnUpdate() ColumnDefinition } +// Expression is a type for expressions that can be used as default values for columns. +// +// Example: +// +// schema.Timestamp("created_at").Default(schema.Expression("CURRENT_TIMESTAMP")) +type Expression string + +func (e Expression) String() string { + return string(e) +} + var _ ColumnDefinition = &columnDefinition{} type columnDefinition struct { - name string - columnType columnType - customColumnType string // for custom column types - commands []string - comment string - defaultValue any - onUpdateValue string - nullable bool - autoIncrement bool - unsigned bool - primary bool - index bool - indexName string - unique bool - uniqueName string - length int - precision int - total int - places int - changed bool - allowedEnums []string // for enum type columns - subType string // for geography and geometry types - srid int // for geography and geometry types -} - -func (c *columnDefinition) addCommand(command string) { - if command == "" { - return - } - if !slices.Contains(c.commands, command) { - c.commands = append(c.commands, command) - } + name string + columnType string + charset *string + collation *string + comment *string + defaultValue any + onUpdateValue any + useCurrent *bool + useCurrentOnUpdate *bool + nullable *bool + autoIncrement *bool + unsigned *bool + primary *bool + index *bool + indexName string + unique *bool + uniqueName string + length *int + precision *int + total *int + places *int + change bool + allowed []string // for enum type columns + subtype *string // for geography and geometry types + srid *int // for geography and geometry types } -func (c *columnDefinition) hasCommand(command string) bool { - return slices.Contains(c.commands, command) +func (c *columnDefinition) AutoIncrement() ColumnDefinition { + c.autoIncrement = ptrOf(true) + return c } -func (c *columnDefinition) AutoIncrement() ColumnDefinition { - c.addCommand("autoIncrement") - c.autoIncrement = true +func (c *columnDefinition) Charset(charset string) ColumnDefinition { + c.charset = &charset + return c +} + +func (c *columnDefinition) Change() ColumnDefinition { + c.change = true + return c +} + +func (c *columnDefinition) Collation(collation string) ColumnDefinition { + c.collation = &collation return c } func (c *columnDefinition) Comment(comment string) ColumnDefinition { - c.addCommand("comment") - c.comment = comment + c.comment = &comment return c } func (c *columnDefinition) Default(value any) ColumnDefinition { - c.addCommand("default") c.defaultValue = value return c } -func (c *columnDefinition) Index(indexName ...string) ColumnDefinition { - c.index = true - c.indexName = optional("", indexName...) - c.addCommand("index") +func (c *columnDefinition) Index(params ...any) ColumnDefinition { + index := true + for _, param := range params { + switch v := param.(type) { + case bool: + index = v + case string: + c.indexName = v + } + } + c.index = &index return c } func (c *columnDefinition) Nullable(value ...bool) ColumnDefinition { - c.addCommand("nullable") - c.nullable = optional(true, value...) + c.nullable = optionalPtr(true, value...) return c } -func (c *columnDefinition) Primary() ColumnDefinition { - c.addCommand("primary") - c.primary = true +func (c *columnDefinition) OnUpdate(value any) ColumnDefinition { + c.onUpdateValue = value return c } -func (c *columnDefinition) Unique(indexName ...string) ColumnDefinition { - c.addCommand("unique") - c.unique = true - c.uniqueName = optional("", indexName...) +func (c *columnDefinition) Primary(value ...bool) ColumnDefinition { + val := optional(true, value...) + c.primary = &val return c } -func (c *columnDefinition) Change() ColumnDefinition { - c.addCommand("change") - c.changed = true +func (c *columnDefinition) Unique(params ...any) ColumnDefinition { + unique := true + for _, param := range params { + switch v := param.(type) { + case bool: + unique = v + case string: + c.uniqueName = v + } + } + c.unique = &unique return c } func (c *columnDefinition) Unsigned() ColumnDefinition { - c.unsigned = true + c.unsigned = ptrOf(true) return c } func (c *columnDefinition) UseCurrent() ColumnDefinition { - c.Default("CURRENT_TIMESTAMP") + c.useCurrent = ptrOf(true) return c } func (c *columnDefinition) UseCurrentOnUpdate() ColumnDefinition { - c.onUpdateValue = "CURRENT_TIMESTAMP" + c.useCurrentOnUpdate = ptrOf(true) return c } diff --git a/command.go b/command.go new file mode 100644 index 0000000..8d59f5c --- /dev/null +++ b/command.go @@ -0,0 +1,43 @@ +package schema + +const ( + commandAdd string = "add" + commandChange string = "change" + commandCreate string = "create" + commandCreateIfNotExists string = "createIfNotExists" + commandComment string = "comment" + commandDefault string = "default" + commandDrop string = "drop" + commandDropIfExists string = "dropIfExists" + commandDropColumn string = "dropColumn" + commandDropForeign string = "dropForeign" + commandDropFullText string = "dropFullText" + commandDropIndex string = "dropIndex" + commandDropPrimary string = "dropPrimary" + commandDropUnique string = "dropUnique" + commandForeign string = "foreign" + commandFullText string = "fullText" + commandIndex string = "index" + commandPrimary string = "primary" + commandRename string = "rename" + commandRenameColumn string = "renameColumn" + commandRenameIndex string = "renameIndex" + commandUnique string = "unique" +) + +type command struct { + column *columnDefinition + deferrable *bool + initiallyImmediate *bool + algorithm string + from string + index string + language string + name string + on string + onDelete string + onUpdate string + to string + columns []string + references []string +} diff --git a/dialect.go b/dialect.go index a355a9d..5ce78f6 100644 --- a/dialect.go +++ b/dialect.go @@ -23,15 +23,18 @@ func (d dialect) String() string { } } -var dialectValue dialect = dialectUnknown -var debug = false +var dialectValue = dialectUnknown +var cfg = &config{ + debug: false, // default value +} -// SetDialect sets the current dialect for the schema package. -func SetDialect(d string) error { - dialectValue = dialectFromString(d) +// Init initializes the schema package with the given dialect and options. +func Init(dialect string, options ...Option) error { + dialectValue = dialectFromString(dialect) if dialectValue == dialectUnknown { - return fmt.Errorf("unknown dialect: %s", d) + return fmt.Errorf("unknown dialect: %s", dialect) } + cfg = applyOptions(options...) return nil } @@ -46,8 +49,3 @@ func dialectFromString(d string) dialect { return dialectUnknown } } - -// SetDebug enables or disables debug mode for the schema package. -func SetDebug(d bool) { - debug = d -} diff --git a/dialect_test.go b/dialect_test.go deleted file mode 100644 index 5889903..0000000 --- a/dialect_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package schema - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestSetDialect(t *testing.T) { - testCases := []struct { - dialect string - expectError bool - }{ - {"postgres", false}, - {"pgx", false}, - {"mysql", false}, - {"mariadb", false}, - {"sqlite", true}, - } - for _, tc := range testCases { - t.Run(tc.dialect, func(t *testing.T) { - err := SetDialect(tc.dialect) - if tc.expectError { - assert.Error(t, err, "Expected error for unsupported dialect %s", tc.dialect) - } else { - assert.NoError(t, err, "Did not expect error for supported dialect %s", tc.dialect) - } - }) - } -} diff --git a/examples/basic/cmd/migrate/migrate.go b/examples/basic/cmd/migrate/migrate.go index d55c33b..8fb54ae 100644 --- a/examples/basic/cmd/migrate/migrate.go +++ b/examples/basic/cmd/migrate/migrate.go @@ -81,7 +81,7 @@ func (m *migrator) init() error { if err := goose.SetDialect(m.dialect); err != nil { return err } - if err := schema.SetDialect(m.dialect); err != nil { + if err := schema.Init(m.dialect); err != nil { return err } return nil diff --git a/foreign_key_definition.go b/foreign_key_definition.go index 4c6828b..d92596e 100644 --- a/foreign_key_definition.go +++ b/foreign_key_definition.go @@ -13,115 +13,97 @@ type ForeignKeyDefinition interface { // Name sets the name of the foreign key constraint. // This is optional and can be used to give a specific name to the foreign key. Name(name string) ForeignKeyDefinition - // NoActionOnDelete sets the foreign key to do nothing on delete. + // NoActionOnDelete set the foreign key to do nothing on delete. NoActionOnDelete() ForeignKeyDefinition - // NoActionOnUpdate sets the foreign key to do nothing on update. + // NoActionOnUpdate set the foreign key to do nothing on the update. NoActionOnUpdate() ForeignKeyDefinition - // NullOnDelete sets the foreign key to set the column to NULL on delete. + // NullOnDelete set the foreign key to set the column to NULL on delete. NullOnDelete() ForeignKeyDefinition - // NullOnUpdate sets the foreign key to set the column to NULL on update. + // NullOnUpdate set the foreign key to set the column to NULL on update. NullOnUpdate() ForeignKeyDefinition - // On sets the table that this foreign key references. + // On sets the table that these foreign key references. On(table string) ForeignKeyDefinition - // OnDelete sets the action to take when the referenced row is deleted. + // OnDelete set the action to take when the referenced row is deleted. OnDelete(action string) ForeignKeyDefinition - // OnUpdate sets the action to take when the referenced row is updated. + // OnUpdate set the action to take when the referenced row is updated. OnUpdate(action string) ForeignKeyDefinition - // References sets the column that this foreign key references in the other table. + // References set the column that this foreign key references in the other table. References(column string) ForeignKeyDefinition - // RestrictOnDelete sets the foreign key to restrict deletion of the referenced row. + // RestrictOnDelete set the foreign key to restrict deletion of the referenced row. RestrictOnDelete() ForeignKeyDefinition - // RestrictOnUpdate sets the foreign key to restrict updating of the referenced row. + // RestrictOnUpdate set the foreign key to restrict updating of the referenced row. RestrictOnUpdate() ForeignKeyDefinition } -var _ ForeignKeyDefinition = &foreignKeyDefinition{} - type foreignKeyDefinition struct { - tableName string - column string - constaintName string // name of the foreign key constraint - references string - on string - onDelete string - onUpdate string - deferrable *bool - initiallyImmediate *bool + *command } -func (fk *foreignKeyDefinition) CascadeOnDelete() ForeignKeyDefinition { - fk.onDelete = "CASCADE" - return fk +func (fd *foreignKeyDefinition) CascadeOnDelete() ForeignKeyDefinition { + return fd.OnDelete("CASCADE") } -func (fk *foreignKeyDefinition) CascadeOnUpdate() ForeignKeyDefinition { - fk.onUpdate = "CASCADE" - return fk +func (fd *foreignKeyDefinition) CascadeOnUpdate() ForeignKeyDefinition { + return fd.OnUpdate("CASCADE") } -func (fk *foreignKeyDefinition) Deferrable(value ...bool) ForeignKeyDefinition { +func (fd *foreignKeyDefinition) Deferrable(value ...bool) ForeignKeyDefinition { val := optional(true, value...) - fk.deferrable = &val - return fk + fd.deferrable = &val + return fd } -func (fk *foreignKeyDefinition) InitiallyImmediate(value ...bool) ForeignKeyDefinition { +func (fd *foreignKeyDefinition) InitiallyImmediate(value ...bool) ForeignKeyDefinition { val := optional(true, value...) - fk.initiallyImmediate = &val - return fk + fd.initiallyImmediate = &val + return fd } -func (fk *foreignKeyDefinition) Name(name string) ForeignKeyDefinition { - fk.constaintName = name - return fk +func (fd *foreignKeyDefinition) Name(name string) ForeignKeyDefinition { + fd.index = name + return fd } -func (fk *foreignKeyDefinition) NoActionOnDelete() ForeignKeyDefinition { - fk.onDelete = "NO ACTION" - return fk +func (fd *foreignKeyDefinition) NoActionOnDelete() ForeignKeyDefinition { + return fd.OnDelete("NO ACTION") } -func (fk *foreignKeyDefinition) NoActionOnUpdate() ForeignKeyDefinition { - fk.onUpdate = "NO ACTION" - return fk +func (fd *foreignKeyDefinition) NoActionOnUpdate() ForeignKeyDefinition { + return fd.OnUpdate("NO ACTION") } -func (fk *foreignKeyDefinition) NullOnDelete() ForeignKeyDefinition { - fk.onDelete = "SET NULL" - return fk +func (fd *foreignKeyDefinition) NullOnDelete() ForeignKeyDefinition { + return fd.OnDelete("SET NULL") } -func (fk *foreignKeyDefinition) NullOnUpdate() ForeignKeyDefinition { - fk.onUpdate = "SET NULL" - return fk +func (fd *foreignKeyDefinition) NullOnUpdate() ForeignKeyDefinition { + return fd.OnUpdate("SET NULL") } -func (fk *foreignKeyDefinition) On(table string) ForeignKeyDefinition { - fk.on = table - return fk +func (fd *foreignKeyDefinition) On(table string) ForeignKeyDefinition { + fd.on = table + return fd } -func (fk *foreignKeyDefinition) OnDelete(action string) ForeignKeyDefinition { - fk.onDelete = action - return fk +func (fd *foreignKeyDefinition) OnDelete(action string) ForeignKeyDefinition { + fd.onDelete = action + return fd } -func (fk *foreignKeyDefinition) OnUpdate(action string) ForeignKeyDefinition { - fk.onUpdate = action - return fk +func (fd *foreignKeyDefinition) OnUpdate(action string) ForeignKeyDefinition { + fd.onUpdate = action + return fd } -func (fk *foreignKeyDefinition) References(column string) ForeignKeyDefinition { - fk.references = column - return fk +func (fd *foreignKeyDefinition) References(columns string) ForeignKeyDefinition { + fd.references = []string{columns} + return fd } -func (fk *foreignKeyDefinition) RestrictOnDelete() ForeignKeyDefinition { - fk.onDelete = "RESTRICT" - return fk +func (fd *foreignKeyDefinition) RestrictOnDelete() ForeignKeyDefinition { + return fd.OnDelete("RESTRICT") } -func (fk *foreignKeyDefinition) RestrictOnUpdate() ForeignKeyDefinition { - fk.onUpdate = "RESTRICT" - return fk +func (fd *foreignKeyDefinition) RestrictOnUpdate() ForeignKeyDefinition { + return fd.OnUpdate("RESTRICT") } diff --git a/grammar.go b/grammar.go index ccdce5e..12b70b1 100644 --- a/grammar.go +++ b/grammar.go @@ -10,50 +10,53 @@ type grammar interface { compileCreate(bp *Blueprint) (string, error) compileCreateIfNotExists(bp *Blueprint) (string, error) compileAdd(bp *Blueprint) (string, error) - compileChange(bp *Blueprint) ([]string, error) + compileChange(bp *Blueprint, command *command) (string, error) compileDrop(bp *Blueprint) (string, error) compileDropIfExists(bp *Blueprint) (string, error) - compileRename(bp *Blueprint) (string, error) - compileDropColumn(blueprint *Blueprint) (string, error) - compileRenameColumn(blueprint *Blueprint, oldName, newName string) (string, error) - compileIndex(blueprint *Blueprint, index *indexDefinition) (string, error) - compileUnique(blueprint *Blueprint, index *indexDefinition) (string, error) - compilePrimary(blueprint *Blueprint, index *indexDefinition) (string, error) - compileFullText(blueprint *Blueprint, index *indexDefinition) (string, error) - compileDropIndex(blueprint *Blueprint, indexName string) (string, error) - compileDropUnique(blueprint *Blueprint, indexName string) (string, error) - compileDropFulltext(blueprint *Blueprint, indexName string) (string, error) - compileDropPrimary(blueprint *Blueprint, indexName string) (string, error) - compileRenameIndex(blueprint *Blueprint, oldName, newName string) (string, error) - compileForeign(blueprint *Blueprint, foreignKey *foreignKeyDefinition) (string, error) - compileDropForeign(blueprint *Blueprint, foreignKeyName string) (string, error) + compileRename(bp *Blueprint, command *command) (string, error) + compileDropColumn(blueprint *Blueprint, command *command) (string, error) + compileRenameColumn(blueprint *Blueprint, command *command) (string, error) + compileIndex(blueprint *Blueprint, command *command) (string, error) + compileUnique(blueprint *Blueprint, command *command) (string, error) + compilePrimary(blueprint *Blueprint, command *command) (string, error) + compileFullText(blueprint *Blueprint, command *command) (string, error) + compileDropIndex(blueprint *Blueprint, command *command) (string, error) + compileDropUnique(blueprint *Blueprint, command *command) (string, error) + compileDropFulltext(blueprint *Blueprint, command *command) (string, error) + compileDropPrimary(blueprint *Blueprint, command *command) (string, error) + compileRenameIndex(blueprint *Blueprint, command *command) (string, error) + compileForeign(blueprint *Blueprint, command *command) (string, error) + compileDropForeign(blueprint *Blueprint, command *command) (string, error) + getFluentCommands() []func(blueprint *Blueprint, command *command) string + createIndexName(blueprint *Blueprint, idxType string, columns ...string) string } type baseGrammar struct{} -func (g *baseGrammar) compileForeign(blueprint *Blueprint, foreignKey *foreignKeyDefinition) (string, error) { - if foreignKey.column == "" || foreignKey.on == "" || foreignKey.references == "" { +func (g *baseGrammar) compileForeign(blueprint *Blueprint, command *command) (string, error) { + if len(command.columns) == 0 || slices.Contains(command.columns, "") || command.on == "" || + len(command.references) == 0 || slices.Contains(command.references, "") { return "", fmt.Errorf("foreign key definition is incomplete: column, on, and references must be set") } onDelete := "" - if foreignKey.onDelete != "" { - onDelete = fmt.Sprintf(" ON DELETE %s", foreignKey.onDelete) + if command.onDelete != "" { + onDelete = fmt.Sprintf(" ON DELETE %s", command.onDelete) } onUpdate := "" - if foreignKey.onUpdate != "" { - onUpdate = fmt.Sprintf(" ON UPDATE %s", foreignKey.onUpdate) + if command.onUpdate != "" { + onUpdate = fmt.Sprintf(" ON UPDATE %s", command.onUpdate) } - containtName := foreignKey.constaintName - if containtName == "" { - containtName = g.createForeignKeyName(blueprint, foreignKey) + index := command.index + if index == "" { + index = g.createForeignKeyName(blueprint, command) } return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s(%s)%s%s", blueprint.name, - containtName, - foreignKey.column, - foreignKey.on, - foreignKey.references, + index, + command.columns[0], + command.on, + command.references[0], onDelete, onUpdate, ), nil @@ -78,58 +81,57 @@ func (g *baseGrammar) columnize(columns []string) string { return strings.Join(columns, ", ") } -func (g *baseGrammar) getDefaultValue(col *columnDefinition) string { - if col.defaultValue == nil { - return "NULL" +func (g *baseGrammar) getValue(value any) string { + switch v := value.(type) { + case Expression: + return v.String() + default: + return fmt.Sprintf("'%v'", v) } - useQuote := slices.Contains([]columnType{columnTypeString, columnTypeChar, columnTypeEnum}, col.columnType) +} - switch v := col.defaultValue.(type) { - case string: - if useQuote { - return g.quoteString(v) - } - return v - case int, int64, float64: - return fmt.Sprintf("%v", v) +func (g *baseGrammar) getDefaultValue(value any) string { + if value == nil { + return "NULL" + } + switch v := value.(type) { + case Expression: + return v.String() case bool: - if v { - return "true" - } - return "false" + return ternary(v, "'1'", "'0'") default: - return fmt.Sprintf("'%v'", v) // Fallback for other types + return fmt.Sprintf("'%v'", v) } } -func (g *baseGrammar) createIndexName(blueprint *Blueprint, index *indexDefinition) string { +func (g *baseGrammar) createIndexName(blueprint *Blueprint, idxType string, columns ...string) string { tableName := blueprint.name if strings.Contains(tableName, ".") { parts := strings.Split(tableName, ".") tableName = parts[len(parts)-1] // Use the last part as the table name } - switch index.indexType { - case indexTypePrimary: + switch idxType { + case commandPrimary: return fmt.Sprintf("pk_%s", tableName) - case indexTypeUnique: - return fmt.Sprintf("uk_%s_%s", tableName, strings.Join(index.columns, "_")) - case indexTypeIndex: - return fmt.Sprintf("idx_%s_%s", tableName, strings.Join(index.columns, "_")) - case indexTypeFullText: - return fmt.Sprintf("ft_%s_%s", tableName, strings.Join(index.columns, "_")) + case commandUnique: + return fmt.Sprintf("uk_%s_%s", tableName, strings.Join(columns, "_")) + case commandIndex: + return fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_")) + case commandFullText: + return fmt.Sprintf("ft_%s_%s", tableName, strings.Join(columns, "_")) default: return "" } } -func (g *baseGrammar) createForeignKeyName(blueprint *Blueprint, foreignKey *foreignKeyDefinition) string { +func (g *baseGrammar) createForeignKeyName(blueprint *Blueprint, command *command) string { tableName := blueprint.name if strings.Contains(tableName, ".") { parts := strings.Split(tableName, ".") tableName = parts[len(parts)-1] // Use the last part as the table name } - on := foreignKey.on + on := command.on if strings.Contains(on, ".") { parts := strings.Split(on, ".") on = parts[len(parts)-1] // Use the last part as the column name diff --git a/index_definition.go b/index_definition.go index 2798584..6611920 100644 --- a/index_definition.go +++ b/index_definition.go @@ -15,16 +15,8 @@ type IndexDefinition interface { Name(name string) IndexDefinition } -var _ IndexDefinition = &indexDefinition{} - type indexDefinition struct { - name string - indexType indexType - algorithm string - columns []string - language string - deferrable *bool - initiallyImmediate *bool + *command } func (id *indexDefinition) Algorithm(algorithm string) IndexDefinition { @@ -50,6 +42,6 @@ func (id *indexDefinition) Language(language string) IndexDefinition { } func (id *indexDefinition) Name(name string) IndexDefinition { - id.name = name + id.index = name return id } diff --git a/mysql_builder.go b/mysql_builder.go index 653b58c..bb2dbfb 100644 --- a/mysql_builder.go +++ b/mysql_builder.go @@ -12,11 +12,12 @@ type mysqlBuilder struct { grammar *mysqlGrammar } -func newMysqlBuilder() Builder { +func newMysqlBuilder(options ...Option) Builder { grammar := newMysqlGrammar() + cfg := applyOptions(options...) return &mysqlBuilder{ - baseBuilder: baseBuilder{grammar: grammar}, + baseBuilder: baseBuilder{grammar: grammar, debug: cfg.debug}, grammar: grammar, } } @@ -27,7 +28,7 @@ func (b *mysqlBuilder) getCurrentDatabase(ctx context.Context, tx *sql.Tx) (stri } query := b.grammar.compileCurrentDatabase() - row := queryRowContext(ctx, tx, query) + row := tx.QueryRowContext(ctx, query) var dbName string if err := row.Scan(&dbName); err != nil { return "", err @@ -48,16 +49,15 @@ func (b *mysqlBuilder) CreateIfNotExists(ctx context.Context, tx *sql.Tx, name s return nil // Table already exists, no need to create it } - bp := &Blueprint{name: name} + bp := &Blueprint{name: name, grammar: b.grammar} bp.createIfNotExists() blueprint(bp) - statements, err := bp.toSql(b.grammar) - if err != nil { + if err := bp.build(ctx, tx); err != nil { return err } - return execContext(ctx, tx, statements...) + return nil } func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) { @@ -75,7 +75,7 @@ func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName str return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := tx.QueryContext(ctx, query) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName str return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := tx.QueryContext(ctx, query) if err != nil { return nil, err } @@ -145,7 +145,7 @@ func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, if err != nil { return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := tx.QueryContext(ctx, query) if err != nil { return nil, err } @@ -252,7 +252,7 @@ func (b *mysqlBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (b return false, err } - row := queryRowContext(ctx, tx, query) + row := tx.QueryRowContext(ctx, query) var exists bool if err := row.Scan(&exists); err != nil { if errors.Is(err, sql.ErrNoRows) { diff --git a/mysql_builder_test.go b/mysql_builder_test.go index 39f50c6..f5c1c99 100644 --- a/mysql_builder_test.go +++ b/mysql_builder_test.go @@ -17,8 +17,9 @@ func TestMysqlBuilderSuite(t *testing.T) { type mysqlBuilderSuite struct { suite.Suite - ctx context.Context - db *sql.DB + ctx context.Context + db *sql.DB + builder schema.Builder } func (s *mysqlBuilderSuite) SetupSuite() { @@ -42,7 +43,8 @@ func (s *mysqlBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - schema.SetDebug(false) + s.builder, err = schema.NewBuilder("mysql", schema.WithDebug()) + s.Require().NoError(err) } func (s *mysqlBuilderSuite) TearDownSuite() { @@ -50,7 +52,7 @@ func (s *mysqlBuilderSuite) TearDownSuite() { } func (s *mysqlBuilderSuite) AfterTest(_, _ string) { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) tables, err := builder.GetTables(s.ctx, tx) @@ -66,7 +68,7 @@ func (s *mysqlBuilderSuite) AfterTest(_, _ string) { } func (s *mysqlBuilderSuite) TestCreate() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -112,7 +114,7 @@ func (s *mysqlBuilderSuite) TestCreate() { table.UnsignedBigInteger("user_id") table.String("order_id", 255).Unique() table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE") }) @@ -123,7 +125,7 @@ func (s *mysqlBuilderSuite) TestCreate() { table.ID() table.String("order_id", 255).Unique("uk_orders_2_order_id") table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Index("created_at").Name("idx_orders_created_at").Algorithm("BTREE") }) @@ -140,7 +142,7 @@ func (s *mysqlBuilderSuite) TestCreate() { } func (s *mysqlBuilderSuite) TestCreateIfNotExists() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -161,7 +163,7 @@ func (s *mysqlBuilderSuite) TestCreateIfNotExists() { err := builder.CreateIfNotExists(s.ctx, tx, "test_table", nil) s.Error(err) }) - s.Run("when all parameters are valid, should create table successfully", func() { + s.Run("when all parameters are valid, should error", func() { err = builder.CreateIfNotExists(context.Background(), tx, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) @@ -170,20 +172,12 @@ func (s *mysqlBuilderSuite) TestCreateIfNotExists() { table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") }) - s.NoError(err, "expected no error when creating table with valid parameters") - }) - s.Run("when table already exists, should not return error", func() { - err = builder.CreateIfNotExists(context.Background(), tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - }) - s.NoError(err, "expected no error when creating table that already exists with CreateIfNotExists") + s.Error(err) }) } func (s *mysqlBuilderSuite) TestDrop() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -202,8 +196,7 @@ func (s *mysqlBuilderSuite) TestDrop() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") err = builder.Drop(context.Background(), tx, "users") @@ -216,7 +209,7 @@ func (s *mysqlBuilderSuite) TestDrop() { } func (s *mysqlBuilderSuite) TestDropIfExists() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -235,8 +228,7 @@ func (s *mysqlBuilderSuite) TestDropIfExists() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") err = builder.DropIfExists(context.Background(), tx, "users") @@ -249,7 +241,7 @@ func (s *mysqlBuilderSuite) TestDropIfExists() { } func (s *mysqlBuilderSuite) TestRename() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -278,7 +270,7 @@ func (s *mysqlBuilderSuite) TestRename() { } func (s *mysqlBuilderSuite) TestTable() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -306,8 +298,7 @@ func (s *mysqlBuilderSuite) TestTable() { table.String("email", 255).Unique("uk_users_email") table.String("password", 255).Nullable() table.Text("bio").Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.FullText("bio") }) @@ -392,7 +383,7 @@ func (s *mysqlBuilderSuite) TestTable() { } func (s *mysqlBuilderSuite) TestGetColumns() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -418,8 +409,7 @@ func (s *mysqlBuilderSuite) TestGetColumns() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before getting columns") @@ -431,7 +421,7 @@ func (s *mysqlBuilderSuite) TestGetColumns() { } func (s *mysqlBuilderSuite) TestGetIndexes() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -450,8 +440,7 @@ func (s *mysqlBuilderSuite) TestGetIndexes() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.Index("name").Name("idx_users_name") }) @@ -470,7 +459,7 @@ func (s *mysqlBuilderSuite) TestGetIndexes() { } func (s *mysqlBuilderSuite) TestGetTables() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -486,8 +475,7 @@ func (s *mysqlBuilderSuite) TestGetTables() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before getting tables") @@ -506,7 +494,7 @@ func (s *mysqlBuilderSuite) TestGetTables() { } func (s *mysqlBuilderSuite) TestHasColumn() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -532,8 +520,7 @@ func (s *mysqlBuilderSuite) TestHasColumn() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking for column existence") @@ -548,7 +535,7 @@ func (s *mysqlBuilderSuite) TestHasColumn() { } func (s *mysqlBuilderSuite) TestHasColumns() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -574,8 +561,7 @@ func (s *mysqlBuilderSuite) TestHasColumns() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking for columns existence") @@ -590,7 +576,7 @@ func (s *mysqlBuilderSuite) TestHasColumns() { } func (s *mysqlBuilderSuite) TestHasIndex() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -612,7 +598,7 @@ func (s *mysqlBuilderSuite) TestHasIndex() { table.Integer("user_id") table.String("order_id", 255) table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.Index("company_id", "user_id") table.Unique("order_id").Name("uk_orders3_order_id").Algorithm("BTREE") @@ -634,7 +620,7 @@ func (s *mysqlBuilderSuite) TestHasIndex() { } func (s *mysqlBuilderSuite) TestHasTable() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -655,8 +641,7 @@ func (s *mysqlBuilderSuite) TestHasTable() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking if it exists") diff --git a/mysql_grammar.go b/mysql_grammar.go index 7437bca..65c7c92 100644 --- a/mysql_grammar.go +++ b/mysql_grammar.go @@ -1,6 +1,7 @@ package schema import ( + "errors" "fmt" "slices" "strings" @@ -8,12 +9,19 @@ import ( type mysqlGrammar struct { baseGrammar + + serials []string } var _ grammar = (*mysqlGrammar)(nil) func newMysqlGrammar() *mysqlGrammar { - return &mysqlGrammar{} + return &mysqlGrammar{ + serials: []string{ + columnTypeBigInteger, columnTypeInteger, columnTypeMediumInteger, columnTypeSmallInteger, + columnTypeTinyInteger, + }, + } } func (g *mysqlGrammar) compileCurrentDatabase() string { @@ -102,7 +110,7 @@ func (g *mysqlGrammar) compileCreateEngine(sql string, blueprint *Blueprint) str } func (g *mysqlGrammar) compileCreateIfNotExists(blueprint *Blueprint) (string, error) { - return g.compileCreate(blueprint) + return "", errors.New("MySQL does not support CREATE TABLE IF NOT EXISTS with custom options") } func (g *mysqlGrammar) compileAdd(blueprint *Blueprint) (string, error) { @@ -125,51 +133,22 @@ func (g *mysqlGrammar) compileAdd(blueprint *Blueprint) (string, error) { ), nil } -func (g *mysqlGrammar) compileChange(bp *Blueprint) ([]string, error) { - if len(bp.getChangedColumns()) == 0 { - return nil, nil +func (g *mysqlGrammar) compileChange(bp *Blueprint, command *command) (string, error) { + column := command.column + if column.name == "" { + return "", fmt.Errorf("column name cannot be empty for change operation") } - var sqls []string - for _, col := range bp.getChangedColumns() { - if col.name == "" { - return nil, fmt.Errorf("column name cannot be empty for change operation") - } - sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s %s", bp.name, col.name, g.getType(col)) - - if col.hasCommand("nullable") { - if col.nullable { - sql += " NULL" - } else { - sql += " NOT NULL" - } - } - if col.hasCommand("default") { - if col.defaultValue != nil { - sql += fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col)) - } else { - sql += " DEFAULT NULL" - } - } - if col.hasCommand("comment") { - if col.comment != "" { - sql += fmt.Sprintf(" COMMENT '%s'", col.comment) - } else { - sql += " COMMENT ''" - } - } - - sqls = append(sqls, sql) + sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s %s", bp.name, column.name, g.getType(column)) + for _, modifier := range g.modifiers() { + sql += modifier(column) } - return sqls, nil + return sql, nil } -func (g *mysqlGrammar) compileRename(blueprint *Blueprint) (string, error) { - if blueprint.newName == "" { - return "", fmt.Errorf("new table name cannot be empty") - } - return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, blueprint.newName), nil +func (g *mysqlGrammar) compileRename(blueprint *Blueprint, command *command) (string, error) { + return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, command.to), nil } func (g *mysqlGrammar) compileDrop(blueprint *Blueprint) (string, error) { @@ -186,12 +165,12 @@ func (g *mysqlGrammar) compileDropIfExists(blueprint *Blueprint) (string, error) return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil } -func (g *mysqlGrammar) compileDropColumn(blueprint *Blueprint) (string, error) { - if len(blueprint.dropColumns) == 0 { +func (g *mysqlGrammar) compileDropColumn(blueprint *Blueprint, command *command) (string, error) { + if len(command.columns) == 0 { return "", fmt.Errorf("no columns to drop") } - columns := make([]string, len(blueprint.dropColumns)) - for i, col := range blueprint.dropColumns { + columns := make([]string, len(command.columns)) + for i, col := range command.columns { if col == "" { return "", fmt.Errorf("column name cannot be empty") } @@ -201,111 +180,112 @@ func (g *mysqlGrammar) compileDropColumn(blueprint *Blueprint) (string, error) { return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil } -func (g *mysqlGrammar) compileRenameColumn(blueprint *Blueprint, oldName, newName string) (string, error) { - if oldName == "" || newName == "" { +func (g *mysqlGrammar) compileRenameColumn(blueprint *Blueprint, command *command) (string, error) { + if command.from == "" || command.to == "" { return "", fmt.Errorf("old and new column names cannot be empty") } - return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, oldName, newName), nil + return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, command.from, command.to), nil } -func (g *mysqlGrammar) compileIndex(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { +func (g *mysqlGrammar) compileIndex(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { return "", fmt.Errorf("index column cannot be empty") } - indexName := index.name + indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, index) + indexName = g.createIndexName(blueprint, commandIndex, command.columns...) } - sql := fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(index.columns)) - if index.algorithm != "" { - sql += fmt.Sprintf(" USING %s", index.algorithm) + sql := fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(command.columns)) + if command.algorithm != "" { + sql += fmt.Sprintf(" USING %s", command.algorithm) } return sql, nil } -func (g *mysqlGrammar) compileUnique(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { +func (g *mysqlGrammar) compileUnique(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { return "", fmt.Errorf("unique column cannot be empty") } - indexName := index.name + indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, index) + indexName = g.createIndexName(blueprint, commandUnique, command.columns...) } - sql := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(index.columns)) - if index.algorithm != "" { - sql += fmt.Sprintf(" USING %s", index.algorithm) + sql := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(command.columns)) + if command.algorithm != "" { + sql += fmt.Sprintf(" USING %s", command.algorithm) } return sql, nil } -func (g *mysqlGrammar) compileFullText(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { +func (g *mysqlGrammar) compileFullText(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { return "", fmt.Errorf("fulltext index column cannot be empty") } - indexName := index.name + indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, index) + indexName = g.createIndexName(blueprint, commandFullText, command.columns...) } - return fmt.Sprintf("CREATE FULLTEXT INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(index.columns)), nil + return fmt.Sprintf("CREATE FULLTEXT INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(command.columns)), nil } -func (g *mysqlGrammar) compilePrimary(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { +func (g *mysqlGrammar) compilePrimary(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { return "", fmt.Errorf("primary key column cannot be empty") } - indexName := index.name + indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, index) + indexName = g.createIndexName(blueprint, commandPrimary) } - return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(index.columns)), nil + return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(command.columns)), nil } -func (g *mysqlGrammar) compileDropIndex(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { +func (g *mysqlGrammar) compileDropIndex(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { return "", fmt.Errorf("index name cannot be empty") } - return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, indexName), nil + return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, command.index), nil } -func (g *mysqlGrammar) compileDropUnique(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { +func (g *mysqlGrammar) compileDropUnique(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { return "", fmt.Errorf("unique index name cannot be empty") } - return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, indexName), nil + return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, command.index), nil } -func (g *mysqlGrammar) compileDropFulltext(blueprint *Blueprint, indexName string) (string, error) { - return g.compileDropIndex(blueprint, indexName) +func (g *mysqlGrammar) compileDropFulltext(blueprint *Blueprint, command *command) (string, error) { + return g.compileDropIndex(blueprint, command) } -func (g *mysqlGrammar) compileDropPrimary(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { - return "", fmt.Errorf("primary key index name cannot be empty") - } +func (g *mysqlGrammar) compileDropPrimary(blueprint *Blueprint, _ *command) (string, error) { return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", blueprint.name), nil } -func (g *mysqlGrammar) compileRenameIndex(blueprint *Blueprint, oldName, newName string) (string, error) { - if oldName == "" || newName == "" { +func (g *mysqlGrammar) compileRenameIndex(blueprint *Blueprint, command *command) (string, error) { + if command.from == "" || command.to == "" { return "", fmt.Errorf("old and new index names cannot be empty") } - return fmt.Sprintf("ALTER TABLE %s RENAME INDEX %s TO %s", blueprint.name, oldName, newName), nil + return fmt.Sprintf("ALTER TABLE %s RENAME INDEX %s TO %s", blueprint.name, command.from, command.to), nil } -func (g *mysqlGrammar) compileDropForeign(blueprint *Blueprint, foreignKeyName string) (string, error) { - if foreignKeyName == "" { +func (g *mysqlGrammar) compileDropForeign(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { return "", fmt.Errorf("foreign key name cannot be empty") } - return fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s", blueprint.name, foreignKeyName), nil + return fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s", blueprint.name, command.index), nil +} + +func (g *mysqlGrammar) getFluentCommands() []func(*Blueprint, *command) string { + return []func(*Blueprint, *command) string{} } func (g *mysqlGrammar) getColumns(blueprint *Blueprint) ([]string, error) { @@ -315,24 +295,20 @@ func (g *mysqlGrammar) getColumns(blueprint *Blueprint) ([]string, error) { return nil, fmt.Errorf("column name cannot be empty") } sql := col.name + " " + g.getType(col) - if col.hasCommand("default") { - if col.defaultValue != nil { - sql += fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col)) - } else { - sql += " DEFAULT NULL" - } - } - if col.onUpdateValue != "" { - sql += fmt.Sprintf(" ON UPDATE %s", col.onUpdateValue) - } - if col.nullable { - sql += " NULL" - } else { - sql += " NOT NULL" + sql += g.modifyUnsigned(col) + sql += g.modifyIncrement(col) + + if col.defaultValue != nil { + sql += g.modifyDefault(col) } - if col.comment != "" { - sql += fmt.Sprintf(" COMMENT '%s'", col.comment) + if col.onUpdateValue != nil { + sql += g.modifyOnUpdate(col) } + sql += g.modifyCharset(col) + sql += g.modifyCollate(col) + sql += g.modifyNullable(col) + sql += g.modifyComment(col) + columns = append(columns, sql) } @@ -342,100 +318,301 @@ func (g *mysqlGrammar) getColumns(blueprint *Blueprint) ([]string, error) { func (g *mysqlGrammar) getConstraints(blueprint *Blueprint) []string { var constrains []string for _, col := range blueprint.getAddedColumns() { - if col.primary { - pkConstraintName := g.createIndexName(blueprint, &indexDefinition{indexType: indexTypePrimary}) + if col.primary != nil && *col.primary { + pkConstraintName := g.createIndexName(blueprint, commandPrimary) sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")" constrains = append(constrains, sql) continue } - if col.unique { - uniqueName := col.uniqueName - if uniqueName == "" { - uniqueName = g.createIndexName(blueprint, &indexDefinition{ - indexType: indexTypeUnique, - columns: []string{col.name}, - }) - } - sql := fmt.Sprintf("CONSTRAINT %s UNIQUE (%s)", uniqueName, col.name) - constrains = append(constrains, sql) - } } return constrains } func (g *mysqlGrammar) getType(col *columnDefinition) string { - switch col.columnType { - case columnTypeCustom: - return col.customColumnType - case columnTypeBoolean: - return "TINYINT(1)" - case columnTypeChar: - return fmt.Sprintf("CHAR(%d)", col.length) - case columnTypeString: - return fmt.Sprintf("VARCHAR(%d)", col.length) - case columnTypeDecimal: - return fmt.Sprintf("DECIMAL(%d, %d)", col.total, col.places) - case columnTypeDouble, columnTypeFloat: - if col.total > 0 && col.places > 0 { - return fmt.Sprintf("DOUBLE(%d, %d)", col.total, col.places) - } - return "DOUBLE" - case columnTypeBigInteger: - return g.modifyUnsignedAndAutoIncrement("BIGINT", col) - case columnTypeInteger: - return g.modifyUnsignedAndAutoIncrement("INT", col) - case columnTypeSmallInteger: - return g.modifyUnsignedAndAutoIncrement("SMALLINT", col) - case columnTypeMediumInteger: - return g.modifyUnsignedAndAutoIncrement("MEDIUMINT", col) - case columnTypeTinyInteger: - return g.modifyUnsignedAndAutoIncrement("TINYINT", col) - case columnTypeTime: - return fmt.Sprintf("TIME(%d)", col.precision) - case columnTypeDateTime, columnTypeDateTimeTz: - if col.precision > 0 { - return fmt.Sprintf("DATETIME(%d)", col.precision) - } - return "DATETIME" - case columnTypeTimestamp, columnTypeTimestampTz: - if col.precision > 0 { - return fmt.Sprintf("TIMESTAMP(%d)", col.precision) - } - return "TIMESTAMP" - case columnTypeGeography: - return fmt.Sprintf("GEOGRAPHY(%s, %d)", col.subType, col.srid) - case columnTypeEnum: - return fmt.Sprintf("ENUM(%s)", g.quoteString(strings.Join(col.allowedEnums, "','"))) - default: - colType, ok := map[columnType]string{ - columnTypeBoolean: "BOOLEAN", - columnTypeLongText: "LONGTEXT", - columnTypeText: "TEXT", - columnTypeMediumText: "MEDIUMTEXT", - columnTypeTinyText: "TINYTEXT", - columnTypeDate: "DATE", - columnTypeYear: "YEAR", - columnTypeJSON: "JSON", - columnTypeJSONB: "JSON", - columnTypeUUID: "UUID", - columnTypeBinary: "BLOB", - columnTypeGeometry: "GEOMETRY", - columnTypePoint: "POINT", - }[col.columnType] - if !ok { - return "ERROR: Unknown column type " + string(col.columnType) + typeFuncMap := map[string]func(*columnDefinition) string{ + columnTypeChar: g.typeChar, + columnTypeString: g.typeString, + columnTypeTinyText: g.typeTinyText, + columnTypeText: g.typeText, + columnTypeMediumText: g.typeMediumText, + columnTypeLongText: g.typeLongText, + columnTypeInteger: g.typeInteger, + columnTypeBigInteger: g.typeBigInteger, + columnTypeMediumInteger: g.typeMediumInteger, + columnTypeSmallInteger: g.typeSmallInteger, + columnTypeTinyInteger: g.typeTinyInteger, + columnTypeFloat: g.typeFloat, + columnTypeDouble: g.typeDouble, + columnTypeDecimal: g.typeDecimal, + columnTypeBoolean: g.typeBoolean, + columnTypeEnum: g.typeEnum, + columnTypeJSON: g.typeJson, + columnTypeJSONB: g.typeJsonb, + columnTypeDate: g.typeDate, + columnTypeDateTime: g.typeDateTime, + columnTypeDateTimeTz: g.typeDateTimeTz, + columnTypeTime: g.typeTime, + columnTypeTimeTz: g.typeTimeTz, + columnTypeTimestamp: g.typeTimestamp, + columnTypeTimestampTz: g.typeTimestampTz, + columnTypeYear: g.typeYear, + columnTypeBinary: g.typeBinary, + columnTypeUUID: g.typeUUID, + columnTypeGeography: g.typeGeography, + columnTypeGeometry: g.typeGeometry, + columnTypePoint: g.typePoint, + } + if fn, ok := typeFuncMap[col.columnType]; ok { + return fn(col) + } + return col.columnType +} + +func (g *mysqlGrammar) typeChar(col *columnDefinition) string { + return fmt.Sprintf("CHAR(%d)", *col.length) +} + +func (g *mysqlGrammar) typeString(col *columnDefinition) string { + return fmt.Sprintf("VARCHAR(%d)", *col.length) +} + +func (g *mysqlGrammar) typeTinyText(col *columnDefinition) string { + return "TINYTEXT" +} + +func (g *mysqlGrammar) typeText(col *columnDefinition) string { + return "TEXT" +} + +func (g *mysqlGrammar) typeMediumText(col *columnDefinition) string { + return "MEDIUMTEXT" +} + +func (g *mysqlGrammar) typeLongText(col *columnDefinition) string { + return "LONGTEXT" +} + +func (g *mysqlGrammar) typeBigInteger(col *columnDefinition) string { + return "BIGINT" +} + +func (g *mysqlGrammar) typeInteger(col *columnDefinition) string { + return "INT" +} + +func (g *mysqlGrammar) typeMediumInteger(col *columnDefinition) string { + return "MEDIUMINT" +} + +func (g *mysqlGrammar) typeSmallInteger(col *columnDefinition) string { + return "SMALLINT" +} + +func (g *mysqlGrammar) typeTinyInteger(col *columnDefinition) string { + return "TINYINT" +} + +func (g *mysqlGrammar) typeFloat(col *columnDefinition) string { + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("FLOAT(%d)", *col.precision) + } + return "FLOAT" +} + +func (g *mysqlGrammar) typeDouble(col *columnDefinition) string { + return "DOUBLE" +} + +func (g *mysqlGrammar) typeDecimal(col *columnDefinition) string { + return fmt.Sprintf("DECIMAL(%d, %d)", *col.total, *col.places) +} + +func (g *mysqlGrammar) typeBoolean(col *columnDefinition) string { + return "TINYINT(1)" +} + +func (g *mysqlGrammar) typeEnum(col *columnDefinition) string { + allowedValues := make([]string, len(col.allowed)) + for i, e := range col.allowed { + allowedValues[i] = g.quoteString(e) + } + return fmt.Sprintf("ENUM(%s)", strings.Join(allowedValues, ", ")) +} + +func (g *mysqlGrammar) typeJson(col *columnDefinition) string { + return "JSON" +} + +func (g *mysqlGrammar) typeJsonb(col *columnDefinition) string { + return "JSON" +} + +func (g *mysqlGrammar) typeDate(col *columnDefinition) string { + return "DATE" +} + +func (g *mysqlGrammar) typeDateTime(col *columnDefinition) string { + current := "CURRENT_TIMESTAMP" + if col.precision != nil && *col.precision > 0 { + current = fmt.Sprintf("CURRENT_TIMESTAMP(%d)", *col.precision) + } + if col.useCurrent != nil && *col.useCurrent { + col.Default(Expression(current)) + } + if col.useCurrentOnUpdate != nil && *col.useCurrentOnUpdate { + col.OnUpdate(Expression(current)) + } + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("DATETIME(%d)", *col.precision) + } + return "DATETIME" +} + +func (g *mysqlGrammar) typeDateTimeTz(col *columnDefinition) string { + return g.typeDateTime(col) +} + +func (g *mysqlGrammar) typeTime(col *columnDefinition) string { + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("TIME(%d)", *col.precision) + } + return "TIME" +} + +func (g *mysqlGrammar) typeTimeTz(col *columnDefinition) string { + return g.typeTime(col) +} + +func (g *mysqlGrammar) typeTimestamp(col *columnDefinition) string { + current := "CURRENT_TIMESTAMP" + if col.precision != nil && *col.precision > 0 { + current = fmt.Sprintf("CURRENT_TIMESTAMP(%d)", *col.precision) + } + if col.useCurrent != nil && *col.useCurrent { + col.Default(Expression(current)) + } + if col.useCurrentOnUpdate != nil && *col.useCurrentOnUpdate { + col.OnUpdate(Expression(current)) + } + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("TIMESTAMP(%d)", *col.precision) + } + return "TIMESTAMP" +} + +func (g *mysqlGrammar) typeTimestampTz(col *columnDefinition) string { + return g.typeTimestamp(col) +} + +func (g *mysqlGrammar) typeYear(col *columnDefinition) string { + return "YEAR" +} + +func (g *mysqlGrammar) typeBinary(col *columnDefinition) string { + if col.length != nil && *col.length > 0 { + return fmt.Sprintf("BINARY(%d)", *col.length) + } + return "BLOB" +} + +func (g *mysqlGrammar) typeUUID(col *columnDefinition) string { + return "CHAR(36)" // Default UUID length +} + +func (g *mysqlGrammar) typeGeometry(col *columnDefinition) string { + subtype := ternary(col.subtype != nil, ptrOf(strings.ToUpper(*col.subtype)), nil) + if subtype != nil { + if !slices.Contains([]string{"POINT", "LINESTRING", "POLYGON", "GEOMETRYCOLLECTION", "MULTIPOINT", "MULTILINESTRING"}, *subtype) { + subtype = nil } - return colType + } + + if subtype == nil { + subtype = ptrOf("GEOMETRY") + } + if col.srid != nil && *col.srid > 0 { + return fmt.Sprintf("%s SRID %d", *subtype, *col.srid) + } + return *subtype +} + +func (g *mysqlGrammar) typeGeography(col *columnDefinition) string { + return g.typeGeometry(col) +} + +func (g *mysqlGrammar) typePoint(col *columnDefinition) string { + col.subtype = ptrOf("POINT") + return g.typeGeometry(col) +} + +func (g *mysqlGrammar) modifiers() []func(*columnDefinition) string { + return []func(*columnDefinition) string{ + g.modifyUnsigned, + g.modifyCharset, + g.modifyCollate, + g.modifyNullable, + g.modifyDefault, + g.modifyOnUpdate, + g.modifyIncrement, + g.modifyComment, } } -func (g *mysqlGrammar) modifyUnsignedAndAutoIncrement(sql string, col *columnDefinition) string { - if col.unsigned { - sql += " UNSIGNED" +func (g *mysqlGrammar) modifyCharset(col *columnDefinition) string { + if col.charset != nil && *col.charset != "" { + return fmt.Sprintf(" CHARACTER SET %s", *col.charset) } - if col.autoIncrement { - sql += " AUTO_INCREMENT" + return "" +} + +func (g *mysqlGrammar) modifyCollate(col *columnDefinition) string { + if col.collation != nil && *col.collation != "" { + return fmt.Sprintf(" COLLATE %s", *col.collation) } - return sql + return "" +} + +func (g *mysqlGrammar) modifyComment(col *columnDefinition) string { + if col.comment != nil && *col.comment != "" { + return fmt.Sprintf(" COMMENT '%s'", *col.comment) + } + return "" +} + +func (g *mysqlGrammar) modifyDefault(col *columnDefinition) string { + if col.defaultValue != nil { + return fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col.defaultValue)) + } + return "" +} + +func (g *mysqlGrammar) modifyIncrement(col *columnDefinition) string { + if slices.Contains(g.serials, col.columnType) && + col.autoIncrement != nil && *col.autoIncrement && + col.primary != nil && *col.primary { + return " AUTO_INCREMENT" + } + return "" +} + +func (g *mysqlGrammar) modifyNullable(col *columnDefinition) string { + if col.nullable != nil && *col.nullable { + return " NULL" + } + return " NOT NULL" +} + +func (g *mysqlGrammar) modifyOnUpdate(col *columnDefinition) string { + if col.onUpdateValue != nil { + return fmt.Sprintf(" ON UPDATE %s", g.getValue(col.onUpdateValue)) + } + return "" +} + +func (g *mysqlGrammar) modifyUnsigned(col *columnDefinition) string { + if col.unsigned != nil && *col.unsigned { + return " UNSIGNED" + } + return "" } diff --git a/mysql_grammar_test.go b/mysql_grammar_test.go index 583c18e..57f9b21 100644 --- a/mysql_grammar_test.go +++ b/mysql_grammar_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMysqlGrammar_CompileCreate(t *testing.T) { @@ -206,7 +207,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.Integer("age").Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN age INT"}, + want: []string{"ALTER TABLE users MODIFY COLUMN age INT NOT NULL"}, wantErr: false, }, { @@ -233,14 +234,14 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.String("status", 50).Default("active").Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN status VARCHAR(50) DEFAULT 'active'"}, + want: []string{"ALTER TABLE users MODIFY COLUMN status VARCHAR(50) NOT NULL DEFAULT 'active'"}, wantErr: false, }, { name: "change column with null default", table: "users", blueprint: func(table *Blueprint) { - table.Text("description").Default(nil).Change() + table.Text("description").Nullable().Default(nil).Change() }, want: []string{"ALTER TABLE users MODIFY COLUMN description TEXT DEFAULT NULL"}, wantErr: false, @@ -301,15 +302,18 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.table} + bp := &Blueprint{name: tt.table, grammar: g, dialect: dialectMySQL} tt.blueprint(bp) - got, err := g.compileChange(bp) + statements, err := bp.toSql() if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return } assert.NoError(t, err, "Did not expect error for test case: %s", tt.name) - assert.Equal(t, tt.want, got, "Expected SQL to match for test case: %s", tt.name) + require.Len(t, statements, len(tt.want), "Expected number of SQL statements to match for test case: %s", tt.name) + for i, stmt := range statements { + assert.Equal(t, tt.want[i], stmt, "Expected SQL to match for test case: %s", tt.name) + } }) } } @@ -345,21 +349,15 @@ func TestMysqlGrammar_CompileRename(t *testing.T) { want: "ALTER TABLE table1 RENAME TO table2", wantErr: false, }, - { - name: "empty new name should return error", - table: "users", - newName: "", - wantErr: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{ - name: tt.table, - newName: tt.newName, + name: tt.table, } - got, err := g.compileRename(bp) + bp.rename(tt.newName) + got, err := g.compileRename(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -506,12 +504,6 @@ func TestMysqlGrammar_CompileDropColumn(t *testing.T) { want: "ALTER TABLE users DROP COLUMN email, DROP COLUMN phone, DROP COLUMN address", wantErr: false, }, - { - name: "no columns to drop should return error", - table: "users", - blueprint: func(table *Blueprint) {}, - wantErr: true, - }, { name: "empty column name should return error", table: "users", @@ -536,7 +528,7 @@ func TestMysqlGrammar_CompileDropColumn(t *testing.T) { name: tt.table, } tt.blueprint(bp) - got, err := g.compileDropColumn(bp) + got, err := g.compileDropColumn(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -592,7 +584,8 @@ func TestMysqlGrammar_CompileRenameColumn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileRenameColumn(bp, tt.oldName, tt.newName) + command := &command{from: tt.oldName, to: tt.newName} + got, err := g.compileRenameColumn(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -714,7 +707,7 @@ func TestMysqlGrammar_CompileForeign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileForeign(bp, bp.foreignKeys[0]) + got, err := g.compileForeign(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -753,7 +746,8 @@ func TestMysqlGrammar_CompileDropForeign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropForeign(bp, tt.fkName) + command := &command{index: tt.fkName} + got, err := g.compileDropForeign(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -833,7 +827,7 @@ func TestMysqlGrammar_CompileIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileIndex(bp, bp.indexes[0]) + got, err := g.compileIndex(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -921,7 +915,7 @@ func TestMysqlGrammar_CompileUnique(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileUnique(bp, bp.indexes[0]) + got, err := g.compileUnique(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1017,7 +1011,7 @@ func TestMysqlGrammar_CompilePrimary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compilePrimary(bp, bp.indexes[0]) + got, err := g.compilePrimary(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1113,7 +1107,7 @@ func TestMysqlGrammar_CompileFullText(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileFullText(bp, bp.indexes[0]) + got, err := g.compileFullText(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1152,7 +1146,8 @@ func TestMysqlGrammar_CompileDropIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropIndex(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := g.compileDropIndex(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1190,7 +1185,8 @@ func TestMysqlGrammar_CompileDropUnique(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropUnique(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := g.compileDropUnique(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1228,7 +1224,8 @@ func TestMysqlGrammar_CompileDropFulltext(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropFulltext(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := g.compileDropFulltext(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1284,18 +1281,13 @@ func TestMysqlGrammar_CompileDropPrimary(t *testing.T) { want: "ALTER TABLE orders DROP PRIMARY KEY", wantErr: false, }, - { - name: "empty primary key index name should return error", - table: "users", - indexName: "", - wantErr: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropPrimary(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := g.compileDropPrimary(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1383,7 +1375,8 @@ func TestMysqlGrammar_CompileRenameIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileRenameIndex(bp, tt.oldName, tt.newName) + command := &command{from: tt.oldName, to: tt.newName} + got, err := g.compileRenameIndex(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1438,32 +1431,25 @@ func TestMysqlGrammar_GetType(t *testing.T) { want: "DECIMAL(10, 2)", }, { - name: "double column type with precision", + name: "double column type", blueprint: func(table *Blueprint) { - table.Double("value", 8, 2) - }, - want: "DOUBLE(8, 2)", - }, - { - name: "double column type without precision", - blueprint: func(table *Blueprint) { - table.Double("value", 0, 0) + table.Double("value") }, want: "DOUBLE", }, { name: "float column type with precision", blueprint: func(table *Blueprint) { - table.Float("value", 6, 2) + table.Float("value", 6) }, - want: "DOUBLE(6, 2)", + want: "FLOAT(6)", }, { name: "float column type without precision", blueprint: func(table *Blueprint) { - table.Float("value", 0, 0) + table.Float("value") }, - want: "DOUBLE", + want: "FLOAT(53)", }, { name: "big integer column type", @@ -1472,27 +1458,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "BIGINT", }, - { - name: "big integer unsigned", - blueprint: func(table *Blueprint) { - table.BigInteger("id").Unsigned() - }, - want: "BIGINT UNSIGNED", - }, - { - name: "big integer auto increment", - blueprint: func(table *Blueprint) { - table.BigInteger("id").AutoIncrement() - }, - want: "BIGINT AUTO_INCREMENT", - }, - { - name: "big integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.BigInteger("id").Unsigned().AutoIncrement() - }, - want: "BIGINT UNSIGNED AUTO_INCREMENT", - }, { name: "integer column type", blueprint: func(table *Blueprint) { @@ -1500,27 +1465,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "INT", }, - { - name: "integer unsigned", - blueprint: func(table *Blueprint) { - table.Integer("count").Unsigned() - }, - want: "INT UNSIGNED", - }, - { - name: "integer auto increment", - blueprint: func(table *Blueprint) { - table.Integer("id").AutoIncrement() - }, - want: "INT AUTO_INCREMENT", - }, - { - name: "integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.Increments("id") - }, - want: "INT UNSIGNED AUTO_INCREMENT", - }, { name: "small integer column type", blueprint: func(table *Blueprint) { @@ -1528,13 +1472,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "SMALLINT", }, - { - name: "small integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.SmallInteger("id").Unsigned().AutoIncrement() - }, - want: "SMALLINT UNSIGNED AUTO_INCREMENT", - }, { name: "medium integer column type", blueprint: func(table *Blueprint) { @@ -1542,13 +1479,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "MEDIUMINT", }, - { - name: "medium integer unsigned", - blueprint: func(table *Blueprint) { - table.UnsignedMediumInteger("value") - }, - want: "MEDIUMINT UNSIGNED", - }, { name: "small integer column type", blueprint: func(table *Blueprint) { @@ -1556,13 +1486,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "SMALLINT", }, - { - name: "small integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.SmallIncrements("id") - }, - want: "SMALLINT UNSIGNED AUTO_INCREMENT", - }, { name: "tiny integer column type", blueprint: func(table *Blueprint) { @@ -1570,26 +1493,12 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "TINYINT", }, - { - name: "tiny integer auto increment", - blueprint: func(table *Blueprint) { - table.TinyInteger("id").AutoIncrement() - }, - want: "TINYINT AUTO_INCREMENT", - }, - { - name: "tiny integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.TinyIncrements("id") - }, - want: "TINYINT UNSIGNED AUTO_INCREMENT", - }, { name: "time column type", blueprint: func(table *Blueprint) { table.Time("created_at") }, - want: "TIME(0)", + want: "TIME", }, { name: "datetime column type with precision", @@ -1647,19 +1556,12 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "TIMESTAMP", }, - { - name: "geography column type", - blueprint: func(table *Blueprint) { - table.Geography("location", "POINT", 4326) - }, - want: "GEOGRAPHY(POINT, 4326)", - }, { name: "enum column type", blueprint: func(table *Blueprint) { table.Enum("status", []string{"active", "inactive", "pending"}) }, - want: "ENUM('active','inactive','pending')", + want: "ENUM('active', 'inactive', 'pending')", }, { name: "long text column type", @@ -1722,7 +1624,7 @@ func TestMysqlGrammar_GetType(t *testing.T) { blueprint: func(table *Blueprint) { table.UUID("uuid") }, - want: "UUID", + want: "CHAR(36)", }, { name: "binary column type", @@ -1731,19 +1633,26 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "BLOB", }, + { + name: "geography column type", + blueprint: func(table *Blueprint) { + table.Geography("location", "LINESTRING", 4326) + }, + want: "LINESTRING SRID 4326", + }, { name: "geometry column type", blueprint: func(table *Blueprint) { - table.Geometry("shape", "GEOMETRY", 4326) + table.Geometry("shape", "", 4326) }, - want: "GEOMETRY", + want: "GEOMETRY SRID 4326", }, { name: "point column type", blueprint: func(table *Blueprint) { table.Point("location") }, - want: "POINT", + want: "POINT SRID 4326", }, } diff --git a/options.go b/options.go new file mode 100644 index 0000000..e14fc19 --- /dev/null +++ b/options.go @@ -0,0 +1,26 @@ +package schema + +type config struct { + debug bool +} + +type Option func(*config) + +// WithDebug sets the debug mode for the schema package. +func WithDebug(debug ...bool) Option { + return func(c *config) { + c.debug = optional(true, debug...) + } +} + +func applyOptions(opts ...Option) *config { + cfg := &config{ + debug: false, // default value + } + + for _, opt := range opts { + opt(cfg) + } + + return cfg +} diff --git a/postgres_builder.go b/postgres_builder.go index f46b9fd..1aacc64 100644 --- a/postgres_builder.go +++ b/postgres_builder.go @@ -12,11 +12,12 @@ type postgresBuilder struct { grammar *pgGrammar } -func newPostgresBuilder() Builder { +func newPostgresBuilder(options ...Option) Builder { grammar := newPgGrammar() + cfg := applyOptions(options...) return &postgresBuilder{ - baseBuilder: baseBuilder{grammar: grammar}, + baseBuilder: baseBuilder{grammar: grammar, debug: cfg.debug}, grammar: grammar, } } @@ -43,7 +44,7 @@ func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := tx.QueryContext(ctx, query) if err != nil { return nil, err } @@ -76,7 +77,7 @@ func (b *postgresBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName if err != nil { return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := tx.QueryContext(ctx, query) if err != nil { return nil, err } @@ -106,7 +107,7 @@ func (b *postgresBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableIn return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := tx.QueryContext(ctx, query) if err != nil { return nil, err } @@ -216,7 +217,7 @@ func (b *postgresBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) } var exists bool - if err := queryRowContext(ctx, tx, query).Scan(&exists); err != nil { + if err := tx.QueryRowContext(ctx, query).Scan(&exists); err != nil { if errors.Is(err, sql.ErrNoRows) { return false, nil // Table does not exist } diff --git a/postgres_builder_test.go b/postgres_builder_test.go index 334aea7..b110b25 100644 --- a/postgres_builder_test.go +++ b/postgres_builder_test.go @@ -16,12 +16,6 @@ func TestPostgresBuilderSuite(t *testing.T) { suite.Run(t, new(postgresBuilderSuite)) } -type postgresBuilderSuite struct { - suite.Suite - ctx context.Context - db *sql.DB -} - type dbConfig struct { Database string Username string @@ -43,6 +37,13 @@ func parseTestConfig() dbConfig { } } +type postgresBuilderSuite struct { + suite.Suite + ctx context.Context + db *sql.DB + builder schema.Builder +} + func (s *postgresBuilderSuite) SetupSuite() { s.ctx = context.Background() @@ -57,7 +58,8 @@ func (s *postgresBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - schema.SetDebug(true) + s.builder, err = schema.NewBuilder("postgres", schema.WithDebug()) + s.Require().NoError(err) } func (s *postgresBuilderSuite) TearDownSuite() { @@ -65,8 +67,7 @@ func (s *postgresBuilderSuite) TearDownSuite() { } func (s *postgresBuilderSuite) TestCreate() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -86,21 +87,21 @@ func (s *postgresBuilderSuite) TestCreate() { s.Run("when all parameters are valid, should create table successfully", func() { err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() + table.String("name") + table.String("email").Unique() + table.String("password").Nullable() table.Timestamps() }) s.NoError(err, "expected no error when creating table with valid parameters") }) s.Run("when use custom schema should create it successfully", func() { - _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_publics") + _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_public") s.NoError(err, "expected no error when creating custom schema") - err = builder.Create(context.Background(), tx, "custom_publics.users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "custom_public.users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() + table.String("name") + table.String("email").Unique() + table.String("password").Nullable() table.TimestampsTz() }) s.NoError(err, "expected no error when creating table with custom schema") @@ -118,9 +119,9 @@ func (s *postgresBuilderSuite) TestCreate() { err = builder.Create(context.Background(), tx, "orders", func(table *schema.Blueprint) { table.ID() table.BigInteger("user_id") - table.String("order_id", 255).Unique() + table.String("order_id").Unique() table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE") }) @@ -129,9 +130,9 @@ func (s *postgresBuilderSuite) TestCreate() { s.Run("when have custom index should create it successfully", func() { err = builder.Create(context.Background(), tx, "orders_2", func(table *schema.Blueprint) { table.ID() - table.String("order_id", 255).Unique("uk_orders_2_order_id") + table.String("order_id").Unique("uk_orders_2_order_id") table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Index("created_at").Name("idx_orders_created_at").Algorithm("BTREE") }) @@ -140,16 +141,15 @@ func (s *postgresBuilderSuite) TestCreate() { s.Run("when table already exists, should return error", func() { err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() + table.String("name") + table.String("email").Unique() }) s.Error(err, "expected error when creating table that already exists") }) } func (s *postgresBuilderSuite) TestCreateIfNotExists() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -169,27 +169,25 @@ func (s *postgresBuilderSuite) TestCreateIfNotExists() { s.Run("when all parameters are valid, should create table successfully", func() { err = builder.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.String("name") + table.String("email").Unique() + table.String("password").Nullable() + table.TimestampsTz() }) s.NoError(err, "expected no error when creating table with valid parameters") }) s.Run("when table already exists, should not return error", func() { err = builder.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255) + table.String("name") + table.String("email") }) s.NoError(err, "expected no error when creating table that already exists") }) } func (s *postgresBuilderSuite) TestDrop() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -201,11 +199,10 @@ func (s *postgresBuilderSuite) TestDrop() { s.Run("when all parameters are valid, should drop table successfully", func() { err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.String("name") + table.String("email").Unique() + table.String("password").Nullable() + table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") err = builder.Drop(s.ctx, tx, "users") @@ -218,8 +215,7 @@ func (s *postgresBuilderSuite) TestDrop() { } func (s *postgresBuilderSuite) TestDropIfExists() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -235,11 +231,10 @@ func (s *postgresBuilderSuite) TestDropIfExists() { s.Run("when all parameters are valid, should drop table successfully", func() { err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.String("name") + table.String("email").Unique() + table.String("password").Nullable() + table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") err = builder.DropIfExists(s.ctx, tx, "users") @@ -252,8 +247,7 @@ func (s *postgresBuilderSuite) TestDropIfExists() { } func (s *postgresBuilderSuite) TestRename() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -287,8 +281,7 @@ func (s *postgresBuilderSuite) TestRename() { } func (s *postgresBuilderSuite) TestTable() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -312,8 +305,7 @@ func (s *postgresBuilderSuite) TestTable() { table.String("email", 255).Unique("uk_users_email") table.String("password", 255).Nullable() table.Text("bio").Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.FullText("bio") }) @@ -359,7 +351,7 @@ func (s *postgresBuilderSuite) TestTable() { }) s.Run("should drop unique constraint", func() { err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.DropUnique("uk_users_email") + table.DropUnique([]string{"email"}) }) s.NoError(err, "expected no error when dropping unique constraint from table") }) @@ -397,8 +389,7 @@ func (s *postgresBuilderSuite) TestTable() { } func (s *postgresBuilderSuite) TestGetColumns() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -434,8 +425,7 @@ func (s *postgresBuilderSuite) TestGetColumns() { } func (s *postgresBuilderSuite) TestGetIndexes() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -474,8 +464,7 @@ func (s *postgresBuilderSuite) TestGetIndexes() { } func (s *postgresBuilderSuite) TestGetTables() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -506,8 +495,7 @@ func (s *postgresBuilderSuite) TestGetTables() { } func (s *postgresBuilderSuite) TestHasColumn() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -544,8 +532,7 @@ func (s *postgresBuilderSuite) TestHasColumn() { } func (s *postgresBuilderSuite) TestHasColumns() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -582,8 +569,7 @@ func (s *postgresBuilderSuite) TestHasColumns() { } func (s *postgresBuilderSuite) TestHasIndex() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck @@ -627,8 +613,7 @@ func (s *postgresBuilderSuite) TestHasIndex() { } func (s *postgresBuilderSuite) TestHasTable() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck diff --git a/postgres_grammar.go b/postgres_grammar.go index 4e02e7b..dac364f 100644 --- a/postgres_grammar.go +++ b/postgres_grammar.go @@ -69,8 +69,7 @@ func (g *pgGrammar) compileCreate(blueprint *Blueprint) (string, error) { if err != nil { return "", err } - constraints := g.getConstraints(blueprint) - columns = append(columns, constraints...) + columns = append(columns, g.getConstraints(blueprint)...) return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil } @@ -79,8 +78,7 @@ func (g *pgGrammar) compileCreateIfNotExists(blueprint *Blueprint) (string, erro if err != nil { return "", err } - constraints := g.getConstraints(blueprint) - columns = append(columns, constraints...) + columns = append(columns, g.getConstraints(blueprint)...) return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil } @@ -94,11 +92,11 @@ func (g *pgGrammar) compileAdd(blueprint *Blueprint) (string, error) { return "", err } columns = g.prefixArray("ADD COLUMN ", columns) - constraints := g.getConstraints(blueprint) - constraints = g.prefixArray("ADD ", constraints) - - columns = append(columns, constraints...) + if len(constraints) > 0 { + constraints = g.prefixArray("ADD ", constraints) + columns = append(columns, constraints...) + } return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, @@ -106,46 +104,25 @@ func (g *pgGrammar) compileAdd(blueprint *Blueprint) (string, error) { ), nil } -func (g *pgGrammar) compileChange(bp *Blueprint) ([]string, error) { - if len(bp.getChangedColumns()) == 0 { - return nil, nil +func (g *pgGrammar) compileChange(bp *Blueprint, command *command) (string, error) { + column := command.column + if column.name == "" { + return "", fmt.Errorf("column name cannot be empty for change operation") } - var queries []string - for _, col := range bp.getChangedColumns() { - if col.name == "" { - return nil, fmt.Errorf("column name cannot be empty for change operation") - } - var statements []string - - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s TYPE %s", col.name, g.getType(col))) - if col.hasCommand("default") { - if col.defaultValue != nil { - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", col.name, g.getDefaultValue(col))) - } else { - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s DROP DEFAULT", col.name)) - } - } - if col.hasCommand("nullable") { - if col.nullable { - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s DROP NOT NULL", col.name)) - } else { - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s SET NOT NULL", col.name)) - } - } - sql := "ALTER TABLE " + bp.name + " " + strings.Join(statements, ", ") - queries = append(queries, sql) - - if col.hasCommand("comment") { - if col.comment != "" { - queries = append(queries, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", bp.name, col.name, col.comment)) - } else { - queries = append(queries, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS NULL", bp.name, col.name)) - } + var changes []string + changes = append(changes, fmt.Sprintf("TYPE %s", g.getType(command.column))) + for _, modifier := range g.modifiers() { + change := modifier(command.column) + if change != "" { + changes = append(changes, strings.TrimSpace(change)) } } - return queries, nil + return fmt.Sprintf("ALTER TABLE %s %s", + bp.name, + strings.Join(g.prefixArray(fmt.Sprintf("ALTER COLUMN %s ", column.name), changes), ", "), + ), nil } func (g *pgGrammar) compileDrop(blueprint *Blueprint) (string, error) { @@ -156,65 +133,85 @@ func (g *pgGrammar) compileDropIfExists(blueprint *Blueprint) (string, error) { return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil } -func (g *pgGrammar) compileRename(blueprint *Blueprint) (string, error) { - return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, blueprint.newName), nil +func (g *pgGrammar) compileRename(blueprint *Blueprint, command *command) (string, error) { + return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, command.to), nil } -func (g *pgGrammar) compileDropColumn(blueprint *Blueprint) (string, error) { - if len(blueprint.dropColumns) == 0 { +func (g *pgGrammar) compileDropColumn(blueprint *Blueprint, command *command) (string, error) { + if len(blueprint.columns) == 0 { return "", nil } - columns := g.prefixArray("DROP COLUMN ", blueprint.dropColumns) + columns := g.prefixArray("DROP COLUMN ", command.columns) return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil } -func (g *pgGrammar) compileRenameColumn(blueprint *Blueprint, oldName, newName string) (string, error) { - if oldName == "" || newName == "" { +func (g *pgGrammar) compileRenameColumn(blueprint *Blueprint, command *command) (string, error) { + if command.from == "" || command.to == "" { return "", fmt.Errorf("table name, old column name, and new column name cannot be empty for rename operation") } - return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, oldName, newName), nil + return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, command.from, command.to), nil +} + +func (g *pgGrammar) compileFullText(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { + return "", fmt.Errorf("fulltext index column cannot be empty") + } + indexName := command.index + if indexName == "" { + indexName = g.createIndexName(blueprint, commandFullText, command.columns...) + } + language := command.language + if language == "" { + language = "english" // Default language for full-text search + } + var columns []string + for _, col := range command.columns { + columns = append(columns, fmt.Sprintf("to_tsvector(%s, %s)", g.quoteString(language), col)) + } + + return fmt.Sprintf("CREATE INDEX %s ON %s USING GIN (%s)", indexName, blueprint.name, strings.Join(columns, " || ")), nil } -func (g *pgGrammar) compileIndex(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { +func (g *pgGrammar) compileIndex(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { return "", fmt.Errorf("index column cannot be empty") } - indexName := index.name + indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, index) + indexName = g.createIndexName(blueprint, commandIndex, command.columns...) } sql := fmt.Sprintf("CREATE INDEX %s ON %s", indexName, blueprint.name) - if index.algorithm != "" { - sql += fmt.Sprintf(" USING %s", index.algorithm) + if command.algorithm != "" { + sql += fmt.Sprintf(" USING %s", command.algorithm) } - return fmt.Sprintf("%s (%s)", sql, g.columnize(index.columns)), nil + return fmt.Sprintf("%s (%s)", sql, g.columnize(command.columns)), nil } -func (g *pgGrammar) compileUnique(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { +func (g *pgGrammar) compileUnique(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { return "", fmt.Errorf("unique index column cannot be empty") } - indexName := index.name + indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, index) + indexName = g.createIndexName(blueprint, commandUnique, command.columns...) } sql := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)", blueprint.name, indexName, - g.columnize(index.columns), + g.columnize(command.columns), ) - if index.deferrable != nil { - if *index.deferrable { + if command.deferrable != nil { + if *command.deferrable { sql += " DEFERRABLE" } else { sql += " NOT DEFERRABLE" } } - if index.deferrable != nil && *index.deferrable && index.initiallyImmediate != nil { - if *index.initiallyImmediate { + if command.deferrable != nil && *command.deferrable && command.initiallyImmediate != nil { + if *command.initiallyImmediate { sql += " INITIALLY IMMEDIATE" } else { sql += " INITIALLY DEFERRED" @@ -224,84 +221,64 @@ func (g *pgGrammar) compileUnique(blueprint *Blueprint, index *indexDefinition) return sql, nil } -func (g *pgGrammar) compilePrimary(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { +func (g *pgGrammar) compilePrimary(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { return "", fmt.Errorf("primary key index column cannot be empty") } - indexName := index.name + indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, index) + indexName = g.createIndexName(blueprint, commandPrimary, command.columns...) } - return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(index.columns)), nil + return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(command.columns)), nil } -func (g *pgGrammar) compileFullText(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { - return "", fmt.Errorf("fulltext index column cannot be empty") - } - indexName := index.name - if indexName == "" { - indexName = g.createIndexName(blueprint, index) - } - language := index.language - if language == "" { - language = "english" // Default language for full-text search - } - var columns []string - for _, col := range index.columns { - columns = append(columns, fmt.Sprintf("to_tsvector(%s, %s)", g.quoteString(language), col)) +func (g *pgGrammar) compileDropIndex(_ *Blueprint, command *command) (string, error) { + if command.index == "" { + return "", fmt.Errorf("index name cannot be empty for drop operation") } - - return fmt.Sprintf("CREATE INDEX %s ON %s USING GIN (%s)", indexName, blueprint.name, strings.Join(columns, " || ")), nil + return fmt.Sprintf("DROP INDEX %s", command.index), nil } -func (g *pgGrammar) compileDropIndex(_ *Blueprint, indexName string) (string, error) { - if indexName == "" { - return "", fmt.Errorf("index name cannot be empty for drop operation") - } - return fmt.Sprintf("DROP INDEX %s", indexName), nil +func (g *pgGrammar) compileDropFulltext(blueprint *Blueprint, command *command) (string, error) { + return g.compileDropIndex(blueprint, command) } -func (g *pgGrammar) compileDropUnique(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { +func (g *pgGrammar) compileDropUnique(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { return "", fmt.Errorf("index name cannot be empty for drop operation") } - return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, indexName), nil -} - -func (g *pgGrammar) compileDropFulltext(blueprint *Blueprint, indexName string) (string, error) { - return g.compileDropIndex(blueprint, indexName) + return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil } -func (g *pgGrammar) compileRenameIndex(_ *Blueprint, oldName, newName string) (string, error) { - if oldName == "" || newName == "" { - return "", fmt.Errorf("index names for rename operation cannot be empty: oldName=%s, newName=%s", oldName, newName) +func (g *pgGrammar) compileDropPrimary(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { + command.index = g.createIndexName(blueprint, commandPrimary) } - return fmt.Sprintf("ALTER INDEX %s RENAME TO %s", oldName, newName), nil + return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil } -func (g *pgGrammar) compileDropPrimary(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { - indexName = g.createIndexName(blueprint, &indexDefinition{indexType: indexTypePrimary}) +func (g *pgGrammar) compileRenameIndex(_ *Blueprint, command *command) (string, error) { + if command.from == "" || command.to == "" { + return "", fmt.Errorf("index names for rename operation cannot be empty: oldName=%s, newName=%s", command.from, command.to) } - return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, indexName), nil + return fmt.Sprintf("ALTER INDEX %s RENAME TO %s", command.from, command.to), nil } -func (g *pgGrammar) compileForeign(blueprint *Blueprint, foreignKey *foreignKeyDefinition) (string, error) { - sql, err := g.baseGrammar.compileForeign(blueprint, foreignKey) +func (g *pgGrammar) compileForeign(blueprint *Blueprint, command *command) (string, error) { + sql, err := g.baseGrammar.compileForeign(blueprint, command) if err != nil { return "", err } - if foreignKey.deferrable != nil { - if *foreignKey.deferrable { + if command.deferrable != nil { + if *command.deferrable { sql += " DEFERRABLE" } else { sql += " NOT DEFERRABLE" } } - if foreignKey.deferrable != nil && *foreignKey.deferrable && foreignKey.initiallyImmediate != nil { - if *foreignKey.initiallyImmediate { + if command.deferrable != nil && *command.deferrable && command.initiallyImmediate != nil { + if *command.initiallyImmediate { sql += " INITIALLY IMMEDIATE" } else { sql += " INITIALLY DEFERRED" @@ -311,11 +288,29 @@ func (g *pgGrammar) compileForeign(blueprint *Blueprint, foreignKey *foreignKeyD return sql, nil } -func (g *pgGrammar) compileDropForeign(blueprint *Blueprint, foreignKeyName string) (string, error) { - if foreignKeyName == "" { +func (g *pgGrammar) compileDropForeign(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { return "", fmt.Errorf("foreign key name cannot be empty for drop operation") } - return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, foreignKeyName), nil + return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil +} + +func (g *pgGrammar) getFluentCommands() []func(blueprint *Blueprint, command *command) string { + return []func(blueprint *Blueprint, command *command) string{ + g.compileComment, + } +} + +func (g *pgGrammar) compileComment(blueprint *Blueprint, command *command) string { + if command.column.comment != nil || command.column.change { + sql := fmt.Sprintf("COMMENT ON COLUMN %s.%s IS ", blueprint.name, command.column.name) + if command.column.comment == nil { + return sql + "NULL" + } else { + return sql + fmt.Sprintf("'%s'", *command.column.comment) + } + } + return "" } func (g *pgGrammar) getColumns(blueprint *Blueprint) ([]string, error) { @@ -325,20 +320,8 @@ func (g *pgGrammar) getColumns(blueprint *Blueprint) ([]string, error) { return nil, fmt.Errorf("column name cannot be empty") } sql := col.name + " " + g.getType(col) - if col.hasCommand("default") { - if col.defaultValue != nil { - sql += fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col)) - } else { - sql += " DEFAULT NULL" - } - } - if col.nullable { - sql += " NULL" - } else { - sql += " NOT NULL" - } - if col.comment != "" { - sql += fmt.Sprintf(" COMMENT '%s'", col.comment) + for _, modifier := range g.modifiers() { + sql += modifier(col) } columns = append(columns, sql) } @@ -349,102 +332,254 @@ func (g *pgGrammar) getColumns(blueprint *Blueprint) ([]string, error) { func (g *pgGrammar) getConstraints(blueprint *Blueprint) []string { var constrains []string for _, col := range blueprint.getAddedColumns() { - if col.primary { - pkConstraintName := g.createIndexName(blueprint, &indexDefinition{indexType: indexTypePrimary}) + if col.primary != nil && *col.primary { + pkConstraintName := g.createIndexName(blueprint, commandPrimary) sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")" constrains = append(constrains, sql) continue } - if col.unique { - uniqueName := col.uniqueName - if uniqueName == "" { - uniqueName = g.createIndexName(blueprint, &indexDefinition{ - indexType: indexTypeUnique, - columns: []string{col.name}, - }) - } - sql := fmt.Sprintf("CONSTRAINT %s UNIQUE (%s)", uniqueName, col.name) - constrains = append(constrains, sql) - } } return constrains } func (g *pgGrammar) getType(col *columnDefinition) string { - switch col.columnType { - case columnTypeCustom: - return col.customColumnType - case columnTypeChar: - if col.length > 0 { - return fmt.Sprintf("CHAR(%d)", col.length) - } - return "CHAR" - case columnTypeString: - if col.length > 0 { - return fmt.Sprintf("VARCHAR(%d)", col.length) - } - return "VARCHAR" - case columnTypeDecimal: - return fmt.Sprintf("DECIMAL(%d, %d)", col.total, col.places) - case columnTypeBigInteger: - if col.autoIncrement { - return "BIGSERIAL" - } - return "BIGINT" - case columnTypeInteger, columnTypeMediumInteger: - if col.autoIncrement { - return "SERIAL" - } - return "INTEGER" - case columnTypeSmallInteger, columnTypeTinyInteger: - if col.autoIncrement { - return "SMALLSERIAL" - } - return "SMALLINT" - case columnTypeTime: - return fmt.Sprintf("TIME(%d)", col.precision) - case columnTypeTimestamp, columnTypeDateTime: - return fmt.Sprintf("TIMESTAMP(%d)", col.precision) - case columnTypeTimestampTz, columnTypeDateTimeTz: - return fmt.Sprintf("TIMESTAMPTZ(%d)", col.precision) - case columnTypeGeography: - return fmt.Sprintf("GEOGRAPHY(%s, %d)", col.subType, col.srid) - case columnTypeGeometry: - if col.srid > 0 { - return fmt.Sprintf("GEOMETRY(%s, %d)", col.subType, col.srid) - } - return fmt.Sprintf("GEOMETRY(%s)", col.subType) - case columnTypePoint: - return fmt.Sprintf("POINT(%d)", col.srid) - case columnTypeEnum: - if len(col.allowedEnums) == 0 { - return "VARCHAR(255)" // Fallback to VARCHAR if no enums are defined - } - enumValues := make([]string, len(col.allowedEnums)) - for i, v := range col.allowedEnums { - enumValues[i] = fmt.Sprintf("'%s'", v) - } - return "VARCHAR(255) CHECK (" + col.name + " IN (" + strings.Join(enumValues, ", ") + "))" - default: - colType, ok := map[columnType]string{ - columnTypeBoolean: "BOOLEAN", - columnTypeLongText: "TEXT", - columnTypeText: "TEXT", - columnTypeMediumText: "TEXT", - columnTypeTinyText: "TEXT", - columnTypeDouble: "DOUBLE PRECISION", - columnTypeFloat: "REAL", - columnTypeDate: "DATE", - columnTypeYear: "INTEGER", // PostgreSQL does not have a YEAR type, using INTEGER instead - columnTypeJSON: "JSON", - columnTypeJSONB: "JSONB", - columnTypeUUID: "UUID", - columnTypeBinary: "BYTEA", - }[col.columnType] - if !ok { - return "ERROR: Unknown column type " + string(col.columnType) + typeMapFunc := map[string]func(*columnDefinition) string{ + columnTypeChar: g.typeChar, + columnTypeString: g.typeString, + columnTypeTinyText: g.typeTinyText, + columnTypeText: g.typeText, + columnTypeMediumText: g.typeMediumText, + columnTypeLongText: g.typeLongText, + columnTypeInteger: g.typeInteger, + columnTypeBigInteger: g.typeBigInteger, + columnTypeMediumInteger: g.typeMediumInteger, + columnTypeSmallInteger: g.typeSmallInteger, + columnTypeTinyInteger: g.typeTinyInteger, + columnTypeFloat: g.typeFloat, + columnTypeDouble: g.typeDouble, + columnTypeDecimal: g.typeDecimal, + columnTypeBoolean: g.typeBoolean, + columnTypeEnum: g.typeEnum, + columnTypeJSON: g.typeJson, + columnTypeJSONB: g.typeJsonb, + columnTypeDate: g.typeDate, + columnTypeDateTime: g.typeDateTime, + columnTypeDateTimeTz: g.typeDateTimeTz, + columnTypeTime: g.typeTime, + columnTypeTimeTz: g.typeTimeTz, + columnTypeTimestamp: g.typeTimestamp, + columnTypeTimestampTz: g.typeTimestampTz, + columnTypeYear: g.typeYear, + columnTypeBinary: g.typeBinary, + columnTypeUUID: g.typeUUID, + columnTypeGeography: g.typeGeography, + columnTypeGeometry: g.typeGeometry, + columnTypePoint: g.typePoint, + } + if fn, ok := typeMapFunc[col.columnType]; ok { + return fn(col) + } + return col.columnType +} + +func (g *pgGrammar) typeChar(col *columnDefinition) string { + if col.length != nil && *col.length > 0 { + return fmt.Sprintf("CHAR(%d)", *col.length) + } + return "CHAR" +} + +func (g *pgGrammar) typeString(col *columnDefinition) string { + if col.length != nil && *col.length > 0 { + return fmt.Sprintf("VARCHAR(%d)", *col.length) + } + return "VARCHAR" +} + +func (g *pgGrammar) typeTinyText(_ *columnDefinition) string { + return "VARCHAR(255)" +} + +func (g *pgGrammar) typeText(_ *columnDefinition) string { + return "TEXT" +} + +func (g *pgGrammar) typeMediumText(_ *columnDefinition) string { + return "TEXT" +} + +func (g *pgGrammar) typeLongText(_ *columnDefinition) string { + return "TEXT" +} + +func (g *pgGrammar) typeBigInteger(col *columnDefinition) string { + if col.autoIncrement != nil && *col.autoIncrement { + return "BIGSERIAL" + } + return "BIGINT" +} + +func (g *pgGrammar) typeInteger(col *columnDefinition) string { + if col.autoIncrement != nil && *col.autoIncrement { + return "SERIAL" + } + return "INTEGER" +} + +func (g *pgGrammar) typeMediumInteger(col *columnDefinition) string { + return g.typeInteger(col) +} + +func (g *pgGrammar) typeSmallInteger(col *columnDefinition) string { + if col.autoIncrement != nil && *col.autoIncrement { + return "SMALLSERIAL" + } + return "SMALLINT" +} + +func (g *pgGrammar) typeTinyInteger(col *columnDefinition) string { + return g.typeSmallInteger(col) +} + +func (g *pgGrammar) typeFloat(_ *columnDefinition) string { + return "REAL" +} + +func (g *pgGrammar) typeDouble(_ *columnDefinition) string { + return "DOUBLE PRECISION" +} + +func (g *pgGrammar) typeDecimal(col *columnDefinition) string { + return fmt.Sprintf("DECIMAL(%d, %d)", *col.total, *col.places) +} + +func (g *pgGrammar) typeBoolean(_ *columnDefinition) string { + return "BOOLEAN" +} + +func (g *pgGrammar) typeEnum(col *columnDefinition) string { + enumValues := make([]string, len(col.allowed)) + for i, v := range col.allowed { + enumValues[i] = g.quoteString(v) + } + return "VARCHAR(255) CHECK (" + col.name + " IN (" + strings.Join(enumValues, ", ") + "))" +} + +func (g *pgGrammar) typeJson(_ *columnDefinition) string { + return "JSON" +} + +func (g *pgGrammar) typeJsonb(_ *columnDefinition) string { + return "JSONB" +} + +func (g *pgGrammar) typeDate(_ *columnDefinition) string { + return "DATE" +} + +func (g *pgGrammar) typeDateTime(col *columnDefinition) string { + return g.typeTimestamp(col) +} + +func (g *pgGrammar) typeDateTimeTz(col *columnDefinition) string { + return g.typeTimestampTz(col) +} + +func (g *pgGrammar) typeTime(col *columnDefinition) string { + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("TIME(%d)", *col.precision) + } + return "TIME" +} + +func (g *pgGrammar) typeTimeTz(col *columnDefinition) string { + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("TIMETZ(%d)", *col.precision) + } + return "TIMETZ" +} + +func (g *pgGrammar) typeTimestamp(col *columnDefinition) string { + if col.useCurrent != nil && *col.useCurrent { + col.Default(Expression("CURRENT_TIMESTAMP")) + } + if col.precision != nil { + return fmt.Sprintf("TIMESTAMP(%d)", *col.precision) + } + return "TIMESTAMP" +} + +func (g *pgGrammar) typeTimestampTz(col *columnDefinition) string { + if col.useCurrent != nil && *col.useCurrent { + col.Default(Expression("CURRENT_TIMESTAMP")) + } + if col.precision != nil { + return fmt.Sprintf("TIMESTAMPTZ(%d)", *col.precision) + } + return "TIMESTAMPTZ" +} + +func (g *pgGrammar) typeYear(_ *columnDefinition) string { + return "INTEGER" +} + +func (g *pgGrammar) typeBinary(_ *columnDefinition) string { + return "BYTEA" +} + +func (g *pgGrammar) typeUUID(_ *columnDefinition) string { + return "UUID" +} + +func (g *pgGrammar) typeGeography(col *columnDefinition) string { + if col.subtype != nil && col.srid != nil { + return fmt.Sprintf("GEOGRAPHY(%s, %d)", *col.subtype, *col.srid) + } else if col.subtype != nil { + return fmt.Sprintf("GEOGRAPHY(%s)", *col.subtype) + } + return "GEOGRAPHY" +} + +func (g *pgGrammar) typeGeometry(col *columnDefinition) string { + if col.subtype != nil && col.srid != nil { + return fmt.Sprintf("GEOMETRY(%s, %d)", *col.subtype, *col.srid) + } else if col.subtype != nil { + return fmt.Sprintf("GEOMETRY(%s)", *col.subtype) + } + return "GEOMETRY" +} + +func (g *pgGrammar) typePoint(col *columnDefinition) string { + if col.srid != nil { + return fmt.Sprintf("POINT(%d)", *col.srid) + } + return "POINT" +} + +func (g *pgGrammar) modifiers() []func(*columnDefinition) string { + return []func(*columnDefinition) string{ + g.modifyNullable, + g.modifyDefault, + } +} + +func (g *pgGrammar) modifyNullable(col *columnDefinition) string { + if col.change { + if col.nullable != nil && *col.nullable { + return " DROP NOT NULL" } - return colType + return " SET NOT NULL" + } + if col.nullable != nil && *col.nullable { + return " NULL" + } + return " NOT NULL" +} + +func (g *pgGrammar) modifyDefault(col *columnDefinition) string { + if col.defaultValue != nil { + return fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col.defaultValue)) } + return "" } diff --git a/postgres_grammar_test.go b/postgres_grammar_test.go index 843f791..716e5f2 100644 --- a/postgres_grammar_test.go +++ b/postgres_grammar_test.go @@ -21,13 +21,13 @@ func TestPgGrammar_CompileCreate(t *testing.T) { table: "users", blueprint: func(table *Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255) + table.String("name") + table.String("email") table.String("password").Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() + table.Timestamp("updated_at").UseCurrent() }, - want: "CREATE TABLE users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))", + want: "CREATE TABLE users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR(255) NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))", }, { name: "Create table with foreign key", @@ -35,7 +35,7 @@ func TestPgGrammar_CompileCreate(t *testing.T) { blueprint: func(table *Blueprint) { table.ID() table.Integer("user_id") - table.String("title", 255) + table.String("title") table.Text("content").Nullable() table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE") }, @@ -45,7 +45,7 @@ func TestPgGrammar_CompileCreate(t *testing.T) { name: "Create table with column name is empty", table: "empty_column_table", blueprint: func(table *Blueprint) { - table.String("", 255) // Intentionally empty column name + table.String("") // Intentionally empty column name }, wantErr: true, }, @@ -82,20 +82,20 @@ func TestPgGrammar_CompileCreateIfNotExists(t *testing.T) { table: "users", blueprint: func(table *Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255) + table.String("name") + table.String("email") table.String("password").Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() + table.Timestamp("updated_at").UseCurrent() }, - want: "CREATE TABLE IF NOT EXISTS users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))", + want: "CREATE TABLE IF NOT EXISTS users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR(255) NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))", wantErr: false, }, { name: "Create table with column name is empty", table: "empty_column_table", blueprint: func(table *Blueprint) { - table.String("", 255) // Intentionally empty column name + table.String("") // Intentionally empty column name }, wantErr: true, }, @@ -153,7 +153,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) { blueprint: func(table *Blueprint) { table.Boolean("active").Default(true) }, - want: "ALTER TABLE users ADD COLUMN active BOOLEAN DEFAULT true NOT NULL", + want: "ALTER TABLE users ADD COLUMN active BOOLEAN DEFAULT '1' NOT NULL", wantErr: false, }, { @@ -189,15 +189,15 @@ func TestPgGrammar_CompileAdd(t *testing.T) { blueprint: func(table *Blueprint) { table.Decimal("price", 10, 2).Default(0).Comment("Product price") }, - want: "ALTER TABLE products ADD COLUMN price DECIMAL(10, 2) DEFAULT 0 NOT NULL COMMENT 'Product price'", + want: "ALTER TABLE products ADD COLUMN price DECIMAL(10, 2) DEFAULT '0' NOT NULL COMMENT 'Product price'", wantErr: false, }, { name: "Add timestamp columns", table: "orders", blueprint: func(table *Blueprint) { - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP").Nullable() + table.Timestamp("created_at").UseCurrent() + table.Timestamp("updated_at").UseCurrent().Nullable() }, want: "ALTER TABLE orders ADD COLUMN created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, ADD COLUMN updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NULL", wantErr: false, @@ -358,7 +358,7 @@ func TestPgGrammar_CompileChange(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileChange(bp) + got, err := grammar.compileChange(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -455,8 +455,9 @@ func TestPgGrammar_CompileRename(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.oldName, newName: tt.newName} - got, err := grammar.compileRename(bp) + bp := &Blueprint{name: tt.oldName} + bp.rename(tt.newName) + got, err := grammar.compileRename(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -583,7 +584,7 @@ func TestPgGrammar_CompileDropColumn(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileDropColumn(bp) + got, err := grammar.compileDropColumn(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -640,7 +641,8 @@ func TestPgGrammar_CompileRenameColumn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileRenameColumn(bp, tt.oldName, tt.newName) + command := &command{from: tt.oldName, to: tt.newName} + got, err := grammar.compileRenameColumn(bp, command) if tt.wantErr { assert.Error(t, err) return @@ -684,7 +686,8 @@ func TestPgGrammar_CompileDropIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{} - got, err := grammar.compileDropIndex(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := grammar.compileDropIndex(bp, command) if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) @@ -729,7 +732,8 @@ func TestPgGrammar_CompileDropPrimary(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := grammar.compileDropPrimary(tt.blueprint, tt.indexName) + command := &command{index: tt.indexName} + got, err := grammar.compileDropPrimary(tt.blueprint, command) if tt.wantErr { assert.Error(t, err) return @@ -794,7 +798,8 @@ func TestPgGrammar_CompileRenameIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileRenameIndex(bp, tt.oldName, tt.newName) + command := &command{from: tt.oldName, to: tt.newName} + got, err := grammar.compileRenameIndex(bp, command) if tt.wantErr { assert.Error(t, err) return @@ -972,7 +977,7 @@ func TestPgGrammar_CompileForeign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileForeign(bp, bp.foreignKeys[0]) + got, err := grammar.compileForeign(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1020,7 +1025,8 @@ func TestPgGrammar_CompileDropForeign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileDropForeign(bp, tt.foreignKeyName) + command := &command{index: tt.foreignKeyName} + got, err := grammar.compileDropForeign(bp, command) if tt.wantErr { assert.Error(t, err) return @@ -1100,7 +1106,7 @@ func TestPgGrammar_CompileIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileIndex(bp, bp.indexes[0]) + got, err := grammar.compileIndex(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1208,7 +1214,7 @@ func TestPgGrammar_CompileUnique(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileUnique(bp, bp.indexes[0]) + got, err := grammar.compileUnique(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1306,7 +1312,7 @@ func TestPgGrammar_CompileFullText(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileFullText(bp, bp.indexes[0]) + got, err := grammar.compileFullText(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1356,7 +1362,8 @@ func TestPgGrammar_CompileDropUnique(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{} - got, err := grammar.compileDropUnique(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := grammar.compileDropUnique(bp, command) if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) @@ -1413,7 +1420,8 @@ func TestPgGrammar_CompileDropFulltext(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{} - got, err := grammar.compileDropFulltext(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := grammar.compileDropFulltext(bp, command) if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) @@ -1494,7 +1502,7 @@ func TestPgGrammar_CompilePrimary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compilePrimary(bp, bp.indexes[0]) + got, err := grammar.compilePrimary(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1540,7 +1548,7 @@ func TestPgGrammar_GetType(t *testing.T) { blueprint: func(table *Blueprint) { table.Char("code") }, - want: "CHAR", + want: "CHAR(255)", }, { name: "string column type", @@ -1557,16 +1565,9 @@ func TestPgGrammar_GetType(t *testing.T) { want: "DECIMAL(10, 2)", }, { - name: "double column type with precision", - blueprint: func(table *Blueprint) { - table.Double("value", 8, 2) - }, - want: "DOUBLE PRECISION", - }, - { - name: "double column type without precision", + name: "double column type", blueprint: func(table *Blueprint) { - table.Double("value", 0, 0) + table.Double("value") }, want: "DOUBLE PRECISION", }, @@ -1743,7 +1744,7 @@ func TestPgGrammar_GetType(t *testing.T) { blueprint: func(table *Blueprint) { table.TinyText("notes") }, - want: "TEXT", + want: "VARCHAR(255)", }, { name: "date column type", @@ -1822,13 +1823,6 @@ func TestPgGrammar_GetType(t *testing.T) { }, want: "VARCHAR(255) CHECK (status IN ('active', 'inactive'))", }, - { - name: "Enum with empty values", - blueprint: func(table *Blueprint) { - table.Enum("status", []string{}) - }, - want: "VARCHAR(255)", - }, } for _, tt := range tests { diff --git a/schema.go b/schema.go index ea067f5..a37f761 100644 --- a/schema.go +++ b/schema.go @@ -43,7 +43,9 @@ func newBuilder() (Builder, error) { return nil, errors.New("schema dialect is not set, please call schema.SetDialect() before using schema functions") } - builder, err := NewBuilder(dialectValue.String()) + builder, err := NewBuilder(dialectValue.String(), + WithDebug(cfg.debug), + ) if err != nil { return nil, err } diff --git a/schema_test.go b/schema_test.go index 5a7bfcb..1e1540d 100644 --- a/schema_test.go +++ b/schema_test.go @@ -35,10 +35,9 @@ func (s *schemaTestSuite) SetupSuite() { s.Require().NoError(err) s.db = db - schema.SetDebug(false) s.Run("when dialect is not set should return error", func() { - err := schema.SetDialect("") + err := schema.Init("") s.Error(err) s.ErrorContains(err, "unknown dialect") @@ -62,7 +61,7 @@ func (s *schemaTestSuite) SetupSuite() { } }) s.Run("when dialect is set to postgres should not return error", func() { - err = schema.SetDialect("postgres") + err = schema.Init("postgres") s.Require().NoError(err) }) } diff --git a/util.go b/util.go index 994156b..e9e838e 100644 --- a/util.go +++ b/util.go @@ -1,44 +1,33 @@ package schema -import ( - "context" - "database/sql" - "log" -) - func optional[T any](defaultValue T, values ...T) T { - return optionalAtIndex(0, defaultValue, values...) + if len(values) > 0 { + return values[0] + } + return defaultValue } -func optionalAtIndex[T any](index int, defaultValue T, values ...T) T { - if index < len(values) { - return values[index] +func optionalPtr[T any](defaultValue T, values ...T) *T { + if len(values) > 0 { + return &values[0] } - return defaultValue + return &defaultValue } -func execContext(ctx context.Context, tx *sql.Tx, queries ...string) error { - for _, query := range queries { - if debug { - log.Printf("Executing SQL: %s\n", query) - } - if _, err := tx.ExecContext(ctx, query); err != nil { - return err - } +func optionalNil[T any](values ...T) *T { + if len(values) > 0 { + return &values[0] } return nil } -func queryRowContext(ctx context.Context, tx *sql.Tx, query string, args ...any) *sql.Row { - if debug { - log.Printf("Executing Query: %s with args: %v\n", query, args) - } - return tx.QueryRowContext(ctx, query, args...) +func ptrOf[T any](value T) *T { + return &value } -func queryContext(ctx context.Context, tx *sql.Tx, query string, args ...any) (*sql.Rows, error) { - if debug { - log.Printf("Executing Query: %s with args: %v\n", query, args) +func ternary[T any](condition bool, trueValue, falseValue T) T { + if condition { + return trueValue } - return tx.QueryContext(ctx, query, args...) + return falseValue } From 253ee139f92b17e70453b85bdedcbade1a0f5155 Mon Sep 17 00:00:00 2001 From: Ahmad Faiz Kamaludin Date: Sat, 30 Aug 2025 21:50:41 +0700 Subject: [PATCH 2/7] feat: refactor schema builder --- Makefile | 53 ++-- blueprint.go | 123 +++++----- builder.go | 85 ++----- column_definition.go | 75 ++++-- command.go | 41 ++-- config.go | 27 ++ dialect.go | 51 ---- examples/basic/cmd/migrate/cmd.go | 62 +++++ examples/basic/cmd/migrate/migrate.go | 67 ++--- examples/basic/cmd/root.go | 20 ++ examples/basic/config/config.go | 20 ++ examples/basic/config/database.go | 14 +- examples/basic/config/util.go | 23 ++ examples/basic/go.mod | 9 +- examples/basic/go.sum | 10 +- examples/basic/main.go | 50 +--- ...o => 20250830103612_create_users_table.go} | 3 +- ...o => 20250830103653_create_roles_table.go} | 0 ...20250830103714_create_user_roles_table.go} | 0 foreign_key_definition.go | 6 +- go.mod | 7 +- go.sum | 19 +- grammar.go | 131 +++++----- index_definition.go | 6 +- internal/config/config.go | 29 +++ internal/dialect/dialect.go | 25 ++ internal/parser/migration.go | 21 ++ internal/parser/migration_test.go | 132 ++++++++++ util.go => internal/util/util.go | 12 +- mysql_builder.go | 76 +++--- mysql_builder_test.go | 37 +-- mysql_grammar.go | 152 ++++++------ mysql_grammar_test.go | 71 ++---- options.go | 26 -- postgres_builder.go | 39 +-- postgres_builder_test.go | 63 +---- postgres_grammar.go | 230 +++++++++--------- postgres_grammar_test.go | 166 +++++-------- schema.go | 33 +-- schema_test.go | 72 +----- template.go | 101 ++++++++ 41 files changed, 1118 insertions(+), 1069 deletions(-) create mode 100644 config.go delete mode 100644 dialect.go create mode 100644 examples/basic/cmd/migrate/cmd.go create mode 100644 examples/basic/cmd/root.go create mode 100644 examples/basic/config/config.go create mode 100644 examples/basic/config/util.go rename examples/basic/migrations/{20250626235117_create_users_table.go => 20250830103612_create_users_table.go} (83%) rename examples/basic/migrations/{20250628092119_create_roles_table.go => 20250830103653_create_roles_table.go} (100%) rename examples/basic/migrations/{20250628092223_create_user_roles_table.go => 20250830103714_create_user_roles_table.go} (100%) create mode 100644 internal/config/config.go create mode 100644 internal/dialect/dialect.go create mode 100644 internal/parser/migration.go create mode 100644 internal/parser/migration_test.go rename util.go => internal/util/util.go (51%) delete mode 100644 options.go create mode 100644 template.go diff --git a/Makefile b/Makefile index 3370ac2..2dc2199 100644 --- a/Makefile +++ b/Makefile @@ -1,49 +1,54 @@ COVERAGE_FILE := coverage.out COVERAGE_HTML := coverage.html -MIN_COVERAGE := 80 FORMAT ?= dots -# Format code .PHONY: fmt -fmt: +fmt: # Format code @echo "Formatting code..." @go fmt ./... -# Lint code .PHONY: lint -lint: +lint: # Lint code @echo "Linting code..." @golangci-lint run --timeout 5m -# Install dependencies and tools .PHONY: install -install: +install: # Install dependencies @echo "Installing dependencies..." @go mod download @go mod tidy -# Run tests +install-tools: # Install tools + @echo "Installing tools..." + @go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest + @go install gotest.tools/gotestsum@latest + .PHONY: test -test: +test: # Run tests @echo "Running tests..." - go test -v -cover -race ./... -coverprofile=$(COVERAGE_FILE) -coverpkg=./... + @gotestsum --format=$(FORMAT) -- ./... + +.PHONY: testcov +testcov: # Run tests with coverage + @echo "Running tests with coverage..." + @gotestsum --format=$(FORMAT) -- -coverprofile=$(COVERAGE_FILE) ./... + @echo "Total coverage is: " $(shell go tool cover -func=$(COVERAGE_FILE) | grep total | awk '{print $$3}') -.PHONY: coverage -coverage: - @echo "Generating test coverage report..." +.PHONY: testcov-html +testcov-html: testcov # Generate coverage HTML report @go tool cover -html=$(COVERAGE_FILE) -o $(COVERAGE_HTML) - @go tool cover -func=$(COVERAGE_FILE) | tee coverage.txt @echo "Coverage HTML report generated: $(COVERAGE_HTML)" @open $(COVERAGE_HTML) -.PHONY: coverage-check -coverage-check: - @COVERAGE=$$(go tool cover -func=$(COVERAGE_FILE) | grep total | awk '{print $$3}' | sed 's/%//'); \ - RESULT=$$(echo "$$COVERAGE < $(MIN_COVERAGE)" | bc); \ - if [ "$$RESULT" -eq 1 ]; then \ - echo "Coverage is below $(MIN_COVERAGE)%: $$COVERAGE%"; \ - exit 1; \ - else \ - echo "Coverage is sufficient: $$COVERAGE%"; \ - fi \ No newline at end of file +.PHONY: help +help: + @echo "Available commands:" + @echo " fmt - Format code" + @echo " lint - Lint code" + @echo " install - Install dependencies" + @echo " install-tools- Install development tools" + @echo " test - Run tests" + @echo " testcov - Run tests with coverage" + @echo " testcov-html - Generate coverage HTML report" + @echo " help - Show this help message" \ No newline at end of file diff --git a/blueprint.go b/blueprint.go index a036772..f2b3149 100644 --- a/blueprint.go +++ b/blueprint.go @@ -5,6 +5,9 @@ import ( "database/sql" "fmt" "log" + + "github.com/afkdevs/go-schema/internal/dialect" + "github.com/afkdevs/go-schema/internal/util" ) const ( @@ -32,19 +35,23 @@ const ( columnTypeTimestampTz string = "timestampTz" columnTypeYear string = "year" columnTypeBinary string = "binary" - columnTypeJSON string = "json" - columnTypeJSONB string = "jsonb" + columnTypeJson string = "json" + columnTypeJsonb string = "jsonb" columnTypeGeography string = "geography" columnTypeGeometry string = "geometry" columnTypePoint string = "point" - columnTypeUUID string = "uuid" + columnTypeUuid string = "uuid" columnTypeEnum string = "enum" ) +const ( + defaultStringLength int = 255 + defaultTimePrecision int = 0 +) + // Blueprint represents a schema blueprint for creating or altering a database table. type Blueprint struct { - debug bool - dialect dialect + dialect dialect.Dialect columns []*columnDefinition commands []*command grammar grammar @@ -52,6 +59,7 @@ type Blueprint struct { charset string collation string engine string + verbose bool } // Charset sets the character set for the table in the blueprint. @@ -82,14 +90,14 @@ func (b *Blueprint) Boolean(name string) ColumnDefinition { // Char creates a new char column definition in the blueprint. func (b *Blueprint) Char(name string, length ...int) ColumnDefinition { return b.addColumn(columnTypeChar, name, &columnDefinition{ - length: optionalPtr(1, length...), + length: util.OptionalPtr(defaultStringLength, length...), }) } // String creates a new string column definition in the blueprint. func (b *Blueprint) String(name string, length ...int) ColumnDefinition { return b.addColumn(columnTypeString, name, &columnDefinition{ - length: optionalPtr(255, length...), + length: util.OptionalPtr(defaultStringLength, length...), }) } @@ -138,8 +146,8 @@ func (b *Blueprint) Decimal(name string, params ...int) ColumnDefinition { defaultPlaces = params[1] } return b.addColumn(columnTypeDecimal, name, &columnDefinition{ - total: optionalPtr(8, params...), - places: ptrOf(defaultPlaces), + total: util.OptionalPtr(8, params...), + places: util.PtrOf(defaultPlaces), }) } @@ -151,7 +159,7 @@ func (b *Blueprint) Double(name string) ColumnDefinition { // Float creates a new float column definition in the blueprint. func (b *Blueprint) Float(name string, precision ...int) ColumnDefinition { return b.addColumn(columnTypeFloat, name, &columnDefinition{ - precision: optionalPtr(53, precision...), + precision: util.OptionalPtr(53, precision...), }) } @@ -159,7 +167,7 @@ func (b *Blueprint) Float(name string, precision ...int) ColumnDefinition { // // If a name is provided, it will be used as the column name; otherwise, "id" will be used. func (b *Blueprint) ID(name ...string) ColumnDefinition { - return b.BigIncrements(optional("id", name...)).Primary() + return b.BigIncrements(util.Optional("id", name...)).Primary() } // Increments create a new increment column definition in the blueprint. @@ -229,14 +237,14 @@ func (b *Blueprint) UnsignedTinyInteger(name string) ColumnDefinition { // DateTime creates a new date time column definition in the blueprint. func (b *Blueprint) DateTime(name string, precision ...int) ColumnDefinition { return b.addColumn(columnTypeDateTime, name, &columnDefinition{ - precision: optionalPtr(0, precision...), + precision: util.OptionalPtr(defaultTimePrecision, precision...), }) } // DateTimeTz creates a new date time with a time zone column definition in the blueprint. func (b *Blueprint) DateTimeTz(name string, precision ...int) ColumnDefinition { return b.addColumn(columnTypeDateTimeTz, name, &columnDefinition{ - precision: optionalPtr(0, precision...), + precision: util.OptionalPtr(defaultTimePrecision, precision...), }) } @@ -248,14 +256,14 @@ func (b *Blueprint) Date(name string) ColumnDefinition { // Time creates a new time column definition in the blueprint. func (b *Blueprint) Time(name string, precision ...int) ColumnDefinition { return b.addColumn(columnTypeTime, name, &columnDefinition{ - precision: optionalPtr(0, precision...), + precision: util.OptionalPtr(defaultTimePrecision, precision...), }) } // TimeTz creates a new time with time zone column definition in the blueprint. func (b *Blueprint) TimeTz(name string, precision ...int) ColumnDefinition { return b.addColumn(columnTypeTimeTz, name, &columnDefinition{ - precision: optionalPtr(0, precision...), + precision: util.OptionalPtr(defaultTimePrecision, precision...), }) } @@ -263,7 +271,7 @@ func (b *Blueprint) TimeTz(name string, precision ...int) ColumnDefinition { // The precision parameter is optional and defaults to 0 if not provided. func (b *Blueprint) Timestamp(name string, precision ...int) ColumnDefinition { return b.addColumn(columnTypeTimestamp, name, &columnDefinition{ - precision: optionalPtr(0, precision...), + precision: util.OptionalPtr(defaultTimePrecision, precision...), }) } @@ -271,7 +279,7 @@ func (b *Blueprint) Timestamp(name string, precision ...int) ColumnDefinition { // The precision parameter is optional and defaults to 0 if not provided. func (b *Blueprint) TimestampTz(name string, precision ...int) ColumnDefinition { return b.addColumn(columnTypeTimestampTz, name, &columnDefinition{ - precision: optionalPtr(0, precision...), + precision: util.OptionalPtr(defaultTimePrecision, precision...), }) } @@ -295,23 +303,23 @@ func (b *Blueprint) Year(name string) ColumnDefinition { // Binary creates a new binary column definition in the blueprint. func (b *Blueprint) Binary(name string, length ...int) ColumnDefinition { return b.addColumn(columnTypeBinary, name, &columnDefinition{ - length: optionalNil(length...), + length: util.OptionalNil(length...), }) } // JSON creates a new JSON column definition in the blueprint. func (b *Blueprint) JSON(name string) ColumnDefinition { - return b.addColumn(columnTypeJSON, name) + return b.addColumn(columnTypeJson, name) } // JSONB creates a new JSONB column definition in the blueprint. func (b *Blueprint) JSONB(name string) ColumnDefinition { - return b.addColumn(columnTypeJSONB, name) + return b.addColumn(columnTypeJsonb, name) } // UUID creates a new UUID column definition in the blueprint. func (b *Blueprint) UUID(name string) ColumnDefinition { - return b.addColumn(columnTypeUUID, name) + return b.addColumn(columnTypeUuid, name) } // Geography creates a new geography column definition in the blueprint. @@ -319,8 +327,8 @@ func (b *Blueprint) UUID(name string) ColumnDefinition { // The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geography type. func (b *Blueprint) Geography(name string, subtype string, srid ...int) ColumnDefinition { return b.addColumn(columnTypeGeography, name, &columnDefinition{ - subtype: optionalPtr("", subtype), - srid: optionalPtr(4326, srid...), + subtype: util.OptionalPtr("", subtype), + srid: util.OptionalPtr(4326, srid...), }) } @@ -329,15 +337,15 @@ func (b *Blueprint) Geography(name string, subtype string, srid ...int) ColumnDe // The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geometry type. func (b *Blueprint) Geometry(name string, subtype string, srid ...int) ColumnDefinition { return b.addColumn(columnTypeGeometry, name, &columnDefinition{ - subtype: optionalPtr("", subtype), - srid: optionalNil(srid...), + subtype: util.OptionalPtr("", subtype), + srid: util.OptionalNil(srid...), }) } // Point creates a new point column definition in the blueprint. func (b *Blueprint) Point(name string, srid ...int) ColumnDefinition { return b.addColumn(columnTypePoint, name, &columnDefinition{ - srid: optionalPtr(4326, srid...), + srid: util.OptionalPtr(4326, srid...), }) } @@ -497,17 +505,13 @@ func (b *Blueprint) create() { func (b *Blueprint) creating() bool { for _, command := range b.commands { - if command.name == commandCreate || command.name == commandCreateIfNotExists { + if command.name == commandCreate { return true } } return false } -func (b *Blueprint) createIfNotExists() { - b.addCommand(commandCreateIfNotExists) -} - func (b *Blueprint) drop() { b.addCommand(commandDrop) } @@ -530,9 +534,11 @@ func (b *Blueprint) addImpliedCommands() { b.commands = append([]*command{{name: commandAdd}}, b.commands...) } if len(b.getChangedColumns()) > 0 { + changedCommands := make([]*command, 0, len(b.getChangedColumns())) for _, col := range b.getChangedColumns() { - b.commands = append([]*command{{name: commandChange, column: col}}, b.commands...) + changedCommands = append(changedCommands, &command{name: commandChange, column: col}) } + b.commands = append(changedCommands, b.commands...) } } } @@ -540,7 +546,7 @@ func (b *Blueprint) addImpliedCommands() { func (b *Blueprint) addFluentIndexes() { for _, col := range b.columns { if col.primary != nil { - if b.dialect == dialectMySQL { + if b.dialect == dialect.MySQL { continue } if !*col.primary && col.change { @@ -573,7 +579,7 @@ func (b *Blueprint) addFluentIndexes() { func (b *Blueprint) getFluentStatements() []string { var statements []string for _, column := range b.columns { - for _, fluentCommand := range b.grammar.getFluentCommands() { + for _, fluentCommand := range b.grammar.GetFluentCommands() { if statement := fluentCommand(b, &command{column: column}); statement != "" { statements = append(statements, statement) } @@ -588,8 +594,8 @@ func (b *Blueprint) build(ctx context.Context, tx *sql.Tx) error { return err } for _, statement := range statements { - if b.debug { - log.Print(statement) + if b.verbose { + log.Println(statement) } if _, err := tx.ExecContext(ctx, statement); err != nil { return err @@ -603,29 +609,28 @@ func (b *Blueprint) toSql() ([]string, error) { var statements []string - mainCommandMap := map[string]func(*Blueprint) (string, error){ - commandCreate: b.grammar.compileCreate, - commandCreateIfNotExists: b.grammar.compileCreateIfNotExists, - commandAdd: b.grammar.compileAdd, - commandDrop: b.grammar.compileDrop, - commandDropIfExists: b.grammar.compileDropIfExists, + mainCommandMap := map[string]func(blueprint *Blueprint) (string, error){ + commandCreate: b.grammar.CompileCreate, + commandAdd: b.grammar.CompileAdd, + commandDrop: b.grammar.CompileDrop, + commandDropIfExists: b.grammar.CompileDropIfExists, } - secondaryCommandMap := map[string]func(*Blueprint, *command) (string, error){ - commandChange: b.grammar.compileChange, - commandDropColumn: b.grammar.compileDropColumn, - commandDropIndex: b.grammar.compileDropIndex, - commandDropForeign: b.grammar.compileDropForeign, - commandDropFullText: b.grammar.compileDropFulltext, - commandDropPrimary: b.grammar.compileDropPrimary, - commandDropUnique: b.grammar.compileDropUnique, - commandForeign: b.grammar.compileForeign, - commandFullText: b.grammar.compileFullText, - commandIndex: b.grammar.compileIndex, - commandPrimary: b.grammar.compilePrimary, - commandRename: b.grammar.compileRename, - commandRenameColumn: b.grammar.compileRenameColumn, - commandRenameIndex: b.grammar.compileRenameIndex, - commandUnique: b.grammar.compileUnique, + secondaryCommandMap := map[string]func(blueprint *Blueprint, command *command) (string, error){ + commandChange: b.grammar.CompileChange, + commandDropColumn: b.grammar.CompileDropColumn, + commandDropIndex: b.grammar.CompileDropIndex, + commandDropForeign: b.grammar.CompileDropForeign, + commandDropFullText: b.grammar.CompileDropFulltext, + commandDropPrimary: b.grammar.CompileDropPrimary, + commandDropUnique: b.grammar.CompileDropUnique, + commandForeign: b.grammar.CompileForeign, + commandFullText: b.grammar.CompileFullText, + commandIndex: b.grammar.CompileIndex, + commandPrimary: b.grammar.CompilePrimary, + commandRename: b.grammar.CompileRename, + commandRenameColumn: b.grammar.CompileRenameColumn, + commandRenameIndex: b.grammar.CompileRenameIndex, + commandUnique: b.grammar.CompileUnique, } for _, cmd := range b.commands { if compileFunc, exists := mainCommandMap[cmd.name]; exists { @@ -688,7 +693,7 @@ func (b *Blueprint) dropIndexCommand(name string, indexType string, index any) { index: index, }) case []string: - indexName := b.grammar.createIndexName(b, indexType, index...) + indexName := b.grammar.CreateIndexName(b, indexType, index...) b.addCommand(name, &command{ index: indexName, }) diff --git a/builder.go b/builder.go index aff8d9a..a300ca2 100644 --- a/builder.go +++ b/builder.go @@ -4,14 +4,14 @@ import ( "context" "database/sql" "errors" + + "github.com/afkdevs/go-schema/internal/dialect" ) // Builder is an interface that defines methods for creating, dropping, and managing database tables. type Builder interface { // Create creates a new table with the given name and applies the provided blueprint. Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error - // CreateIfNotExists creates a new table with the given name and applies the provided blueprint if it does not already exist. - CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error // Drop removes the table with the given name. Drop(ctx context.Context, tx *sql.Tx, name string) error // DropIfExists removes the table with the given name if it exists. @@ -40,53 +40,30 @@ type Builder interface { // It returns an error if the dialect is not supported. // // Supported dialects are "postgres", "pgx", "mysql", and "mariadb". -func NewBuilder(dialect string, options ...Option) (Builder, error) { - dialectVal := dialectFromString(dialect) +func NewBuilder(dialectValue string) (Builder, error) { + dialectVal := dialect.FromString(dialectValue) switch dialectVal { - case dialectMySQL: - return newMysqlBuilder(options...), nil - case dialectPostgres: - return newPostgresBuilder(options...), nil + case dialect.MySQL: + return newMysqlBuilder(), nil + case dialect.Postgres: + return newPostgresBuilder(), nil default: - return nil, errors.New("unsupported dialect: " + dialect) + return nil, errors.New("unsupported dialect: " + dialectValue) } } type baseBuilder struct { - debug bool grammar grammar + verbose bool } func (b *baseBuilder) newBlueprint(name string) *Blueprint { - return &Blueprint{name: name, grammar: b.grammar, debug: b.debug} -} - -func (b *baseBuilder) validateTxAndName(tx *sql.Tx, name string) error { - if name == "" { - return errors.New("table name is empty") - } - if tx == nil { - return errors.New("transaction is nil") - } - return nil -} - -func (b *baseBuilder) validateCreateAndAlter(tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if name == "" { - return errors.New("table name is empty") - } - if blueprint == nil { - return errors.New("blueprint function is nil") - } - if tx == nil { - return errors.New("transaction is nil") - } - return nil + return &Blueprint{name: name, grammar: b.grammar, verbose: b.verbose} } func (b *baseBuilder) Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil { - return err + if tx == nil || name == "" || blueprint == nil { + return errors.New("invalid arguments: transaction, name, or blueprint is nil/empty") } bp := b.newBlueprint(name) @@ -100,25 +77,9 @@ func (b *baseBuilder) Create(ctx context.Context, tx *sql.Tx, name string, bluep return nil } -func (b *baseBuilder) CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil { - return err - } - - bp := b.newBlueprint(name) - bp.createIfNotExists() - blueprint(bp) - - if err := bp.build(ctx, tx); err != nil { - return err - } - - return nil -} - func (b *baseBuilder) Drop(ctx context.Context, tx *sql.Tx, name string) error { - if err := b.validateTxAndName(tx, name); err != nil { - return err + if tx == nil || name == "" { + return errors.New("invalid arguments: transaction is nil or name is empty") } bp := b.newBlueprint(name) @@ -132,8 +93,8 @@ func (b *baseBuilder) Drop(ctx context.Context, tx *sql.Tx, name string) error { } func (b *baseBuilder) DropIfExists(ctx context.Context, tx *sql.Tx, name string) error { - if err := b.validateTxAndName(tx, name); err != nil { - return err + if tx == nil || name == "" { + return errors.New("invalid arguments: transaction is nil or name is empty") } bp := b.newBlueprint(name) @@ -147,12 +108,10 @@ func (b *baseBuilder) DropIfExists(ctx context.Context, tx *sql.Tx, name string) } func (b *baseBuilder) Rename(ctx context.Context, tx *sql.Tx, oldName string, newName string) error { - if oldName == "" || newName == "" { - return errors.New("old or new table name is empty") - } - if tx == nil { - return errors.New("transaction is nil") + if tx == nil || oldName == "" || newName == "" { + return errors.New("invalid arguments: transaction is nil or old/new table name is empty") } + bp := b.newBlueprint(oldName) bp.rename(newName) @@ -164,8 +123,8 @@ func (b *baseBuilder) Rename(ctx context.Context, tx *sql.Tx, oldName string, ne } func (b *baseBuilder) Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil { - return err + if tx == nil || name == "" || blueprint == nil { + return errors.New("invalid arguments: transaction is nil or name/blueprint is empty") } bp := b.newBlueprint(name) diff --git a/column_definition.go b/column_definition.go index f6d48aa..b50ac39 100644 --- a/column_definition.go +++ b/column_definition.go @@ -1,5 +1,11 @@ package schema +import ( + "slices" + + "github.com/afkdevs/go-schema/internal/util" +) + // ColumnDefinition defines the interface for defining a column in a database table. type ColumnDefinition interface { // AutoIncrement sets the column to auto-increment. @@ -33,20 +39,8 @@ type ColumnDefinition interface { UseCurrentOnUpdate() ColumnDefinition } -// Expression is a type for expressions that can be used as default values for columns. -// -// Example: -// -// schema.Timestamp("created_at").Default(schema.Expression("CURRENT_TIMESTAMP")) -type Expression string - -func (e Expression) String() string { - return string(e) -} - -var _ ColumnDefinition = &columnDefinition{} - type columnDefinition struct { + commands []string name string columnType string charset *string @@ -54,8 +48,8 @@ type columnDefinition struct { comment *string defaultValue any onUpdateValue any - useCurrent *bool - useCurrentOnUpdate *bool + useCurrent bool + useCurrentOnUpdate bool nullable *bool autoIncrement *bool unsigned *bool @@ -74,8 +68,43 @@ type columnDefinition struct { srid *int // for geography and geometry types } +// Expression is a type for expressions that can be used as default values for columns. +// +// Example: +// +// schema.Timestamp("created_at").Default(schema.Expression("CURRENT_TIMESTAMP")) +type Expression string + +func (e Expression) String() string { + return string(e) +} + +var _ ColumnDefinition = &columnDefinition{} + +func (c *columnDefinition) addCommand(command string) { + c.commands = append(c.commands, command) +} + +func (c *columnDefinition) hasCommand(command string) bool { + return slices.Contains(c.commands, command) +} + +func (c *columnDefinition) SetDefault(value any) { + c.addCommand("default") + c.defaultValue = value +} + +func (c *columnDefinition) SetOnUpdate(value any) { + c.addCommand("onUpdate") + c.onUpdateValue = value +} + +func (c *columnDefinition) SetSubtype(value *string) { + c.subtype = value +} + func (c *columnDefinition) AutoIncrement() ColumnDefinition { - c.autoIncrement = ptrOf(true) + c.autoIncrement = util.PtrOf(true) return c } @@ -95,11 +124,13 @@ func (c *columnDefinition) Collation(collation string) ColumnDefinition { } func (c *columnDefinition) Comment(comment string) ColumnDefinition { + c.addCommand("comment") c.comment = &comment return c } func (c *columnDefinition) Default(value any) ColumnDefinition { + c.addCommand("default") c.defaultValue = value return c @@ -120,17 +151,19 @@ func (c *columnDefinition) Index(params ...any) ColumnDefinition { } func (c *columnDefinition) Nullable(value ...bool) ColumnDefinition { - c.nullable = optionalPtr(true, value...) + c.addCommand("nullable") + c.nullable = util.OptionalPtr(true, value...) return c } func (c *columnDefinition) OnUpdate(value any) ColumnDefinition { + c.addCommand("onUpdate") c.onUpdateValue = value return c } func (c *columnDefinition) Primary(value ...bool) ColumnDefinition { - val := optional(true, value...) + val := util.Optional(true, value...) c.primary = &val return c } @@ -150,16 +183,16 @@ func (c *columnDefinition) Unique(params ...any) ColumnDefinition { } func (c *columnDefinition) Unsigned() ColumnDefinition { - c.unsigned = ptrOf(true) + c.unsigned = util.PtrOf(true) return c } func (c *columnDefinition) UseCurrent() ColumnDefinition { - c.useCurrent = ptrOf(true) + c.useCurrent = true return c } func (c *columnDefinition) UseCurrentOnUpdate() ColumnDefinition { - c.useCurrentOnUpdate = ptrOf(true) + c.useCurrentOnUpdate = true return c } diff --git a/command.go b/command.go index 8d59f5c..907798a 100644 --- a/command.go +++ b/command.go @@ -1,28 +1,25 @@ package schema const ( - commandAdd string = "add" - commandChange string = "change" - commandCreate string = "create" - commandCreateIfNotExists string = "createIfNotExists" - commandComment string = "comment" - commandDefault string = "default" - commandDrop string = "drop" - commandDropIfExists string = "dropIfExists" - commandDropColumn string = "dropColumn" - commandDropForeign string = "dropForeign" - commandDropFullText string = "dropFullText" - commandDropIndex string = "dropIndex" - commandDropPrimary string = "dropPrimary" - commandDropUnique string = "dropUnique" - commandForeign string = "foreign" - commandFullText string = "fullText" - commandIndex string = "index" - commandPrimary string = "primary" - commandRename string = "rename" - commandRenameColumn string = "renameColumn" - commandRenameIndex string = "renameIndex" - commandUnique string = "unique" + commandAdd string = "add" + commandChange string = "change" + commandCreate string = "create" + commandDrop string = "drop" + commandDropIfExists string = "dropIfExists" + commandDropColumn string = "dropColumn" + commandDropForeign string = "dropForeign" + commandDropFullText string = "dropFullText" + commandDropIndex string = "dropIndex" + commandDropPrimary string = "dropPrimary" + commandDropUnique string = "dropUnique" + commandForeign string = "foreign" + commandFullText string = "fullText" + commandIndex string = "index" + commandPrimary string = "primary" + commandRename string = "rename" + commandRenameColumn string = "renameColumn" + commandRenameIndex string = "renameIndex" + commandUnique string = "unique" ) type command struct { diff --git a/config.go b/config.go new file mode 100644 index 0000000..341e709 --- /dev/null +++ b/config.go @@ -0,0 +1,27 @@ +package schema + +import ( + "errors" + + "github.com/afkdevs/go-schema/internal/config" + "github.com/afkdevs/go-schema/internal/dialect" +) + +// SetDialect sets the migrator dialect +func SetDialect(d string) error { + dialectValue := dialect.FromString(d) + if dialectValue == dialect.Unknown { + return errors.New("unsupported dialect: " + d) + } + cfg := config.Get() + cfg.Dialect = dialectValue + config.Set(cfg) + return nil +} + +// SetVerbose enables or disables verbose mode +func SetVerbose(enabled bool) { + cfg := config.Get() + cfg.Verbose = enabled + config.Set(cfg) +} diff --git a/dialect.go b/dialect.go deleted file mode 100644 index 5ce78f6..0000000 --- a/dialect.go +++ /dev/null @@ -1,51 +0,0 @@ -package schema - -import ( - "fmt" -) - -type dialect uint8 - -const ( - dialectUnknown dialect = iota - dialectMySQL - dialectPostgres -) - -func (d dialect) String() string { - switch d { - case dialectMySQL: - return "mysql" - case dialectPostgres: - return "postgres" - default: - return "" - } -} - -var dialectValue = dialectUnknown -var cfg = &config{ - debug: false, // default value -} - -// Init initializes the schema package with the given dialect and options. -func Init(dialect string, options ...Option) error { - dialectValue = dialectFromString(dialect) - if dialectValue == dialectUnknown { - return fmt.Errorf("unknown dialect: %s", dialect) - } - cfg = applyOptions(options...) - - return nil -} - -func dialectFromString(d string) dialect { - switch d { - case "mysql", "mariadb": - return dialectMySQL - case "postgres", "pgx": - return dialectPostgres - default: - return dialectUnknown - } -} diff --git a/examples/basic/cmd/migrate/cmd.go b/examples/basic/cmd/migrate/cmd.go new file mode 100644 index 0000000..fbac4f0 --- /dev/null +++ b/examples/basic/cmd/migrate/cmd.go @@ -0,0 +1,62 @@ +package migrate + +import ( + "context" + + "github.com/urfave/cli/v3" +) + +func Command() *cli.Command { + cmd := &cli.Command{ + Name: "migrate", + Usage: "Migration tool", + Commands: []*cli.Command{ + { + Name: "create", + Usage: "Create a new migration file", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "name", + Aliases: []string{"n"}, + Usage: "Name of the migration", + Required: true, + }, + }, + Action: func(ctx context.Context, c *cli.Command) error { + return Create(c.String("name")) + }, + }, + { + Name: "up", + Usage: "Run all pending migrations", + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "dry-run", + Aliases: []string{"d"}, + Usage: "Preview the migration without applying changes", + Required: false, + }, + }, + Action: func(ctx context.Context, c *cli.Command) error { + return Up(c.Bool("dry-run")) + }, + }, + { + Name: "reset", + Usage: "Reset all migrations", + Flags: []cli.Flag{ + &cli.BoolFlag{ + Name: "dry-run", + Aliases: []string{"d"}, + Usage: "Preview the migration without applying changes", + Required: false, + }, + }, + Action: func(ctx context.Context, c *cli.Command) error { + return Reset(c.Bool("dry-run")) + }, + }, + }, + } + return cmd +} diff --git a/examples/basic/cmd/migrate/migrate.go b/examples/basic/cmd/migrate/migrate.go index 8fb54ae..795dffa 100644 --- a/examples/basic/cmd/migrate/migrate.go +++ b/examples/basic/cmd/migrate/migrate.go @@ -11,55 +11,49 @@ import ( "github.com/pressly/goose/v3" ) -func Up() error { - m, err := newMigrator() +const ( + directory = "migrations" + tableName = "schema_migrations" +) + +func Up(dryRun bool) error { + db, err := initMigrator() if err != nil { return err } - return goose.Up(m.db, m.dir) + return goose.Up(db, directory) } func Create(name string) error { - m, err := newMigrator() - if err != nil { - return err - } - return goose.Create(m.db, m.dir, name, m.migrationType) + tmpl := schema.GooseMigrationTemplate(name) + return goose.CreateWithTemplate(nil, directory, tmpl, name, "go") } -func Down() error { - m, err := newMigrator() +func Reset(dryRun bool) error { + db, err := initMigrator() if err != nil { return err } - return goose.Reset(m.db, m.dir) + return goose.Reset(db, directory) } -type migrator struct { - dir string - dialect string - tableName string - migrationType string - db *sql.DB -} - -func newMigrator() (*migrator, error) { - db, err := newDatabase(config.GetDatabase()) +func initMigrator() (*sql.DB, error) { + if err := goose.SetDialect("postgres"); err != nil { + return nil, fmt.Errorf("failed to set dialect: %w", err) + } + goose.SetTableName(tableName) + if err := schema.SetDialect("postgres"); err != nil { + return nil, fmt.Errorf("failed to set schema dialect: %w", err) + } + cfg, err := config.Load() if err != nil { return nil, err } - m := &migrator{ - dir: "migrations", - dialect: "postgres", - tableName: "schema_migrations", - migrationType: "go", - db: db, - } - if err := m.init(); err != nil { + db, err := newDatabase(cfg.Database) + if err != nil { return nil, err } - - return m, nil + return db, nil } func newDatabase(cfg config.Database) (*sql.DB, error) { @@ -75,14 +69,3 @@ func newDatabase(cfg config.Database) (*sql.DB, error) { return db, nil } - -func (m *migrator) init() error { - goose.SetTableName(m.tableName) - if err := goose.SetDialect(m.dialect); err != nil { - return err - } - if err := schema.Init(m.dialect); err != nil { - return err - } - return nil -} diff --git a/examples/basic/cmd/root.go b/examples/basic/cmd/root.go new file mode 100644 index 0000000..e7a80d0 --- /dev/null +++ b/examples/basic/cmd/root.go @@ -0,0 +1,20 @@ +package cmd + +import ( + "context" + + "github.com/afkdevs/go-schema/examples/basic/cmd/migrate" + "github.com/urfave/cli/v3" +) + +var cmd = &cli.Command{ + Name: "schema-example", + Usage: "A simple schema example", + Commands: []*cli.Command{ + migrate.Command(), + }, +} + +func Execute(args []string) error { + return cmd.Run(context.Background(), args) +} diff --git a/examples/basic/config/config.go b/examples/basic/config/config.go new file mode 100644 index 0000000..1b028be --- /dev/null +++ b/examples/basic/config/config.go @@ -0,0 +1,20 @@ +package config + +import "github.com/joho/godotenv" + +type Config struct { + Database Database +} + +func Load() (Config, error) { + var config Config + err := godotenv.Load() + if err != nil { + return config, err + } + config = Config{ + Database: getDatabaseConfig(), + } + + return config, nil +} diff --git a/examples/basic/config/database.go b/examples/basic/config/database.go index 722d0ec..8065182 100644 --- a/examples/basic/config/database.go +++ b/examples/basic/config/database.go @@ -16,13 +16,13 @@ func (db Database) DSN() string { db.User, db.Password, db.Host, db.Port, db.Database, db.SSLMode) } -func GetDatabase() Database { +func getDatabaseConfig() Database { return Database{ - Host: "localhost", - Port: 5432, - User: "root", - Password: "password", - Database: "schema_example", - SSLMode: "disable", + Host: getEnv("DB_HOST"), + Port: getEnvInt("DB_PORT"), + User: getEnv("DB_USER"), + Password: getEnv("DB_PASSWORD"), + Database: getEnv("DB_NAME"), + SSLMode: getEnv("DB_SSLMODE"), } } diff --git a/examples/basic/config/util.go b/examples/basic/config/util.go new file mode 100644 index 0000000..d2b5c43 --- /dev/null +++ b/examples/basic/config/util.go @@ -0,0 +1,23 @@ +package config + +import ( + "os" + "strconv" +) + +func getEnv(key string) string { + value := os.Getenv(key) + return value +} + +func getEnvInt(key string) int { + value := os.Getenv(key) + if value == "" { + return 0 + } + intValue, err := strconv.Atoi(value) + if err != nil { + return 0 + } + return intValue +} diff --git a/examples/basic/go.mod b/examples/basic/go.mod index a4c82dd..3443552 100644 --- a/examples/basic/go.mod +++ b/examples/basic/go.mod @@ -3,17 +3,18 @@ module github.com/afkdevs/go-schema/examples/basic go 1.23.0 require ( - github.com/afkdevs/go-schema v0.0.0-000000000000-00010101000000 + github.com/afkdevs/go-schema v0.0.0 + github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 github.com/pressly/goose/v3 v3.24.3 github.com/urfave/cli/v3 v3.3.8 ) -replace github.com/afkdevs/go-schema => ../.. - require ( github.com/mfridman/interpolate v0.0.2 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/sync v0.14.0 // indirect + golang.org/x/sync v0.16.0 // indirect ) + +replace github.com/afkdevs/go-schema => ../.. diff --git a/examples/basic/go.sum b/examples/basic/go.sum index 69cac7b..1b5a029 100644 --- a/examples/basic/go.sum +++ b/examples/basic/go.sum @@ -8,6 +8,8 @@ github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1 github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -24,16 +26,16 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/urfave/cli/v3 v3.3.8 h1:BzolUExliMdet9NlJ/u4m5vHSotJ3PzEqSAZ1oPMa/E= github.com/urfave/cli/v3 v3.3.8/go.mod h1:FJSKtM/9AiiTOJL4fJ6TbMUkxBXn7GO9guZqoZtpYpo= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= -golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/basic/main.go b/examples/basic/main.go index 6a5590b..4c15da0 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -1,60 +1,14 @@ package main import ( - "context" "log" "os" - "github.com/afkdevs/go-schema/examples/basic/cmd/migrate" - "github.com/urfave/cli/v3" + "github.com/afkdevs/go-schema/examples/basic/cmd" ) func main() { - app := &cli.Command{ - Name: "schema-example", - Usage: "A simple schema example", - Commands: []*cli.Command{ - { - Name: "migrate", - Aliases: []string{"m"}, - Usage: "Run database migrations", - Commands: []*cli.Command{ - { - Name: "create", - Usage: "Create a new migration file", - Flags: []cli.Flag{ - &cli.StringFlag{ - Name: "name", - Aliases: []string{"n"}, - Usage: "Name of the migration", - Required: true, - }, - }, - Action: func(ctx context.Context, c *cli.Command) error { - return migrate.Create(c.String("name")) - }, - }, - { - Name: "up", - Aliases: []string{"u"}, - Usage: "Run all pending migrations", - Action: func(ctx context.Context, c *cli.Command) error { - return migrate.Up() - }, - }, - { - Name: "down", - Aliases: []string{"d"}, - Usage: "Down all migrations", - Action: func(ctx context.Context, c *cli.Command) error { - return migrate.Down() - }, - }, - }, - }, - }, - } - if err := app.Run(context.Background(), os.Args); err != nil { + if err := cmd.Execute(os.Args); err != nil { log.Printf("Error running app: %v\n", err) os.Exit(1) } diff --git a/examples/basic/migrations/20250626235117_create_users_table.go b/examples/basic/migrations/20250830103612_create_users_table.go similarity index 83% rename from examples/basic/migrations/20250626235117_create_users_table.go rename to examples/basic/migrations/20250830103612_create_users_table.go index 23d318e..53776a3 100644 --- a/examples/basic/migrations/20250626235117_create_users_table.go +++ b/examples/basic/migrations/20250830103612_create_users_table.go @@ -19,8 +19,7 @@ func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error { table.String("email") table.Timestamp("email_verified_at").Nullable() table.String("password") - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) } diff --git a/examples/basic/migrations/20250628092119_create_roles_table.go b/examples/basic/migrations/20250830103653_create_roles_table.go similarity index 100% rename from examples/basic/migrations/20250628092119_create_roles_table.go rename to examples/basic/migrations/20250830103653_create_roles_table.go diff --git a/examples/basic/migrations/20250628092223_create_user_roles_table.go b/examples/basic/migrations/20250830103714_create_user_roles_table.go similarity index 100% rename from examples/basic/migrations/20250628092223_create_user_roles_table.go rename to examples/basic/migrations/20250830103714_create_user_roles_table.go diff --git a/foreign_key_definition.go b/foreign_key_definition.go index d92596e..476e05b 100644 --- a/foreign_key_definition.go +++ b/foreign_key_definition.go @@ -1,5 +1,7 @@ package schema +import "github.com/afkdevs/go-schema/internal/util" + // ForeignKeyDefinition defines the interface for defining a foreign key constraint in a database table. type ForeignKeyDefinition interface { // CascadeOnDelete sets the foreign key to cascade on delete. @@ -48,13 +50,13 @@ func (fd *foreignKeyDefinition) CascadeOnUpdate() ForeignKeyDefinition { } func (fd *foreignKeyDefinition) Deferrable(value ...bool) ForeignKeyDefinition { - val := optional(true, value...) + val := util.Optional(true, value...) fd.deferrable = &val return fd } func (fd *foreignKeyDefinition) InitiallyImmediate(value ...bool) ForeignKeyDefinition { - val := optional(true, value...) + val := util.Optional(true, value...) fd.initiallyImmediate = &val return fd } diff --git a/go.mod b/go.mod index 1f0d73d..9c4b745 100644 --- a/go.mod +++ b/go.mod @@ -1,19 +1,16 @@ module github.com/afkdevs/go-schema -go 1.23 +go 1.23.0 require ( github.com/go-sql-driver/mysql v1.9.3 github.com/lib/pq v1.10.9 - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 ) require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.13.1 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d87df1b..f813003 100644 --- a/go.sum +++ b/go.sum @@ -1,29 +1,16 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/grammar.go b/grammar.go index 12b70b1..aed1955 100644 --- a/grammar.go +++ b/grammar.go @@ -4,36 +4,37 @@ import ( "fmt" "slices" "strings" + + "github.com/afkdevs/go-schema/internal/util" ) type grammar interface { - compileCreate(bp *Blueprint) (string, error) - compileCreateIfNotExists(bp *Blueprint) (string, error) - compileAdd(bp *Blueprint) (string, error) - compileChange(bp *Blueprint, command *command) (string, error) - compileDrop(bp *Blueprint) (string, error) - compileDropIfExists(bp *Blueprint) (string, error) - compileRename(bp *Blueprint, command *command) (string, error) - compileDropColumn(blueprint *Blueprint, command *command) (string, error) - compileRenameColumn(blueprint *Blueprint, command *command) (string, error) - compileIndex(blueprint *Blueprint, command *command) (string, error) - compileUnique(blueprint *Blueprint, command *command) (string, error) - compilePrimary(blueprint *Blueprint, command *command) (string, error) - compileFullText(blueprint *Blueprint, command *command) (string, error) - compileDropIndex(blueprint *Blueprint, command *command) (string, error) - compileDropUnique(blueprint *Blueprint, command *command) (string, error) - compileDropFulltext(blueprint *Blueprint, command *command) (string, error) - compileDropPrimary(blueprint *Blueprint, command *command) (string, error) - compileRenameIndex(blueprint *Blueprint, command *command) (string, error) - compileForeign(blueprint *Blueprint, command *command) (string, error) - compileDropForeign(blueprint *Blueprint, command *command) (string, error) - getFluentCommands() []func(blueprint *Blueprint, command *command) string - createIndexName(blueprint *Blueprint, idxType string, columns ...string) string + CompileCreate(bp *Blueprint) (string, error) + CompileAdd(bp *Blueprint) (string, error) + CompileChange(bp *Blueprint, command *command) (string, error) + CompileDrop(bp *Blueprint) (string, error) + CompileDropIfExists(bp *Blueprint) (string, error) + CompileRename(bp *Blueprint, command *command) (string, error) + CompileDropColumn(blueprint *Blueprint, command *command) (string, error) + CompileRenameColumn(blueprint *Blueprint, command *command) (string, error) + CompileIndex(blueprint *Blueprint, command *command) (string, error) + CompileUnique(blueprint *Blueprint, command *command) (string, error) + CompilePrimary(blueprint *Blueprint, command *command) (string, error) + CompileFullText(blueprint *Blueprint, command *command) (string, error) + CompileDropIndex(blueprint *Blueprint, command *command) (string, error) + CompileDropUnique(blueprint *Blueprint, command *command) (string, error) + CompileDropFulltext(blueprint *Blueprint, command *command) (string, error) + CompileDropPrimary(blueprint *Blueprint, command *command) (string, error) + CompileRenameIndex(blueprint *Blueprint, command *command) (string, error) + CompileForeign(blueprint *Blueprint, command *command) (string, error) + CompileDropForeign(blueprint *Blueprint, command *command) (string, error) + GetFluentCommands() []func(blueprint *Blueprint, command *command) string + CreateIndexName(blueprint *Blueprint, idxType string, columns ...string) string } type baseGrammar struct{} -func (g *baseGrammar) compileForeign(blueprint *Blueprint, command *command) (string, error) { +func (g *baseGrammar) CompileForeign(blueprint *Blueprint, command *command) (string, error) { if len(command.columns) == 0 || slices.Contains(command.columns, "") || command.on == "" || len(command.references) == 0 || slices.Contains(command.references, "") { return "", fmt.Errorf("foreign key definition is incomplete: column, on, and references must be set") @@ -48,7 +49,7 @@ func (g *baseGrammar) compileForeign(blueprint *Blueprint, command *command) (st } index := command.index if index == "" { - index = g.createForeignKeyName(blueprint, command) + index = g.CreateForeignKeyName(blueprint, command) } return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s(%s)%s%s", @@ -62,11 +63,46 @@ func (g *baseGrammar) compileForeign(blueprint *Blueprint, command *command) (st ), nil } -func (g *baseGrammar) quoteString(s string) string { +func (g *baseGrammar) CreateIndexName(blueprint *Blueprint, idxType string, columns ...string) string { + tableName := blueprint.name + if strings.Contains(tableName, ".") { + parts := strings.Split(tableName, ".") + tableName = parts[len(parts)-1] // Use the last part as the table name + } + + switch idxType { + case "primary": + return fmt.Sprintf("pk_%s", tableName) + case "unique": + return fmt.Sprintf("uk_%s_%s", tableName, strings.Join(columns, "_")) + case "index": + return fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_")) + case "fulltext": + return fmt.Sprintf("ft_%s_%s", tableName, strings.Join(columns, "_")) + default: + return "" + } +} + +func (g *baseGrammar) CreateForeignKeyName(blueprint *Blueprint, command *command) string { + tableName := blueprint.name + if strings.Contains(tableName, ".") { + parts := strings.Split(tableName, ".") + tableName = parts[len(parts)-1] // Use the last part as the table name + } + on := command.on + if strings.Contains(on, ".") { + parts := strings.Split(on, ".") + on = parts[len(parts)-1] // Use the last part as the column name + } + return fmt.Sprintf("fk_%s_%s", tableName, on) +} + +func (g *baseGrammar) QuoteString(s string) string { return "'" + s + "'" } -func (g *baseGrammar) prefixArray(prefix string, items []string) []string { +func (g *baseGrammar) PrefixArray(prefix string, items []string) []string { prefixed := make([]string, len(items)) for i, item := range items { prefixed[i] = fmt.Sprintf("%s%s", prefix, item) @@ -74,14 +110,14 @@ func (g *baseGrammar) prefixArray(prefix string, items []string) []string { return prefixed } -func (g *baseGrammar) columnize(columns []string) string { +func (g *baseGrammar) Columnize(columns []string) string { if len(columns) == 0 { return "" } return strings.Join(columns, ", ") } -func (g *baseGrammar) getValue(value any) string { +func (g *baseGrammar) GetValue(value any) string { switch v := value.(type) { case Expression: return v.String() @@ -90,7 +126,7 @@ func (g *baseGrammar) getValue(value any) string { } } -func (g *baseGrammar) getDefaultValue(value any) string { +func (g *baseGrammar) GetDefaultValue(value any) string { if value == nil { return "NULL" } @@ -98,43 +134,8 @@ func (g *baseGrammar) getDefaultValue(value any) string { case Expression: return v.String() case bool: - return ternary(v, "'1'", "'0'") + return util.Ternary(v, "'1'", "'0'") default: return fmt.Sprintf("'%v'", v) } } - -func (g *baseGrammar) createIndexName(blueprint *Blueprint, idxType string, columns ...string) string { - tableName := blueprint.name - if strings.Contains(tableName, ".") { - parts := strings.Split(tableName, ".") - tableName = parts[len(parts)-1] // Use the last part as the table name - } - - switch idxType { - case commandPrimary: - return fmt.Sprintf("pk_%s", tableName) - case commandUnique: - return fmt.Sprintf("uk_%s_%s", tableName, strings.Join(columns, "_")) - case commandIndex: - return fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_")) - case commandFullText: - return fmt.Sprintf("ft_%s_%s", tableName, strings.Join(columns, "_")) - default: - return "" - } -} - -func (g *baseGrammar) createForeignKeyName(blueprint *Blueprint, command *command) string { - tableName := blueprint.name - if strings.Contains(tableName, ".") { - parts := strings.Split(tableName, ".") - tableName = parts[len(parts)-1] // Use the last part as the table name - } - on := command.on - if strings.Contains(on, ".") { - parts := strings.Split(on, ".") - on = parts[len(parts)-1] // Use the last part as the column name - } - return fmt.Sprintf("fk_%s_%s", tableName, on) -} diff --git a/index_definition.go b/index_definition.go index 6611920..d13e5f8 100644 --- a/index_definition.go +++ b/index_definition.go @@ -1,5 +1,7 @@ package schema +import "github.com/afkdevs/go-schema/internal/util" + // IndexDefinition defines the interface for defining an index in a database table. type IndexDefinition interface { // Algorithm sets the algorithm for the index. @@ -25,13 +27,13 @@ func (id *indexDefinition) Algorithm(algorithm string) IndexDefinition { } func (id *indexDefinition) Deferrable(value ...bool) IndexDefinition { - val := optional(true, value...) + val := util.Optional(true, value...) id.deferrable = &val return id } func (id *indexDefinition) InitiallyImmediate(value ...bool) IndexDefinition { - val := optional(true, value...) + val := util.Optional(true, value...) id.initiallyImmediate = &val return id } diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..da85f61 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,29 @@ +package config + +import ( + "sync/atomic" + + "github.com/afkdevs/go-schema/internal/dialect" +) + +type Config struct { + Dialect dialect.Dialect + Verbose bool +} + +var config = atomic.Pointer[Config]{} + +func init() { + config.Store(&Config{ + Dialect: dialect.Unknown, + Verbose: true, + }) +} + +func Set(newConfig *Config) { + config.Store(newConfig) +} + +func Get() *Config { + return config.Load() +} diff --git a/internal/dialect/dialect.go b/internal/dialect/dialect.go new file mode 100644 index 0000000..3291cba --- /dev/null +++ b/internal/dialect/dialect.go @@ -0,0 +1,25 @@ +package dialect + +// Dialect is the type of database dialect. +type Dialect string + +const ( + MySQL Dialect = "mysql" + Postgres Dialect = "postgres" + Unknown Dialect = "" +) + +func (d Dialect) String() string { + return string(d) +} + +func FromString(dialect string) Dialect { + switch dialect { + case "mysql", "mariadb": + return MySQL + case "postgres", "pgx": + return Postgres + default: + return Unknown + } +} diff --git a/internal/parser/migration.go b/internal/parser/migration.go new file mode 100644 index 0000000..5e89a25 --- /dev/null +++ b/internal/parser/migration.go @@ -0,0 +1,21 @@ +package parser + +import "regexp" + +func ParseMigrationName(filename string) (tableName string, create bool) { + // Regex patterns for common migration styles + createPattern := regexp.MustCompile(`^create_(?P[a-z0-9_]+?)(?:_table)?$`) + addColPattern := regexp.MustCompile(`^add_(?P[a-z0-9_]+?)_to_(?P
[a-z0-9_]+?)(?:_table)?$`) + removeColPattern := regexp.MustCompile(`^remove_(?P[a-z0-9_]+?)_from_(?P
[a-z0-9_]+?)(?:_table)?$`) + + switch { + case createPattern.MatchString(filename): + return createPattern.ReplaceAllString(filename, "${table}"), true + case addColPattern.MatchString(filename): + return addColPattern.ReplaceAllString(filename, "${table}"), false + case removeColPattern.MatchString(filename): + return removeColPattern.ReplaceAllString(filename, "${table}"), false + default: + return "", false // Unknown migration style + } +} diff --git a/internal/parser/migration_test.go b/internal/parser/migration_test.go new file mode 100644 index 0000000..c890078 --- /dev/null +++ b/internal/parser/migration_test.go @@ -0,0 +1,132 @@ +package parser_test + +import ( + "testing" + + "github.com/afkdevs/go-schema/internal/parser" +) + +func TestParseMigrationName(t *testing.T) { + tests := []struct { + name string + filename string + wantTable string + wantCreate bool + }{ + // Create table patterns + { + name: "create table with table suffix", + filename: "create_users_table", + wantTable: "users", + wantCreate: true, + }, + { + name: "create table without table suffix", + filename: "create_posts", + wantTable: "posts", + wantCreate: true, + }, + { + name: "create table with underscores", + filename: "create_user_profiles_table", + wantTable: "user_profiles", + wantCreate: true, + }, + { + name: "create table with numbers", + filename: "create_table_v2", + wantTable: "table_v2", + wantCreate: true, + }, + + // Add column patterns + { + name: "add column with table suffix", + filename: "add_email_to_users_table", + wantTable: "users", + wantCreate: false, + }, + { + name: "add column without table suffix", + filename: "add_name_to_posts", + wantTable: "posts", + wantCreate: false, + }, + { + name: "add multiple columns", + filename: "add_first_name_last_name_to_users_table", + wantTable: "users", + wantCreate: false, + }, + { + name: "add column with underscores", + filename: "add_created_at_to_user_profiles", + wantTable: "user_profiles", + wantCreate: false, + }, + + // Remove column patterns + { + name: "remove column with table suffix", + filename: "remove_email_from_users_table", + wantTable: "users", + wantCreate: false, + }, + { + name: "remove column without table suffix", + filename: "remove_name_from_posts", + wantTable: "posts", + wantCreate: false, + }, + { + name: "remove multiple columns", + filename: "remove_old_field_from_users_table", + wantTable: "users", + wantCreate: false, + }, + { + name: "remove column with underscores", + filename: "remove_updated_at_from_user_profiles", + wantTable: "user_profiles", + wantCreate: false, + }, + + // Unknown patterns + { + name: "unknown migration pattern", + filename: "update_users_table", + wantTable: "", + wantCreate: false, + }, + { + name: "empty filename", + filename: "", + wantTable: "", + wantCreate: false, + }, + { + name: "invalid format", + filename: "invalid_migration_name", + wantTable: "", + wantCreate: false, + }, + { + name: "uppercase letters", + filename: "create_Users_table", + wantTable: "", + wantCreate: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotTable, gotCreate := parser.ParseMigrationName(tt.filename) + if gotTable != tt.wantTable { + t.Errorf("ParseMigrationName() gotTable = %v, want %v", gotTable, tt.wantTable) + } + if gotCreate != tt.wantCreate { + t.Errorf("ParseMigrationName() gotCreate = %v, want %v", gotCreate, tt.wantCreate) + } + }) + } +} diff --git a/util.go b/internal/util/util.go similarity index 51% rename from util.go rename to internal/util/util.go index e9e838e..61a4ca3 100644 --- a/util.go +++ b/internal/util/util.go @@ -1,31 +1,31 @@ -package schema +package util -func optional[T any](defaultValue T, values ...T) T { +func Optional[T any](defaultValue T, values ...T) T { if len(values) > 0 { return values[0] } return defaultValue } -func optionalPtr[T any](defaultValue T, values ...T) *T { +func OptionalPtr[T any](defaultValue T, values ...T) *T { if len(values) > 0 { return &values[0] } return &defaultValue } -func optionalNil[T any](values ...T) *T { +func OptionalNil[T any](values ...T) *T { if len(values) > 0 { return &values[0] } return nil } -func ptrOf[T any](value T) *T { +func PtrOf[T any](value T) *T { return &value } -func ternary[T any](condition bool, trueValue, falseValue T) T { +func Ternary[T any](condition bool, trueValue, falseValue T) T { if condition { return trueValue } diff --git a/mysql_builder.go b/mysql_builder.go index bb2dbfb..284052a 100644 --- a/mysql_builder.go +++ b/mysql_builder.go @@ -5,6 +5,8 @@ import ( "database/sql" "errors" "strings" + + "github.com/afkdevs/go-schema/internal/config" ) type mysqlBuilder struct { @@ -12,22 +14,20 @@ type mysqlBuilder struct { grammar *mysqlGrammar } -func newMysqlBuilder(options ...Option) Builder { +var _ Builder = (*mysqlBuilder)(nil) + +func newMysqlBuilder() Builder { grammar := newMysqlGrammar() - cfg := applyOptions(options...) + cfg := config.Get() return &mysqlBuilder{ - baseBuilder: baseBuilder{grammar: grammar, debug: cfg.debug}, + baseBuilder: baseBuilder{grammar: grammar, verbose: cfg.Verbose}, grammar: grammar, } } func (b *mysqlBuilder) getCurrentDatabase(ctx context.Context, tx *sql.Tx) (string, error) { - if tx == nil { - return "", errors.New("transaction is nil") - } - - query := b.grammar.compileCurrentDatabase() + query := b.grammar.CompileCurrentDatabase() row := tx.QueryRowContext(ctx, query) var dbName string if err := row.Scan(&dbName); err != nil { @@ -36,33 +36,9 @@ func (b *mysqlBuilder) getCurrentDatabase(ctx context.Context, tx *sql.Tx) (stri return dbName, nil } -func (b *mysqlBuilder) CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil { - return err - } - - exist, err := b.HasTable(ctx, tx, name) - if err != nil { - return err - } - if exist { - return nil // Table already exists, no need to create it - } - - bp := &Blueprint{name: name, grammar: b.grammar} - bp.createIfNotExists() - blueprint(bp) - - if err := bp.build(ctx, tx); err != nil { - return err - } - - return nil -} - func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return nil, err + if tx == nil || tableName == "" { + return nil, errors.New("invalid arguments: transaction is nil or table name is empty") } database, err := b.getCurrentDatabase(ctx, tx) @@ -70,7 +46,7 @@ func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName str return nil, err } - query, err := b.grammar.compileColumns(database, tableName) + query, err := b.grammar.CompileColumns(database, tableName) if err != nil { return nil, err } @@ -102,8 +78,8 @@ func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName str } func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return nil, err + if tx == nil || tableName == "" { + return nil, errors.New("invalid arguments: transaction is nil or table name is empty") } database, err := b.getCurrentDatabase(ctx, tx) @@ -111,7 +87,7 @@ func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName str return nil, err } - query, err := b.grammar.compileIndexes(database, tableName) + query, err := b.grammar.CompileIndexes(database, tableName) if err != nil { return nil, err } @@ -136,12 +112,16 @@ func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName str } func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) { + if tx == nil { + return nil, errors.New("invalid arguments: transaction is nil") + } + database, err := b.getCurrentDatabase(ctx, tx) if err != nil { return nil, err } - query, err := b.grammar.compileTables(database) + query, err := b.grammar.CompileTables(database) if err != nil { return nil, err } @@ -163,15 +143,15 @@ func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, } func (b *mysqlBuilder) HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) { - if columnName == "" { - return false, errors.New("column name cannot be empty") + if tx == nil || columnName == "" { + return false, errors.New("invalid arguments: transaction is nil or column name is empty") } return b.HasColumns(ctx, tx, tableName, []string{columnName}) } func (b *mysqlBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return false, err + if tx == nil || tableName == "" { + return false, errors.New("invalid arguments: transaction is nil or table name is empty") } if len(columnNames) == 0 { return false, errors.New("no column names provided") @@ -194,8 +174,8 @@ func (b *mysqlBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName str } func (b *mysqlBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return false, err + if tx == nil || tableName == "" { + return false, errors.New("invalid arguments: transaction is nil or table name is empty") } existingIndexes, err := b.GetIndexes(ctx, tx, tableName) @@ -238,8 +218,8 @@ func (b *mysqlBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName strin } func (b *mysqlBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { - if err := b.validateTxAndName(tx, name); err != nil { - return false, err + if tx == nil || name == "" { + return false, errors.New("invalid arguments: transaction is nil or table name is empty") } database, err := b.getCurrentDatabase(ctx, tx) @@ -247,7 +227,7 @@ func (b *mysqlBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (b return false, err } - query, err := b.grammar.compileTableExists(database, name) + query, err := b.grammar.CompileTableExists(database, name) if err != nil { return false, err } diff --git a/mysql_builder_test.go b/mysql_builder_test.go index f5c1c99..b274864 100644 --- a/mysql_builder_test.go +++ b/mysql_builder_test.go @@ -43,7 +43,7 @@ func (s *mysqlBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - s.builder, err = schema.NewBuilder("mysql", schema.WithDebug()) + s.builder, err = schema.NewBuilder("mysql") s.Require().NoError(err) } @@ -141,41 +141,6 @@ func (s *mysqlBuilderSuite) TestCreate() { }) } -func (s *mysqlBuilderSuite) TestCreateIfNotExists() { - builder := s.builder - tx, err := s.db.BeginTx(s.ctx, nil) - s.Require().NoError(err) - defer tx.Rollback() //nolint:errcheck - - s.Run("when tx is nil, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, nil, "test_table", func(table *schema.Blueprint) { - table.String("name") - }) - s.Error(err) - }) - s.Run("when table name is empty, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, tx, "", func(table *schema.Blueprint) { - table.String("name") - }) - s.Error(err) - }) - s.Run("when blueprint is nil, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, tx, "test_table", nil) - s.Error(err) - }) - s.Run("when all parameters are valid, should error", func() { - err = builder.CreateIfNotExists(context.Background(), tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") - }) - s.Error(err) - }) -} - func (s *mysqlBuilderSuite) TestDrop() { builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) diff --git a/mysql_grammar.go b/mysql_grammar.go index 65c7c92..3c07d85 100644 --- a/mysql_grammar.go +++ b/mysql_grammar.go @@ -1,10 +1,11 @@ package schema import ( - "errors" "fmt" "slices" "strings" + + "github.com/afkdevs/go-schema/internal/util" ) type mysqlGrammar struct { @@ -13,63 +14,61 @@ type mysqlGrammar struct { serials []string } -var _ grammar = (*mysqlGrammar)(nil) - func newMysqlGrammar() *mysqlGrammar { return &mysqlGrammar{ serials: []string{ - columnTypeBigInteger, columnTypeInteger, columnTypeMediumInteger, columnTypeSmallInteger, - columnTypeTinyInteger, + "bigInteger", "integer", "mediumInteger", "smallInteger", + "tinyInteger", }, } } -func (g *mysqlGrammar) compileCurrentDatabase() string { +func (g *mysqlGrammar) CompileCurrentDatabase() string { return "SELECT DATABASE()" } -func (g *mysqlGrammar) compileTableExists(database string, table string) (string, error) { +func (g *mysqlGrammar) CompileTableExists(database string, table string) (string, error) { return fmt.Sprintf( "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'", - g.quoteString(database), - g.quoteString(table), + g.QuoteString(database), + g.QuoteString(table), ), nil } -func (g *mysqlGrammar) compileTables(database string) (string, error) { +func (g *mysqlGrammar) CompileTables(database string) (string, error) { return fmt.Sprintf( "select table_name as `name`, (data_length + index_length) as `size`, "+ "table_comment as `comment`, engine as `engine`, table_collation as `collation` "+ "from information_schema.tables where table_schema = %s and table_type in ('BASE TABLE', 'SYSTEM VERSIONED') "+ "order by table_name", - g.quoteString(database), + g.QuoteString(database), ), nil } -func (g *mysqlGrammar) compileColumns(database, table string) (string, error) { +func (g *mysqlGrammar) CompileColumns(database, table string) (string, error) { return fmt.Sprintf( "select column_name as `name`, data_type as `type_name`, column_type as `type`, "+ "collation_name as `collation`, is_nullable as `nullable`, "+ "column_default as `default`, column_comment as `comment`, extra as `extra` "+ "from information_schema.columns where table_schema = %s and table_name = %s "+ "order by ordinal_position asc", - g.quoteString(database), - g.quoteString(table), + g.QuoteString(database), + g.QuoteString(table), ), nil } -func (g *mysqlGrammar) compileIndexes(database, table string) (string, error) { +func (g *mysqlGrammar) CompileIndexes(database, table string) (string, error) { return fmt.Sprintf( "select index_name as `name`, group_concat(column_name order by seq_in_index) as `columns`, "+ "index_type as `type`, not non_unique as `unique` "+ "from information_schema.statistics where table_schema = %s and table_name = %s "+ "group by index_name, index_type, non_unique", - g.quoteString(database), - g.quoteString(table), + g.QuoteString(database), + g.QuoteString(table), ), nil } -func (g *mysqlGrammar) compileCreate(blueprint *Blueprint) (string, error) { +func (g *mysqlGrammar) CompileCreate(blueprint *Blueprint) (string, error) { sql, err := g.compileCreateTable(blueprint) if err != nil { return "", err @@ -109,11 +108,7 @@ func (g *mysqlGrammar) compileCreateEngine(sql string, blueprint *Blueprint) str return sql } -func (g *mysqlGrammar) compileCreateIfNotExists(blueprint *Blueprint) (string, error) { - return "", errors.New("MySQL does not support CREATE TABLE IF NOT EXISTS with custom options") -} - -func (g *mysqlGrammar) compileAdd(blueprint *Blueprint) (string, error) { +func (g *mysqlGrammar) CompileAdd(blueprint *Blueprint) (string, error) { if len(blueprint.getAddedColumns()) == 0 { return "", nil } @@ -122,9 +117,9 @@ func (g *mysqlGrammar) compileAdd(blueprint *Blueprint) (string, error) { if err != nil { return "", err } - columns = g.prefixArray("ADD COLUMN ", columns) + columns = g.PrefixArray("ADD COLUMN ", columns) constraints := g.getConstraints(blueprint) - constraints = g.prefixArray("ADD ", constraints) + constraints = g.PrefixArray("ADD ", constraints) columns = append(columns, constraints...) return fmt.Sprintf("ALTER TABLE %s %s", @@ -133,7 +128,7 @@ func (g *mysqlGrammar) compileAdd(blueprint *Blueprint) (string, error) { ), nil } -func (g *mysqlGrammar) compileChange(bp *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileChange(bp *Blueprint, command *command) (string, error) { column := command.column if column.name == "" { return "", fmt.Errorf("column name cannot be empty for change operation") @@ -147,25 +142,25 @@ func (g *mysqlGrammar) compileChange(bp *Blueprint, command *command) (string, e return sql, nil } -func (g *mysqlGrammar) compileRename(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileRename(blueprint *Blueprint, command *command) (string, error) { return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, command.to), nil } -func (g *mysqlGrammar) compileDrop(blueprint *Blueprint) (string, error) { +func (g *mysqlGrammar) CompileDrop(blueprint *Blueprint) (string, error) { if blueprint.name == "" { return "", fmt.Errorf("table name cannot be empty") } return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil } -func (g *mysqlGrammar) compileDropIfExists(blueprint *Blueprint) (string, error) { +func (g *mysqlGrammar) CompileDropIfExists(blueprint *Blueprint) (string, error) { if blueprint.name == "" { return "", fmt.Errorf("table name cannot be empty") } return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil } -func (g *mysqlGrammar) compileDropColumn(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileDropColumn(blueprint *Blueprint, command *command) (string, error) { if len(command.columns) == 0 { return "", fmt.Errorf("no columns to drop") } @@ -176,28 +171,28 @@ func (g *mysqlGrammar) compileDropColumn(blueprint *Blueprint, command *command) } columns[i] = col } - columns = g.prefixArray("DROP COLUMN ", columns) + columns = g.PrefixArray("DROP COLUMN ", columns) return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil } -func (g *mysqlGrammar) compileRenameColumn(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileRenameColumn(blueprint *Blueprint, command *command) (string, error) { if command.from == "" || command.to == "" { return "", fmt.Errorf("old and new column names cannot be empty") } return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, command.from, command.to), nil } -func (g *mysqlGrammar) compileIndex(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileIndex(blueprint *Blueprint, command *command) (string, error) { if slices.Contains(command.columns, "") { return "", fmt.Errorf("index column cannot be empty") } indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, commandIndex, command.columns...) + indexName = g.CreateIndexName(blueprint, "index", command.columns...) } - sql := fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(command.columns)) + sql := fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, blueprint.name, g.Columnize(command.columns)) if command.algorithm != "" { sql += fmt.Sprintf(" USING %s", command.algorithm) } @@ -205,16 +200,16 @@ func (g *mysqlGrammar) compileIndex(blueprint *Blueprint, command *command) (str return sql, nil } -func (g *mysqlGrammar) compileUnique(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileUnique(blueprint *Blueprint, command *command) (string, error) { if slices.Contains(command.columns, "") { return "", fmt.Errorf("unique column cannot be empty") } indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, commandUnique, command.columns...) + indexName = g.CreateIndexName(blueprint, "unique", command.columns...) } - sql := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(command.columns)) + sql := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, blueprint.name, g.Columnize(command.columns)) if command.algorithm != "" { sql += fmt.Sprintf(" USING %s", command.algorithm) } @@ -222,69 +217,69 @@ func (g *mysqlGrammar) compileUnique(blueprint *Blueprint, command *command) (st return sql, nil } -func (g *mysqlGrammar) compileFullText(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileFullText(blueprint *Blueprint, command *command) (string, error) { if slices.Contains(command.columns, "") { return "", fmt.Errorf("fulltext index column cannot be empty") } indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, commandFullText, command.columns...) + indexName = g.CreateIndexName(blueprint, "fulltext", command.columns...) } - return fmt.Sprintf("CREATE FULLTEXT INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(command.columns)), nil + return fmt.Sprintf("CREATE FULLTEXT INDEX %s ON %s (%s)", indexName, blueprint.name, g.Columnize(command.columns)), nil } -func (g *mysqlGrammar) compilePrimary(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompilePrimary(blueprint *Blueprint, command *command) (string, error) { if slices.Contains(command.columns, "") { return "", fmt.Errorf("primary key column cannot be empty") } indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, commandPrimary) + indexName = g.CreateIndexName(blueprint, "primary", command.columns...) } - return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(command.columns)), nil + return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.Columnize(command.columns)), nil } -func (g *mysqlGrammar) compileDropIndex(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileDropIndex(blueprint *Blueprint, command *command) (string, error) { if command.index == "" { return "", fmt.Errorf("index name cannot be empty") } return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, command.index), nil } -func (g *mysqlGrammar) compileDropUnique(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileDropUnique(blueprint *Blueprint, command *command) (string, error) { if command.index == "" { return "", fmt.Errorf("unique index name cannot be empty") } return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, command.index), nil } -func (g *mysqlGrammar) compileDropFulltext(blueprint *Blueprint, command *command) (string, error) { - return g.compileDropIndex(blueprint, command) +func (g *mysqlGrammar) CompileDropFulltext(blueprint *Blueprint, command *command) (string, error) { + return g.CompileDropIndex(blueprint, command) } -func (g *mysqlGrammar) compileDropPrimary(blueprint *Blueprint, _ *command) (string, error) { +func (g *mysqlGrammar) CompileDropPrimary(blueprint *Blueprint, _ *command) (string, error) { return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", blueprint.name), nil } -func (g *mysqlGrammar) compileRenameIndex(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileRenameIndex(blueprint *Blueprint, command *command) (string, error) { if command.from == "" || command.to == "" { return "", fmt.Errorf("old and new index names cannot be empty") } return fmt.Sprintf("ALTER TABLE %s RENAME INDEX %s TO %s", blueprint.name, command.from, command.to), nil } -func (g *mysqlGrammar) compileDropForeign(blueprint *Blueprint, command *command) (string, error) { +func (g *mysqlGrammar) CompileDropForeign(blueprint *Blueprint, command *command) (string, error) { if command.index == "" { return "", fmt.Errorf("foreign key name cannot be empty") } return fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s", blueprint.name, command.index), nil } -func (g *mysqlGrammar) getFluentCommands() []func(*Blueprint, *command) string { +func (g *mysqlGrammar) GetFluentCommands() []func(*Blueprint, *command) string { return []func(*Blueprint, *command) string{} } @@ -297,13 +292,8 @@ func (g *mysqlGrammar) getColumns(blueprint *Blueprint) ([]string, error) { sql := col.name + " " + g.getType(col) sql += g.modifyUnsigned(col) sql += g.modifyIncrement(col) - - if col.defaultValue != nil { - sql += g.modifyDefault(col) - } - if col.onUpdateValue != nil { - sql += g.modifyOnUpdate(col) - } + sql += g.modifyDefault(col) + sql += g.modifyOnUpdate(col) sql += g.modifyCharset(col) sql += g.modifyCollate(col) sql += g.modifyNullable(col) @@ -319,7 +309,7 @@ func (g *mysqlGrammar) getConstraints(blueprint *Blueprint) []string { var constrains []string for _, col := range blueprint.getAddedColumns() { if col.primary != nil && *col.primary { - pkConstraintName := g.createIndexName(blueprint, commandPrimary) + pkConstraintName := g.CreateIndexName(blueprint, "primary") sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")" constrains = append(constrains, sql) continue @@ -347,8 +337,8 @@ func (g *mysqlGrammar) getType(col *columnDefinition) string { columnTypeDecimal: g.typeDecimal, columnTypeBoolean: g.typeBoolean, columnTypeEnum: g.typeEnum, - columnTypeJSON: g.typeJson, - columnTypeJSONB: g.typeJsonb, + columnTypeJson: g.typeJson, + columnTypeJsonb: g.typeJsonb, columnTypeDate: g.typeDate, columnTypeDateTime: g.typeDateTime, columnTypeDateTimeTz: g.typeDateTimeTz, @@ -358,7 +348,7 @@ func (g *mysqlGrammar) getType(col *columnDefinition) string { columnTypeTimestampTz: g.typeTimestampTz, columnTypeYear: g.typeYear, columnTypeBinary: g.typeBinary, - columnTypeUUID: g.typeUUID, + columnTypeUuid: g.typeUuid, columnTypeGeography: g.typeGeography, columnTypeGeometry: g.typeGeometry, columnTypePoint: g.typePoint, @@ -435,7 +425,7 @@ func (g *mysqlGrammar) typeBoolean(col *columnDefinition) string { func (g *mysqlGrammar) typeEnum(col *columnDefinition) string { allowedValues := make([]string, len(col.allowed)) for i, e := range col.allowed { - allowedValues[i] = g.quoteString(e) + allowedValues[i] = g.QuoteString(e) } return fmt.Sprintf("ENUM(%s)", strings.Join(allowedValues, ", ")) } @@ -457,11 +447,11 @@ func (g *mysqlGrammar) typeDateTime(col *columnDefinition) string { if col.precision != nil && *col.precision > 0 { current = fmt.Sprintf("CURRENT_TIMESTAMP(%d)", *col.precision) } - if col.useCurrent != nil && *col.useCurrent { - col.Default(Expression(current)) + if col.useCurrent { + col.SetDefault(Expression(current)) } - if col.useCurrentOnUpdate != nil && *col.useCurrentOnUpdate { - col.OnUpdate(Expression(current)) + if col.useCurrentOnUpdate { + col.SetOnUpdate(Expression(current)) } if col.precision != nil && *col.precision > 0 { return fmt.Sprintf("DATETIME(%d)", *col.precision) @@ -489,11 +479,11 @@ func (g *mysqlGrammar) typeTimestamp(col *columnDefinition) string { if col.precision != nil && *col.precision > 0 { current = fmt.Sprintf("CURRENT_TIMESTAMP(%d)", *col.precision) } - if col.useCurrent != nil && *col.useCurrent { - col.Default(Expression(current)) + if col.useCurrent { + col.SetDefault(Expression(current)) } - if col.useCurrentOnUpdate != nil && *col.useCurrentOnUpdate { - col.OnUpdate(Expression(current)) + if col.useCurrentOnUpdate { + col.SetOnUpdate(Expression(current)) } if col.precision != nil && *col.precision > 0 { return fmt.Sprintf("TIMESTAMP(%d)", *col.precision) @@ -516,12 +506,12 @@ func (g *mysqlGrammar) typeBinary(col *columnDefinition) string { return "BLOB" } -func (g *mysqlGrammar) typeUUID(col *columnDefinition) string { +func (g *mysqlGrammar) typeUuid(col *columnDefinition) string { return "CHAR(36)" // Default UUID length } func (g *mysqlGrammar) typeGeometry(col *columnDefinition) string { - subtype := ternary(col.subtype != nil, ptrOf(strings.ToUpper(*col.subtype)), nil) + subtype := util.Ternary(col.subtype != nil, util.PtrOf(strings.ToUpper(*col.subtype)), nil) if subtype != nil { if !slices.Contains([]string{"POINT", "LINESTRING", "POLYGON", "GEOMETRYCOLLECTION", "MULTIPOINT", "MULTILINESTRING"}, *subtype) { subtype = nil @@ -529,7 +519,7 @@ func (g *mysqlGrammar) typeGeometry(col *columnDefinition) string { } if subtype == nil { - subtype = ptrOf("GEOMETRY") + subtype = util.PtrOf("GEOMETRY") } if col.srid != nil && *col.srid > 0 { return fmt.Sprintf("%s SRID %d", *subtype, *col.srid) @@ -542,7 +532,7 @@ func (g *mysqlGrammar) typeGeography(col *columnDefinition) string { } func (g *mysqlGrammar) typePoint(col *columnDefinition) string { - col.subtype = ptrOf("POINT") + col.SetSubtype(util.PtrOf("POINT")) return g.typeGeometry(col) } @@ -574,15 +564,15 @@ func (g *mysqlGrammar) modifyCollate(col *columnDefinition) string { } func (g *mysqlGrammar) modifyComment(col *columnDefinition) string { - if col.comment != nil && *col.comment != "" { + if col.comment != nil { return fmt.Sprintf(" COMMENT '%s'", *col.comment) } return "" } func (g *mysqlGrammar) modifyDefault(col *columnDefinition) string { - if col.defaultValue != nil { - return fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col.defaultValue)) + if col.hasCommand("default") { + return fmt.Sprintf(" DEFAULT %s", g.GetDefaultValue(col.defaultValue)) } return "" } @@ -604,8 +594,8 @@ func (g *mysqlGrammar) modifyNullable(col *columnDefinition) string { } func (g *mysqlGrammar) modifyOnUpdate(col *columnDefinition) string { - if col.onUpdateValue != nil { - return fmt.Sprintf(" ON UPDATE %s", g.getValue(col.onUpdateValue)) + if col.hasCommand("onUpdate") { + return fmt.Sprintf(" ON UPDATE %s", g.GetValue(col.onUpdateValue)) } return "" } diff --git a/mysql_grammar_test.go b/mysql_grammar_test.go index 57f9b21..1dcb2af 100644 --- a/mysql_grammar_test.go +++ b/mysql_grammar_test.go @@ -3,8 +3,8 @@ package schema import ( "testing" + "github.com/afkdevs/go-schema/internal/dialect" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestMysqlGrammar_CompileCreate(t *testing.T) { @@ -83,7 +83,7 @@ func TestMysqlGrammar_CompileCreate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileCreate(bp) + got, err := g.CompileCreate(bp) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -173,7 +173,7 @@ func TestMysqlGrammar_CompileAdd(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileAdd(bp) + got, err := g.CompileAdd(bp) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -243,7 +243,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.Text("description").Nullable().Default(nil).Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN description TEXT DEFAULT NULL"}, + want: []string{"ALTER TABLE users MODIFY COLUMN description TEXT NULL DEFAULT NULL"}, wantErr: false, }, { @@ -252,7 +252,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.Integer("age").Comment("User age in years").Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN age INT COMMENT 'User age in years'"}, + want: []string{"ALTER TABLE users MODIFY COLUMN age INT NOT NULL COMMENT 'User age in years'"}, wantErr: false, }, { @@ -261,7 +261,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.Text("notes").Comment("").Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN notes TEXT COMMENT ''"}, + want: []string{"ALTER TABLE users MODIFY COLUMN notes TEXT NOT NULL COMMENT ''"}, wantErr: false, }, { @@ -285,7 +285,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { table.SmallInteger("age").Nullable().Change() }, want: []string{ - "ALTER TABLE users MODIFY COLUMN name VARCHAR(200)", + "ALTER TABLE users MODIFY COLUMN name VARCHAR(200) NOT NULL", "ALTER TABLE users MODIFY COLUMN age SMALLINT NULL", }, wantErr: false, @@ -302,7 +302,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.table, grammar: g, dialect: dialectMySQL} + bp := &Blueprint{name: tt.table, grammar: g, dialect: dialect.MySQL} tt.blueprint(bp) statements, err := bp.toSql() if tt.wantErr { @@ -310,10 +310,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { return } assert.NoError(t, err, "Did not expect error for test case: %s", tt.name) - require.Len(t, statements, len(tt.want), "Expected number of SQL statements to match for test case: %s", tt.name) - for i, stmt := range statements { - assert.Equal(t, tt.want[i], stmt, "Expected SQL to match for test case: %s", tt.name) - } + assert.Equal(t, tt.want, statements, "Expected SQL to match for test case: %s", tt.name) }) } } @@ -357,7 +354,7 @@ func TestMysqlGrammar_CompileRename(t *testing.T) { name: tt.table, } bp.rename(tt.newName) - got, err := g.compileRename(bp, bp.commands[0]) + got, err := g.CompileRename(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -411,7 +408,7 @@ func TestMysqlGrammar_CompileDrop(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDrop(bp) + got, err := g.CompileDrop(bp) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -465,7 +462,7 @@ func TestMysqlGrammar_CompileDropIfExists(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropIfExists(bp) + got, err := g.CompileDropIfExists(bp) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -528,7 +525,7 @@ func TestMysqlGrammar_CompileDropColumn(t *testing.T) { name: tt.table, } tt.blueprint(bp) - got, err := g.compileDropColumn(bp, bp.commands[0]) + got, err := g.CompileDropColumn(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -585,7 +582,7 @@ func TestMysqlGrammar_CompileRenameColumn(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{from: tt.oldName, to: tt.newName} - got, err := g.compileRenameColumn(bp, command) + got, err := g.CompileRenameColumn(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -707,7 +704,7 @@ func TestMysqlGrammar_CompileForeign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileForeign(bp, bp.commands[0]) + got, err := g.CompileForeign(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -747,7 +744,7 @@ func TestMysqlGrammar_CompileDropForeign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{index: tt.fkName} - got, err := g.compileDropForeign(bp, command) + got, err := g.CompileDropForeign(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -827,7 +824,7 @@ func TestMysqlGrammar_CompileIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileIndex(bp, bp.commands[0]) + got, err := g.CompileIndex(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -915,7 +912,7 @@ func TestMysqlGrammar_CompileUnique(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileUnique(bp, bp.commands[0]) + got, err := g.CompileUnique(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1011,7 +1008,7 @@ func TestMysqlGrammar_CompilePrimary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compilePrimary(bp, bp.commands[0]) + got, err := g.CompilePrimary(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1107,7 +1104,7 @@ func TestMysqlGrammar_CompileFullText(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileFullText(bp, bp.commands[0]) + got, err := g.CompileFullText(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1147,7 +1144,7 @@ func TestMysqlGrammar_CompileDropIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{index: tt.indexName} - got, err := g.compileDropIndex(bp, command) + got, err := g.CompileDropIndex(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1186,7 +1183,7 @@ func TestMysqlGrammar_CompileDropUnique(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{index: tt.indexName} - got, err := g.compileDropUnique(bp, command) + got, err := g.CompileDropUnique(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1225,7 +1222,7 @@ func TestMysqlGrammar_CompileDropFulltext(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{index: tt.indexName} - got, err := g.compileDropFulltext(bp, command) + got, err := g.CompileDropFulltext(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1287,7 +1284,7 @@ func TestMysqlGrammar_CompileDropPrimary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{index: tt.indexName} - got, err := g.compileDropPrimary(bp, command) + got, err := g.CompileDropPrimary(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1376,7 +1373,7 @@ func TestMysqlGrammar_CompileRenameIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{from: tt.oldName, to: tt.newName} - got, err := g.compileRenameIndex(bp, command) + got, err := g.CompileRenameIndex(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1700,14 +1697,6 @@ func TestMysqlGrammar_GetColumns(t *testing.T) { want: []string{"status VARCHAR(50) DEFAULT 'active' NOT NULL"}, wantErr: false, }, - { - name: "column with on update value", - blueprint: func(table *Blueprint) { - table.Timestamp("updated_at", 0).UseCurrentOnUpdate() - }, - want: []string{"updated_at TIMESTAMP ON UPDATE CURRENT_TIMESTAMP NOT NULL"}, - wantErr: false, - }, { name: "nullable column", blueprint: func(table *Blueprint) { @@ -1759,21 +1748,13 @@ func TestMysqlGrammar_GetColumns(t *testing.T) { want: []string{"id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL"}, wantErr: false, }, - { - name: "timestamp with default", - blueprint: func(table *Blueprint) { - table.Timestamp("created_at", 0).UseCurrent() - }, - want: []string{"created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL"}, - wantErr: false, - }, { name: "multiple columns with different attributes", blueprint: func(table *Blueprint) { table.BigInteger("id").Unsigned().AutoIncrement().Primary() table.String("name", 255).Comment("User name") table.String("email", 255).Nullable() - table.Timestamp("created_at", 0).Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at", 0).UseCurrent() }, want: []string{ "id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL", diff --git a/options.go b/options.go deleted file mode 100644 index e14fc19..0000000 --- a/options.go +++ /dev/null @@ -1,26 +0,0 @@ -package schema - -type config struct { - debug bool -} - -type Option func(*config) - -// WithDebug sets the debug mode for the schema package. -func WithDebug(debug ...bool) Option { - return func(c *config) { - c.debug = optional(true, debug...) - } -} - -func applyOptions(opts ...Option) *config { - cfg := &config{ - debug: false, // default value - } - - for _, opt := range opts { - opt(cfg) - } - - return cfg -} diff --git a/postgres_builder.go b/postgres_builder.go index 1aacc64..c49f6f1 100644 --- a/postgres_builder.go +++ b/postgres_builder.go @@ -5,19 +5,21 @@ import ( "database/sql" "errors" "strings" + + "github.com/afkdevs/go-schema/internal/config" ) type postgresBuilder struct { baseBuilder - grammar *pgGrammar + grammar *postgresGrammar } -func newPostgresBuilder(options ...Option) Builder { - grammar := newPgGrammar() - cfg := applyOptions(options...) +func newPostgresBuilder() Builder { + grammar := newPostgresGrammar() + cfg := config.Get() return &postgresBuilder{ - baseBuilder: baseBuilder{grammar: grammar, debug: cfg.debug}, + baseBuilder: baseBuilder{grammar: grammar, verbose: cfg.Verbose}, grammar: grammar, } } @@ -31,15 +33,15 @@ func (b *postgresBuilder) parseSchemaAndTable(name string) (string, string) { } func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return nil, err + if tx == nil || tableName == "" { + return nil, errors.New("invalid arguments: transaction is nil or table name is empty") } schema, name := b.parseSchemaAndTable(tableName) if schema == "" { schema = "public" // Default schema for PostgreSQL } - query, err := b.grammar.compileColumns(schema, name) + query, err := b.grammar.CompileColumns(schema, name) if err != nil { return nil, err } @@ -66,14 +68,14 @@ func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName } func (b *postgresBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return nil, err + if tx == nil || tableName == "" { + return nil, errors.New("invalid arguments: transaction is nil or table name is empty") } schema, name := b.parseSchemaAndTable(tableName) if schema == "" { schema = "public" // Default schema for PostgreSQL } - query, err := b.grammar.compileIndexes(schema, name) + query, err := b.grammar.CompileIndexes(schema, name) if err != nil { return nil, err } @@ -102,7 +104,7 @@ func (b *postgresBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableIn return nil, errors.New("transaction is nil") } - query, err := b.grammar.compileTables() + query, err := b.grammar.CompileTables() if err != nil { return nil, err } @@ -130,6 +132,9 @@ func (b *postgresBuilder) HasColumn(ctx context.Context, tx *sql.Tx, tableName s } func (b *postgresBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) { + if tx == nil || tableName == "" { + return false, errors.New("invalid arguments: transaction is nil or table name is empty") + } if len(columnNames) == 0 { return false, errors.New("no column names provided") } @@ -159,8 +164,8 @@ func (b *postgresBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName } func (b *postgresBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return false, err + if tx == nil || tableName == "" { + return false, errors.New("invalid arguments: transaction is nil or table name is empty") } existingIndexes, err := b.GetIndexes(ctx, tx, tableName) @@ -203,15 +208,15 @@ func (b *postgresBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName st } func (b *postgresBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { - if err := b.validateTxAndName(tx, name); err != nil { - return false, err + if tx == nil || name == "" { + return false, errors.New("invalid arguments: transaction is nil or table name is empty") } schema, name := b.parseSchemaAndTable(name) if schema == "" { schema = "public" // Default schema for PostgreSQL } - query, err := b.grammar.compileTableExists(schema, name) + query, err := b.grammar.CompileTableExists(schema, name) if err != nil { return false, err } diff --git a/postgres_builder_test.go b/postgres_builder_test.go index b110b25..33f2645 100644 --- a/postgres_builder_test.go +++ b/postgres_builder_test.go @@ -58,7 +58,7 @@ func (s *postgresBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - s.builder, err = schema.NewBuilder("postgres", schema.WithDebug()) + s.builder, err = schema.NewBuilder("postgres") s.Require().NoError(err) } @@ -148,44 +148,6 @@ func (s *postgresBuilderSuite) TestCreate() { }) } -func (s *postgresBuilderSuite) TestCreateIfNotExists() { - builder := s.builder - tx, err := s.db.BeginTx(s.ctx, nil) - s.Require().NoError(err) - defer tx.Rollback() //nolint:errcheck - - s.Run("when tx is nil, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, nil, "test_table", func(table *schema.Blueprint) {}) - s.Error(err, "expected error when transaction is nil") - }) - s.Run("when table name is empty, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, tx, "", func(table *schema.Blueprint) {}) - s.Error(err, "expected error when table name is empty") - }) - s.Run("when blueprint is nil, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, tx, "test_table", nil) - s.Error(err, "expected error when blueprint is nil") - }) - s.Run("when all parameters are valid, should create table successfully", func() { - err = builder.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name") - table.String("email").Unique() - table.String("password").Nullable() - table.TimestampsTz() - }) - s.NoError(err, "expected no error when creating table with valid parameters") - }) - s.Run("when table already exists, should not return error", func() { - err = builder.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name") - table.String("email") - }) - s.NoError(err, "expected no error when creating table that already exists") - }) -} - func (s *postgresBuilderSuite) TestDrop() { builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) @@ -408,8 +370,7 @@ func (s *postgresBuilderSuite) TestGetColumns() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before getting columns") @@ -444,8 +405,7 @@ func (s *postgresBuilderSuite) TestGetIndexes() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.Index("name").Name("idx_users_name") }) @@ -479,8 +439,7 @@ func (s *postgresBuilderSuite) TestGetTables() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before getting tables") @@ -516,8 +475,7 @@ func (s *postgresBuilderSuite) TestHasColumn() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking column existence") @@ -553,8 +511,7 @@ func (s *postgresBuilderSuite) TestHasColumns() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking column existence") @@ -591,7 +548,7 @@ func (s *postgresBuilderSuite) TestHasIndex() { table.Integer("user_id") table.String("order_id", 255) table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Index("company_id", "user_id") table.Unique("order_id").Name("uk_orders_order_id") @@ -634,8 +591,7 @@ func (s *postgresBuilderSuite) TestHasTable() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking existence") @@ -656,8 +612,7 @@ func (s *postgresBuilderSuite) TestHasTable() { table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table with custom schema") diff --git a/postgres_grammar.go b/postgres_grammar.go index dac364f..a5fdb25 100644 --- a/postgres_grammar.go +++ b/postgres_grammar.go @@ -6,32 +6,30 @@ import ( "strings" ) -type pgGrammar struct { +type postgresGrammar struct { baseGrammar } -var _ grammar = (*pgGrammar)(nil) - -func newPgGrammar() *pgGrammar { - return &pgGrammar{} +func newPostgresGrammar() *postgresGrammar { + return &postgresGrammar{} } -func (g *pgGrammar) compileTableExists(schema string, table string) (string, error) { +func (g *postgresGrammar) CompileTableExists(schema string, table string) (string, error) { return fmt.Sprintf( "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'", - g.quoteString(schema), - g.quoteString(table), + g.QuoteString(schema), + g.QuoteString(table), ), nil } -func (g *pgGrammar) compileTables() (string, error) { +func (g *postgresGrammar) CompileTables() (string, error) { return "select c.relname as name, n.nspname as schema, pg_total_relation_size(c.oid) as size, " + "obj_description(c.oid, 'pg_class') as comment from pg_class c, pg_namespace n " + "where c.relkind in ('r', 'p') and n.oid = c.relnamespace and n.nspname not in ('pg_catalog', 'information_schema') " + "order by c.relname", nil } -func (g *pgGrammar) compileColumns(schema, table string) (string, error) { +func (g *postgresGrammar) CompileColumns(schema, table string) (string, error) { return fmt.Sprintf( "select a.attname as name, t.typname as type_name, format_type(a.atttypid, a.atttypmod) as type, "+ "(select tc.collcollate from pg_catalog.pg_collation tc where tc.oid = a.attcollation) as collation, "+ @@ -41,12 +39,12 @@ func (g *pgGrammar) compileColumns(schema, table string) (string, error) { "from pg_attribute a, pg_class c, pg_type t, pg_namespace n "+ "where c.relname = %s and n.nspname = %s and a.attnum > 0 and a.attrelid = c.oid and a.atttypid = t.oid and n.oid = c.relnamespace "+ "order by a.attnum", - g.quoteString(table), - g.quoteString(schema), + g.QuoteString(table), + g.QuoteString(schema), ), nil } -func (g *pgGrammar) compileIndexes(schema, table string) (string, error) { +func (g *postgresGrammar) CompileIndexes(schema, table string) (string, error) { return fmt.Sprintf( "select ic.relname as name, string_agg(a.attname, ',' order by indseq.ord) as columns, "+ "am.amname as \"type\", i.indisunique as \"unique\", i.indisprimary as \"primary\" "+ @@ -59,12 +57,12 @@ func (g *pgGrammar) compileIndexes(schema, table string) (string, error) { "left join pg_attribute a on a.attrelid = i.indrelid and a.attnum = indseq.num "+ "where tc.relname = %s and tn.nspname = %s "+ "group by ic.relname, am.amname, i.indisunique, i.indisprimary", - g.quoteString(table), - g.quoteString(schema), + g.QuoteString(table), + g.QuoteString(schema), ), nil } -func (g *pgGrammar) compileCreate(blueprint *Blueprint) (string, error) { +func (g *postgresGrammar) CompileCreate(blueprint *Blueprint) (string, error) { columns, err := g.getColumns(blueprint) if err != nil { return "", err @@ -73,16 +71,7 @@ func (g *pgGrammar) compileCreate(blueprint *Blueprint) (string, error) { return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil } -func (g *pgGrammar) compileCreateIfNotExists(blueprint *Blueprint) (string, error) { - columns, err := g.getColumns(blueprint) - if err != nil { - return "", err - } - columns = append(columns, g.getConstraints(blueprint)...) - return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil -} - -func (g *pgGrammar) compileAdd(blueprint *Blueprint) (string, error) { +func (g *postgresGrammar) CompileAdd(blueprint *Blueprint) (string, error) { if len(blueprint.getAddedColumns()) == 0 { return "", nil } @@ -91,10 +80,10 @@ func (g *pgGrammar) compileAdd(blueprint *Blueprint) (string, error) { if err != nil { return "", err } - columns = g.prefixArray("ADD COLUMN ", columns) + columns = g.PrefixArray("ADD COLUMN ", columns) constraints := g.getConstraints(blueprint) if len(constraints) > 0 { - constraints = g.prefixArray("ADD ", constraints) + constraints = g.PrefixArray("ADD ", constraints) columns = append(columns, constraints...) } @@ -104,7 +93,7 @@ func (g *pgGrammar) compileAdd(blueprint *Blueprint) (string, error) { ), nil } -func (g *pgGrammar) compileChange(bp *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileChange(bp *Blueprint, command *command) (string, error) { column := command.column if column.name == "" { return "", fmt.Errorf("column name cannot be empty for change operation") @@ -121,45 +110,45 @@ func (g *pgGrammar) compileChange(bp *Blueprint, command *command) (string, erro return fmt.Sprintf("ALTER TABLE %s %s", bp.name, - strings.Join(g.prefixArray(fmt.Sprintf("ALTER COLUMN %s ", column.name), changes), ", "), + strings.Join(g.PrefixArray(fmt.Sprintf("ALTER COLUMN %s ", column.name), changes), ", "), ), nil } -func (g *pgGrammar) compileDrop(blueprint *Blueprint) (string, error) { +func (g *postgresGrammar) CompileDrop(blueprint *Blueprint) (string, error) { return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil } -func (g *pgGrammar) compileDropIfExists(blueprint *Blueprint) (string, error) { +func (g *postgresGrammar) CompileDropIfExists(blueprint *Blueprint) (string, error) { return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil } -func (g *pgGrammar) compileRename(blueprint *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileRename(blueprint *Blueprint, command *command) (string, error) { return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, command.to), nil } -func (g *pgGrammar) compileDropColumn(blueprint *Blueprint, command *command) (string, error) { - if len(blueprint.columns) == 0 { +func (g *postgresGrammar) CompileDropColumn(blueprint *Blueprint, command *command) (string, error) { + if len(command.columns) == 0 { return "", nil } - columns := g.prefixArray("DROP COLUMN ", command.columns) + columns := g.PrefixArray("DROP COLUMN ", command.columns) return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil } -func (g *pgGrammar) compileRenameColumn(blueprint *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileRenameColumn(blueprint *Blueprint, command *command) (string, error) { if command.from == "" || command.to == "" { return "", fmt.Errorf("table name, old column name, and new column name cannot be empty for rename operation") } return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, command.from, command.to), nil } -func (g *pgGrammar) compileFullText(blueprint *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileFullText(blueprint *Blueprint, command *command) (string, error) { if slices.Contains(command.columns, "") { return "", fmt.Errorf("fulltext index column cannot be empty") } indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, commandFullText, command.columns...) + indexName = g.CreateIndexName(blueprint, "fulltext", command.columns...) } language := command.language if language == "" { @@ -167,40 +156,40 @@ func (g *pgGrammar) compileFullText(blueprint *Blueprint, command *command) (str } var columns []string for _, col := range command.columns { - columns = append(columns, fmt.Sprintf("to_tsvector(%s, %s)", g.quoteString(language), col)) + columns = append(columns, fmt.Sprintf("to_tsvector(%s, %s)", g.QuoteString(language), col)) } return fmt.Sprintf("CREATE INDEX %s ON %s USING GIN (%s)", indexName, blueprint.name, strings.Join(columns, " || ")), nil } -func (g *pgGrammar) compileIndex(blueprint *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileIndex(blueprint *Blueprint, command *command) (string, error) { if slices.Contains(command.columns, "") { return "", fmt.Errorf("index column cannot be empty") } indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, commandIndex, command.columns...) + indexName = g.CreateIndexName(blueprint, "index", command.columns...) } sql := fmt.Sprintf("CREATE INDEX %s ON %s", indexName, blueprint.name) if command.algorithm != "" { sql += fmt.Sprintf(" USING %s", command.algorithm) } - return fmt.Sprintf("%s (%s)", sql, g.columnize(command.columns)), nil + return fmt.Sprintf("%s (%s)", sql, g.Columnize(command.columns)), nil } -func (g *pgGrammar) compileUnique(blueprint *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileUnique(blueprint *Blueprint, command *command) (string, error) { if slices.Contains(command.columns, "") { return "", fmt.Errorf("unique index column cannot be empty") } indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, commandUnique, command.columns...) + indexName = g.CreateIndexName(blueprint, "unique", command.columns...) } sql := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)", blueprint.name, indexName, - g.columnize(command.columns), + g.Columnize(command.columns), ) if command.deferrable != nil { @@ -221,51 +210,52 @@ func (g *pgGrammar) compileUnique(blueprint *Blueprint, command *command) (strin return sql, nil } -func (g *pgGrammar) compilePrimary(blueprint *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompilePrimary(blueprint *Blueprint, command *command) (string, error) { if slices.Contains(command.columns, "") { return "", fmt.Errorf("primary key index column cannot be empty") } indexName := command.index if indexName == "" { - indexName = g.createIndexName(blueprint, commandPrimary, command.columns...) + indexName = g.CreateIndexName(blueprint, "primary", command.columns...) } - return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(command.columns)), nil + return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.Columnize(command.columns)), nil } -func (g *pgGrammar) compileDropIndex(_ *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileDropIndex(_ *Blueprint, command *command) (string, error) { if command.index == "" { return "", fmt.Errorf("index name cannot be empty for drop operation") } return fmt.Sprintf("DROP INDEX %s", command.index), nil } -func (g *pgGrammar) compileDropFulltext(blueprint *Blueprint, command *command) (string, error) { - return g.compileDropIndex(blueprint, command) +func (g *postgresGrammar) CompileDropFulltext(blueprint *Blueprint, command *command) (string, error) { + return g.CompileDropIndex(blueprint, command) } -func (g *pgGrammar) compileDropUnique(blueprint *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileDropUnique(blueprint *Blueprint, command *command) (string, error) { if command.index == "" { return "", fmt.Errorf("index name cannot be empty for drop operation") } return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil } -func (g *pgGrammar) compileDropPrimary(blueprint *Blueprint, command *command) (string, error) { - if command.index == "" { - command.index = g.createIndexName(blueprint, commandPrimary) +func (g *postgresGrammar) CompileDropPrimary(blueprint *Blueprint, command *command) (string, error) { + index := command.index + if index == "" { + index = g.CreateIndexName(blueprint, "primary", command.columns...) } - return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil + return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, index), nil } -func (g *pgGrammar) compileRenameIndex(_ *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileRenameIndex(_ *Blueprint, command *command) (string, error) { if command.from == "" || command.to == "" { return "", fmt.Errorf("index names for rename operation cannot be empty: oldName=%s, newName=%s", command.from, command.to) } return fmt.Sprintf("ALTER INDEX %s RENAME TO %s", command.from, command.to), nil } -func (g *pgGrammar) compileForeign(blueprint *Blueprint, command *command) (string, error) { - sql, err := g.baseGrammar.compileForeign(blueprint, command) +func (g *postgresGrammar) CompileForeign(blueprint *Blueprint, command *command) (string, error) { + sql, err := g.baseGrammar.CompileForeign(blueprint, command) if err != nil { return "", err } @@ -288,21 +278,21 @@ func (g *pgGrammar) compileForeign(blueprint *Blueprint, command *command) (stri return sql, nil } -func (g *pgGrammar) compileDropForeign(blueprint *Blueprint, command *command) (string, error) { +func (g *postgresGrammar) CompileDropForeign(blueprint *Blueprint, command *command) (string, error) { if command.index == "" { return "", fmt.Errorf("foreign key name cannot be empty for drop operation") } return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil } -func (g *pgGrammar) getFluentCommands() []func(blueprint *Blueprint, command *command) string { +func (g *postgresGrammar) GetFluentCommands() []func(blueprint *Blueprint, command *command) string { return []func(blueprint *Blueprint, command *command) string{ - g.compileComment, + g.CompileComment, } } -func (g *pgGrammar) compileComment(blueprint *Blueprint, command *command) string { - if command.column.comment != nil || command.column.change { +func (g *postgresGrammar) CompileComment(blueprint *Blueprint, command *command) string { + if command.column.comment != nil { sql := fmt.Sprintf("COMMENT ON COLUMN %s.%s IS ", blueprint.name, command.column.name) if command.column.comment == nil { return sql + "NULL" @@ -313,7 +303,7 @@ func (g *pgGrammar) compileComment(blueprint *Blueprint, command *command) strin return "" } -func (g *pgGrammar) getColumns(blueprint *Blueprint) ([]string, error) { +func (g *postgresGrammar) getColumns(blueprint *Blueprint) ([]string, error) { var columns []string for _, col := range blueprint.getAddedColumns() { if col.name == "" { @@ -329,11 +319,11 @@ func (g *pgGrammar) getColumns(blueprint *Blueprint) ([]string, error) { return columns, nil } -func (g *pgGrammar) getConstraints(blueprint *Blueprint) []string { +func (g *postgresGrammar) getConstraints(blueprint *Blueprint) []string { var constrains []string for _, col := range blueprint.getAddedColumns() { if col.primary != nil && *col.primary { - pkConstraintName := g.createIndexName(blueprint, commandPrimary) + pkConstraintName := g.CreateIndexName(blueprint, "primary") sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")" constrains = append(constrains, sql) continue @@ -343,7 +333,7 @@ func (g *pgGrammar) getConstraints(blueprint *Blueprint) []string { return constrains } -func (g *pgGrammar) getType(col *columnDefinition) string { +func (g *postgresGrammar) getType(col *columnDefinition) string { typeMapFunc := map[string]func(*columnDefinition) string{ columnTypeChar: g.typeChar, columnTypeString: g.typeString, @@ -361,8 +351,8 @@ func (g *pgGrammar) getType(col *columnDefinition) string { columnTypeDecimal: g.typeDecimal, columnTypeBoolean: g.typeBoolean, columnTypeEnum: g.typeEnum, - columnTypeJSON: g.typeJson, - columnTypeJSONB: g.typeJsonb, + columnTypeJson: g.typeJson, + columnTypeJsonb: g.typeJsonb, columnTypeDate: g.typeDate, columnTypeDateTime: g.typeDateTime, columnTypeDateTimeTz: g.typeDateTimeTz, @@ -372,7 +362,7 @@ func (g *pgGrammar) getType(col *columnDefinition) string { columnTypeTimestampTz: g.typeTimestampTz, columnTypeYear: g.typeYear, columnTypeBinary: g.typeBinary, - columnTypeUUID: g.typeUUID, + columnTypeUuid: g.typeUuid, columnTypeGeography: g.typeGeography, columnTypeGeometry: g.typeGeometry, columnTypePoint: g.typePoint, @@ -383,126 +373,126 @@ func (g *pgGrammar) getType(col *columnDefinition) string { return col.columnType } -func (g *pgGrammar) typeChar(col *columnDefinition) string { +func (g *postgresGrammar) typeChar(col *columnDefinition) string { if col.length != nil && *col.length > 0 { return fmt.Sprintf("CHAR(%d)", *col.length) } return "CHAR" } -func (g *pgGrammar) typeString(col *columnDefinition) string { +func (g *postgresGrammar) typeString(col *columnDefinition) string { if col.length != nil && *col.length > 0 { return fmt.Sprintf("VARCHAR(%d)", *col.length) } return "VARCHAR" } -func (g *pgGrammar) typeTinyText(_ *columnDefinition) string { +func (g *postgresGrammar) typeTinyText(_ *columnDefinition) string { return "VARCHAR(255)" } -func (g *pgGrammar) typeText(_ *columnDefinition) string { +func (g *postgresGrammar) typeText(_ *columnDefinition) string { return "TEXT" } -func (g *pgGrammar) typeMediumText(_ *columnDefinition) string { +func (g *postgresGrammar) typeMediumText(_ *columnDefinition) string { return "TEXT" } -func (g *pgGrammar) typeLongText(_ *columnDefinition) string { +func (g *postgresGrammar) typeLongText(_ *columnDefinition) string { return "TEXT" } -func (g *pgGrammar) typeBigInteger(col *columnDefinition) string { +func (g *postgresGrammar) typeBigInteger(col *columnDefinition) string { if col.autoIncrement != nil && *col.autoIncrement { return "BIGSERIAL" } return "BIGINT" } -func (g *pgGrammar) typeInteger(col *columnDefinition) string { +func (g *postgresGrammar) typeInteger(col *columnDefinition) string { if col.autoIncrement != nil && *col.autoIncrement { return "SERIAL" } return "INTEGER" } -func (g *pgGrammar) typeMediumInteger(col *columnDefinition) string { +func (g *postgresGrammar) typeMediumInteger(col *columnDefinition) string { return g.typeInteger(col) } -func (g *pgGrammar) typeSmallInteger(col *columnDefinition) string { +func (g *postgresGrammar) typeSmallInteger(col *columnDefinition) string { if col.autoIncrement != nil && *col.autoIncrement { return "SMALLSERIAL" } return "SMALLINT" } -func (g *pgGrammar) typeTinyInteger(col *columnDefinition) string { +func (g *postgresGrammar) typeTinyInteger(col *columnDefinition) string { return g.typeSmallInteger(col) } -func (g *pgGrammar) typeFloat(_ *columnDefinition) string { +func (g *postgresGrammar) typeFloat(_ *columnDefinition) string { return "REAL" } -func (g *pgGrammar) typeDouble(_ *columnDefinition) string { +func (g *postgresGrammar) typeDouble(_ *columnDefinition) string { return "DOUBLE PRECISION" } -func (g *pgGrammar) typeDecimal(col *columnDefinition) string { +func (g *postgresGrammar) typeDecimal(col *columnDefinition) string { return fmt.Sprintf("DECIMAL(%d, %d)", *col.total, *col.places) } -func (g *pgGrammar) typeBoolean(_ *columnDefinition) string { +func (g *postgresGrammar) typeBoolean(_ *columnDefinition) string { return "BOOLEAN" } -func (g *pgGrammar) typeEnum(col *columnDefinition) string { +func (g *postgresGrammar) typeEnum(col *columnDefinition) string { enumValues := make([]string, len(col.allowed)) for i, v := range col.allowed { - enumValues[i] = g.quoteString(v) + enumValues[i] = g.QuoteString(v) } return "VARCHAR(255) CHECK (" + col.name + " IN (" + strings.Join(enumValues, ", ") + "))" } -func (g *pgGrammar) typeJson(_ *columnDefinition) string { +func (g *postgresGrammar) typeJson(_ *columnDefinition) string { return "JSON" } -func (g *pgGrammar) typeJsonb(_ *columnDefinition) string { +func (g *postgresGrammar) typeJsonb(_ *columnDefinition) string { return "JSONB" } -func (g *pgGrammar) typeDate(_ *columnDefinition) string { +func (g *postgresGrammar) typeDate(_ *columnDefinition) string { return "DATE" } -func (g *pgGrammar) typeDateTime(col *columnDefinition) string { +func (g *postgresGrammar) typeDateTime(col *columnDefinition) string { return g.typeTimestamp(col) } -func (g *pgGrammar) typeDateTimeTz(col *columnDefinition) string { +func (g *postgresGrammar) typeDateTimeTz(col *columnDefinition) string { return g.typeTimestampTz(col) } -func (g *pgGrammar) typeTime(col *columnDefinition) string { - if col.precision != nil && *col.precision > 0 { +func (g *postgresGrammar) typeTime(col *columnDefinition) string { + if col.precision != nil { return fmt.Sprintf("TIME(%d)", *col.precision) } return "TIME" } -func (g *pgGrammar) typeTimeTz(col *columnDefinition) string { - if col.precision != nil && *col.precision > 0 { +func (g *postgresGrammar) typeTimeTz(col *columnDefinition) string { + if col.precision != nil { return fmt.Sprintf("TIMETZ(%d)", *col.precision) } return "TIMETZ" } -func (g *pgGrammar) typeTimestamp(col *columnDefinition) string { - if col.useCurrent != nil && *col.useCurrent { - col.Default(Expression("CURRENT_TIMESTAMP")) +func (g *postgresGrammar) typeTimestamp(col *columnDefinition) string { + if col.useCurrent { + col.SetDefault(Expression("CURRENT_TIMESTAMP")) } if col.precision != nil { return fmt.Sprintf("TIMESTAMP(%d)", *col.precision) @@ -510,9 +500,9 @@ func (g *pgGrammar) typeTimestamp(col *columnDefinition) string { return "TIMESTAMP" } -func (g *pgGrammar) typeTimestampTz(col *columnDefinition) string { - if col.useCurrent != nil && *col.useCurrent { - col.Default(Expression("CURRENT_TIMESTAMP")) +func (g *postgresGrammar) typeTimestampTz(col *columnDefinition) string { + if col.useCurrent { + col.SetDefault(Expression("CURRENT_TIMESTAMP")) } if col.precision != nil { return fmt.Sprintf("TIMESTAMPTZ(%d)", *col.precision) @@ -520,19 +510,19 @@ func (g *pgGrammar) typeTimestampTz(col *columnDefinition) string { return "TIMESTAMPTZ" } -func (g *pgGrammar) typeYear(_ *columnDefinition) string { +func (g *postgresGrammar) typeYear(_ *columnDefinition) string { return "INTEGER" } -func (g *pgGrammar) typeBinary(_ *columnDefinition) string { +func (g *postgresGrammar) typeBinary(_ *columnDefinition) string { return "BYTEA" } -func (g *pgGrammar) typeUUID(_ *columnDefinition) string { +func (g *postgresGrammar) typeUuid(_ *columnDefinition) string { return "UUID" } -func (g *pgGrammar) typeGeography(col *columnDefinition) string { +func (g *postgresGrammar) typeGeography(col *columnDefinition) string { if col.subtype != nil && col.srid != nil { return fmt.Sprintf("GEOGRAPHY(%s, %d)", *col.subtype, *col.srid) } else if col.subtype != nil { @@ -541,7 +531,7 @@ func (g *pgGrammar) typeGeography(col *columnDefinition) string { return "GEOGRAPHY" } -func (g *pgGrammar) typeGeometry(col *columnDefinition) string { +func (g *postgresGrammar) typeGeometry(col *columnDefinition) string { if col.subtype != nil && col.srid != nil { return fmt.Sprintf("GEOMETRY(%s, %d)", *col.subtype, *col.srid) } else if col.subtype != nil { @@ -550,23 +540,26 @@ func (g *pgGrammar) typeGeometry(col *columnDefinition) string { return "GEOMETRY" } -func (g *pgGrammar) typePoint(col *columnDefinition) string { +func (g *postgresGrammar) typePoint(col *columnDefinition) string { if col.srid != nil { return fmt.Sprintf("POINT(%d)", *col.srid) } return "POINT" } -func (g *pgGrammar) modifiers() []func(*columnDefinition) string { +func (g *postgresGrammar) modifiers() []func(*columnDefinition) string { return []func(*columnDefinition) string{ - g.modifyNullable, g.modifyDefault, + g.modifyNullable, } } -func (g *pgGrammar) modifyNullable(col *columnDefinition) string { +func (g *postgresGrammar) modifyNullable(col *columnDefinition) string { if col.change { - if col.nullable != nil && *col.nullable { + if col.nullable == nil { + return "" + } + if *col.nullable { return " DROP NOT NULL" } return " SET NOT NULL" @@ -577,9 +570,12 @@ func (g *pgGrammar) modifyNullable(col *columnDefinition) string { return " NOT NULL" } -func (g *pgGrammar) modifyDefault(col *columnDefinition) string { - if col.defaultValue != nil { - return fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col.defaultValue)) +func (g *postgresGrammar) modifyDefault(col *columnDefinition) string { + if col.hasCommand("default") { + if col.change { + return fmt.Sprintf(" SET DEFAULT %s", g.GetDefaultValue(col.defaultValue)) + } + return fmt.Sprintf(" DEFAULT %s", g.GetDefaultValue(col.defaultValue)) } return "" } diff --git a/postgres_grammar_test.go b/postgres_grammar_test.go index 716e5f2..2a805f0 100644 --- a/postgres_grammar_test.go +++ b/postgres_grammar_test.go @@ -7,7 +7,7 @@ import ( ) func TestPgGrammar_CompileCreate(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -55,57 +55,7 @@ func TestPgGrammar_CompileCreate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileCreate(bp) - if tt.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - assert.Equal(t, tt.want, got, "SQL statement mismatch for %s", tt.name) - }) - } -} - -func TestPgGrammar_CompileCreateIfNotExists(t *testing.T) { - grammar := newPgGrammar() - - tests := []struct { - name string - table string - blueprint func(table *Blueprint) - want string - wantErr bool - }{ - { - name: "Create simple table if not exists", - table: "users", - blueprint: func(table *Blueprint) { - table.ID() - table.String("name") - table.String("email") - table.String("password").Nullable() - table.Timestamp("created_at").UseCurrent() - table.Timestamp("updated_at").UseCurrent() - }, - want: "CREATE TABLE IF NOT EXISTS users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR(255) NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))", - wantErr: false, - }, - { - name: "Create table with column name is empty", - table: "empty_column_table", - blueprint: func(table *Blueprint) { - table.String("") // Intentionally empty column name - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.table} - tt.blueprint(bp) - got, err := grammar.compileCreateIfNotExists(bp) + got, err := grammar.CompileCreate(bp) if tt.wantErr { assert.Error(t, err) return @@ -118,7 +68,7 @@ func TestPgGrammar_CompileCreateIfNotExists(t *testing.T) { } func TestPgGrammar_CompileAdd(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -162,7 +112,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) { blueprint: func(table *Blueprint) { table.String("notes", 500).Comment("User notes") }, - want: "ALTER TABLE users ADD COLUMN notes VARCHAR(500) NOT NULL COMMENT 'User notes'", + want: "ALTER TABLE users ADD COLUMN notes VARCHAR(500) NOT NULL", wantErr: false, }, { @@ -187,9 +137,9 @@ func TestPgGrammar_CompileAdd(t *testing.T) { name: "Add complex column with all attributes", table: "products", blueprint: func(table *Blueprint) { - table.Decimal("price", 10, 2).Default(0).Comment("Product price") + table.Decimal("price", 10, 2).Default(0) }, - want: "ALTER TABLE products ADD COLUMN price DECIMAL(10, 2) DEFAULT '0' NOT NULL COMMENT 'Product price'", + want: "ALTER TABLE products ADD COLUMN price DECIMAL(10, 2) DEFAULT '0' NOT NULL", wantErr: false, }, { @@ -253,7 +203,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileAdd(bp) + got, err := grammar.CompileAdd(bp) if tt.wantErr { assert.Error(t, err) return @@ -266,7 +216,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) { } func TestPgGrammar_CompileChange(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -309,7 +259,7 @@ func TestPgGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.String("email", 500).Default(nil).Change() }, - want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500), ALTER COLUMN email DROP DEFAULT"}, + want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500), ALTER COLUMN email SET DEFAULT NULL"}, }, { name: "Add comment to column", @@ -328,7 +278,7 @@ func TestPgGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.String("email", 500).Comment("").Change() }, - want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500)", "COMMENT ON COLUMN users.email IS NULL"}, + want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500)", "COMMENT ON COLUMN users.email IS ''"}, }, { name: "Set column to not nullable", @@ -356,9 +306,9 @@ func TestPgGrammar_CompileChange(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.table} + bp := &Blueprint{name: tt.table, grammar: grammar} tt.blueprint(bp) - got, err := grammar.compileChange(bp, bp.commands[0]) + got, err := bp.toSql() if tt.wantErr { assert.Error(t, err) return @@ -371,7 +321,7 @@ func TestPgGrammar_CompileChange(t *testing.T) { } func TestPgGrammar_CompileDrop(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -390,7 +340,7 @@ func TestPgGrammar_CompileDrop(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileDrop(bp) + got, err := grammar.CompileDrop(bp) if tt.wantErr { assert.Error(t, err) return @@ -403,7 +353,7 @@ func TestPgGrammar_CompileDrop(t *testing.T) { } func TestPgGrammar_CompileDropIfExists(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -422,7 +372,7 @@ func TestPgGrammar_CompileDropIfExists(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileDropIfExists(bp) + got, err := grammar.CompileDropIfExists(bp) if tt.wantErr { assert.Error(t, err) return @@ -435,7 +385,7 @@ func TestPgGrammar_CompileDropIfExists(t *testing.T) { } func TestPgGrammar_CompileRename(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -457,7 +407,7 @@ func TestPgGrammar_CompileRename(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.oldName} bp.rename(tt.newName) - got, err := grammar.compileRename(bp, bp.commands[0]) + got, err := grammar.CompileRename(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -470,7 +420,7 @@ func TestPgGrammar_CompileRename(t *testing.T) { } func TestPgGrammar_GetColumns(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -504,14 +454,7 @@ func TestPgGrammar_GetColumns(t *testing.T) { blueprint: func(table *Blueprint) { table.Boolean("active").Default(true) }, - want: []string{"active BOOLEAN DEFAULT true NOT NULL"}, - }, - { - name: "Column with comment", - blueprint: func(table *Blueprint) { - table.String("description", 500).Comment("User description") - }, - want: []string{"description VARCHAR(500) NOT NULL COMMENT 'User description'"}, + want: []string{"active BOOLEAN DEFAULT '1' NOT NULL"}, }, { name: "Primary key column", @@ -544,13 +487,13 @@ func TestPgGrammar_GetColumns(t *testing.T) { } func TestPgGrammar_CompileDropColumn(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string table string blueprint func(table *Blueprint) - want string + wants []string wantErr bool }{ { @@ -559,45 +502,46 @@ func TestPgGrammar_CompileDropColumn(t *testing.T) { blueprint: func(table *Blueprint) { table.DropColumn("email") }, - want: "ALTER TABLE users DROP COLUMN email", + wants: []string{"ALTER TABLE users DROP COLUMN email"}, wantErr: false, }, { name: "Drop multiple columns", table: "users", blueprint: func(table *Blueprint) { - table.DropColumn("email", "phone", "address") + table.DropColumn("email", "phone") + table.DropColumn("address") }, - want: "ALTER TABLE users DROP COLUMN email, DROP COLUMN phone, DROP COLUMN address", + wants: []string{"ALTER TABLE users DROP COLUMN email, DROP COLUMN phone", "ALTER TABLE users DROP COLUMN address"}, wantErr: false, }, { name: "No columns to drop", table: "users", blueprint: func(table *Blueprint) {}, - want: "", + wants: nil, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.table} + bp := &Blueprint{name: tt.table, grammar: grammar} tt.blueprint(bp) - got, err := grammar.compileDropColumn(bp, bp.commands[0]) + got, err := bp.toSql() if tt.wantErr { assert.Error(t, err) return } assert.NoError(t, err) - assert.Equal(t, tt.want, got) + assert.Equal(t, tt.wants, got) }) } } func TestPgGrammar_CompileRenameColumn(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -642,7 +586,7 @@ func TestPgGrammar_CompileRenameColumn(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{from: tt.oldName, to: tt.newName} - got, err := grammar.compileRenameColumn(bp, command) + got, err := grammar.CompileRenameColumn(bp, command) if tt.wantErr { assert.Error(t, err) return @@ -655,7 +599,7 @@ func TestPgGrammar_CompileRenameColumn(t *testing.T) { } func TestPgGrammar_CompileDropIndex(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -687,7 +631,7 @@ func TestPgGrammar_CompileDropIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{} command := &command{index: tt.indexName} - got, err := grammar.compileDropIndex(bp, command) + got, err := grammar.CompileDropIndex(bp, command) if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) @@ -701,7 +645,7 @@ func TestPgGrammar_CompileDropIndex(t *testing.T) { } func TestPgGrammar_CompileDropPrimary(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -733,7 +677,7 @@ func TestPgGrammar_CompileDropPrimary(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { command := &command{index: tt.indexName} - got, err := grammar.compileDropPrimary(tt.blueprint, command) + got, err := grammar.CompileDropPrimary(tt.blueprint, command) if tt.wantErr { assert.Error(t, err) return @@ -746,7 +690,7 @@ func TestPgGrammar_CompileDropPrimary(t *testing.T) { } func TestPgGrammar_CompileRenameIndex(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -799,7 +743,7 @@ func TestPgGrammar_CompileRenameIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{from: tt.oldName, to: tt.newName} - got, err := grammar.compileRenameIndex(bp, command) + got, err := grammar.CompileRenameIndex(bp, command) if tt.wantErr { assert.Error(t, err) return @@ -812,7 +756,7 @@ func TestPgGrammar_CompileRenameIndex(t *testing.T) { } func TestPgGrammar_CompileForeign(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -977,7 +921,7 @@ func TestPgGrammar_CompileForeign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileForeign(bp, bp.commands[0]) + got, err := grammar.CompileForeign(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -990,7 +934,7 @@ func TestPgGrammar_CompileForeign(t *testing.T) { } func TestPgGrammar_CompileDropForeign(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1026,7 +970,7 @@ func TestPgGrammar_CompileDropForeign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} command := &command{index: tt.foreignKeyName} - got, err := grammar.compileDropForeign(bp, command) + got, err := grammar.CompileDropForeign(bp, command) if tt.wantErr { assert.Error(t, err) return @@ -1039,7 +983,7 @@ func TestPgGrammar_CompileDropForeign(t *testing.T) { } func TestPgGrammar_CompileIndex(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1106,7 +1050,7 @@ func TestPgGrammar_CompileIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileIndex(bp, bp.commands[0]) + got, err := grammar.CompileIndex(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1119,7 +1063,7 @@ func TestPgGrammar_CompileIndex(t *testing.T) { } func TestPgGrammar_CompileUnique(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1214,7 +1158,7 @@ func TestPgGrammar_CompileUnique(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileUnique(bp, bp.commands[0]) + got, err := grammar.CompileUnique(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1227,7 +1171,7 @@ func TestPgGrammar_CompileUnique(t *testing.T) { } func TestPgGrammar_CompileFullText(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1312,7 +1256,7 @@ func TestPgGrammar_CompileFullText(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileFullText(bp, bp.commands[0]) + got, err := grammar.CompileFullText(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1325,7 +1269,7 @@ func TestPgGrammar_CompileFullText(t *testing.T) { } func TestPgGrammar_CompileDropUnique(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1363,7 +1307,7 @@ func TestPgGrammar_CompileDropUnique(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{} command := &command{index: tt.indexName} - got, err := grammar.compileDropUnique(bp, command) + got, err := grammar.CompileDropUnique(bp, command) if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) @@ -1377,7 +1321,7 @@ func TestPgGrammar_CompileDropUnique(t *testing.T) { } func TestPgGrammar_CompileDropFulltext(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1421,7 +1365,7 @@ func TestPgGrammar_CompileDropFulltext(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{} command := &command{index: tt.indexName} - got, err := grammar.compileDropFulltext(bp, command) + got, err := grammar.CompileDropFulltext(bp, command) if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) @@ -1435,7 +1379,7 @@ func TestPgGrammar_CompileDropFulltext(t *testing.T) { } func TestPgGrammar_CompilePrimary(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1502,7 +1446,7 @@ func TestPgGrammar_CompilePrimary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compilePrimary(bp, bp.commands[0]) + got, err := grammar.CompilePrimary(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1515,7 +1459,7 @@ func TestPgGrammar_CompilePrimary(t *testing.T) { } func TestPgGrammar_GetType(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string diff --git a/schema.go b/schema.go index a37f761..f5531ff 100644 --- a/schema.go +++ b/schema.go @@ -4,6 +4,9 @@ import ( "context" "database/sql" "errors" + + "github.com/afkdevs/go-schema/internal/config" + "github.com/afkdevs/go-schema/internal/dialect" ) // Column represents a database column with its properties. @@ -39,13 +42,12 @@ type TableInfo struct { } func newBuilder() (Builder, error) { - if dialectValue == dialectUnknown { + cfg := config.Get() + if cfg.Dialect == dialect.Unknown { return nil, errors.New("schema dialect is not set, please call schema.SetDialect() before using schema functions") } - builder, err := NewBuilder(dialectValue.String(), - WithDebug(cfg.debug), - ) + builder, err := NewBuilder(cfg.Dialect.String()) if err != nil { return nil, err } @@ -76,29 +78,6 @@ func Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table * return builder.Create(ctx, tx, name, blueprint) } -// CreateIfNotExists creates a new table with the given name and blueprint if it does not already exist. -// The blueprint function is used to define the structure of the table. -// It returns an error if the table creation fails. -// -// Example: -// -// err := schema.CreateIfNotExists(ctx, tx, "users", func(table *schema.Blueprint) { -// table.ID() -// table.String("name").Nullable(false) -// table.String("email").Unique().Nullable(false) -// table.String("password").Nullable() -// table.Timestamp("created_at").Default("CURRENT_TIMESTAMP").Nullable(false) -// table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP").Nullable(false) -// }) -func CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - builder, err := newBuilder() - if err != nil { - return err - } - - return builder.CreateIfNotExists(ctx, tx, name, blueprint) -} - // Drop removes the table with the given name. // It returns an error if the table removal fails. // diff --git a/schema_test.go b/schema_test.go index 1e1540d..8462997 100644 --- a/schema_test.go +++ b/schema_test.go @@ -37,13 +37,8 @@ func (s *schemaTestSuite) SetupSuite() { s.db = db s.Run("when dialect is not set should return error", func() { - err := schema.Init("") - s.Error(err) - s.ErrorContains(err, "unknown dialect") - builderFuncs := []func() error{ func() error { return schema.Create(ctx, nil, "", nil) }, - func() error { return schema.CreateIfNotExists(ctx, nil, "", nil) }, func() error { return schema.Drop(ctx, nil, "") }, func() error { return schema.DropIfExists(ctx, nil, "") }, func() error { _, err := schema.GetColumns(ctx, nil, ""); return err }, @@ -60,10 +55,7 @@ func (s *schemaTestSuite) SetupSuite() { s.Error(fn(), "Expected error when dialect is not set") } }) - s.Run("when dialect is set to postgres should not return error", func() { - err = schema.Init("postgres") - s.Require().NoError(err) - }) + schema.SetDialect("postgres") } func (s *schemaTestSuite) TearDownSuite() { @@ -98,67 +90,19 @@ func (s *schemaTestSuite) TestCreate() { table.ID() }) s.Error(err) - s.ErrorContains(err, "table name is empty") + s.ErrorContains(err, "invalid arguments") }) s.Run("when blueprint function is nil should return error", func() { err := schema.Create(s.ctx, tx, "test", nil) s.Error(err) - s.ErrorContains(err, "blueprint function is nil") + s.ErrorContains(err, "invalid arguments") }) s.Run("when transaction is nil should return error", func() { err := schema.Create(s.ctx, nil, "test", func(table *schema.Blueprint) { table.ID() }) s.Error(err) - s.ErrorContains(err, "transaction is nil") - }) -} - -func (s *schemaTestSuite) TestCreateIfNotExists() { - tx, err := s.db.BeginTx(s.ctx, nil) - s.Require().NoError(err) - defer tx.Rollback() //nolint:errcheck - - s.Run("when parameters are valid should create table", func() { - err := schema.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name") - table.String("email").Unique() - table.String("password") - table.Timestamp("created_at").UseCurrent() - table.Timestamp("updated_at").UseCurrent() - }) - s.NoError(err) - }) - s.Run("when table already exists should return no error", func() { - err := schema.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name") - table.String("email") - table.String("password") - table.Timestamp("created_at").UseCurrent() - table.Timestamp("updated_at").UseCurrent() - }) - s.NoError(err) - }) - s.Run("when table name is empty should return error", func() { - err := schema.CreateIfNotExists(s.ctx, tx, "", func(table *schema.Blueprint) { - table.ID() - }) - s.Error(err) - s.ErrorContains(err, "table name is empty") - }) - s.Run("when blueprint function is nil should return error", func() { - err := schema.CreateIfNotExists(s.ctx, tx, "test", nil) - s.Error(err) - s.ErrorContains(err, "blueprint function is nil") - }) - s.Run("when transaction is nil should return error", func() { - err := schema.CreateIfNotExists(s.ctx, nil, "test", func(table *schema.Blueprint) { - table.ID() - }) - s.Error(err) - s.ErrorContains(err, "transaction is nil") + s.ErrorContains(err, "invalid arguments") }) } @@ -187,12 +131,12 @@ func (s *schemaTestSuite) TestDrop() { s.Run("when table name is empty should return error", func() { err := schema.Drop(s.ctx, tx, "") s.Error(err) - s.ErrorContains(err, "table name is empty") + s.ErrorContains(err, "invalid arguments") }) s.Run("when transaction is nil should return error", func() { err := schema.Drop(s.ctx, nil, "test") s.Error(err) - s.ErrorContains(err, "transaction is nil") + s.ErrorContains(err, "invalid arguments") }) } @@ -221,12 +165,12 @@ func (s *schemaTestSuite) TestDropIfExists() { s.Run("when table name is empty should return error", func() { err := schema.DropIfExists(s.ctx, tx, "") s.Error(err) - s.ErrorContains(err, "table name is empty") + s.ErrorContains(err, "invalid arguments") }) s.Run("when transaction is nil should return error", func() { err := schema.DropIfExists(s.ctx, nil, "test") s.Error(err) - s.ErrorContains(err, "transaction is nil") + s.ErrorContains(err, "invalid arguments") }) } diff --git a/template.go b/template.go new file mode 100644 index 0000000..deb5928 --- /dev/null +++ b/template.go @@ -0,0 +1,101 @@ +package schema + +import ( + "text/template" + + "github.com/afkdevs/go-schema/internal/parser" +) + +func GooseMigrationTemplate(name string) *template.Template { + tableName, create := parser.ParseMigrationName(name) + if create { + return migrationCreateTemplate(tableName) + } + if tableName != "" { + return migrationUpdateTemplate(tableName) + } + return migrationTemplate +} + +var migrationTemplate = template.Must(template.New("migrator.go-migration").Parse(`package migrations + +import ( + "context" + "database/sql" + + "github.com/afkdevs/go-schema" + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) +} + +func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { + // This code is executed when the migration is applied. + return nil +} + +func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { + // This code is executed when the migration is rolled back. + return nil +} +`)) + +func migrationCreateTemplate(table string) *template.Template { + tmpl := `package migrations + +import ( + "context" + "database/sql" + + "github.com/afkdevs/go-schema" + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) +} + +func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { + return schema.Create(ctx, tx, "` + table + `", func(table *schema.Blueprint) { + // Define your table schema here + }) +} + +func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { + return schema.DropIfExists(ctx, tx, "` + table + `") +} +` + return template.Must(template.New("migration-create").Parse(tmpl)) +} + +func migrationUpdateTemplate(table string) *template.Template { + tmpl := `package migrations + +import ( + "context" + "database/sql" + + "github.com/afkdevs/go-schema" + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) +} + +func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { + return schema.Table(ctx, tx, "` + table + `", func(table *schema.Blueprint) { + // Define your table schema changes here + }) +} + +func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { + return schema.Table(ctx, tx, "` + table + `", func(table *schema.Blueprint) { + // Define your table schema changes here + }) +} +` + return template.Must(template.New("migration-update").Parse(tmpl)) +} From b88574b03ad8a7ceb08b0b9c0850e09f9fe54af2 Mon Sep 17 00:00:00 2001 From: Ahmad Faiz Kamaludin Date: Sat, 30 Aug 2025 23:03:17 +0700 Subject: [PATCH 3/7] feat: refactor pkg name to migris --- config.go | 13 +++- template.go => create.go | 31 +++++--- down.go | 23 ++++++ examples/basic/cmd/migrate/cmd.go | 4 +- examples/basic/cmd/migrate/migrate.go | 25 +++--- examples/basic/cmd/root.go | 2 +- examples/basic/go.mod | 8 +- examples/basic/go.sum | 24 +++--- examples/basic/main.go | 2 +- .../20250830103612_create_users_table.go | 6 +- .../20250830103653_create_roles_table.go | 6 +- .../20250830103714_create_user_roles_table.go | 6 +- go.mod | 7 +- go.sum | 32 ++++++++ internal/config/config.go | 16 ++-- internal/dialect/dialect.go | 13 ++++ internal/parser/migration_test.go | 2 +- provider.go | 34 ++++++++ register.go | 78 +++++++++++++++++++ reset.go | 23 ++++++ blueprint.go => schema/blueprint.go | 4 +- builder.go => schema/builder.go | 2 +- .../column_definition.go | 2 +- command.go => schema/command.go | 0 .../foreign_key_definition.go | 2 +- grammar.go => schema/grammar.go | 2 +- .../index_definition.go | 2 +- mysql_builder.go => schema/mysql_builder.go | 2 +- .../mysql_builder_test.go | 70 ++++++++--------- mysql_grammar.go => schema/mysql_grammar.go | 2 +- .../mysql_grammar_test.go | 2 +- .../postgres_builder.go | 2 +- .../postgres_builder_test.go | 74 +++++++++--------- .../postgres_grammar.go | 0 .../postgres_grammar_test.go | 0 schema.go => schema/schema.go | 4 +- schema_test.go => schema/schema_test.go | 5 +- up.go | 25 ++++++ 38 files changed, 399 insertions(+), 156 deletions(-) rename template.go => create.go (73%) create mode 100644 down.go create mode 100644 provider.go create mode 100644 register.go create mode 100644 reset.go rename blueprint.go => schema/blueprint.go (99%) rename builder.go => schema/builder.go (98%) rename column_definition.go => schema/column_definition.go (99%) rename command.go => schema/command.go (100%) rename foreign_key_definition.go => schema/foreign_key_definition.go (98%) rename grammar.go => schema/grammar.go (98%) rename index_definition.go => schema/index_definition.go (96%) rename mysql_builder.go => schema/mysql_builder.go (99%) rename mysql_builder_test.go => schema/mysql_builder_test.go (92%) rename mysql_grammar.go => schema/mysql_grammar.go (99%) rename mysql_grammar_test.go => schema/mysql_grammar_test.go (99%) rename postgres_builder.go => schema/postgres_builder.go (99%) rename postgres_builder_test.go => schema/postgres_builder_test.go (91%) rename postgres_grammar.go => schema/postgres_grammar.go (100%) rename postgres_grammar_test.go => schema/postgres_grammar_test.go (100%) rename schema.go => schema/schema.go (98%) rename schema_test.go => schema/schema_test.go (99%) create mode 100644 up.go diff --git a/config.go b/config.go index 341e709..1a80e34 100644 --- a/config.go +++ b/config.go @@ -1,10 +1,10 @@ -package schema +package migris import ( "errors" - "github.com/afkdevs/go-schema/internal/config" - "github.com/afkdevs/go-schema/internal/dialect" + "github.com/afkdevs/migris/internal/config" + "github.com/afkdevs/migris/internal/dialect" ) // SetDialect sets the migrator dialect @@ -25,3 +25,10 @@ func SetVerbose(enabled bool) { cfg.Verbose = enabled config.Set(cfg) } + +// SetTableName sets the table name for the migrator +func SetTableName(name string) { + cfg := config.Get() + cfg.TableName = name + config.Set(cfg) +} diff --git a/template.go b/create.go similarity index 73% rename from template.go rename to create.go index deb5928..1b139d7 100644 --- a/template.go +++ b/create.go @@ -1,12 +1,19 @@ -package schema +package migris import ( "text/template" - "github.com/afkdevs/go-schema/internal/parser" + "github.com/afkdevs/migris/internal/parser" + "github.com/pressly/goose/v3" ) -func GooseMigrationTemplate(name string) *template.Template { +// Create a new migration file +func Create(dir, name string) error { + tmpl := getMigrationTemplate(name) + return goose.CreateWithTemplate(nil, dir, tmpl, name, "go") +} + +func getMigrationTemplate(name string) *template.Template { tableName, create := parser.ParseMigrationName(name) if create { return migrationCreateTemplate(tableName) @@ -23,12 +30,12 @@ import ( "context" "database/sql" - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/schema" ) func init() { - goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) + migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) } func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { @@ -49,12 +56,12 @@ import ( "context" "database/sql" - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/schema" ) func init() { - goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) + migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) } func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { @@ -77,12 +84,12 @@ import ( "context" "database/sql" - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/schema" ) func init() { - goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) + migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) } func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { diff --git a/down.go b/down.go new file mode 100644 index 0000000..dd364c1 --- /dev/null +++ b/down.go @@ -0,0 +1,23 @@ +package migris + +import ( + "context" + "database/sql" +) + +func Down(db *sql.DB, dir string) error { + ctx := context.Background() + return DownContext(ctx, db, dir) +} + +func DownContext(ctx context.Context, db *sql.DB, dir string) error { + provider, err := newProvider(db, dir) + if err != nil { + return err + } + _, err = provider.Down(ctx) + if err != nil { + return err + } + return nil +} diff --git a/examples/basic/cmd/migrate/cmd.go b/examples/basic/cmd/migrate/cmd.go index fbac4f0..a5b1ddc 100644 --- a/examples/basic/cmd/migrate/cmd.go +++ b/examples/basic/cmd/migrate/cmd.go @@ -38,7 +38,7 @@ func Command() *cli.Command { }, }, Action: func(ctx context.Context, c *cli.Command) error { - return Up(c.Bool("dry-run")) + return Up() }, }, { @@ -53,7 +53,7 @@ func Command() *cli.Command { }, }, Action: func(ctx context.Context, c *cli.Command) error { - return Reset(c.Bool("dry-run")) + return Reset() }, }, }, diff --git a/examples/basic/cmd/migrate/migrate.go b/examples/basic/cmd/migrate/migrate.go index 795dffa..ae081c6 100644 --- a/examples/basic/cmd/migrate/migrate.go +++ b/examples/basic/cmd/migrate/migrate.go @@ -4,45 +4,38 @@ import ( "database/sql" "fmt" - "github.com/afkdevs/go-schema" - "github.com/afkdevs/go-schema/examples/basic/config" - _ "github.com/afkdevs/go-schema/examples/basic/migrations" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/examples/basic/config" + _ "github.com/afkdevs/migris/examples/basic/migrations" _ "github.com/lib/pq" // PostgreSQL driver - "github.com/pressly/goose/v3" ) const ( directory = "migrations" - tableName = "schema_migrations" ) -func Up(dryRun bool) error { +func Up() error { db, err := initMigrator() if err != nil { return err } - return goose.Up(db, directory) + return migris.Up(db, directory) } func Create(name string) error { - tmpl := schema.GooseMigrationTemplate(name) - return goose.CreateWithTemplate(nil, directory, tmpl, name, "go") + return migris.Create(directory, name) } -func Reset(dryRun bool) error { +func Reset() error { db, err := initMigrator() if err != nil { return err } - return goose.Reset(db, directory) + return migris.Reset(db, directory) } func initMigrator() (*sql.DB, error) { - if err := goose.SetDialect("postgres"); err != nil { - return nil, fmt.Errorf("failed to set dialect: %w", err) - } - goose.SetTableName(tableName) - if err := schema.SetDialect("postgres"); err != nil { + if err := migris.SetDialect("postgres"); err != nil { return nil, fmt.Errorf("failed to set schema dialect: %w", err) } cfg, err := config.Load() diff --git a/examples/basic/cmd/root.go b/examples/basic/cmd/root.go index e7a80d0..7f6846a 100644 --- a/examples/basic/cmd/root.go +++ b/examples/basic/cmd/root.go @@ -3,7 +3,7 @@ package cmd import ( "context" - "github.com/afkdevs/go-schema/examples/basic/cmd/migrate" + "github.com/afkdevs/migris/examples/basic/cmd/migrate" "github.com/urfave/cli/v3" ) diff --git a/examples/basic/go.mod b/examples/basic/go.mod index 3443552..6f3febf 100644 --- a/examples/basic/go.mod +++ b/examples/basic/go.mod @@ -1,20 +1,20 @@ -module github.com/afkdevs/go-schema/examples/basic +module github.com/afkdevs/migris/examples/basic go 1.23.0 require ( - github.com/afkdevs/go-schema v0.0.0 + github.com/afkdevs/migris v0.0.0 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 - github.com/pressly/goose/v3 v3.24.3 github.com/urfave/cli/v3 v3.3.8 ) require ( github.com/mfridman/interpolate v0.0.2 // indirect + github.com/pressly/goose/v3 v3.25.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/sync v0.16.0 // indirect ) -replace github.com/afkdevs/go-schema => ../.. +replace github.com/afkdevs/migris => ../.. diff --git a/examples/basic/go.sum b/examples/basic/go.sum index 1b5a029..99b7bc7 100644 --- a/examples/basic/go.sum +++ b/examples/basic/go.sum @@ -20,8 +20,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pressly/goose/v3 v3.24.3 h1:DSWWNwwggVUsYZ0X2VitiAa9sKuqtBfe+Jr9zFGwWlM= -github.com/pressly/goose/v3 v3.24.3/go.mod h1:v9zYL4xdViLHCUUJh/mhjnm6JrK7Eul8AS93IxiZM4E= +github.com/pressly/goose/v3 v3.25.0 h1:6WeYhMWGRCzpyd89SpODFnCBCKz41KrVbRT58nVjGng= +github.com/pressly/goose/v3 v3.25.0/go.mod h1:4hC1KrritdCxtuFsqgs1R4AU5bWtTAf+cnWvfhf2DNY= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= @@ -32,19 +32,19 @@ github.com/urfave/cli/v3 v3.3.8 h1:BzolUExliMdet9NlJ/u4m5vHSotJ3PzEqSAZ1oPMa/E= github.com/urfave/cli/v3 v3.3.8/go.mod h1:FJSKtM/9AiiTOJL4fJ6TbMUkxBXn7GO9guZqoZtpYpo= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= -golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -modernc.org/libc v1.65.0 h1:e183gLDnAp9VJh6gWKdTy0CThL9Pt7MfcR/0bgb7Y1Y= -modernc.org/libc v1.65.0/go.mod h1:7m9VzGq7APssBTydds2zBcxGREwvIGpuUBaKTXdm2Qs= +modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ= +modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= -modernc.org/memory v1.10.0 h1:fzumd51yQ1DxcOxSO+S6X7+QTuVU+n8/Aj7swYjFfC4= -modernc.org/memory v1.10.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= -modernc.org/sqlite v1.37.0 h1:s1TMe7T3Q3ovQiK2Ouz4Jwh7dw4ZDqbebSDTlSJdfjI= -modernc.org/sqlite v1.37.0/go.mod h1:5YiWv+YviqGMuGw4V+PNplcyaJ5v+vQd7TQOgkACoJM= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.38.2 h1:Aclu7+tgjgcQVShZqim41Bbw9Cho0y/7WzYptXqkEek= +modernc.org/sqlite v1.38.2/go.mod h1:cPTJYSlgg3Sfg046yBShXENNtPrWrDX8bsbAQBzgQ5E= diff --git a/examples/basic/main.go b/examples/basic/main.go index 4c15da0..ac85745 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -4,7 +4,7 @@ import ( "log" "os" - "github.com/afkdevs/go-schema/examples/basic/cmd" + "github.com/afkdevs/migris/examples/basic/cmd" ) func main() { diff --git a/examples/basic/migrations/20250830103612_create_users_table.go b/examples/basic/migrations/20250830103612_create_users_table.go index 53776a3..393b615 100644 --- a/examples/basic/migrations/20250830103612_create_users_table.go +++ b/examples/basic/migrations/20250830103612_create_users_table.go @@ -4,12 +4,12 @@ import ( "context" "database/sql" - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/schema" ) func init() { - goose.AddMigrationContext(upCreateUsersTable, downCreateUsersTable) + migris.AddMigrationContext(upCreateUsersTable, downCreateUsersTable) } func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error { diff --git a/examples/basic/migrations/20250830103653_create_roles_table.go b/examples/basic/migrations/20250830103653_create_roles_table.go index d857cc5..7354d31 100644 --- a/examples/basic/migrations/20250830103653_create_roles_table.go +++ b/examples/basic/migrations/20250830103653_create_roles_table.go @@ -4,12 +4,12 @@ import ( "context" "database/sql" - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/schema" ) func init() { - goose.AddMigrationContext(upCreateRolesTable, downCreateRolesTable) + migris.AddMigrationContext(upCreateRolesTable, downCreateRolesTable) } func upCreateRolesTable(ctx context.Context, tx *sql.Tx) error { diff --git a/examples/basic/migrations/20250830103714_create_user_roles_table.go b/examples/basic/migrations/20250830103714_create_user_roles_table.go index 713023b..581b201 100644 --- a/examples/basic/migrations/20250830103714_create_user_roles_table.go +++ b/examples/basic/migrations/20250830103714_create_user_roles_table.go @@ -4,12 +4,12 @@ import ( "context" "database/sql" - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/schema" ) func init() { - goose.AddMigrationContext(upCreateUserRolesTable, downCreateUserRolesTable) + migris.AddMigrationContext(upCreateUserRolesTable, downCreateUserRolesTable) } func upCreateUserRolesTable(ctx context.Context, tx *sql.Tx) error { diff --git a/go.mod b/go.mod index 9c4b745..718df7b 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,21 @@ -module github.com/afkdevs/go-schema +module github.com/afkdevs/migris go 1.23.0 require ( github.com/go-sql-driver/mysql v1.9.3 github.com/lib/pq v1.10.9 + github.com/pressly/goose/v3 v3.25.0 github.com/stretchr/testify v1.11.1 ) require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mfridman/interpolate v0.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/sethvargo/go-retry v0.3.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/sync v0.16.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f813003..9628ba3 100644 --- a/go.sum +++ b/go.sum @@ -2,15 +2,47 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= +github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pressly/goose/v3 v3.25.0 h1:6WeYhMWGRCzpyd89SpODFnCBCKz41KrVbRT58nVjGng= +github.com/pressly/goose/v3 v3.25.0/go.mod h1:4hC1KrritdCxtuFsqgs1R4AU5bWtTAf+cnWvfhf2DNY= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= +github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ= +modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.38.2 h1:Aclu7+tgjgcQVShZqim41Bbw9Cho0y/7WzYptXqkEek= +modernc.org/sqlite v1.38.2/go.mod h1:cPTJYSlgg3Sfg046yBShXENNtPrWrDX8bsbAQBzgQ5E= diff --git a/internal/config/config.go b/internal/config/config.go index da85f61..293570b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,25 +3,27 @@ package config import ( "sync/atomic" - "github.com/afkdevs/go-schema/internal/dialect" + "github.com/afkdevs/migris/internal/dialect" ) type Config struct { - Dialect dialect.Dialect - Verbose bool + Dialect dialect.Dialect + TableName string + Verbose bool } var config = atomic.Pointer[Config]{} func init() { config.Store(&Config{ - Dialect: dialect.Unknown, - Verbose: true, + Dialect: dialect.Unknown, + TableName: "migris_db_version", + Verbose: true, }) } -func Set(newConfig *Config) { - config.Store(newConfig) +func Set(cfg *Config) { + config.Store(cfg) } func Get() *Config { diff --git a/internal/dialect/dialect.go b/internal/dialect/dialect.go index 3291cba..c9f2fb4 100644 --- a/internal/dialect/dialect.go +++ b/internal/dialect/dialect.go @@ -1,5 +1,7 @@ package dialect +import "github.com/pressly/goose/v3/database" + // Dialect is the type of database dialect. type Dialect string @@ -13,6 +15,17 @@ func (d Dialect) String() string { return string(d) } +func (d Dialect) GooseDialect() database.Dialect { + switch d { + case MySQL: + return database.DialectMySQL + case Postgres: + return database.DialectPostgres + default: + return database.DialectCustom + } +} + func FromString(dialect string) Dialect { switch dialect { case "mysql", "mariadb": diff --git a/internal/parser/migration_test.go b/internal/parser/migration_test.go index c890078..f7c7b99 100644 --- a/internal/parser/migration_test.go +++ b/internal/parser/migration_test.go @@ -3,7 +3,7 @@ package parser_test import ( "testing" - "github.com/afkdevs/go-schema/internal/parser" + "github.com/afkdevs/migris/internal/parser" ) func TestParseMigrationName(t *testing.T) { diff --git a/provider.go b/provider.go new file mode 100644 index 0000000..42122b9 --- /dev/null +++ b/provider.go @@ -0,0 +1,34 @@ +package migris + +import ( + "database/sql" + "errors" + "os" + + "github.com/afkdevs/migris/internal/config" + "github.com/afkdevs/migris/internal/dialect" + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/database" +) + +func newProvider(db *sql.DB, dir string) (*goose.Provider, error) { + cfg := config.Get() + if cfg.Dialect == dialect.Unknown { + return nil, errors.New("unknown database dialect") + } + dialect := cfg.Dialect.GooseDialect() + store, err := database.NewStore(dialect, cfg.TableName) + if err != nil { + return nil, err + } + provider, err := goose.NewProvider(database.DialectCustom, db, os.DirFS(dir), + goose.WithStore(store), + goose.WithDisableGlobalRegistry(true), + goose.WithGoMigrations(getMigrations()...), + goose.WithVerbose(cfg.Verbose), + ) + if err != nil { + return nil, err + } + return provider, nil +} diff --git a/register.go b/register.go new file mode 100644 index 0000000..b0bd30f --- /dev/null +++ b/register.go @@ -0,0 +1,78 @@ +package migris + +import ( + "context" + "database/sql" + "fmt" + "runtime" + "slices" + + "github.com/pressly/goose/v3" +) + +var ( + registeredGoMigrations = make(map[int64]*goose.Migration) +) + +// GoMigrationContext is a Go migration func that is run within a transaction and receives a +// context. +type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error + +// AddMigrationContext adds Go migrations. +func AddMigrationContext(up, down GoMigrationContext) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationContext(filename, up, down) +} + +// AddNamedMigrationContext adds named Go migrations. +func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { + if err := register( + filename, + true, + &goose.GoFunc{RunTx: up, Mode: goose.TransactionEnabled}, + &goose.GoFunc{RunTx: down, Mode: goose.TransactionEnabled}, + ); err != nil { + panic(err) + } +} + +func register(filename string, useTx bool, up, down *goose.GoFunc) error { + v, _ := goose.NumericComponent(filename) + if existing, ok := registeredGoMigrations[v]; ok { + return fmt.Errorf("failed to add migration %q: version %d conflicts with %q", + filename, + v, + existing.Source, + ) + } + // Add to global as a registered migration. + m := goose.NewGoMigration(v, up, down) + m.Source = filename + // We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but + // we know based on the register function what the user is requesting. + m.UseTx = useTx + registeredGoMigrations[v] = m + return nil +} + +func getMigrations() []*goose.Migration { + type migrationWithVersion struct { + version int64 + migration *goose.Migration + } + var migrations []*migrationWithVersion + for _, m := range registeredGoMigrations { + migrations = append(migrations, &migrationWithVersion{ + version: m.Version, + migration: m, + }) + } + slices.SortFunc(migrations, func(a, b *migrationWithVersion) int { + return int(a.version - b.version) + }) + var results []*goose.Migration + for _, m := range migrations { + results = append(results, m.migration) + } + return results +} diff --git a/reset.go b/reset.go new file mode 100644 index 0000000..43ee030 --- /dev/null +++ b/reset.go @@ -0,0 +1,23 @@ +package migris + +import ( + "context" + "database/sql" +) + +func Reset(db *sql.DB, dir string) error { + ctx := context.Background() + return ResetContext(ctx, db, dir) +} + +func ResetContext(ctx context.Context, db *sql.DB, dir string) error { + provider, err := newProvider(db, dir) + if err != nil { + return err + } + _, err = provider.DownTo(ctx, 0) + if err != nil { + return err + } + return nil +} diff --git a/blueprint.go b/schema/blueprint.go similarity index 99% rename from blueprint.go rename to schema/blueprint.go index f2b3149..c28b533 100644 --- a/blueprint.go +++ b/schema/blueprint.go @@ -6,8 +6,8 @@ import ( "fmt" "log" - "github.com/afkdevs/go-schema/internal/dialect" - "github.com/afkdevs/go-schema/internal/util" + "github.com/afkdevs/migris/internal/dialect" + "github.com/afkdevs/migris/internal/util" ) const ( diff --git a/builder.go b/schema/builder.go similarity index 98% rename from builder.go rename to schema/builder.go index a300ca2..823d014 100644 --- a/builder.go +++ b/schema/builder.go @@ -5,7 +5,7 @@ import ( "database/sql" "errors" - "github.com/afkdevs/go-schema/internal/dialect" + "github.com/afkdevs/migris/internal/dialect" ) // Builder is an interface that defines methods for creating, dropping, and managing database tables. diff --git a/column_definition.go b/schema/column_definition.go similarity index 99% rename from column_definition.go rename to schema/column_definition.go index b50ac39..4c4f006 100644 --- a/column_definition.go +++ b/schema/column_definition.go @@ -3,7 +3,7 @@ package schema import ( "slices" - "github.com/afkdevs/go-schema/internal/util" + "github.com/afkdevs/migris/internal/util" ) // ColumnDefinition defines the interface for defining a column in a database table. diff --git a/command.go b/schema/command.go similarity index 100% rename from command.go rename to schema/command.go diff --git a/foreign_key_definition.go b/schema/foreign_key_definition.go similarity index 98% rename from foreign_key_definition.go rename to schema/foreign_key_definition.go index 476e05b..cd86613 100644 --- a/foreign_key_definition.go +++ b/schema/foreign_key_definition.go @@ -1,6 +1,6 @@ package schema -import "github.com/afkdevs/go-schema/internal/util" +import "github.com/afkdevs/migris/internal/util" // ForeignKeyDefinition defines the interface for defining a foreign key constraint in a database table. type ForeignKeyDefinition interface { diff --git a/grammar.go b/schema/grammar.go similarity index 98% rename from grammar.go rename to schema/grammar.go index aed1955..45ab2e3 100644 --- a/grammar.go +++ b/schema/grammar.go @@ -5,7 +5,7 @@ import ( "slices" "strings" - "github.com/afkdevs/go-schema/internal/util" + "github.com/afkdevs/migris/internal/util" ) type grammar interface { diff --git a/index_definition.go b/schema/index_definition.go similarity index 96% rename from index_definition.go rename to schema/index_definition.go index d13e5f8..6aa5b67 100644 --- a/index_definition.go +++ b/schema/index_definition.go @@ -1,6 +1,6 @@ package schema -import "github.com/afkdevs/go-schema/internal/util" +import "github.com/afkdevs/migris/internal/util" // IndexDefinition defines the interface for defining an index in a database table. type IndexDefinition interface { diff --git a/mysql_builder.go b/schema/mysql_builder.go similarity index 99% rename from mysql_builder.go rename to schema/mysql_builder.go index 284052a..6bc4399 100644 --- a/mysql_builder.go +++ b/schema/mysql_builder.go @@ -6,7 +6,7 @@ import ( "errors" "strings" - "github.com/afkdevs/go-schema/internal/config" + "github.com/afkdevs/migris/internal/config" ) type mysqlBuilder struct { diff --git a/mysql_builder_test.go b/schema/mysql_builder_test.go similarity index 92% rename from mysql_builder_test.go rename to schema/mysql_builder_test.go index b274864..bf3d06e 100644 --- a/mysql_builder_test.go +++ b/schema/mysql_builder_test.go @@ -6,7 +6,7 @@ import ( "fmt" "testing" - "github.com/afkdevs/go-schema" + schema2 "github.com/afkdevs/migris/schema" _ "github.com/go-sql-driver/mysql" // MySQL driver "github.com/stretchr/testify/suite" ) @@ -19,7 +19,7 @@ type mysqlBuilderSuite struct { suite.Suite ctx context.Context db *sql.DB - builder schema.Builder + builder schema2.Builder } func (s *mysqlBuilderSuite) SetupSuite() { @@ -43,7 +43,7 @@ func (s *mysqlBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - s.builder, err = schema.NewBuilder("mysql") + s.builder, err = schema2.NewBuilder("mysql") s.Require().NoError(err) } @@ -74,13 +74,13 @@ func (s *mysqlBuilderSuite) TestCreate() { defer tx.Rollback() //nolint:errcheck s.Run("when tx is nil, should return error", func() { - err := builder.Create(s.ctx, nil, "test_table", func(table *schema.Blueprint) { + err := builder.Create(s.ctx, nil, "test_table", func(table *schema2.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.Create(s.ctx, tx, "", func(table *schema.Blueprint) { + err := builder.Create(s.ctx, tx, "", func(table *schema2.Blueprint) { table.String("name") }) s.Error(err) @@ -90,7 +90,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.Error(err) }) s.Run("when all parameters are valid, should create table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -100,7 +100,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with valid parameters") }) s.Run("when have composite primary key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "user_roles", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "user_roles", func(table *schema2.Blueprint) { table.Integer("user_id") table.Integer("role_id") @@ -109,7 +109,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with composite primary key") }) s.Run("when have foreign key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "orders", func(table *schema2.Blueprint) { table.ID() table.UnsignedBigInteger("user_id") table.String("order_id", 255).Unique() @@ -121,7 +121,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with foreign key") }) s.Run("when have custom index should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders_2", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "orders_2", func(table *schema2.Blueprint) { table.ID() table.String("order_id", 255).Unique("uk_orders_2_order_id") table.Decimal("amount", 10, 2) @@ -132,7 +132,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with custom index") }) s.Run("when table already exists, should return error", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -156,7 +156,7 @@ func (s *mysqlBuilderSuite) TestDrop() { s.Error(err) }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -188,7 +188,7 @@ func (s *mysqlBuilderSuite) TestDropIfExists() { s.Error(err) }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -224,7 +224,7 @@ func (s *mysqlBuilderSuite) TestRename() { s.Error(err) }) s.Run("when all parameters are valid, should rename table successfully", func() { - err = builder.Create(context.Background(), tx, "old_table", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "old_table", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) }) @@ -241,13 +241,13 @@ func (s *mysqlBuilderSuite) TestTable() { defer tx.Rollback() //nolint:errcheck s.Run("when tx is nil, should return error", func() { - err := builder.Table(s.ctx, nil, "test_table", func(table *schema.Blueprint) { + err := builder.Table(s.ctx, nil, "test_table", func(table *schema2.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.Table(s.ctx, tx, "", func(table *schema.Blueprint) { + err := builder.Table(s.ctx, tx, "", func(table *schema2.Blueprint) { table.String("name") }) s.Error(err) @@ -257,7 +257,7 @@ func (s *mysqlBuilderSuite) TestTable() { s.Error(err) }) s.Run("when all parameters are valid, should modify table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique("uk_users_email") @@ -270,75 +270,75 @@ func (s *mysqlBuilderSuite) TestTable() { s.NoError(err, "expected no error when creating table before modifying it") s.Run("should add new columns and modify existing ones", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.String("address", 255).Nullable() table.String("phone", 20).Nullable().Unique("uk_users_phone") }) s.NoError(err, "expected no error when modifying table with valid parameters") }) s.Run("should modify existing column", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.String("email", 255).Nullable().Change() }) s.NoError(err, "expected no error when modifying existing column") }) s.Run("should drop column and rename existing one", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropColumn("password") table.RenameColumn("name", "full_name") }) s.NoError(err, "expected no error when dropping column and renaming existing one") }) s.Run("should add index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.Index("phone").Name("idx_users_phone").Algorithm("BTREE") }) s.NoError(err, "expected no error when adding index to table") }) s.Run("should rename index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.RenameIndex("idx_users_phone", "idx_users_contact") }) s.NoError(err, "expected no error when renaming index in table") }) s.Run("should drop index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropIndex("idx_users_contact") }) s.NoError(err, "expected no error when dropping index from table") }) s.Run("should drop unique constraint", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropUnique("uk_users_email") }) s.NoError(err, "expected no error when dropping unique constraint from table") }) s.Run("should drop fulltext index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropFulltext("ft_users_bio") }) s.NoError(err, "expected no error when dropping fulltext index from table") }) s.Run("should add foreign key", func() { - err = builder.Create(s.ctx, tx, "roles", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "roles", func(table *schema2.Blueprint) { table.UnsignedInteger("id").Primary() table.String("role_name", 255).Unique("uk_roles_role_name") }) s.NoError(err, "expected no error when creating roles table before adding foreign key") - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.UnsignedInteger("role_id").Nullable() table.Foreign("role_id").References("id").On("roles").OnDelete("SET NULL").OnUpdate("CASCADE") }) s.NoError(err, "expected no error when adding foreign key to users table") }) s.Run("should drop foreign key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropForeign("fk_users_roles") }) s.NoError(err, "expected no error when dropping foreign key from users table") }) s.Run("should drop primary key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.UnsignedBigInteger("id").Change() table.DropPrimary("users_pkey") }) @@ -369,7 +369,7 @@ func (s *mysqlBuilderSuite) TestGetColumns() { s.Empty(columns) }) s.Run("when table exists, should return columns successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -400,7 +400,7 @@ func (s *mysqlBuilderSuite) TestGetIndexes() { s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -435,7 +435,7 @@ func (s *mysqlBuilderSuite) TestGetTables() { s.Nil(tables) }) s.Run("when all parameters are valid", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -480,7 +480,7 @@ func (s *mysqlBuilderSuite) TestHasColumn() { s.False(exists) }) s.Run("when all parameters are valid", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -521,7 +521,7 @@ func (s *mysqlBuilderSuite) TestHasColumns() { s.False(exists) }) s.Run("when all parameters are valid", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -557,7 +557,7 @@ func (s *mysqlBuilderSuite) TestHasIndex() { s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "orders", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "orders", func(table *schema2.Blueprint) { table.ID() table.Integer("company_id") table.Integer("user_id") @@ -601,7 +601,7 @@ func (s *mysqlBuilderSuite) TestHasTable() { s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() diff --git a/mysql_grammar.go b/schema/mysql_grammar.go similarity index 99% rename from mysql_grammar.go rename to schema/mysql_grammar.go index 3c07d85..4c9d49d 100644 --- a/mysql_grammar.go +++ b/schema/mysql_grammar.go @@ -5,7 +5,7 @@ import ( "slices" "strings" - "github.com/afkdevs/go-schema/internal/util" + "github.com/afkdevs/migris/internal/util" ) type mysqlGrammar struct { diff --git a/mysql_grammar_test.go b/schema/mysql_grammar_test.go similarity index 99% rename from mysql_grammar_test.go rename to schema/mysql_grammar_test.go index 1dcb2af..a50e33d 100644 --- a/mysql_grammar_test.go +++ b/schema/mysql_grammar_test.go @@ -3,7 +3,7 @@ package schema import ( "testing" - "github.com/afkdevs/go-schema/internal/dialect" + "github.com/afkdevs/migris/internal/dialect" "github.com/stretchr/testify/assert" ) diff --git a/postgres_builder.go b/schema/postgres_builder.go similarity index 99% rename from postgres_builder.go rename to schema/postgres_builder.go index c49f6f1..c5e4fca 100644 --- a/postgres_builder.go +++ b/schema/postgres_builder.go @@ -6,7 +6,7 @@ import ( "errors" "strings" - "github.com/afkdevs/go-schema/internal/config" + "github.com/afkdevs/migris/internal/config" ) type postgresBuilder struct { diff --git a/postgres_builder_test.go b/schema/postgres_builder_test.go similarity index 91% rename from postgres_builder_test.go rename to schema/postgres_builder_test.go index 33f2645..6fa5270 100644 --- a/postgres_builder_test.go +++ b/schema/postgres_builder_test.go @@ -7,7 +7,7 @@ import ( "os" "testing" - "github.com/afkdevs/go-schema" + schema2 "github.com/afkdevs/migris/schema" _ "github.com/lib/pq" "github.com/stretchr/testify/suite" ) @@ -41,7 +41,7 @@ type postgresBuilderSuite struct { suite.Suite ctx context.Context db *sql.DB - builder schema.Builder + builder schema2.Builder } func (s *postgresBuilderSuite) SetupSuite() { @@ -58,7 +58,7 @@ func (s *postgresBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - s.builder, err = schema.NewBuilder("postgres") + s.builder, err = schema2.NewBuilder("postgres") s.Require().NoError(err) } @@ -73,11 +73,11 @@ func (s *postgresBuilderSuite) TestCreate() { defer tx.Rollback() //nolint:errcheck s.Run("when tx is nil, should return error", func() { - err := builder.Create(s.ctx, nil, "test_table", func(table *schema.Blueprint) {}) + err := builder.Create(s.ctx, nil, "test_table", func(table *schema2.Blueprint) {}) s.Error(err, "expected error when transaction is nil") }) s.Run("when table name is empty, should return error", func() { - err := builder.Create(s.ctx, tx, "", func(table *schema.Blueprint) {}) + err := builder.Create(s.ctx, tx, "", func(table *schema2.Blueprint) {}) s.Error(err, "expected error when table name is empty") }) s.Run("when blueprint is nil, should return error", func() { @@ -85,7 +85,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.Error(err, "expected error when blueprint is nil") }) s.Run("when all parameters are valid, should create table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -97,7 +97,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.Run("when use custom schema should create it successfully", func() { _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_public") s.NoError(err, "expected no error when creating custom schema") - err = builder.Create(context.Background(), tx, "custom_public.users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "custom_public.users", func(table *schema2.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -107,7 +107,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with custom schema") }) s.Run("when have composite primary key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "user_roles", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "user_roles", func(table *schema2.Blueprint) { table.Integer("user_id") table.Integer("role_id") @@ -116,7 +116,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with composite primary key") }) s.Run("when have foreign key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "orders", func(table *schema2.Blueprint) { table.ID() table.BigInteger("user_id") table.String("order_id").Unique() @@ -128,7 +128,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with foreign key") }) s.Run("when have custom index should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders_2", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "orders_2", func(table *schema2.Blueprint) { table.ID() table.String("order_id").Unique("uk_orders_2_order_id") table.Decimal("amount", 10, 2) @@ -139,7 +139,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with custom index") }) s.Run("when table already exists, should return error", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -159,7 +159,7 @@ func (s *postgresBuilderSuite) TestDrop() { s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -191,7 +191,7 @@ func (s *postgresBuilderSuite) TestDropIfExists() { s.Error(err, "expected error when transaction is nil") }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -227,7 +227,7 @@ func (s *postgresBuilderSuite) TestRename() { s.Error(err, "expected error when new table name is empty") }) s.Run("when all parameters are valid, should rename table successfully", func() { - err = builder.Create(s.ctx, tx, "old_table", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "old_table", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) }) @@ -249,11 +249,11 @@ func (s *postgresBuilderSuite) TestTable() { defer tx.Rollback() //nolint:errcheck s.Run("when tx is nil, should return error", func() { - err := builder.Table(s.ctx, nil, "test_table", func(table *schema.Blueprint) {}) + err := builder.Table(s.ctx, nil, "test_table", func(table *schema2.Blueprint) {}) s.Error(err, "expected error when transaction is nil") }) s.Run("when table name is empty, should return error", func() { - err := builder.Table(s.ctx, tx, "", func(table *schema.Blueprint) {}) + err := builder.Table(s.ctx, tx, "", func(table *schema2.Blueprint) {}) s.Error(err, "expected error when table name is empty") }) s.Run("when blueprint is nil, should return error", func() { @@ -261,7 +261,7 @@ func (s *postgresBuilderSuite) TestTable() { s.Error(err, "expected error when blueprint is nil") }) s.Run("when all parameters are valid, should modify table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique("uk_users_email") @@ -274,75 +274,75 @@ func (s *postgresBuilderSuite) TestTable() { s.NoError(err, "expected no error when creating table before modifying it") s.Run("should add new columns and modify existing ones", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.String("address", 255).Nullable() table.String("phone", 20).Nullable().Unique("uk_users_phone") }) s.NoError(err, "expected no error when modifying table with valid parameters") }) s.Run("should modify existing column", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.String("email", 255).Nullable().Change() }) s.NoError(err, "expected no error when modifying existing column") }) s.Run("should drop column and rename existing one", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropColumn("password") table.RenameColumn("name", "full_name") }) s.NoError(err, "expected no error when dropping column and renaming existing one") }) s.Run("should add index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.Index("phone").Name("idx_users_phone").Algorithm("BTREE") }) s.NoError(err, "expected no error when adding index to table") }) s.Run("should rename index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.RenameIndex("idx_users_phone", "idx_users_contact") }) s.NoError(err, "expected no error when renaming index in table") }) s.Run("should drop index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropIndex("idx_users_contact") }) s.NoError(err, "expected no error when dropping index from table") }) s.Run("should drop unique constraint", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropUnique([]string{"email"}) }) s.NoError(err, "expected no error when dropping unique constraint from table") }) s.Run("should drop fulltext index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropFulltext("ft_users_bio") }) s.NoError(err, "expected no error when dropping fulltext index from table") }) s.Run("should add foreign key", func() { - err = builder.Create(s.ctx, tx, "roles", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "roles", func(table *schema2.Blueprint) { table.ID() table.String("role_name", 255).Unique("uk_roles_role_name") }) s.NoError(err, "expected no error when creating roles table before adding foreign key") - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.Integer("role_id").Nullable() table.Foreign("role_id").References("id").On("roles").OnDelete("SET NULL").OnUpdate("CASCADE") }) s.NoError(err, "expected no error when adding foreign key to users table") }) s.Run("should drop foreign key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropForeign("fk_users_roles") }) s.NoError(err, "expected no error when dropping foreign key from users table") }) s.Run("should drop primary key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.DropPrimary("pk_users") }) s.NoError(err, "expected no error when dropping primary key from users table") @@ -365,7 +365,7 @@ func (s *postgresBuilderSuite) TestGetColumns() { s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -400,7 +400,7 @@ func (s *postgresBuilderSuite) TestGetIndexes() { s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -434,7 +434,7 @@ func (s *postgresBuilderSuite) TestGetTables() { s.Error(err, "expected error when transaction is nil") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -470,7 +470,7 @@ func (s *postgresBuilderSuite) TestHasColumn() { s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -506,7 +506,7 @@ func (s *postgresBuilderSuite) TestHasColumns() { s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -542,7 +542,7 @@ func (s *postgresBuilderSuite) TestHasIndex() { s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "orders", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "orders", func(table *schema2.Blueprint) { table.ID() table.Integer("company_id") table.Integer("user_id") @@ -586,7 +586,7 @@ func (s *postgresBuilderSuite) TestHasTable() { s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -607,7 +607,7 @@ func (s *postgresBuilderSuite) TestHasTable() { _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_publics") s.NoError(err, "expected no error when creating custom schema") - err = builder.Create(s.ctx, tx, "custom_publics.users", func(table *schema.Blueprint) { + err = builder.Create(s.ctx, tx, "custom_publics.users", func(table *schema2.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() diff --git a/postgres_grammar.go b/schema/postgres_grammar.go similarity index 100% rename from postgres_grammar.go rename to schema/postgres_grammar.go diff --git a/postgres_grammar_test.go b/schema/postgres_grammar_test.go similarity index 100% rename from postgres_grammar_test.go rename to schema/postgres_grammar_test.go diff --git a/schema.go b/schema/schema.go similarity index 98% rename from schema.go rename to schema/schema.go index f5531ff..b9ff1f2 100644 --- a/schema.go +++ b/schema/schema.go @@ -5,8 +5,8 @@ import ( "database/sql" "errors" - "github.com/afkdevs/go-schema/internal/config" - "github.com/afkdevs/go-schema/internal/dialect" + "github.com/afkdevs/migris/internal/config" + "github.com/afkdevs/migris/internal/dialect" ) // Column represents a database column with its properties. diff --git a/schema_test.go b/schema/schema_test.go similarity index 99% rename from schema_test.go rename to schema/schema_test.go index 8462997..f28f54e 100644 --- a/schema_test.go +++ b/schema/schema_test.go @@ -6,7 +6,8 @@ import ( "fmt" "testing" - "github.com/afkdevs/go-schema" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/schema" "github.com/stretchr/testify/suite" ) @@ -55,7 +56,7 @@ func (s *schemaTestSuite) SetupSuite() { s.Error(fn(), "Expected error when dialect is not set") } }) - schema.SetDialect("postgres") + migris.SetDialect("postgres") } func (s *schemaTestSuite) TearDownSuite() { diff --git a/up.go b/up.go new file mode 100644 index 0000000..4df1ea7 --- /dev/null +++ b/up.go @@ -0,0 +1,25 @@ +package migris + +import ( + "context" + "database/sql" +) + +// Up applies the migrations in the specified directory. +func Up(db *sql.DB, dir string) error { + ctx := context.Background() + return UpContext(ctx, db, dir) +} + +// UpContext applies the migrations in the specified directory. +func UpContext(ctx context.Context, db *sql.DB, dir string) error { + provider, err := newProvider(db, dir) + if err != nil { + return err + } + _, err = provider.Up(ctx) + if err != nil { + return err + } + return nil +} From 5c1683e7691c9cee193e1e2006064ac51baeac6b Mon Sep 17 00:00:00 2001 From: Ahmad Faiz Kamaludin Date: Sun, 31 Aug 2025 01:36:52 +0700 Subject: [PATCH 4/7] feat: refactor migration func and log --- config.go | 7 - create.go | 29 +-- down.go | 22 +- examples/basic/cmd/migrate/cmd.go | 30 ++- examples/basic/cmd/migrate/migrate.go | 16 ++ examples/basic/go.mod | 5 + examples/basic/go.sum | 13 +- .../20250830103612_create_users_table.go | 11 +- .../20250830103653_create_roles_table.go | 11 +- .../20250830103714_create_user_roles_table.go | 11 +- go.mod | 5 + go.sum | 13 +- internal/config/config.go | 2 - internal/logger/logger.go | 101 ++++++++ provider.go | 3 +- register.go | 49 ++-- reset.go | 22 +- schema/blueprint.go | 14 +- schema/builder.go | 69 +++--- schema/context.go | 43 ++++ schema/mysql_builder.go | 74 +++--- schema/mysql_builder_test.go | 221 +++++++++-------- schema/postgres_builder.go | 58 ++--- schema/postgres_builder_test.go | 232 ++++++++++-------- schema/schema.go | 49 ++-- schema/schema_test.go | 183 +++++++------- status.go | 26 ++ up.go | 23 +- 28 files changed, 807 insertions(+), 535 deletions(-) create mode 100644 internal/logger/logger.go create mode 100644 schema/context.go create mode 100644 status.go diff --git a/config.go b/config.go index 1a80e34..a448753 100644 --- a/config.go +++ b/config.go @@ -19,13 +19,6 @@ func SetDialect(d string) error { return nil } -// SetVerbose enables or disables verbose mode -func SetVerbose(enabled bool) { - cfg := config.Get() - cfg.Verbose = enabled - config.Set(cfg) -} - // SetTableName sets the table name for the migrator func SetTableName(name string) { cfg := config.Get() diff --git a/create.go b/create.go index 1b139d7..5ea1922 100644 --- a/create.go +++ b/create.go @@ -27,9 +27,6 @@ func getMigrationTemplate(name string) *template.Template { var migrationTemplate = template.Must(template.New("migrator.go-migration").Parse(`package migrations import ( - "context" - "database/sql" - "github.com/afkdevs/migris" "github.com/afkdevs/migris/schema" ) @@ -38,12 +35,12 @@ func init() { migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) } -func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { +func up{{.CamelName}}(c *schema.Context) error { // This code is executed when the migration is applied. return nil } -func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { +func down{{.CamelName}}(c *schema.Context) error { // This code is executed when the migration is rolled back. return nil } @@ -53,9 +50,6 @@ func migrationCreateTemplate(table string) *template.Template { tmpl := `package migrations import ( - "context" - "database/sql" - "github.com/afkdevs/migris" "github.com/afkdevs/migris/schema" ) @@ -64,14 +58,14 @@ func init() { migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) } -func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { - return schema.Create(ctx, tx, "` + table + `", func(table *schema.Blueprint) { +func up{{.CamelName}}(c *schema.Context) error { + return schema.Create(c, "` + table + `", func(table *schema.Blueprint) { // Define your table schema here }) } -func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { - return schema.DropIfExists(ctx, tx, "` + table + `") +func down{{.CamelName}}(c *schema.Context) error { + return schema.DropIfExists(c, "` + table + `") } ` return template.Must(template.New("migration-create").Parse(tmpl)) @@ -81,9 +75,6 @@ func migrationUpdateTemplate(table string) *template.Template { tmpl := `package migrations import ( - "context" - "database/sql" - "github.com/afkdevs/migris" "github.com/afkdevs/migris/schema" ) @@ -92,14 +83,14 @@ func init() { migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) } -func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { - return schema.Table(ctx, tx, "` + table + `", func(table *schema.Blueprint) { +func up{{.CamelName}}(c *schema.Context) error { + return schema.Table(c, "` + table + `", func(table *schema.Blueprint) { // Define your table schema changes here }) } -func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { - return schema.Table(ctx, tx, "` + table + `", func(table *schema.Blueprint) { +func down{{.CamelName}}(c *schema.Context) error { + return schema.Table(c, "` + table + `", func(table *schema.Blueprint) { // Define your table schema changes here }) } diff --git a/down.go b/down.go index dd364c1..66bb3c3 100644 --- a/down.go +++ b/down.go @@ -3,6 +3,10 @@ package migris import ( "context" "database/sql" + "errors" + + "github.com/afkdevs/migris/internal/logger" + "github.com/pressly/goose/v3" ) func Down(db *sql.DB, dir string) error { @@ -15,9 +19,25 @@ func DownContext(ctx context.Context, db *sql.DB, dir string) error { if err != nil { return err } - _, err = provider.Down(ctx) + currentVersion, err := provider.GetDBVersion(ctx) + if err != nil { + return err + } + if currentVersion == 0 { + logger.Info("Nothing to rollback.") + return nil + } + logger.Info("Rolling back migrations.\n") + result, err := provider.Down(ctx) if err != nil { + var partialErr *goose.PartialError + if errors.As(err, &partialErr) { + logger.PrintResult(partialErr.Failed) + } return err } + if result != nil { + logger.PrintResult(result) + } return nil } diff --git a/examples/basic/cmd/migrate/cmd.go b/examples/basic/cmd/migrate/cmd.go index a5b1ddc..9b39d89 100644 --- a/examples/basic/cmd/migrate/cmd.go +++ b/examples/basic/cmd/migrate/cmd.go @@ -29,14 +29,6 @@ func Command() *cli.Command { { Name: "up", Usage: "Run all pending migrations", - Flags: []cli.Flag{ - &cli.BoolFlag{ - Name: "dry-run", - Aliases: []string{"d"}, - Usage: "Preview the migration without applying changes", - Required: false, - }, - }, Action: func(ctx context.Context, c *cli.Command) error { return Up() }, @@ -44,18 +36,24 @@ func Command() *cli.Command { { Name: "reset", Usage: "Reset all migrations", - Flags: []cli.Flag{ - &cli.BoolFlag{ - Name: "dry-run", - Aliases: []string{"d"}, - Usage: "Preview the migration without applying changes", - Required: false, - }, - }, Action: func(ctx context.Context, c *cli.Command) error { return Reset() }, }, + { + Name: "down", + Usage: "Revert the last migration", + Action: func(ctx context.Context, c *cli.Command) error { + return Down() + }, + }, + { + Name: "status", + Usage: "Show the status of migrations", + Action: func(ctx context.Context, c *cli.Command) error { + return Status() + }, + }, }, } return cmd diff --git a/examples/basic/cmd/migrate/migrate.go b/examples/basic/cmd/migrate/migrate.go index ae081c6..3695d07 100644 --- a/examples/basic/cmd/migrate/migrate.go +++ b/examples/basic/cmd/migrate/migrate.go @@ -34,6 +34,22 @@ func Reset() error { return migris.Reset(db, directory) } +func Down() error { + db, err := initMigrator() + if err != nil { + return err + } + return migris.Down(db, directory) +} + +func Status() error { + db, err := initMigrator() + if err != nil { + return err + } + return migris.Status(db, directory) +} + func initMigrator() (*sql.DB, error) { if err := migris.SetDialect("postgres"); err != nil { return nil, fmt.Errorf("failed to set schema dialect: %w", err) diff --git a/examples/basic/go.mod b/examples/basic/go.mod index 6f3febf..ded60c3 100644 --- a/examples/basic/go.mod +++ b/examples/basic/go.mod @@ -10,11 +10,16 @@ require ( ) require ( + github.com/fatih/color v1.18.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mfridman/interpolate v0.0.2 // indirect github.com/pressly/goose/v3 v3.25.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/term v0.34.0 // indirect ) replace github.com/afkdevs/migris => ../.. diff --git a/examples/basic/go.sum b/examples/basic/go.sum index 99b7bc7..1b92f73 100644 --- a/examples/basic/go.sum +++ b/examples/basic/go.sum @@ -4,6 +4,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -12,6 +14,9 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= @@ -36,8 +41,12 @@ golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/y golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ= diff --git a/examples/basic/migrations/20250830103612_create_users_table.go b/examples/basic/migrations/20250830103612_create_users_table.go index 393b615..640a37f 100644 --- a/examples/basic/migrations/20250830103612_create_users_table.go +++ b/examples/basic/migrations/20250830103612_create_users_table.go @@ -1,9 +1,6 @@ package migrations import ( - "context" - "database/sql" - "github.com/afkdevs/migris" "github.com/afkdevs/migris/schema" ) @@ -12,8 +9,8 @@ func init() { migris.AddMigrationContext(upCreateUsersTable, downCreateUsersTable) } -func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error { - return schema.Create(ctx, tx, "users", func(table *schema.Blueprint) { +func upCreateUsersTable(c *schema.Context) error { + return schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email") @@ -23,6 +20,6 @@ func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error { }) } -func downCreateUsersTable(ctx context.Context, tx *sql.Tx) error { - return schema.DropIfExists(ctx, tx, "users") +func downCreateUsersTable(c *schema.Context) error { + return schema.DropIfExists(c, "users") } diff --git a/examples/basic/migrations/20250830103653_create_roles_table.go b/examples/basic/migrations/20250830103653_create_roles_table.go index 7354d31..99d5b79 100644 --- a/examples/basic/migrations/20250830103653_create_roles_table.go +++ b/examples/basic/migrations/20250830103653_create_roles_table.go @@ -1,9 +1,6 @@ package migrations import ( - "context" - "database/sql" - "github.com/afkdevs/migris" "github.com/afkdevs/migris/schema" ) @@ -12,13 +9,13 @@ func init() { migris.AddMigrationContext(upCreateRolesTable, downCreateRolesTable) } -func upCreateRolesTable(ctx context.Context, tx *sql.Tx) error { - return schema.Create(ctx, tx, "roles", func(table *schema.Blueprint) { +func upCreateRolesTable(c *schema.Context) error { + return schema.Create(c, "roles", func(table *schema.Blueprint) { table.Increments("id").Primary() table.String("name").Unique().Nullable(false) }) } -func downCreateRolesTable(ctx context.Context, tx *sql.Tx) error { - return schema.DropIfExists(ctx, tx, "roles") +func downCreateRolesTable(c *schema.Context) error { + return schema.DropIfExists(c, "roles") } diff --git a/examples/basic/migrations/20250830103714_create_user_roles_table.go b/examples/basic/migrations/20250830103714_create_user_roles_table.go index 581b201..7f5fbb0 100644 --- a/examples/basic/migrations/20250830103714_create_user_roles_table.go +++ b/examples/basic/migrations/20250830103714_create_user_roles_table.go @@ -1,9 +1,6 @@ package migrations import ( - "context" - "database/sql" - "github.com/afkdevs/migris" "github.com/afkdevs/migris/schema" ) @@ -12,8 +9,8 @@ func init() { migris.AddMigrationContext(upCreateUserRolesTable, downCreateUserRolesTable) } -func upCreateUserRolesTable(ctx context.Context, tx *sql.Tx) error { - return schema.Create(ctx, tx, "user_roles", func(table *schema.Blueprint) { +func upCreateUserRolesTable(c *schema.Context) error { + return schema.Create(c, "user_roles", func(table *schema.Blueprint) { table.BigInteger("user_id") table.Integer("role_id") table.Primary("user_id", "role_id") @@ -22,6 +19,6 @@ func upCreateUserRolesTable(ctx context.Context, tx *sql.Tx) error { }) } -func downCreateUserRolesTable(ctx context.Context, tx *sql.Tx) error { - return schema.DropIfExists(ctx, tx, "user_roles") +func downCreateUserRolesTable(c *schema.Context) error { + return schema.DropIfExists(c, "user_roles") } diff --git a/go.mod b/go.mod index 718df7b..6bab72b 100644 --- a/go.mod +++ b/go.mod @@ -3,19 +3,24 @@ module github.com/afkdevs/migris go 1.23.0 require ( + github.com/fatih/color v1.18.0 github.com/go-sql-driver/mysql v1.9.3 github.com/lib/pq v1.10.9 github.com/pressly/goose/v3 v3.25.0 github.com/stretchr/testify v1.11.1 + golang.org/x/term v0.34.0 ) require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mfridman/interpolate v0.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.35.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9628ba3..e53afdf 100644 --- a/go.sum +++ b/go.sum @@ -4,12 +4,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= @@ -32,8 +37,12 @@ golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/y golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/config/config.go b/internal/config/config.go index 293570b..9c02a82 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,7 +9,6 @@ import ( type Config struct { Dialect dialect.Dialect TableName string - Verbose bool } var config = atomic.Pointer[Config]{} @@ -18,7 +17,6 @@ func init() { config.Store(&Config{ Dialect: dialect.Unknown, TableName: "migris_db_version", - Verbose: true, }) } diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 0000000..232e71e --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,101 @@ +package logger + +import ( + "fmt" + "os" + "strings" + + "github.com/fatih/color" + "github.com/pressly/goose/v3" + "golang.org/x/term" +) + +var ( + grey = color.New(color.FgHiBlack).SprintFunc() + greenBold = color.New(color.FgGreen, color.Bold).SprintFunc() + yellowBold = color.New(color.FgYellow, color.Bold).SprintFunc() + redBold = color.New(color.FgRed, color.Bold).SprintFunc() +) + +func Info(msg string) { + Infof("%s", msg) +} + +func Infof(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + + whiteBgBlue := color.New(color.FgWhite, color.BgBlue).SprintFunc() + fmt.Printf("%s %s\n", whiteBgBlue(" INFO "), msg) +} + +func PrintResults(results []*goose.MigrationResult) { + for _, result := range results { + PrintResult(result) + } +} + +func PrintResult(result *goose.MigrationResult) { + // Get terminal width + width, _, err := term.GetSize(int(os.Stdout.Fd())) + if err != nil || width <= 0 { + width = 80 // fallback + } + + durText := fmt.Sprintf(" %.2fms", result.Duration.Seconds()*1000) + statusText := " DONE" + if result.Error != nil { + statusText = " FAIL" + } + + name := result.Source.Path + + // Calculate how many dots to add + fillLen := width - len(name) - len(durText) - len(statusText) - 1 + if fillLen < 0 { + fillLen = 0 + } + dots := strings.Repeat(".", fillLen) + + // Print + fmt.Printf("%s %s%s", name, grey(dots), grey(durText)) + if result.Error != nil { + fmt.Printf("%s\n", redBold(statusText)) + } else { + fmt.Printf("%s\n", greenBold(statusText)) + } +} + +func PrintStatuses(statuses []*goose.MigrationStatus) { + for _, status := range statuses { + PrintStatus(status) + } +} + +func PrintStatus(status *goose.MigrationStatus) { + width, _, err := term.GetSize(int(os.Stdout.Fd())) + if err != nil || width <= 0 { + width = 80 // fallback + } + + var statusText string + if status.State == goose.StateApplied { + statusText = " Applied" + } else { + statusText = " Pending" + } + + name := status.Source.Path + + fillLen := width - len(name) - len(statusText) - 1 + if fillLen < 0 { + fillLen = 0 + } + dots := strings.Repeat(".", fillLen) + + fmt.Printf("%s %s", name, grey(dots)) + if status.State == goose.StateApplied { + fmt.Printf("%s\n", greenBold(statusText)) + } else { + fmt.Printf("%s\n", yellowBold(statusText)) + } +} diff --git a/provider.go b/provider.go index 42122b9..09af4fe 100644 --- a/provider.go +++ b/provider.go @@ -24,8 +24,7 @@ func newProvider(db *sql.DB, dir string) (*goose.Provider, error) { provider, err := goose.NewProvider(database.DialectCustom, db, os.DirFS(dir), goose.WithStore(store), goose.WithDisableGlobalRegistry(true), - goose.WithGoMigrations(getMigrations()...), - goose.WithVerbose(cfg.Verbose), + goose.WithGoMigrations(registeredMigrations...), ) if err != nil { return nil, err diff --git a/register.go b/register.go index b0bd30f..d4b4add 100644 --- a/register.go +++ b/register.go @@ -4,19 +4,29 @@ import ( "context" "database/sql" "fmt" + "path" "runtime" - "slices" + "github.com/afkdevs/migris/schema" "github.com/pressly/goose/v3" ) var ( - registeredGoMigrations = make(map[int64]*goose.Migration) + registeredVersions = make(map[int64]string) + registeredMigrations = make([]*goose.Migration, 0) ) // GoMigrationContext is a Go migration func that is run within a transaction and receives a // context. -type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error +type GoMigrationContext func(ctx *schema.Context) error + +func (m GoMigrationContext) RunTxFunc(filename string) func(ctx context.Context, tx *sql.Tx) error { + return func(ctx context.Context, tx *sql.Tx) error { + filename = path.Base(filename) + c := schema.NewContext(ctx, tx, schema.WithFilename(filename)) + return m(c) + } +} // AddMigrationContext adds Go migrations. func AddMigrationContext(up, down GoMigrationContext) { @@ -29,8 +39,8 @@ func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { if err := register( filename, true, - &goose.GoFunc{RunTx: up, Mode: goose.TransactionEnabled}, - &goose.GoFunc{RunTx: down, Mode: goose.TransactionEnabled}, + &goose.GoFunc{RunTx: up.RunTxFunc(filename), Mode: goose.TransactionEnabled}, + &goose.GoFunc{RunTx: down.RunTxFunc(filename), Mode: goose.TransactionEnabled}, ); err != nil { panic(err) } @@ -38,11 +48,11 @@ func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { func register(filename string, useTx bool, up, down *goose.GoFunc) error { v, _ := goose.NumericComponent(filename) - if existing, ok := registeredGoMigrations[v]; ok { + if existing, ok := registeredVersions[v]; ok { return fmt.Errorf("failed to add migration %q: version %d conflicts with %q", filename, v, - existing.Source, + existing, ) } // Add to global as a registered migration. @@ -51,28 +61,7 @@ func register(filename string, useTx bool, up, down *goose.GoFunc) error { // We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but // we know based on the register function what the user is requesting. m.UseTx = useTx - registeredGoMigrations[v] = m + registeredVersions[v] = filename + registeredMigrations = append(registeredMigrations, m) return nil } - -func getMigrations() []*goose.Migration { - type migrationWithVersion struct { - version int64 - migration *goose.Migration - } - var migrations []*migrationWithVersion - for _, m := range registeredGoMigrations { - migrations = append(migrations, &migrationWithVersion{ - version: m.Version, - migration: m, - }) - } - slices.SortFunc(migrations, func(a, b *migrationWithVersion) int { - return int(a.version - b.version) - }) - var results []*goose.Migration - for _, m := range migrations { - results = append(results, m.migration) - } - return results -} diff --git a/reset.go b/reset.go index 43ee030..aa1b55f 100644 --- a/reset.go +++ b/reset.go @@ -3,6 +3,10 @@ package migris import ( "context" "database/sql" + "errors" + + "github.com/afkdevs/migris/internal/logger" + "github.com/pressly/goose/v3" ) func Reset(db *sql.DB, dir string) error { @@ -15,9 +19,25 @@ func ResetContext(ctx context.Context, db *sql.DB, dir string) error { if err != nil { return err } - _, err = provider.DownTo(ctx, 0) + currentVersion, err := provider.GetDBVersion(ctx) + if err != nil { + return err + } + if currentVersion == 0 { + logger.Info("Nothing to rollback.") + return nil + } + logger.Info("Rolling back migrations.\n") + results, err := provider.DownTo(ctx, 0) if err != nil { + var partialErr *goose.PartialError + if errors.As(err, &partialErr) { + logger.PrintResults(partialErr.Applied) + logger.PrintResult(partialErr.Failed) + } + return err } + logger.PrintResults(results) return nil } diff --git a/schema/blueprint.go b/schema/blueprint.go index c28b533..069bc4a 100644 --- a/schema/blueprint.go +++ b/schema/blueprint.go @@ -1,10 +1,7 @@ package schema import ( - "context" - "database/sql" "fmt" - "log" "github.com/afkdevs/migris/internal/dialect" "github.com/afkdevs/migris/internal/util" @@ -59,7 +56,6 @@ type Blueprint struct { charset string collation string engine string - verbose bool } // Charset sets the character set for the table in the blueprint. @@ -588,16 +584,16 @@ func (b *Blueprint) getFluentStatements() []string { return statements } -func (b *Blueprint) build(ctx context.Context, tx *sql.Tx) error { +func (b *Blueprint) build(ctx *Context) error { statements, err := b.toSql() if err != nil { return err } for _, statement := range statements { - if b.verbose { - log.Println(statement) - } - if _, err := tx.ExecContext(ctx, statement); err != nil { + // if b.verbose { + // log.Println(statement) + // } + if _, err := ctx.Exec(statement); err != nil { return err } } diff --git a/schema/builder.go b/schema/builder.go index 823d014..6e3cc2c 100644 --- a/schema/builder.go +++ b/schema/builder.go @@ -1,8 +1,6 @@ package schema import ( - "context" - "database/sql" "errors" "github.com/afkdevs/migris/internal/dialect" @@ -11,29 +9,29 @@ import ( // Builder is an interface that defines methods for creating, dropping, and managing database tables. type Builder interface { // Create creates a new table with the given name and applies the provided blueprint. - Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error + Create(c *Context, name string, blueprint func(table *Blueprint)) error // Drop removes the table with the given name. - Drop(ctx context.Context, tx *sql.Tx, name string) error + Drop(c *Context, name string) error // DropIfExists removes the table with the given name if it exists. - DropIfExists(ctx context.Context, tx *sql.Tx, name string) error + DropIfExists(c *Context, name string) error // GetColumns retrieves the columns of the specified table. - GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) + GetColumns(c *Context, tableName string) ([]*Column, error) // GetIndexes retrieves the indexes of the specified table. - GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) + GetIndexes(c *Context, tableName string) ([]*Index, error) // GetTables retrieves all tables in the database. - GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) + GetTables(c *Context) ([]*TableInfo, error) // HasColumn checks if the specified table has the given column. - HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) + HasColumn(c *Context, tableName string, columnName string) (bool, error) // HasColumns checks if the specified table has all the given columns. - HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) + HasColumns(c *Context, tableName string, columnNames []string) (bool, error) // HasIndex checks if the specified table has the given index. - HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) + HasIndex(c *Context, tableName string, indexes []string) (bool, error) // HasTable checks if a table with the given name exists. - HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) + HasTable(c *Context, name string) (bool, error) // Rename renames a table from oldName to newName. - Rename(ctx context.Context, tx *sql.Tx, oldName string, newName string) error + Rename(c *Context, oldName string, newName string) error // Table applies the provided blueprint to the specified table. - Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error + Table(c *Context, name string, blueprint func(table *Blueprint)) error } // NewBuilder creates a new Builder instance based on the specified dialect. @@ -54,83 +52,82 @@ func NewBuilder(dialectValue string) (Builder, error) { type baseBuilder struct { grammar grammar - verbose bool } func (b *baseBuilder) newBlueprint(name string) *Blueprint { - return &Blueprint{name: name, grammar: b.grammar, verbose: b.verbose} + return &Blueprint{name: name, grammar: b.grammar} } -func (b *baseBuilder) Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if tx == nil || name == "" || blueprint == nil { - return errors.New("invalid arguments: transaction, name, or blueprint is nil/empty") +func (b *baseBuilder) Create(c *Context, name string, blueprint func(table *Blueprint)) error { + if c == nil || name == "" || blueprint == nil { + return errors.New("invalid arguments: context, name, or blueprint is nil/empty") } bp := b.newBlueprint(name) bp.create() blueprint(bp) - if err := bp.build(ctx, tx); err != nil { + if err := bp.build(c); err != nil { return err } return nil } -func (b *baseBuilder) Drop(ctx context.Context, tx *sql.Tx, name string) error { - if tx == nil || name == "" { - return errors.New("invalid arguments: transaction is nil or name is empty") +func (b *baseBuilder) Drop(c *Context, name string) error { + if c == nil || name == "" { + return errors.New("invalid arguments: context is nil or name is empty") } bp := b.newBlueprint(name) bp.drop() - if err := bp.build(ctx, tx); err != nil { + if err := bp.build(c); err != nil { return err } return nil } -func (b *baseBuilder) DropIfExists(ctx context.Context, tx *sql.Tx, name string) error { - if tx == nil || name == "" { - return errors.New("invalid arguments: transaction is nil or name is empty") +func (b *baseBuilder) DropIfExists(c *Context, name string) error { + if c == nil || name == "" { + return errors.New("invalid arguments: context is nil or name is empty") } bp := b.newBlueprint(name) bp.dropIfExists() - if err := bp.build(ctx, tx); err != nil { + if err := bp.build(c); err != nil { return err } return nil } -func (b *baseBuilder) Rename(ctx context.Context, tx *sql.Tx, oldName string, newName string) error { - if tx == nil || oldName == "" || newName == "" { - return errors.New("invalid arguments: transaction is nil or old/new table name is empty") +func (b *baseBuilder) Rename(c *Context, oldName string, newName string) error { + if c == nil || oldName == "" || newName == "" { + return errors.New("invalid arguments: context is nil or old/new table name is empty") } bp := b.newBlueprint(oldName) bp.rename(newName) - if err := bp.build(ctx, tx); err != nil { + if err := bp.build(c); err != nil { return err } return nil } -func (b *baseBuilder) Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if tx == nil || name == "" || blueprint == nil { - return errors.New("invalid arguments: transaction is nil or name/blueprint is empty") +func (b *baseBuilder) Table(c *Context, name string, blueprint func(table *Blueprint)) error { + if c == nil || name == "" || blueprint == nil { + return errors.New("invalid arguments: context is nil or name/blueprint is empty") } bp := b.newBlueprint(name) blueprint(bp) - if err := bp.build(ctx, tx); err != nil { + if err := bp.build(c); err != nil { return err } diff --git a/schema/context.go b/schema/context.go new file mode 100644 index 0000000..aa74605 --- /dev/null +++ b/schema/context.go @@ -0,0 +1,43 @@ +package schema + +import ( + "context" + "database/sql" +) + +type Context struct { + ctx context.Context + tx *sql.Tx + filename string +} + +type ContextOptions func(*Context) + +func WithFilename(filename string) ContextOptions { + return func(c *Context) { + c.filename = filename + } +} + +func NewContext(ctx context.Context, tx *sql.Tx, opts ...ContextOptions) *Context { + c := &Context{ + ctx: ctx, + tx: tx, + } + for _, opt := range opts { + opt(c) + } + return c +} + +func (c *Context) Exec(query string, args ...any) (sql.Result, error) { + return c.tx.ExecContext(c.ctx, query, args...) +} + +func (c *Context) Query(query string, args ...any) (*sql.Rows, error) { + return c.tx.QueryContext(c.ctx, query, args...) +} + +func (c *Context) QueryRow(query string, args ...any) *sql.Row { + return c.tx.QueryRowContext(c.ctx, query, args...) +} diff --git a/schema/mysql_builder.go b/schema/mysql_builder.go index 6bc4399..a02b984 100644 --- a/schema/mysql_builder.go +++ b/schema/mysql_builder.go @@ -1,12 +1,9 @@ package schema import ( - "context" "database/sql" "errors" "strings" - - "github.com/afkdevs/migris/internal/config" ) type mysqlBuilder struct { @@ -18,17 +15,16 @@ var _ Builder = (*mysqlBuilder)(nil) func newMysqlBuilder() Builder { grammar := newMysqlGrammar() - cfg := config.Get() return &mysqlBuilder{ - baseBuilder: baseBuilder{grammar: grammar, verbose: cfg.Verbose}, + baseBuilder: baseBuilder{grammar: grammar}, grammar: grammar, } } -func (b *mysqlBuilder) getCurrentDatabase(ctx context.Context, tx *sql.Tx) (string, error) { +func (b *mysqlBuilder) getCurrentDatabase(c *Context) (string, error) { query := b.grammar.CompileCurrentDatabase() - row := tx.QueryRowContext(ctx, query) + row := c.QueryRow(query) var dbName string if err := row.Scan(&dbName); err != nil { return "", err @@ -36,12 +32,12 @@ func (b *mysqlBuilder) getCurrentDatabase(ctx context.Context, tx *sql.Tx) (stri return dbName, nil } -func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) { - if tx == nil || tableName == "" { - return nil, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *mysqlBuilder) GetColumns(c *Context, tableName string) ([]*Column, error) { + if c == nil || tableName == "" { + return nil, errors.New("invalid arguments: context is nil or table name is empty") } - database, err := b.getCurrentDatabase(ctx, tx) + database, err := b.getCurrentDatabase(c) if err != nil { return nil, err } @@ -51,7 +47,7 @@ func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName str return nil, err } - rows, err := tx.QueryContext(ctx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -77,12 +73,12 @@ func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName str return columns, nil } -func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) { - if tx == nil || tableName == "" { - return nil, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *mysqlBuilder) GetIndexes(c *Context, tableName string) ([]*Index, error) { + if c == nil || tableName == "" { + return nil, errors.New("invalid arguments: context is nil or table name is empty") } - database, err := b.getCurrentDatabase(ctx, tx) + database, err := b.getCurrentDatabase(c) if err != nil { return nil, err } @@ -92,7 +88,7 @@ func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName str return nil, err } - rows, err := tx.QueryContext(ctx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -111,12 +107,12 @@ func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName str return indexes, nil } -func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) { - if tx == nil { - return nil, errors.New("invalid arguments: transaction is nil") +func (b *mysqlBuilder) GetTables(c *Context) ([]*TableInfo, error) { + if c == nil { + return nil, errors.New("invalid arguments: context is nil") } - database, err := b.getCurrentDatabase(ctx, tx) + database, err := b.getCurrentDatabase(c) if err != nil { return nil, err } @@ -125,7 +121,7 @@ func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, if err != nil { return nil, err } - rows, err := tx.QueryContext(ctx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -142,22 +138,22 @@ func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, return tables, nil } -func (b *mysqlBuilder) HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) { - if tx == nil || columnName == "" { - return false, errors.New("invalid arguments: transaction is nil or column name is empty") +func (b *mysqlBuilder) HasColumn(c *Context, tableName string, columnName string) (bool, error) { + if c == nil || columnName == "" { + return false, errors.New("invalid arguments: context is nil or column name is empty") } - return b.HasColumns(ctx, tx, tableName, []string{columnName}) + return b.HasColumns(c, tableName, []string{columnName}) } -func (b *mysqlBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) { - if tx == nil || tableName == "" { - return false, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *mysqlBuilder) HasColumns(c *Context, tableName string, columnNames []string) (bool, error) { + if c == nil || tableName == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } if len(columnNames) == 0 { return false, errors.New("no column names provided") } - columns, err := b.GetColumns(ctx, tx, tableName) + columns, err := b.GetColumns(c, tableName) if err != nil { return false, err } @@ -173,12 +169,12 @@ func (b *mysqlBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName str return true, nil // All specified columns exist } -func (b *mysqlBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) { - if tx == nil || tableName == "" { - return false, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *mysqlBuilder) HasIndex(c *Context, tableName string, indexes []string) (bool, error) { + if c == nil || tableName == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } - existingIndexes, err := b.GetIndexes(ctx, tx, tableName) + existingIndexes, err := b.GetIndexes(c, tableName) if err != nil { return false, err } @@ -217,12 +213,12 @@ func (b *mysqlBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName strin return false, nil // If no specified index exists, return false } -func (b *mysqlBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { - if tx == nil || name == "" { - return false, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *mysqlBuilder) HasTable(c *Context, name string) (bool, error) { + if c == nil || name == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } - database, err := b.getCurrentDatabase(ctx, tx) + database, err := b.getCurrentDatabase(c) if err != nil { return false, err } @@ -232,7 +228,7 @@ func (b *mysqlBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (b return false, err } - row := tx.QueryRowContext(ctx, query) + row := c.QueryRow(query) var exists bool if err := row.Scan(&exists); err != nil { if errors.Is(err, sql.ErrNoRows) { diff --git a/schema/mysql_builder_test.go b/schema/mysql_builder_test.go index bf3d06e..6f6250a 100644 --- a/schema/mysql_builder_test.go +++ b/schema/mysql_builder_test.go @@ -6,7 +6,7 @@ import ( "fmt" "testing" - schema2 "github.com/afkdevs/migris/schema" + "github.com/afkdevs/migris/schema" _ "github.com/go-sql-driver/mysql" // MySQL driver "github.com/stretchr/testify/suite" ) @@ -19,7 +19,7 @@ type mysqlBuilderSuite struct { suite.Suite ctx context.Context db *sql.DB - builder schema2.Builder + builder schema.Builder } func (s *mysqlBuilderSuite) SetupSuite() { @@ -43,7 +43,7 @@ func (s *mysqlBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - s.builder, err = schema2.NewBuilder("mysql") + s.builder, err = schema.NewBuilder("mysql") s.Require().NoError(err) } @@ -55,10 +55,11 @@ func (s *mysqlBuilderSuite) AfterTest(_, _ string) { builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) - tables, err := builder.GetTables(s.ctx, tx) + c := schema.NewContext(s.ctx, tx) + tables, err := builder.GetTables(c) s.Require().NoError(err) for _, table := range tables { - err := builder.DropIfExists(s.ctx, tx, table.Name) + err := builder.DropIfExists(c, table.Name) if err != nil { s.T().Logf("error dropping table %s: %v", table.Name, err) } @@ -73,24 +74,26 @@ func (s *mysqlBuilderSuite) TestCreate() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Create(s.ctx, nil, "test_table", func(table *schema2.Blueprint) { + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Create(nil, "test_table", func(table *schema.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.Create(s.ctx, tx, "", func(table *schema2.Blueprint) { + err := builder.Create(c, "", func(table *schema.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when blueprint is nil, should return error", func() { - err := builder.Create(s.ctx, tx, "test_table", nil) + err := builder.Create(c, "test_table", nil) s.Error(err) }) s.Run("when all parameters are valid, should create table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -100,7 +103,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with valid parameters") }) s.Run("when have composite primary key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "user_roles", func(table *schema2.Blueprint) { + err = builder.Create(c, "user_roles", func(table *schema.Blueprint) { table.Integer("user_id") table.Integer("role_id") @@ -109,7 +112,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with composite primary key") }) s.Run("when have foreign key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders", func(table *schema2.Blueprint) { + err = builder.Create(c, "orders", func(table *schema.Blueprint) { table.ID() table.UnsignedBigInteger("user_id") table.String("order_id", 255).Unique() @@ -121,7 +124,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with foreign key") }) s.Run("when have custom index should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders_2", func(table *schema2.Blueprint) { + err = builder.Create(c, "orders_2", func(table *schema.Blueprint) { table.ID() table.String("order_id", 255).Unique("uk_orders_2_order_id") table.Decimal("amount", 10, 2) @@ -132,7 +135,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with custom index") }) s.Run("when table already exists, should return error", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -147,16 +150,18 @@ func (s *mysqlBuilderSuite) TestDrop() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Drop(s.ctx, nil, "test_table") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Drop(nil, "test_table") s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.Drop(s.ctx, tx, "") + err := builder.Drop(c, "") s.Error(err) }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -164,11 +169,11 @@ func (s *mysqlBuilderSuite) TestDrop() { table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") - err = builder.Drop(context.Background(), tx, "users") + err = builder.Drop(c, "users") s.NoError(err, "expected no error when dropping table with valid parameters") }) s.Run("when table does not exist, should return error", func() { - err = builder.Drop(context.Background(), tx, "non_existent_table") + err = builder.Drop(c, "non_existent_table") s.Error(err, "expected error when dropping a table that does not exist") }) } @@ -179,16 +184,18 @@ func (s *mysqlBuilderSuite) TestDropIfExists() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.DropIfExists(s.ctx, nil, "test_table") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.DropIfExists(nil, "test_table") s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.DropIfExists(s.ctx, tx, "") + err := builder.DropIfExists(c, "") s.Error(err) }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -196,11 +203,11 @@ func (s *mysqlBuilderSuite) TestDropIfExists() { table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") - err = builder.DropIfExists(context.Background(), tx, "users") + err = builder.DropIfExists(c, "users") s.NoError(err, "expected no error when dropping table with valid parameters") }) s.Run("when table does not exist, should not return error", func() { - err = builder.DropIfExists(context.Background(), tx, "non_existent_table") + err = builder.DropIfExists(c, "non_existent_table") s.NoError(err, "expected no error when dropping a table that does not exist with DropIfExists") }) } @@ -211,25 +218,27 @@ func (s *mysqlBuilderSuite) TestRename() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Rename(s.ctx, nil, "old_table", "new_table") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Rename(nil, "old_table", "new_table") s.Error(err) }) s.Run("when old table name is empty, should return error", func() { - err := builder.Rename(s.ctx, tx, "", "new_table") + err := builder.Rename(c, "", "new_table") s.Error(err) }) s.Run("when new table name is empty, should return error", func() { - err := builder.Rename(s.ctx, tx, "old_table", "") + err := builder.Rename(c, "old_table", "") s.Error(err) }) s.Run("when all parameters are valid, should rename table successfully", func() { - err = builder.Create(context.Background(), tx, "old_table", func(table *schema2.Blueprint) { + err = builder.Create(c, "old_table", func(table *schema.Blueprint) { table.ID() table.String("name", 255) }) s.NoError(err, "expected no error when creating old_table before renaming it") - err = builder.Rename(context.Background(), tx, "old_table", "new_table") + err = builder.Rename(c, "old_table", "new_table") s.NoError(err, "expected no error when renaming table with valid parameters") }) } @@ -240,24 +249,26 @@ func (s *mysqlBuilderSuite) TestTable() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Table(s.ctx, nil, "test_table", func(table *schema2.Blueprint) { + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Table(nil, "test_table", func(table *schema.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.Table(s.ctx, tx, "", func(table *schema2.Blueprint) { + err := builder.Table(c, "", func(table *schema.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when blueprint is nil, should return error", func() { - err := builder.Table(s.ctx, tx, "test_table", nil) + err := builder.Table(c, "test_table", nil) s.Error(err) }) s.Run("when all parameters are valid, should modify table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique("uk_users_email") @@ -270,75 +281,75 @@ func (s *mysqlBuilderSuite) TestTable() { s.NoError(err, "expected no error when creating table before modifying it") s.Run("should add new columns and modify existing ones", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.String("address", 255).Nullable() table.String("phone", 20).Nullable().Unique("uk_users_phone") }) s.NoError(err, "expected no error when modifying table with valid parameters") }) s.Run("should modify existing column", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.String("email", 255).Nullable().Change() }) s.NoError(err, "expected no error when modifying existing column") }) s.Run("should drop column and rename existing one", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropColumn("password") table.RenameColumn("name", "full_name") }) s.NoError(err, "expected no error when dropping column and renaming existing one") }) s.Run("should add index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.Index("phone").Name("idx_users_phone").Algorithm("BTREE") }) s.NoError(err, "expected no error when adding index to table") }) s.Run("should rename index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.RenameIndex("idx_users_phone", "idx_users_contact") }) s.NoError(err, "expected no error when renaming index in table") }) s.Run("should drop index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropIndex("idx_users_contact") }) s.NoError(err, "expected no error when dropping index from table") }) s.Run("should drop unique constraint", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropUnique("uk_users_email") }) s.NoError(err, "expected no error when dropping unique constraint from table") }) s.Run("should drop fulltext index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropFulltext("ft_users_bio") }) s.NoError(err, "expected no error when dropping fulltext index from table") }) s.Run("should add foreign key", func() { - err = builder.Create(s.ctx, tx, "roles", func(table *schema2.Blueprint) { + err = builder.Create(c, "roles", func(table *schema.Blueprint) { table.UnsignedInteger("id").Primary() table.String("role_name", 255).Unique("uk_roles_role_name") }) s.NoError(err, "expected no error when creating roles table before adding foreign key") - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.UnsignedInteger("role_id").Nullable() table.Foreign("role_id").References("id").On("roles").OnDelete("SET NULL").OnUpdate("CASCADE") }) s.NoError(err, "expected no error when adding foreign key to users table") }) s.Run("should drop foreign key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropForeign("fk_users_roles") }) s.NoError(err, "expected no error when dropping foreign key from users table") }) s.Run("should drop primary key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.UnsignedBigInteger("id").Change() table.DropPrimary("users_pkey") }) @@ -353,23 +364,25 @@ func (s *mysqlBuilderSuite) TestGetColumns() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - columns, err := builder.GetColumns(s.ctx, nil, "test_table") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + columns, err := builder.GetColumns(nil, "test_table") s.Error(err) s.Nil(columns) }) s.Run("when table name is empty, should return error", func() { - columns, err := builder.GetColumns(s.ctx, tx, "") + columns, err := builder.GetColumns(c, "") s.Error(err) s.Nil(columns) }) s.Run("when table does not exist, should return empty slice", func() { - columns, err := builder.GetColumns(s.ctx, tx, "non_existent_table") + columns, err := builder.GetColumns(c, "non_existent_table") s.NoError(err) s.Empty(columns) }) s.Run("when table exists, should return columns successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -378,7 +391,7 @@ func (s *mysqlBuilderSuite) TestGetColumns() { }) s.NoError(err, "expected no error when creating table before getting columns") - columns, err := builder.GetColumns(context.Background(), tx, "users") + columns, err := builder.GetColumns(c, "users") s.NoError(err, "expected no error when getting columns from existing table") s.NotEmpty(columns) s.Len(columns, 6, "expected 6 columns in the users table") @@ -391,16 +404,18 @@ func (s *mysqlBuilderSuite) TestGetIndexes() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - _, err := builder.GetIndexes(s.ctx, nil, "users_indexes") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + _, err := builder.GetIndexes(nil, "users_indexes") + s.Error(err, "expected error when context is nil") }) s.Run("when table name is empty, should return error", func() { - _, err := builder.GetIndexes(s.ctx, tx, "") + _, err := builder.GetIndexes(c, "") s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -411,13 +426,13 @@ func (s *mysqlBuilderSuite) TestGetIndexes() { }) s.NoError(err, "expected no error when creating table before getting indexes") - indexes, err := builder.GetIndexes(s.ctx, tx, "users") + indexes, err := builder.GetIndexes(c, "users") s.NoError(err, "expected no error when getting indexes with valid parameters") s.Len(indexes, 3, "expected 3 index to be returned") }) s.Run("when table does not exist, should return empty indexes", func() { - indexes, err := builder.GetIndexes(s.ctx, tx, "non_existent_table") + indexes, err := builder.GetIndexes(c, "non_existent_table") s.NoError(err, "expected no error when getting indexes of non-existent table") s.Empty(indexes, "expected empty indexes for non-existent table") }) @@ -429,13 +444,15 @@ func (s *mysqlBuilderSuite) TestGetTables() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when tx is nil, should return error", func() { - tables, err := builder.GetTables(s.ctx, nil) + tables, err := builder.GetTables(nil) s.Error(err, "expected error when transaction is nil") s.Nil(tables) }) s.Run("when all parameters are valid", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -444,7 +461,7 @@ func (s *mysqlBuilderSuite) TestGetTables() { }) s.NoError(err, "expected no error when creating table before getting tables") - tables, err := builder.GetTables(context.Background(), tx) + tables, err := builder.GetTables(c) s.NoError(err, "expected no error when getting tables after creating one") s.NotEmpty(tables, "expected non-empty tables slice after creating a table") found := false @@ -464,23 +481,25 @@ func (s *mysqlBuilderSuite) TestHasColumn() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasColumn(s.ctx, nil, "users", "name") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasColumn(nil, "users", "name") + s.Error(err, "expected error when context is nil") s.False(exists) }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasColumn(s.ctx, tx, "", "name") + exists, err := builder.HasColumn(c, "", "name") s.Error(err, "expected error when table name is empty") s.False(exists) }) s.Run("when column name is empty, should return error", func() { - exists, err := builder.HasColumn(s.ctx, tx, "users", "") + exists, err := builder.HasColumn(c, "users", "") s.Error(err, "expected error when column name is empty") s.False(exists) }) s.Run("when all parameters are valid", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -489,11 +508,11 @@ func (s *mysqlBuilderSuite) TestHasColumn() { }) s.NoError(err, "expected no error when creating table before checking for column existence") - exists, err := builder.HasColumn(context.Background(), tx, "users", "name") + exists, err := builder.HasColumn(c, "users", "name") s.NoError(err, "expected no error when checking for existing column") s.True(exists, "expected 'name' column to exist in users table") - exists, err = builder.HasColumn(context.Background(), tx, "users", "non_existent_column") + exists, err = builder.HasColumn(c, "users", "non_existent_column") s.NoError(err, "expected no error when checking for non-existing column") s.False(exists, "expected 'non_existent_column' to not exist in users table") }) @@ -505,23 +524,25 @@ func (s *mysqlBuilderSuite) TestHasColumns() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasColumns(s.ctx, nil, "users", []string{"name"}) - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasColumns(nil, "users", []string{"name"}) + s.Error(err, "expected error when context is nil") s.False(exists) }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasColumns(s.ctx, tx, "", []string{"name"}) + exists, err := builder.HasColumns(c, "", []string{"name"}) s.Error(err, "expected error when table name is empty") s.False(exists) }) s.Run("when column names are empty, should return error", func() { - exists, err := builder.HasColumns(s.ctx, tx, "users", []string{}) + exists, err := builder.HasColumns(c, "users", []string{}) s.Error(err, "expected error when column names are empty") s.False(exists) }) s.Run("when all parameters are valid", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -530,11 +551,11 @@ func (s *mysqlBuilderSuite) TestHasColumns() { }) s.NoError(err, "expected no error when creating table before checking for columns existence") - exists, err := builder.HasColumns(context.Background(), tx, "users", []string{"name", "email"}) + exists, err := builder.HasColumns(c, "users", []string{"name", "email"}) s.NoError(err, "expected no error when checking for existing columns") s.True(exists, "expected 'name' and 'email' columns to exist in users_has_columns table") - exists, err = builder.HasColumns(context.Background(), tx, "users", []string{"name", "non_existent_column"}) + exists, err = builder.HasColumns(c, "users", []string{"name", "non_existent_column"}) s.NoError(err, "expected no error when checking for mixed existing and non-existing columns") s.False(exists, "expected 'non_existent_column' to not exist in users_has_columns table") }) @@ -546,18 +567,20 @@ func (s *mysqlBuilderSuite) TestHasIndex() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasIndex(s.ctx, nil, "orders", []string{"idx_users_name"}) - s.Error(err, "expected error when transaction is nil") - s.False(exists, "expected exists to be false when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasIndex(nil, "orders", []string{"idx_users_name"}) + s.Error(err, "expected error when context is nil") + s.False(exists, "expected exists to be false when context is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasIndex(s.ctx, tx, "", []string{"idx_users_name"}) + exists, err := builder.HasIndex(c, "", []string{"idx_users_name"}) s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "orders", func(table *schema2.Blueprint) { + err = builder.Create(c, "orders", func(table *schema.Blueprint) { table.ID() table.Integer("company_id") table.Integer("user_id") @@ -570,15 +593,15 @@ func (s *mysqlBuilderSuite) TestHasIndex() { }) s.NoError(err, "expected no error when creating table with index") - exists, err := builder.HasIndex(s.ctx, tx, "orders", []string{"uk_orders3_order_id"}) + exists, err := builder.HasIndex(c, "orders", []string{"uk_orders3_order_id"}) s.NoError(err, "expected no error when checking if index exists with valid parameters") s.True(exists, "expected exists to be true for existing index") - exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"company_id", "user_id"}) + exists, err = builder.HasIndex(c, "orders", []string{"company_id", "user_id"}) s.NoError(err, "expected no error when checking non-existent index") s.True(exists, "expected exists to be true for existing composite index") - exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"non_existent_index"}) + exists, err = builder.HasIndex(c, "orders", []string{"non_existent_index"}) s.NoError(err, "expected no error when checking non-existent index") s.False(exists, "expected exists to be false for non-existent index") }) @@ -590,18 +613,20 @@ func (s *mysqlBuilderSuite) TestHasTable() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasTable(s.ctx, nil, "users") - s.Error(err, "expected error when transaction is nil") - s.False(exists, "expected exists to be false when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasTable(nil, "users") + s.Error(err, "expected error when context is nil") + s.False(exists, "expected exists to be false when context is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasTable(s.ctx, tx, "") + exists, err := builder.HasTable(c, "") s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -610,11 +635,11 @@ func (s *mysqlBuilderSuite) TestHasTable() { }) s.NoError(err, "expected no error when creating table before checking if it exists") - exists, err := builder.HasTable(s.ctx, tx, "users") + exists, err := builder.HasTable(c, "users") s.NoError(err, "expected no error when checking if table exists with valid parameters") s.True(exists, "expected exists to be true for existing table") - exists, err = builder.HasTable(s.ctx, tx, "non_existent_table") + exists, err = builder.HasTable(c, "non_existent_table") s.NoError(err, "expected no error when checking non-existent table") s.False(exists, "expected exists to be false for non-existent table") }) diff --git a/schema/postgres_builder.go b/schema/postgres_builder.go index c5e4fca..be214e3 100644 --- a/schema/postgres_builder.go +++ b/schema/postgres_builder.go @@ -1,12 +1,9 @@ package schema import ( - "context" "database/sql" "errors" "strings" - - "github.com/afkdevs/migris/internal/config" ) type postgresBuilder struct { @@ -16,10 +13,9 @@ type postgresBuilder struct { func newPostgresBuilder() Builder { grammar := newPostgresGrammar() - cfg := config.Get() return &postgresBuilder{ - baseBuilder: baseBuilder{grammar: grammar, verbose: cfg.Verbose}, + baseBuilder: baseBuilder{grammar: grammar}, grammar: grammar, } } @@ -32,9 +28,9 @@ func (b *postgresBuilder) parseSchemaAndTable(name string) (string, string) { return "", names[0] } -func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) { - if tx == nil || tableName == "" { - return nil, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *postgresBuilder) GetColumns(c *Context, tableName string) ([]*Column, error) { + if c == nil || tableName == "" { + return nil, errors.New("invalid arguments: context is nil or table name is empty") } schema, name := b.parseSchemaAndTable(tableName) @@ -46,7 +42,7 @@ func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName return nil, err } - rows, err := tx.QueryContext(ctx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -67,9 +63,9 @@ func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName return columns, nil } -func (b *postgresBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) { - if tx == nil || tableName == "" { - return nil, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *postgresBuilder) GetIndexes(c *Context, tableName string) ([]*Index, error) { + if c == nil || tableName == "" { + return nil, errors.New("invalid arguments: context is nil or table name is empty") } schema, name := b.parseSchemaAndTable(tableName) if schema == "" { @@ -79,7 +75,7 @@ func (b *postgresBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName if err != nil { return nil, err } - rows, err := tx.QueryContext(ctx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -99,9 +95,9 @@ func (b *postgresBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName return indexes, nil } -func (b *postgresBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) { - if tx == nil { - return nil, errors.New("transaction is nil") +func (b *postgresBuilder) GetTables(c *Context) ([]*TableInfo, error) { + if c == nil { + return nil, errors.New("invalid arguments: context is nil") } query, err := b.grammar.CompileTables() @@ -109,7 +105,7 @@ func (b *postgresBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableIn return nil, err } - rows, err := tx.QueryContext(ctx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -127,18 +123,18 @@ func (b *postgresBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableIn return tables, nil } -func (b *postgresBuilder) HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) { - return b.HasColumns(ctx, tx, tableName, []string{columnName}) +func (b *postgresBuilder) HasColumn(c *Context, tableName string, columnName string) (bool, error) { + return b.HasColumns(c, tableName, []string{columnName}) } -func (b *postgresBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) { - if tx == nil || tableName == "" { - return false, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *postgresBuilder) HasColumns(c *Context, tableName string, columnNames []string) (bool, error) { + if c == nil || tableName == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } if len(columnNames) == 0 { return false, errors.New("no column names provided") } - existingColumns, err := b.GetColumns(ctx, tx, tableName) + existingColumns, err := b.GetColumns(c, tableName) if err != nil { return false, err } @@ -163,12 +159,12 @@ func (b *postgresBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName return true, nil // All specified columns exist } -func (b *postgresBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) { - if tx == nil || tableName == "" { - return false, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *postgresBuilder) HasIndex(c *Context, tableName string, indexes []string) (bool, error) { + if c == nil || tableName == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } - existingIndexes, err := b.GetIndexes(ctx, tx, tableName) + existingIndexes, err := b.GetIndexes(c, tableName) if err != nil { return false, err } @@ -207,9 +203,9 @@ func (b *postgresBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName st return false, nil // If no specified index exists, return false } -func (b *postgresBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { - if tx == nil || name == "" { - return false, errors.New("invalid arguments: transaction is nil or table name is empty") +func (b *postgresBuilder) HasTable(c *Context, name string) (bool, error) { + if c == nil || name == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } schema, name := b.parseSchemaAndTable(name) @@ -222,7 +218,7 @@ func (b *postgresBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) } var exists bool - if err := tx.QueryRowContext(ctx, query).Scan(&exists); err != nil { + if err := c.QueryRow(query).Scan(&exists); err != nil { if errors.Is(err, sql.ErrNoRows) { return false, nil // Table does not exist } diff --git a/schema/postgres_builder_test.go b/schema/postgres_builder_test.go index 6fa5270..9ad049d 100644 --- a/schema/postgres_builder_test.go +++ b/schema/postgres_builder_test.go @@ -7,7 +7,7 @@ import ( "os" "testing" - schema2 "github.com/afkdevs/migris/schema" + "github.com/afkdevs/migris/schema" _ "github.com/lib/pq" "github.com/stretchr/testify/suite" ) @@ -41,7 +41,7 @@ type postgresBuilderSuite struct { suite.Suite ctx context.Context db *sql.DB - builder schema2.Builder + builder schema.Builder } func (s *postgresBuilderSuite) SetupSuite() { @@ -58,7 +58,7 @@ func (s *postgresBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - s.builder, err = schema2.NewBuilder("postgres") + s.builder, err = schema.NewBuilder("postgres") s.Require().NoError(err) } @@ -72,20 +72,22 @@ func (s *postgresBuilderSuite) TestCreate() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Create(s.ctx, nil, "test_table", func(table *schema2.Blueprint) {}) - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Create(nil, "test_table", func(table *schema.Blueprint) {}) + s.Error(err, "expected error when context is nil") }) s.Run("when table name is empty, should return error", func() { - err := builder.Create(s.ctx, tx, "", func(table *schema2.Blueprint) {}) + err := builder.Create(c, "", func(table *schema.Blueprint) {}) s.Error(err, "expected error when table name is empty") }) s.Run("when blueprint is nil, should return error", func() { - err := builder.Create(s.ctx, tx, "test_table", nil) + err := builder.Create(c, "test_table", nil) s.Error(err, "expected error when blueprint is nil") }) s.Run("when all parameters are valid, should create table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -97,7 +99,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.Run("when use custom schema should create it successfully", func() { _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_public") s.NoError(err, "expected no error when creating custom schema") - err = builder.Create(context.Background(), tx, "custom_public.users", func(table *schema2.Blueprint) { + err = builder.Create(c, "custom_public.users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -107,7 +109,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with custom schema") }) s.Run("when have composite primary key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "user_roles", func(table *schema2.Blueprint) { + err = builder.Create(c, "user_roles", func(table *schema.Blueprint) { table.Integer("user_id") table.Integer("role_id") @@ -116,7 +118,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with composite primary key") }) s.Run("when have foreign key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders", func(table *schema2.Blueprint) { + err = builder.Create(c, "orders", func(table *schema.Blueprint) { table.ID() table.BigInteger("user_id") table.String("order_id").Unique() @@ -128,7 +130,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with foreign key") }) s.Run("when have custom index should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders_2", func(table *schema2.Blueprint) { + err = builder.Create(c, "orders_2", func(table *schema.Blueprint) { table.ID() table.String("order_id").Unique("uk_orders_2_order_id") table.Decimal("amount", 10, 2) @@ -139,7 +141,7 @@ func (s *postgresBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with custom index") }) s.Run("when table already exists, should return error", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -154,12 +156,18 @@ func (s *postgresBuilderSuite) TestDrop() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Drop(nil, "test_table") + s.Error(err, "expected error when context is nil") + }) s.Run("when table name is empty, should return error", func() { - err := builder.Drop(s.ctx, nil, "") + err := builder.Drop(c, "") s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -167,11 +175,11 @@ func (s *postgresBuilderSuite) TestDrop() { table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") - err = builder.Drop(s.ctx, tx, "users") + err = builder.Drop(c, "users") s.NoError(err, "expected no error when dropping table with valid parameters") }) s.Run("when table does not exist, should return error", func() { - err = builder.Drop(s.ctx, tx, "non_existent_table") + err = builder.Drop(c, "non_existent_table") s.Error(err, "expected error when dropping table that does not exist") }) } @@ -182,16 +190,22 @@ func (s *postgresBuilderSuite) TestDropIfExists() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.DropIfExists(nil, "test_table") + s.Error(err, "expected error when context is nil") + }) s.Run("when table name is empty, should return error", func() { - err := builder.DropIfExists(s.ctx, nil, "") + err := builder.DropIfExists(c, "") s.Error(err, "expected error when table name is empty") }) - s.Run("when tx is nil, should return error", func() { - err = builder.DropIfExists(s.ctx, nil, "test_table") - s.Error(err, "expected error when transaction is nil") + s.Run("when context is nil, should return error", func() { + err = builder.DropIfExists(nil, "test_table") + s.Error(err, "expected error when context is nil") }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -199,11 +213,11 @@ func (s *postgresBuilderSuite) TestDropIfExists() { table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") - err = builder.DropIfExists(s.ctx, tx, "users") + err = builder.DropIfExists(c, "users") s.NoError(err, "expected no error when dropping table with valid parameters") }) s.Run("when table does not exist, should not return error", func() { - err = builder.DropIfExists(s.ctx, tx, "non_existent_table") + err = builder.DropIfExists(c, "non_existent_table") s.NoError(err, "expected no error when dropping non-existent table with IF EXISTS clause") }) } @@ -214,29 +228,31 @@ func (s *postgresBuilderSuite) TestRename() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Rename(s.ctx, nil, "old_table", "new_table") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Rename(nil, "old_table", "new_table") + s.Error(err, "expected error when context is nil") }) s.Run("when old table name is empty, should return error", func() { - err := builder.Rename(s.ctx, tx, "", "new_table") + err := builder.Rename(c, "", "new_table") s.Error(err, "expected error when old table name is empty") }) s.Run("when new table name is empty, should return error", func() { - err := builder.Rename(s.ctx, tx, "old_table", "") + err := builder.Rename(c, "old_table", "") s.Error(err, "expected error when new table name is empty") }) s.Run("when all parameters are valid, should rename table successfully", func() { - err = builder.Create(s.ctx, tx, "old_table", func(table *schema2.Blueprint) { + err = builder.Create(c, "old_table", func(table *schema.Blueprint) { table.ID() table.String("name", 255) }) s.NoError(err, "expected no error when creating table before renaming it") - err = builder.Rename(s.ctx, tx, "old_table", "new_table") + err = builder.Rename(c, "old_table", "new_table") s.NoError(err, "expected no error when renaming table with valid parameters") }) s.Run("when renaming non-existent table, should return error", func() { - err = builder.Rename(s.ctx, tx, "non_existent_table", "new_table") + err = builder.Rename(c, "non_existent_table", "new_table") s.Error(err, "expected error when renaming non-existent table") s.ErrorContains(err, "does not exist", "expected error message to contain 'does not exist'") }) @@ -248,20 +264,22 @@ func (s *postgresBuilderSuite) TestTable() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Table(s.ctx, nil, "test_table", func(table *schema2.Blueprint) {}) - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Table(nil, "test_table", func(table *schema.Blueprint) {}) + s.Error(err, "expected error when context is nil") }) s.Run("when table name is empty, should return error", func() { - err := builder.Table(s.ctx, tx, "", func(table *schema2.Blueprint) {}) + err := builder.Table(c, "", func(table *schema.Blueprint) {}) s.Error(err, "expected error when table name is empty") }) s.Run("when blueprint is nil, should return error", func() { - err := builder.Table(s.ctx, tx, "test_table", nil) + err := builder.Table(c, "test_table", nil) s.Error(err, "expected error when blueprint is nil") }) s.Run("when all parameters are valid, should modify table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique("uk_users_email") @@ -274,75 +292,75 @@ func (s *postgresBuilderSuite) TestTable() { s.NoError(err, "expected no error when creating table before modifying it") s.Run("should add new columns and modify existing ones", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.String("address", 255).Nullable() table.String("phone", 20).Nullable().Unique("uk_users_phone") }) s.NoError(err, "expected no error when modifying table with valid parameters") }) s.Run("should modify existing column", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.String("email", 255).Nullable().Change() }) s.NoError(err, "expected no error when modifying existing column") }) s.Run("should drop column and rename existing one", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropColumn("password") table.RenameColumn("name", "full_name") }) s.NoError(err, "expected no error when dropping column and renaming existing one") }) s.Run("should add index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.Index("phone").Name("idx_users_phone").Algorithm("BTREE") }) s.NoError(err, "expected no error when adding index to table") }) s.Run("should rename index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.RenameIndex("idx_users_phone", "idx_users_contact") }) s.NoError(err, "expected no error when renaming index in table") }) s.Run("should drop index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropIndex("idx_users_contact") }) s.NoError(err, "expected no error when dropping index from table") }) s.Run("should drop unique constraint", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropUnique([]string{"email"}) }) s.NoError(err, "expected no error when dropping unique constraint from table") }) s.Run("should drop fulltext index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropFulltext("ft_users_bio") }) s.NoError(err, "expected no error when dropping fulltext index from table") }) s.Run("should add foreign key", func() { - err = builder.Create(s.ctx, tx, "roles", func(table *schema2.Blueprint) { + err = builder.Create(c, "roles", func(table *schema.Blueprint) { table.ID() table.String("role_name", 255).Unique("uk_roles_role_name") }) s.NoError(err, "expected no error when creating roles table before adding foreign key") - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.Integer("role_id").Nullable() table.Foreign("role_id").References("id").On("roles").OnDelete("SET NULL").OnUpdate("CASCADE") }) s.NoError(err, "expected no error when adding foreign key to users table") }) s.Run("should drop foreign key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropForeign("fk_users_roles") }) s.NoError(err, "expected no error when dropping foreign key from users table") }) s.Run("should drop primary key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropPrimary("pk_users") }) s.NoError(err, "expected no error when dropping primary key from users table") @@ -356,16 +374,18 @@ func (s *postgresBuilderSuite) TestGetColumns() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - _, err := builder.GetColumns(s.ctx, nil, "users") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + _, err := builder.GetColumns(nil, "users") + s.Error(err, "expected error when context is nil") }) s.Run("when table name is empty, should return error", func() { - _, err := builder.GetColumns(s.ctx, tx, "") + _, err := builder.GetColumns(c, "") s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -374,12 +394,12 @@ func (s *postgresBuilderSuite) TestGetColumns() { }) s.NoError(err, "expected no error when creating table before getting columns") - columns, err := builder.GetColumns(s.ctx, tx, "users") + columns, err := builder.GetColumns(c, "users") s.NoError(err, "expected no error when getting columns with valid parameters") s.Len(columns, 6, "expected 6 columns to be returned") }) s.Run("when table does not exist, should return empty columns", func() { - columns, err := builder.GetColumns(s.ctx, tx, "non_existent_table") + columns, err := builder.GetColumns(c, "non_existent_table") s.NoError(err, "expected no error when getting columns of non-existent table") s.Empty(columns, "expected empty columns for non-existent table") }) @@ -391,16 +411,18 @@ func (s *postgresBuilderSuite) TestGetIndexes() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - _, err := builder.GetIndexes(s.ctx, nil, "users") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + _, err := builder.GetIndexes(nil, "users") + s.Error(err, "expected error when contexts is nil") }) s.Run("when table name is empty, should return error", func() { - _, err := builder.GetIndexes(s.ctx, tx, "") + _, err := builder.GetIndexes(c, "") s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -411,13 +433,13 @@ func (s *postgresBuilderSuite) TestGetIndexes() { }) s.NoError(err, "expected no error when creating table before getting indexes") - indexes, err := builder.GetIndexes(s.ctx, tx, "users") + indexes, err := builder.GetIndexes(c, "users") s.NoError(err, "expected no error when getting indexes with valid parameters") s.Len(indexes, 3, "expected 3 index to be returned") }) s.Run("when table does not exist, should return empty indexes", func() { - indexes, err := builder.GetIndexes(s.ctx, tx, "non_existent_table") + indexes, err := builder.GetIndexes(c, "non_existent_table") s.NoError(err, "expected no error when getting indexes of non-existent table") s.Empty(indexes, "expected empty indexes for non-existent table") }) @@ -429,12 +451,14 @@ func (s *postgresBuilderSuite) TestGetTables() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - _, err := builder.GetTables(s.ctx, nil) - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + _, err := builder.GetTables(nil) + s.Error(err, "expected error when context is nil") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -443,7 +467,7 @@ func (s *postgresBuilderSuite) TestGetTables() { }) s.NoError(err, "expected no error when creating table before getting tables") - tables, err := builder.GetTables(s.ctx, tx) + tables, err := builder.GetTables(c) s.NoError(err, "expected no error when getting tables with valid parameters") s.Len(tables, 1, "expected 1 table to be returned") userTable := tables[0] @@ -459,18 +483,20 @@ func (s *postgresBuilderSuite) TestHasColumn() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasColumn(s.ctx, nil, "users", "name") + exists, err := builder.HasColumn(nil, "users", "name") s.Error(err, "expected error when transaction is nil") s.False(exists, "expected exists to be false when transaction is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasColumn(s.ctx, tx, "", "name") + exists, err := builder.HasColumn(c, "", "name") s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -479,11 +505,11 @@ func (s *postgresBuilderSuite) TestHasColumn() { }) s.NoError(err, "expected no error when creating table before checking column existence") - exists, err := builder.HasColumn(s.ctx, tx, "users", "name") + exists, err := builder.HasColumn(c, "users", "name") s.NoError(err, "expected no error when checking if column exists with valid parameters") s.True(exists, "expected exists to be true for existing column") - exists, err = builder.HasColumn(s.ctx, tx, "users", "non_existent_column") + exists, err = builder.HasColumn(c, "users", "non_existent_column") s.NoError(err, "expected no error when checking non-existent column") s.False(exists, "expected exists to be false for non-existent column") }) @@ -495,18 +521,20 @@ func (s *postgresBuilderSuite) TestHasColumns() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasColumns(s.ctx, nil, "users", []string{"name"}) + exists, err := builder.HasColumns(nil, "users", []string{"name"}) s.Error(err, "expected error when transaction is nil") s.False(exists, "expected exists to be false when transaction is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasColumns(s.ctx, tx, "", []string{"name"}) + exists, err := builder.HasColumns(c, "", []string{"name"}) s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -515,11 +543,11 @@ func (s *postgresBuilderSuite) TestHasColumns() { }) s.NoError(err, "expected no error when creating table before checking column existence") - exists, err := builder.HasColumns(s.ctx, tx, "users", []string{"name", "email"}) + exists, err := builder.HasColumns(c, "users", []string{"name", "email"}) s.NoError(err, "expected no error when checking if columns exist with valid parameters") s.True(exists, "expected exists to be true for existing columns") - exists, err = builder.HasColumns(s.ctx, tx, "users", []string{"name", "non_existent_column"}) + exists, err = builder.HasColumns(c, "users", []string{"name", "non_existent_column"}) s.NoError(err, "expected no error when checking mixed existing and non-existent columns") s.False(exists, "expected exists to be false for mixed existing and non-existent columns") }) @@ -531,18 +559,20 @@ func (s *postgresBuilderSuite) TestHasIndex() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasIndex(s.ctx, nil, "users", []string{"idx_users_name"}) - s.Error(err, "expected error when transaction is nil") - s.False(exists, "expected exists to be false when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasIndex(nil, "users", []string{"idx_users_name"}) + s.Error(err, "expected error when context is nil") + s.False(exists, "expected exists to be false when context is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasIndex(s.ctx, tx, "", []string{"idx_users_name"}) + exists, err := builder.HasIndex(c, "", []string{"idx_users_name"}) s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "orders", func(table *schema2.Blueprint) { + err = builder.Create(c, "orders", func(table *schema.Blueprint) { table.ID() table.Integer("company_id") table.Integer("user_id") @@ -555,15 +585,15 @@ func (s *postgresBuilderSuite) TestHasIndex() { }) s.Require().NoError(err, "expected no error when creating table with index") - exists, err := builder.HasIndex(s.ctx, tx, "orders", []string{"uk_orders_order_id"}) + exists, err := builder.HasIndex(c, "orders", []string{"uk_orders_order_id"}) s.NoError(err, "expected no error when checking if index exists with valid parameters") s.True(exists, "expected exists to be true for existing index") - exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"company_id", "user_id"}) + exists, err = builder.HasIndex(c, "orders", []string{"company_id", "user_id"}) s.NoError(err, "expected no error when checking non-existent index") s.True(exists, "expected exists to be true for existing composite index") - exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"non_existent_index"}) + exists, err = builder.HasIndex(c, "orders", []string{"non_existent_index"}) s.NoError(err, "expected no error when checking non-existent index") s.False(exists, "expected exists to be false for non-existent index") }) @@ -575,18 +605,20 @@ func (s *postgresBuilderSuite) TestHasTable() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasTable(s.ctx, nil, "users") - s.Error(err, "expected error when transaction is nil") - s.False(exists, "expected exists to be false when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasTable(nil, "users") + s.Error(err, "expected error when context is nil") + s.False(exists, "expected exists to be false when context is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasTable(s.ctx, tx, "") + exists, err := builder.HasTable(c, "") s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema2.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -595,11 +627,11 @@ func (s *postgresBuilderSuite) TestHasTable() { }) s.NoError(err, "expected no error when creating table before checking existence") - exists, err := builder.HasTable(s.ctx, tx, "users") + exists, err := builder.HasTable(c, "users") s.NoError(err, "expected no error when checking if table exists with valid parameters") s.True(exists, "expected exists to be true for existing table") - exists, err = builder.HasTable(s.ctx, tx, "non_existent_table") + exists, err = builder.HasTable(c, "non_existent_table") s.NoError(err, "expected no error when checking non-existent table") s.False(exists, "expected exists to be false for non-existent table") }) @@ -607,7 +639,7 @@ func (s *postgresBuilderSuite) TestHasTable() { _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_publics") s.NoError(err, "expected no error when creating custom schema") - err = builder.Create(s.ctx, tx, "custom_publics.users", func(table *schema2.Blueprint) { + err = builder.Create(c, "custom_publics.users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -616,10 +648,10 @@ func (s *postgresBuilderSuite) TestHasTable() { }) s.NoError(err, "expected no error when creating table with custom schema") - exists, err := builder.HasTable(s.ctx, tx, "custom_publics.users") + exists, err := builder.HasTable(c, "custom_publics.users") s.NoError(err, "expected no error when checking if table with custom schema exists") s.True(exists, "expected exists to be true for existing table with custom schema") - exists, err = builder.HasTable(s.ctx, tx, "custom_publics.non_existent_table") + exists, err = builder.HasTable(c, "custom_publics.non_existent_table") s.NoError(err, "expected no error when checking non-existent table with custom schema") s.False(exists, "expected exists to be false for non-existent table with custom schema") }) diff --git a/schema/schema.go b/schema/schema.go index b9ff1f2..c818dc6 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -1,7 +1,6 @@ package schema import ( - "context" "database/sql" "errors" @@ -69,13 +68,13 @@ func newBuilder() (Builder, error) { // table.Timestamp("created_at").Default("CURRENT_TIMESTAMP").Nullable(false) // table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP").Nullable(false) // }) -func Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { +func Create(c *Context, name string, blueprint func(table *Blueprint)) error { builder, err := newBuilder() if err != nil { return err } - return builder.Create(ctx, tx, name, blueprint) + return builder.Create(c, name, blueprint) } // Drop removes the table with the given name. @@ -84,13 +83,13 @@ func Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table * // Example: // // err := schema.Drop(ctx, tx, "users") -func Drop(ctx context.Context, tx *sql.Tx, name string) error { +func Drop(c *Context, name string) error { builder, err := newBuilder() if err != nil { return err } - return builder.Drop(ctx, tx, name) + return builder.Drop(c, name) } // DropIfExists removes the table with the given name if it exists. @@ -99,13 +98,13 @@ func Drop(ctx context.Context, tx *sql.Tx, name string) error { // Example: // // err := schema.DropIfExists(ctx, tx, "users") -func DropIfExists(ctx context.Context, tx *sql.Tx, name string) error { +func DropIfExists(c *Context, name string) error { builder, err := newBuilder() if err != nil { return err } - return builder.DropIfExists(ctx, tx, name) + return builder.DropIfExists(c, name) } // GetColumns retrieves the columns of the specified table. @@ -114,13 +113,13 @@ func DropIfExists(ctx context.Context, tx *sql.Tx, name string) error { // Example: // // columns, err := schema.GetColumns(ctx, tx, "users") -func GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) { +func GetColumns(c *Context, tableName string) ([]*Column, error) { builder, err := newBuilder() if err != nil { return nil, err } - return builder.GetColumns(ctx, tx, tableName) + return builder.GetColumns(c, tableName) } // GetIndexes retrieves the indexes of the specified table. @@ -129,13 +128,13 @@ func GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, e // Example: // // indexes, err := schema.GetIndexes(ctx, tx, "users") -func GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) { +func GetIndexes(c *Context, tableName string) ([]*Index, error) { builder, err := newBuilder() if err != nil { return nil, err } - return builder.GetIndexes(ctx, tx, tableName) + return builder.GetIndexes(c, tableName) } // GetTables retrieves all tables in the database. @@ -144,13 +143,13 @@ func GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, er // Example: // // tables, err := schema.GetTables(ctx, tx) -func GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) { +func GetTables(c *Context) ([]*TableInfo, error) { builder, err := newBuilder() if err != nil { return nil, err } - return builder.GetTables(ctx, tx) + return builder.GetTables(c) } // HasColumn checks if a column with the given name exists in the specified table. @@ -159,13 +158,13 @@ func GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) { // Example: // // exists, err := schema.HasColumn(ctx, tx, "users", "email") -func HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) { +func HasColumn(c *Context, tableName string, columnName string) (bool, error) { builder, err := newBuilder() if err != nil { return false, err } - return builder.HasColumn(ctx, tx, tableName, columnName) + return builder.HasColumn(c, tableName, columnName) } // HasColumns checks if the specified columns exist in the given table. @@ -176,13 +175,13 @@ func HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName str // exists, err := schema.HasColumns(ctx, tx, "users", []string{"email", "name"}) // // If any of the specified columns do not exist, it returns false. -func HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) { +func HasColumns(c *Context, tableName string, columnNames []string) (bool, error) { builder, err := newBuilder() if err != nil { return false, err } - return builder.HasColumns(ctx, tx, tableName, columnNames) + return builder.HasColumns(c, tableName, columnNames) } // HasIndex checks if an index with the given name exists in the specified table. @@ -193,13 +192,13 @@ func HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames [ // exists, err := schema.HasIndex(ctx, tx, "users", []string{"uk_users_email"}) // Checks if the index with name "uk_users_email" exists in the "users" table. // // exists, err := schema.HasIndex(ctx, tx, "users", []string{"email", "name"}) // Checks if a composite index exists on the "email" and "name" columns in the "users" table. -func HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) { +func HasIndex(c *Context, tableName string, indexes []string) (bool, error) { builder, err := newBuilder() if err != nil { return false, err } - return builder.HasIndex(ctx, tx, tableName, indexes) + return builder.HasIndex(c, tableName, indexes) } // HasTable checks if a table with the given name exists in the database. @@ -209,13 +208,13 @@ func HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []strin // Example: // // exists, err := schema.HasTable(ctx, tx, "users") -func HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { +func HasTable(c *Context, name string) (bool, error) { builder, err := newBuilder() if err != nil { return false, err } - return builder.HasTable(ctx, tx, name) + return builder.HasTable(c, name) } // Rename changes the name of the table from name to newName. @@ -224,13 +223,13 @@ func HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { // Example: // // err := schema.Rename(ctx, tx, "users", "people") -func Rename(ctx context.Context, tx *sql.Tx, name string, newName string) error { +func Rename(c *Context, name string, newName string) error { builder, err := newBuilder() if err != nil { return err } - return builder.Rename(ctx, tx, name, newName) + return builder.Rename(c, name, newName) } // Table modifies an existing table with the given name and blueprint. @@ -244,11 +243,11 @@ func Rename(ctx context.Context, tx *sql.Tx, name string, newName string) error // table.DropColumn("password") // table.RenameColumn("email", "contact_email") // }) -func Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { +func Table(c *Context, name string, blueprint func(table *Blueprint)) error { builder, err := newBuilder() if err != nil { return err } - return builder.Table(ctx, tx, name, blueprint) + return builder.Table(c, name, blueprint) } diff --git a/schema/schema_test.go b/schema/schema_test.go index f28f54e..c790d58 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -36,26 +36,6 @@ func (s *schemaTestSuite) SetupSuite() { s.Require().NoError(err) s.db = db - - s.Run("when dialect is not set should return error", func() { - builderFuncs := []func() error{ - func() error { return schema.Create(ctx, nil, "", nil) }, - func() error { return schema.Drop(ctx, nil, "") }, - func() error { return schema.DropIfExists(ctx, nil, "") }, - func() error { _, err := schema.GetColumns(ctx, nil, ""); return err }, - func() error { _, err := schema.GetIndexes(ctx, nil, ""); return err }, - func() error { _, err := schema.GetTables(ctx, nil); return err }, - func() error { _, err := schema.HasColumn(ctx, nil, "", ""); return err }, - func() error { _, err := schema.HasColumns(ctx, nil, "", nil); return err }, - func() error { _, err := schema.HasIndex(ctx, nil, "", nil); return err }, - func() error { _, err := schema.HasTable(ctx, nil, ""); return err }, - func() error { return schema.Rename(ctx, nil, "", "") }, - func() error { return schema.Table(ctx, nil, "", nil) }, - } - for _, fn := range builderFuncs { - s.Error(fn(), "Expected error when dialect is not set") - } - }) migris.SetDialect("postgres") } @@ -68,8 +48,10 @@ func (s *schemaTestSuite) TestCreate() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should create table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -80,28 +62,21 @@ func (s *schemaTestSuite) TestCreate() { s.NoError(err) }) s.Run("when table already exists should return error", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() }) s.Error(err) s.ErrorContains(err, "\"users\" already exists") }) s.Run("when table name is empty should return error", func() { - err := schema.Create(s.ctx, tx, "", func(table *schema.Blueprint) { + err := schema.Create(c, "", func(table *schema.Blueprint) { table.ID() }) s.Error(err) s.ErrorContains(err, "invalid arguments") }) s.Run("when blueprint function is nil should return error", func() { - err := schema.Create(s.ctx, tx, "test", nil) - s.Error(err) - s.ErrorContains(err, "invalid arguments") - }) - s.Run("when transaction is nil should return error", func() { - err := schema.Create(s.ctx, nil, "test", func(table *schema.Blueprint) { - table.ID() - }) + err := schema.Create(c, "test", nil) s.Error(err) s.ErrorContains(err, "invalid arguments") }) @@ -112,8 +87,10 @@ func (s *schemaTestSuite) TestDrop() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should drop table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -122,20 +99,20 @@ func (s *schemaTestSuite) TestDrop() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - err = schema.Drop(s.ctx, tx, "users") + err = schema.Drop(c, "users") s.NoError(err) }) s.Run("when table does not exist should return error", func() { - err := schema.Drop(s.ctx, tx, "non_existing_table") + err := schema.Drop(c, "non_existing_table") s.Error(err) }) s.Run("when table name is empty should return error", func() { - err := schema.Drop(s.ctx, tx, "") + err := schema.Drop(c, "") s.Error(err) s.ErrorContains(err, "invalid arguments") }) - s.Run("when transaction is nil should return error", func() { - err := schema.Drop(s.ctx, nil, "test") + s.Run("when context is nil should return error", func() { + err := schema.Drop(nil, "test") s.Error(err) s.ErrorContains(err, "invalid arguments") }) @@ -146,8 +123,10 @@ func (s *schemaTestSuite) TestDropIfExists() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should drop table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -156,20 +135,20 @@ func (s *schemaTestSuite) TestDropIfExists() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - err = schema.DropIfExists(s.ctx, tx, "users") + err = schema.DropIfExists(c, "users") s.NoError(err) }) s.Run("when table does not exist should return no error", func() { - err := schema.DropIfExists(s.ctx, tx, "non_existing_table") + err := schema.DropIfExists(c, "non_existing_table") s.NoError(err) }) s.Run("when table name is empty should return error", func() { - err := schema.DropIfExists(s.ctx, tx, "") + err := schema.DropIfExists(c, "") s.Error(err) s.ErrorContains(err, "invalid arguments") }) - s.Run("when transaction is nil should return error", func() { - err := schema.DropIfExists(s.ctx, nil, "test") + s.Run("when context is nil should return error", func() { + err := schema.DropIfExists(nil, "test") s.Error(err) s.ErrorContains(err, "invalid arguments") }) @@ -180,8 +159,10 @@ func (s *schemaTestSuite) TestGetColumns() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should return columns", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -190,23 +171,23 @@ func (s *schemaTestSuite) TestGetColumns() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - columns, err := schema.GetColumns(s.ctx, tx, "users") + columns, err := schema.GetColumns(c, "users") s.NoError(err) s.NotEmpty(columns) s.Len(columns, 6) }) s.Run("when table does not exist should empty columns", func() { - columns, err := schema.GetColumns(s.ctx, tx, "non_existing_table") + columns, err := schema.GetColumns(c, "non_existing_table") s.NoError(err) s.Nil(columns) }) s.Run("when table name is empty should return error", func() { - columns, err := schema.GetColumns(s.ctx, tx, "") + columns, err := schema.GetColumns(c, "") s.Error(err) s.Nil(columns) }) - s.Run("when transaction is nil should return error", func() { - columns, err := schema.GetColumns(s.ctx, nil, "test") + s.Run("when context is nil should return error", func() { + columns, err := schema.GetColumns(nil, "test") s.Error(err) s.Nil(columns) }) @@ -217,8 +198,10 @@ func (s *schemaTestSuite) TestGetIndexes() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should return indexes", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -227,23 +210,23 @@ func (s *schemaTestSuite) TestGetIndexes() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - indexes, err := schema.GetIndexes(s.ctx, tx, "users") + indexes, err := schema.GetIndexes(c, "users") s.NoError(err) s.NotEmpty(indexes) s.Len(indexes, 2) // Expecting the unique index on email and the primary key index on id }) s.Run("when table does not exist should return empty indexes", func() { - indexes, err := schema.GetIndexes(s.ctx, tx, "non_existing_table") + indexes, err := schema.GetIndexes(c, "non_existing_table") s.NoError(err) s.Nil(indexes) }) s.Run("when table name is empty should return error", func() { - indexes, err := schema.GetIndexes(s.ctx, tx, "") + indexes, err := schema.GetIndexes(c, "") s.Error(err) s.Nil(indexes) }) - s.Run("when transaction is nil should return error", func() { - indexes, err := schema.GetIndexes(s.ctx, nil, "test") + s.Run("when context is nil should return error", func() { + indexes, err := schema.GetIndexes(nil, "test") s.Error(err) s.Nil(indexes) }) @@ -254,13 +237,15 @@ func (s *schemaTestSuite) TestGetTables() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when no tables exist should return empty", func() { - tables, err := schema.GetTables(s.ctx, tx) + tables, err := schema.GetTables(c) s.NoError(err) s.Empty(tables) }) s.Run("when transaction is valid should return tables", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -269,13 +254,13 @@ func (s *schemaTestSuite) TestGetTables() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - tables, err := schema.GetTables(s.ctx, tx) + tables, err := schema.GetTables(c) s.NoError(err) s.NotEmpty(tables) s.Len(tables, 1) // Expecting at least the 'users' table created in previous tests }) - s.Run("when transaction is nil should return error", func() { - tables, err := schema.GetTables(s.ctx, nil) + s.Run("when context is nil should return error", func() { + tables, err := schema.GetTables(nil) s.Error(err) s.Nil(tables) }) @@ -286,8 +271,10 @@ func (s *schemaTestSuite) TestHasColumn() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when column exists should return true", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -297,19 +284,19 @@ func (s *schemaTestSuite) TestHasColumn() { }) s.Require().NoError(err) - exists, err := schema.HasColumn(s.ctx, tx, "users", "email") + exists, err := schema.HasColumn(c, "users", "email") s.NoError(err) s.True(exists) }) s.Run("when column does not exist should return false", func() { - exists, err := schema.HasColumn(s.ctx, tx, "users", "non_existing_column") + exists, err := schema.HasColumn(c, "users", "non_existing_column") s.NoError(err) s.False(exists) }) - s.Run("when transaction is nil should return error", func() { - exists, err := schema.HasColumn(s.ctx, nil, "users", "email") + s.Run("when context is nil should return error", func() { + exists, err := schema.HasColumn(nil, "users", "email") s.Error(err) s.False(exists) }) @@ -320,8 +307,10 @@ func (s *schemaTestSuite) TestHasColumns() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when all columns exist should return true", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -331,17 +320,17 @@ func (s *schemaTestSuite) TestHasColumns() { }) s.Require().NoError(err) - exists, err := schema.HasColumns(s.ctx, tx, "users", []string{"email", "name"}) + exists, err := schema.HasColumns(c, "users", []string{"email", "name"}) s.NoError(err) s.True(exists) }) s.Run("when some columns do not exist should return false", func() { - exists, err := schema.HasColumns(s.ctx, tx, "users", []string{"email", "non_existing_column"}) + exists, err := schema.HasColumns(c, "users", []string{"email", "non_existing_column"}) s.NoError(err) s.False(exists) }) s.Run("when no columns are provided should return error", func() { - exists, err := schema.HasColumns(s.ctx, tx, "users", []string{}) + exists, err := schema.HasColumns(c, "users", []string{}) s.Error(err) s.False(exists) }) @@ -352,8 +341,10 @@ func (s *schemaTestSuite) TestHasIndex() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when index exists should return true", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.Integer("company_id") table.String("name") @@ -366,20 +357,20 @@ func (s *schemaTestSuite) TestHasIndex() { }) s.Require().NoError(err) - exists, err := schema.HasIndex(s.ctx, tx, "users", []string{"email"}) + exists, err := schema.HasIndex(c, "users", []string{"email"}) s.NoError(err) s.True(exists) - exists, err = schema.HasIndex(s.ctx, tx, "users", []string{"company_id", "id"}) + exists, err = schema.HasIndex(c, "users", []string{"company_id", "id"}) s.NoError(err) s.True(exists) - exists, err = schema.HasIndex(s.ctx, tx, "users", []string{"uk_users_email"}) + exists, err = schema.HasIndex(c, "users", []string{"uk_users_email"}) s.NoError(err) s.True(exists) }) s.Run("when index does not exist should return false", func() { - exists, err := schema.HasIndex(s.ctx, tx, "users", []string{"non_existing_index"}) + exists, err := schema.HasIndex(c, "users", []string{"non_existing_index"}) s.NoError(err) s.False(exists) }) @@ -390,8 +381,10 @@ func (s *schemaTestSuite) TestHasTable() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when table exists should return true", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -401,22 +394,22 @@ func (s *schemaTestSuite) TestHasTable() { }) s.Require().NoError(err) - exists, err := schema.HasTable(s.ctx, tx, "users") + exists, err := schema.HasTable(c, "users") s.NoError(err) s.True(exists) }) s.Run("when table does not exist should return false", func() { - exists, err := schema.HasTable(s.ctx, tx, "non_existing_table") + exists, err := schema.HasTable(c, "non_existing_table") s.NoError(err) s.False(exists) }) s.Run("when table name is empty should return error", func() { - exists, err := schema.HasTable(s.ctx, tx, "") + exists, err := schema.HasTable(c, "") s.Error(err) s.False(exists) }) - s.Run("when transaction is nil should return error", func() { - exists, err := schema.HasTable(s.ctx, nil, "users") + s.Run("when context is nil should return error", func() { + exists, err := schema.HasTable(nil, "users") s.Error(err) s.False(exists) }) @@ -427,8 +420,10 @@ func (s *schemaTestSuite) TestRename() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should rename table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -438,27 +433,27 @@ func (s *schemaTestSuite) TestRename() { }) s.Require().NoError(err) - err = schema.Rename(s.ctx, tx, "users", "members") + err = schema.Rename(c, "users", "members") s.NoError(err) - columns, err := schema.GetColumns(s.ctx, tx, "members") + columns, err := schema.GetColumns(c, "members") s.NoError(err) s.NotEmpty(columns) s.Len(columns, 6) }) s.Run("when table does not exist should return error", func() { - err := schema.Rename(s.ctx, tx, "non_existing_table", "new_name") + err := schema.Rename(c, "non_existing_table", "new_name") s.Error(err) }) s.Run("when new name is empty should return error", func() { - err := schema.Rename(s.ctx, tx, "users", "") + err := schema.Rename(c, "users", "") s.Error(err) }) - s.Run("when transaction is nil should return error", func() { - err := schema.Rename(s.ctx, nil, "users", "new_name") + s.Run("when context is nil should return error", func() { + err := schema.Rename(nil, "users", "new_name") s.Error(err) }) } @@ -468,8 +463,10 @@ func (s *schemaTestSuite) TestTable() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should alter table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -479,25 +476,25 @@ func (s *schemaTestSuite) TestTable() { }) s.Require().NoError(err) - err = schema.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = schema.Table(c, "users", func(table *schema.Blueprint) { table.String("phone").Nullable() }) s.NoError(err) - columns, err := schema.GetColumns(s.ctx, tx, "users") + columns, err := schema.GetColumns(c, "users") s.NoError(err) s.Len(columns, 7) // Expecting the new 'phone' column to be added }) s.Run("when table does not exist should return error", func() { - err := schema.Table(s.ctx, tx, "non_existing_table", func(table *schema.Blueprint) { + err := schema.Table(c, "non_existing_table", func(table *schema.Blueprint) { table.String("new_column") }) s.Error(err) }) - s.Run("when transaction is nil should return error", func() { - err := schema.Table(s.ctx, nil, "users", func(table *schema.Blueprint) { + s.Run("when context is nil should return error", func() { + err := schema.Table(nil, "users", func(table *schema.Blueprint) { table.String("new_column") }) s.Error(err) diff --git a/status.go b/status.go new file mode 100644 index 0000000..dcbeb70 --- /dev/null +++ b/status.go @@ -0,0 +1,26 @@ +package migris + +import ( + "context" + "database/sql" + + "github.com/afkdevs/migris/internal/logger" +) + +func Status(db *sql.DB, dir string) error { + ctx := context.Background() + return StatusContext(ctx, db, dir) +} + +func StatusContext(ctx context.Context, db *sql.DB, dir string) error { + provider, err := newProvider(db, dir) + if err != nil { + return err + } + migrations, err := provider.Status(ctx) + if err != nil { + return err + } + logger.PrintStatuses(migrations) + return nil +} diff --git a/up.go b/up.go index 4df1ea7..c5624bd 100644 --- a/up.go +++ b/up.go @@ -3,6 +3,10 @@ package migris import ( "context" "database/sql" + "errors" + + "github.com/afkdevs/migris/internal/logger" + "github.com/pressly/goose/v3" ) // Up applies the migrations in the specified directory. @@ -17,9 +21,26 @@ func UpContext(ctx context.Context, db *sql.DB, dir string) error { if err != nil { return err } - _, err = provider.Up(ctx) + hasPending, err := provider.HasPending(ctx) if err != nil { return err } + if !hasPending { + logger.Info("Nothing to migrate.") + return nil + } + logger.Infof("Running migrations.\n") + + results, err := provider.Up(ctx) + if err != nil { + var partialErr *goose.PartialError + if errors.As(err, &partialErr) { + logger.PrintResults(partialErr.Applied) + logger.PrintResult(partialErr.Failed) + } + + return err + } + logger.PrintResults(results) return nil } From 0eecc7c13dd59cd449415ba917d52d27fc9e203e Mon Sep 17 00:00:00 2001 From: Ahmad Faiz Kamaludin Date: Sun, 31 Aug 2025 02:01:00 +0700 Subject: [PATCH 5/7] feat: update doc and adjust migrate command --- CONTRIBUTING.md | 2 +- README.md | 181 +++++++++++++++++++------- create.go | 6 +- down.go | 11 +- examples/basic/cmd/migrate/migrate.go | 41 +++--- migrate.go | 17 +++ reset.go | 11 +- schema/schema_test.go | 3 +- status.go | 11 +- up.go | 9 +- 10 files changed, 196 insertions(+), 96 deletions(-) create mode 100644 migrate.go diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b13bcce..1e76341 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ -# Contributing to \[Schema] +# Contributing to \[Migris] Thank you for considering contributing to this Go library! We welcome contributions of all kindsβ€”bug reports, feature requests, documentation updates, tests, and code improvements. diff --git a/README.md b/README.md index b6c800f..e4bb720 100644 --- a/README.md +++ b/README.md @@ -1,74 +1,159 @@ -# Go-Schema -[![Go](https://github.com/afkdevs/go-schema/actions/workflows/ci.yml/badge.svg)](https://github.com/afkdevs/go-schema/actions/workflows/ci.yml) -[![Go Report Card](https://goreportcard.com/badge/github.com/afkdevs/go-schema)](https://goreportcard.com/report/github.com/afkdevs/go-schema) -[![codecov](https://codecov.io/gh/afkdevs/go-schema/graph/badge.svg?token=7tbSVRaD4b)](https://codecov.io/gh/afkdevs/go-schema) -[![GoDoc](https://pkg.go.dev/badge/github.com/afkdevs/go-schema)](https://pkg.go.dev/github.com/afkdevs/go-schema) -[![Go Version](https://img.shields.io/github/go-mod/go-version/afkdevs/go-schema)](https://golang.org/doc/devel/release.html) -[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) +# Migris -`Go-Schema` is a simple Go library for building and running SQL schema (DDL) code in a clean, readable, and migration-friendly way. Inspired by Laravel's Schema Builder, it helps you easily create or change database tablesβ€”and works well with tools like [`goose`](https://github.com/pressly/goose). +**Migris** is a database migration library for Go, inspired by Laravel's migrations. +It combines the power of [pressly/goose](https://github.com/pressly/goose) with a fluent schema builder, making migrations easy to write, run, and maintain. -## Features +## ✨ Features -- πŸ“Š Programmatic table and column definitions -- πŸ—ƒοΈ Support for common data types and constraints -- βš™οΈ Auto-generates `CREATE TABLE`, `ALTER TABLE`, index and foreign key SQL -- πŸ”€ Designed to work with database transactions -- πŸ§ͺ Built-in types and functions make migration code clear and testable -- πŸ” Provides helper functions to get list tables, columns, and indexes +- πŸ“¦ Migration management (`up`, `down`, `reset`, `status`, `create`) +- πŸ—οΈ Fluent schema builder (similar to Laravel migrations) +- πŸ—„οΈ Supports PostgreSQL, MySQL, and MariaDB +- πŸ”„ Transaction-based migrations +- πŸ› οΈ Integration with Go projects (no external CLI required) -## Supported Databases +## πŸš€ Installation -Currently, `schema` is tested and optimized for: +```bash +go get github.com/afkdevs/migris +``` -* PostgreSQL -* MySQL / MariaDB -* SQLite (TODO) +## πŸ“š Usage -## Installation +### 1. Create a Migration -```bash -go get github.com/afkdevs/go-schema -``` +Migrations are defined in Go files using the schema builder: -## Integration Example (with goose) ```go package migrations import ( - "context" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/schema" +) + +func init() { + migris.AddMigrationContext(upCreateUsersTable, downCreateUsersTable) +} + +func upCreateUsersTable(c *schema.Context) error { + return schema.Create(c, "users", func(table *schema.Blueprint) { + table.ID() + table.String("name") + table.String("email") + table.Timestamp("email_verified_at").Nullable() + table.String("password") + table.Timestamps() + }) +} + +func downCreateUsersTable(c *schema.Context) error { + return schema.DropIfExists(c, "users") +} +``` + +This creates a `users` table with common fields. + +### 2. Run Migrations + +You can manage migrations directly from Go code: + +```go +package migrate + +import ( "database/sql" + "fmt" - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" + "github.com/afkdevs/migris" + "github.com/afkdevs/migris/examples/basic/config" + _ "github.com/afkdevs/migris/examples/basic/migrations" + _ "github.com/lib/pq" // PostgreSQL driver ) -func init() { - goose.AddMigrationContext(upCreateUsersTable, downCreateUsersTable) +func Up() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Up() +} + +func Create(name string) error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Create(name) +} + +func Reset() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Reset() +} + +func Down() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Down() } -func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error { - return schema.Create(ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name") - table.String("email") - table.Timestamp("email_verified_at").Nullable() - table.String("password") - table.Timestamps() - }) +func Status() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Status() } -func downCreateUsersTable(ctx context.Context, tx *sql.Tx) error { - return schema.Drop(ctx, tx, "users") +func newMigrate() (*migris.Migrate, error) { + if err := migris.SetDialect("postgres"); err != nil { + return nil, fmt.Errorf("failed to set schema dialect: %w", err) + } + dsn := "postgres://user:pass@localhost:5432/mydb?sslmode=disable" + db, err := sql.Open("postgres", dsn) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + return migris.New(db, "migrations"), nil } ``` -For more examples, check out the [examples](examples/basic) directory. -## Documentation -For detailed documentation, please refer to the [GoDoc](https://pkg.go.dev/github.com/afkdevs/go-schema) page. +## πŸ”§ Commands + +Here are the available migration commands: + +| Function | Description | +|-------------------|--------------------------------------------| +| `migris.Up` | Apply all pending migrations | +| `migris.Down` | Rollback the last migration | +| `migris.Reset` | Rollback all migrations | +| `migris.Status` | Show migration status | +| `migris.Create` | Create a new migration file with timestamp | + +## πŸ› οΈ Example Schema + +```go +schema.Create(c, "posts", func(table *schema.Blueprint) { + table.ID() + table.String("title") + table.Text("body") + table.ForeignID("user_id").Constrained("users") + table.Timestamps() +}) +``` + +## πŸ“– Roadmap + +- [ ] Add dry-run mode +- [ ] Add SQLite support +- [ ] CLI wrapper for quick usage -## Contributing -Contributions are welcome! Please read the [contributing guidelines](CONTRIBUTING.md) and submit a pull request. +## πŸ“„ License -## License -This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. \ No newline at end of file +MIT License. +See [LICENSE](./LICENSE) for details. \ No newline at end of file diff --git a/create.go b/create.go index 5ea1922..5d777ca 100644 --- a/create.go +++ b/create.go @@ -7,10 +7,10 @@ import ( "github.com/pressly/goose/v3" ) -// Create a new migration file -func Create(dir, name string) error { +// Create creates a new migration file with the given name in the specified directory. +func (m *Migrate) Create(name string) error { tmpl := getMigrationTemplate(name) - return goose.CreateWithTemplate(nil, dir, tmpl, name, "go") + return goose.CreateWithTemplate(nil, m.dir, tmpl, name, "go") } func getMigrationTemplate(name string) *template.Template { diff --git a/down.go b/down.go index 66bb3c3..756dfef 100644 --- a/down.go +++ b/down.go @@ -2,20 +2,21 @@ package migris import ( "context" - "database/sql" "errors" "github.com/afkdevs/migris/internal/logger" "github.com/pressly/goose/v3" ) -func Down(db *sql.DB, dir string) error { +// Down rolls back the last migration. +func (m *Migrate) Down() error { ctx := context.Background() - return DownContext(ctx, db, dir) + return m.DownContext(ctx) } -func DownContext(ctx context.Context, db *sql.DB, dir string) error { - provider, err := newProvider(db, dir) +// DownContext rolls back the last migration. +func (m *Migrate) DownContext(ctx context.Context) error { + provider, err := newProvider(m.db, m.dir) if err != nil { return err } diff --git a/examples/basic/cmd/migrate/migrate.go b/examples/basic/cmd/migrate/migrate.go index 3695d07..8106ae2 100644 --- a/examples/basic/cmd/migrate/migrate.go +++ b/examples/basic/cmd/migrate/migrate.go @@ -10,47 +10,47 @@ import ( _ "github.com/lib/pq" // PostgreSQL driver ) -const ( - directory = "migrations" -) - func Up() error { - db, err := initMigrator() + m, err := newMigrate() if err != nil { return err } - return migris.Up(db, directory) + return m.Up() } func Create(name string) error { - return migris.Create(directory, name) + m, err := newMigrate() + if err != nil { + return err + } + return m.Create(name) } func Reset() error { - db, err := initMigrator() + m, err := newMigrate() if err != nil { return err } - return migris.Reset(db, directory) + return m.Reset() } func Down() error { - db, err := initMigrator() + m, err := newMigrate() if err != nil { return err } - return migris.Down(db, directory) + return m.Down() } func Status() error { - db, err := initMigrator() + m, err := newMigrate() if err != nil { return err } - return migris.Status(db, directory) + return m.Status() } -func initMigrator() (*sql.DB, error) { +func newMigrate() (*migris.Migrate, error) { if err := migris.SetDialect("postgres"); err != nil { return nil, fmt.Errorf("failed to set schema dialect: %w", err) } @@ -58,23 +58,18 @@ func initMigrator() (*sql.DB, error) { if err != nil { return nil, err } - db, err := newDatabase(cfg.Database) + db, err := openDatabase(cfg.Database) if err != nil { return nil, err } - return db, nil + return migris.New(db, "migrations"), nil } -func newDatabase(cfg config.Database) (*sql.DB, error) { +func openDatabase(cfg config.Database) (*sql.DB, error) { dsn := cfg.DSN() db, err := sql.Open("postgres", dsn) if err != nil { - return nil, fmt.Errorf("failed to connect to database: %w", err) - } - - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("failed to ping database: %w", err) + return nil, fmt.Errorf("failed to open database: %w", err) } - return db, nil } diff --git a/migrate.go b/migrate.go new file mode 100644 index 0000000..de8ce4c --- /dev/null +++ b/migrate.go @@ -0,0 +1,17 @@ +package migris + +import "database/sql" + +// Migrate handles database migrations +type Migrate struct { + dir string + db *sql.DB +} + +// New creates a new Migrate instance +func New(db *sql.DB, dir string) *Migrate { + return &Migrate{ + dir: dir, + db: db, + } +} diff --git a/reset.go b/reset.go index aa1b55f..580b134 100644 --- a/reset.go +++ b/reset.go @@ -2,20 +2,21 @@ package migris import ( "context" - "database/sql" "errors" "github.com/afkdevs/migris/internal/logger" "github.com/pressly/goose/v3" ) -func Reset(db *sql.DB, dir string) error { +// Reset rolls back all migrations. +func (m *Migrate) Reset() error { ctx := context.Background() - return ResetContext(ctx, db, dir) + return m.ResetContext(ctx) } -func ResetContext(ctx context.Context, db *sql.DB, dir string) error { - provider, err := newProvider(db, dir) +// ResetContext rolls back all migrations. +func (m *Migrate) ResetContext(ctx context.Context) error { + provider, err := newProvider(m.db, m.dir) if err != nil { return err } diff --git a/schema/schema_test.go b/schema/schema_test.go index c790d58..4f705ca 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -36,7 +36,8 @@ func (s *schemaTestSuite) SetupSuite() { s.Require().NoError(err) s.db = db - migris.SetDialect("postgres") + err = migris.SetDialect("postgres") + s.Require().NoError(err) } func (s *schemaTestSuite) TearDownSuite() { diff --git a/status.go b/status.go index dcbeb70..ad1ec36 100644 --- a/status.go +++ b/status.go @@ -2,18 +2,19 @@ package migris import ( "context" - "database/sql" "github.com/afkdevs/migris/internal/logger" ) -func Status(db *sql.DB, dir string) error { +// Status returns the status of the migrations. +func (m *Migrate) Status() error { ctx := context.Background() - return StatusContext(ctx, db, dir) + return m.StatusContext(ctx) } -func StatusContext(ctx context.Context, db *sql.DB, dir string) error { - provider, err := newProvider(db, dir) +// StatusContext returns the status of the migrations. +func (m *Migrate) StatusContext(ctx context.Context) error { + provider, err := newProvider(m.db, m.dir) if err != nil { return err } diff --git a/up.go b/up.go index c5624bd..3508309 100644 --- a/up.go +++ b/up.go @@ -2,7 +2,6 @@ package migris import ( "context" - "database/sql" "errors" "github.com/afkdevs/migris/internal/logger" @@ -10,14 +9,14 @@ import ( ) // Up applies the migrations in the specified directory. -func Up(db *sql.DB, dir string) error { +func (m *Migrate) Up() error { ctx := context.Background() - return UpContext(ctx, db, dir) + return m.UpContext(ctx) } // UpContext applies the migrations in the specified directory. -func UpContext(ctx context.Context, db *sql.DB, dir string) error { - provider, err := newProvider(db, dir) +func (m *Migrate) UpContext(ctx context.Context) error { + provider, err := newProvider(m.db, m.dir) if err != nil { return err } From 507f01e055f3ffc24b3c61c4f4ed611bce9d8dc8 Mon Sep 17 00:00:00 2001 From: Ahmad Faiz Kamaludin Date: Sun, 31 Aug 2025 02:55:58 +0700 Subject: [PATCH 6/7] feat: adjust migrate options --- README.md | 26 +---------- config.go | 27 ------------ create.go | 2 +- down.go | 36 ++++++++++++++- examples/basic/cmd/migrate/migrate.go | 9 ++-- internal/config/config.go | 14 +++--- migrate.go | 60 ++++++++++++++++++++++--- options.go | 26 +++++++++++ provider.go | 33 -------------- register.go | 63 ++++++++++++++++++--------- reset.go | 2 +- schema/schema.go | 6 +-- schema/schema_test.go | 7 ++- status.go | 2 +- up.go | 18 ++++++-- 15 files changed, 194 insertions(+), 137 deletions(-) delete mode 100644 config.go create mode 100644 options.go delete mode 100644 provider.go diff --git a/README.md b/README.md index e4bb720..e2a3acc 100644 --- a/README.md +++ b/README.md @@ -119,34 +119,10 @@ func newMigrate() (*migris.Migrate, error) { if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } - return migris.New(db, "migrations"), nil + return migris.New("postgres", migris.WithDB(db), migris.WithMigrationPath("migrations")), nil } ``` -## πŸ”§ Commands - -Here are the available migration commands: - -| Function | Description | -|-------------------|--------------------------------------------| -| `migris.Up` | Apply all pending migrations | -| `migris.Down` | Rollback the last migration | -| `migris.Reset` | Rollback all migrations | -| `migris.Status` | Show migration status | -| `migris.Create` | Create a new migration file with timestamp | - -## πŸ› οΈ Example Schema - -```go -schema.Create(c, "posts", func(table *schema.Blueprint) { - table.ID() - table.String("title") - table.Text("body") - table.ForeignID("user_id").Constrained("users") - table.Timestamps() -}) -``` - ## πŸ“– Roadmap - [ ] Add dry-run mode diff --git a/config.go b/config.go deleted file mode 100644 index a448753..0000000 --- a/config.go +++ /dev/null @@ -1,27 +0,0 @@ -package migris - -import ( - "errors" - - "github.com/afkdevs/migris/internal/config" - "github.com/afkdevs/migris/internal/dialect" -) - -// SetDialect sets the migrator dialect -func SetDialect(d string) error { - dialectValue := dialect.FromString(d) - if dialectValue == dialect.Unknown { - return errors.New("unsupported dialect: " + d) - } - cfg := config.Get() - cfg.Dialect = dialectValue - config.Set(cfg) - return nil -} - -// SetTableName sets the table name for the migrator -func SetTableName(name string) { - cfg := config.Get() - cfg.TableName = name - config.Set(cfg) -} diff --git a/create.go b/create.go index 5d777ca..39a3e44 100644 --- a/create.go +++ b/create.go @@ -10,7 +10,7 @@ import ( // Create creates a new migration file with the given name in the specified directory. func (m *Migrate) Create(name string) error { tmpl := getMigrationTemplate(name) - return goose.CreateWithTemplate(nil, m.dir, tmpl, name, "go") + return goose.CreateWithTemplate(nil, m.migrationPath, tmpl, name, "go") } func getMigrationTemplate(name string) *template.Template { diff --git a/down.go b/down.go index 756dfef..5a4b437 100644 --- a/down.go +++ b/down.go @@ -16,7 +16,7 @@ func (m *Migrate) Down() error { // DownContext rolls back the last migration. func (m *Migrate) DownContext(ctx context.Context) error { - provider, err := newProvider(m.db, m.dir) + provider, err := m.newProvider() if err != nil { return err } @@ -42,3 +42,37 @@ func (m *Migrate) DownContext(ctx context.Context) error { } return nil } + +// DownTo rolls back the migrations to the specified version. +func (m *Migrate) DownTo(version int64) error { + ctx := context.Background() + return m.DownToContext(ctx, version) +} + +// DownToContext rolls back the migrations to the specified version. +func (m *Migrate) DownToContext(ctx context.Context, version int64) error { + provider, err := m.newProvider() + if err != nil { + return err + } + currentVersion, err := provider.GetDBVersion(ctx) + if err != nil { + return err + } + if currentVersion == 0 { + logger.Info("Nothing to rollback.") + return nil + } + logger.Info("Rolling back migrations.\n") + results, err := provider.DownTo(ctx, version) + if err != nil { + var partialErr *goose.PartialError + if errors.As(err, &partialErr) { + logger.PrintResults(partialErr.Applied) + logger.PrintResult(partialErr.Failed) + } + return err + } + logger.PrintResults(results) + return nil +} diff --git a/examples/basic/cmd/migrate/migrate.go b/examples/basic/cmd/migrate/migrate.go index 8106ae2..8cc107c 100644 --- a/examples/basic/cmd/migrate/migrate.go +++ b/examples/basic/cmd/migrate/migrate.go @@ -51,9 +51,6 @@ func Status() error { } func newMigrate() (*migris.Migrate, error) { - if err := migris.SetDialect("postgres"); err != nil { - return nil, fmt.Errorf("failed to set schema dialect: %w", err) - } cfg, err := config.Load() if err != nil { return nil, err @@ -62,7 +59,11 @@ func newMigrate() (*migris.Migrate, error) { if err != nil { return nil, err } - return migris.New(db, "migrations"), nil + migrate, err := migris.New("postgres", migris.WithDB(db)) + if err != nil { + return nil, fmt.Errorf("failed to create migris instance: %w", err) + } + return migrate, nil } func openDatabase(cfg config.Database) (*sql.DB, error) { diff --git a/internal/config/config.go b/internal/config/config.go index 9c02a82..cf3511d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -7,23 +7,23 @@ import ( ) type Config struct { - Dialect dialect.Dialect - TableName string + Dialect dialect.Dialect } var config = atomic.Pointer[Config]{} func init() { config.Store(&Config{ - Dialect: dialect.Unknown, - TableName: "migris_db_version", + Dialect: dialect.Unknown, }) } -func Set(cfg *Config) { +func SetDialect(dialect dialect.Dialect) { + cfg := config.Load() + cfg.Dialect = dialect config.Store(cfg) } -func Get() *Config { - return config.Load() +func GetDialect() dialect.Dialect { + return config.Load().Dialect } diff --git a/migrate.go b/migrate.go index de8ce4c..5991261 100644 --- a/migrate.go +++ b/migrate.go @@ -1,17 +1,63 @@ package migris -import "database/sql" +import ( + "database/sql" + "errors" + "os" + + "github.com/afkdevs/migris/internal/config" + "github.com/afkdevs/migris/internal/dialect" + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/database" +) // Migrate handles database migrations type Migrate struct { - dir string - db *sql.DB + dialect dialect.Dialect + migrationPath string + db *sql.DB + tableName string } // New creates a new Migrate instance -func New(db *sql.DB, dir string) *Migrate { - return &Migrate{ - dir: dir, - db: db, +func New(dialectValue string, opts ...Option) (*Migrate, error) { + dialectVal := dialect.FromString(dialectValue) + if dialectVal == dialect.Unknown { + return nil, errors.New("unknown database dialect") + } + config.SetDialect(dialectVal) + + m := &Migrate{ + dialect: dialectVal, + migrationPath: "migrations", + tableName: "migris_db_version", + } + for _, opt := range opts { + opt(m) + } + if m.db == nil { + return nil, errors.New("database connection is not set, please call WithDB option") + } + return m, nil +} + +func (m *Migrate) newProvider() (*goose.Provider, error) { + val := config.GetDialect() + if val == dialect.Unknown { + return nil, errors.New("unknown database dialect") + } + gooseDialect := val.GooseDialect() + store, err := database.NewStore(gooseDialect, m.tableName) + if err != nil { + return nil, err + } + provider, err := goose.NewProvider(database.DialectCustom, m.db, os.DirFS(m.migrationPath), + goose.WithStore(store), + goose.WithDisableGlobalRegistry(true), + goose.WithGoMigrations(gooseMigrations()...), + ) + if err != nil { + return nil, err } + return provider, nil } diff --git a/options.go b/options.go new file mode 100644 index 0000000..bfaeefa --- /dev/null +++ b/options.go @@ -0,0 +1,26 @@ +package migris + +import "database/sql" + +type Option func(*Migrate) + +// WithTableName sets the table name for the migration. +func WithTableName(name string) Option { + return func(m *Migrate) { + m.tableName = name + } +} + +// WithMigrationPath sets the directory for the migration files. +func WithMigrationPath(path string) Option { + return func(m *Migrate) { + m.migrationPath = path + } +} + +// WithDB sets the database connection for the migration. +func WithDB(db *sql.DB) Option { + return func(m *Migrate) { + m.db = db + } +} diff --git a/provider.go b/provider.go deleted file mode 100644 index 09af4fe..0000000 --- a/provider.go +++ /dev/null @@ -1,33 +0,0 @@ -package migris - -import ( - "database/sql" - "errors" - "os" - - "github.com/afkdevs/migris/internal/config" - "github.com/afkdevs/migris/internal/dialect" - "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/database" -) - -func newProvider(db *sql.DB, dir string) (*goose.Provider, error) { - cfg := config.Get() - if cfg.Dialect == dialect.Unknown { - return nil, errors.New("unknown database dialect") - } - dialect := cfg.Dialect.GooseDialect() - store, err := database.NewStore(dialect, cfg.TableName) - if err != nil { - return nil, err - } - provider, err := goose.NewProvider(database.DialectCustom, db, os.DirFS(dir), - goose.WithStore(store), - goose.WithDisableGlobalRegistry(true), - goose.WithGoMigrations(registeredMigrations...), - ) - if err != nil { - return nil, err - } - return provider, nil -} diff --git a/register.go b/register.go index d4b4add..514fc6d 100644 --- a/register.go +++ b/register.go @@ -13,55 +13,78 @@ import ( var ( registeredVersions = make(map[int64]string) - registeredMigrations = make([]*goose.Migration, 0) + registeredMigrations = make([]*Migration, 0) ) -// GoMigrationContext is a Go migration func that is run within a transaction and receives a +type Migration struct { + version int64 + source string + upFnContext, downFnContext MigrationContext +} + +// MigrationContext is a Go migration func that is run within a transaction and receives a // context. -type GoMigrationContext func(ctx *schema.Context) error +type MigrationContext func(ctx *schema.Context) error -func (m GoMigrationContext) RunTxFunc(filename string) func(ctx context.Context, tx *sql.Tx) error { +func (m MigrationContext) runTxFunc(source string) func(ctx context.Context, tx *sql.Tx) error { return func(ctx context.Context, tx *sql.Tx) error { - filename = path.Base(filename) + filename := path.Base(source) c := schema.NewContext(ctx, tx, schema.WithFilename(filename)) return m(c) } } // AddMigrationContext adds Go migrations. -func AddMigrationContext(up, down GoMigrationContext) { +func AddMigrationContext(up, down MigrationContext) { _, filename, _, _ := runtime.Caller(1) AddNamedMigrationContext(filename, up, down) } // AddNamedMigrationContext adds named Go migrations. -func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { +func AddNamedMigrationContext(source string, up, down MigrationContext) { if err := register( - filename, - true, - &goose.GoFunc{RunTx: up.RunTxFunc(filename), Mode: goose.TransactionEnabled}, - &goose.GoFunc{RunTx: down.RunTxFunc(filename), Mode: goose.TransactionEnabled}, + source, + up, + down, ); err != nil { panic(err) } } -func register(filename string, useTx bool, up, down *goose.GoFunc) error { - v, _ := goose.NumericComponent(filename) +func register(source string, up, down MigrationContext) error { + v, _ := goose.NumericComponent(source) if existing, ok := registeredVersions[v]; ok { return fmt.Errorf("failed to add migration %q: version %d conflicts with %q", - filename, + source, v, existing, ) } // Add to global as a registered migration. - m := goose.NewGoMigration(v, up, down) - m.Source = filename - // We explicitly set transaction to maintain existing behavior. Both up and down may be nil, but - // we know based on the register function what the user is requesting. - m.UseTx = useTx - registeredVersions[v] = filename + m := &Migration{ + version: v, + source: source, + upFnContext: up, + downFnContext: down, + } + registeredVersions[v] = source registeredMigrations = append(registeredMigrations, m) return nil } + +func gooseMigrations() []*goose.Migration { + migrations := make([]*goose.Migration, 0, len(registeredMigrations)) + for _, m := range registeredMigrations { + upFunc := &goose.GoFunc{ + RunTx: m.upFnContext.runTxFunc(m.source), + Mode: goose.TransactionEnabled, + } + downFunc := &goose.GoFunc{ + RunTx: m.downFnContext.runTxFunc(m.source), + Mode: goose.TransactionEnabled, + } + gm := goose.NewGoMigration(m.version, upFunc, downFunc) + migrations = append(migrations, gm) + } + return migrations +} diff --git a/reset.go b/reset.go index 580b134..65346ab 100644 --- a/reset.go +++ b/reset.go @@ -16,7 +16,7 @@ func (m *Migrate) Reset() error { // ResetContext rolls back all migrations. func (m *Migrate) ResetContext(ctx context.Context) error { - provider, err := newProvider(m.db, m.dir) + provider, err := m.newProvider() if err != nil { return err } diff --git a/schema/schema.go b/schema/schema.go index c818dc6..e5ab707 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -41,12 +41,12 @@ type TableInfo struct { } func newBuilder() (Builder, error) { - cfg := config.Get() - if cfg.Dialect == dialect.Unknown { + dialectVal := config.GetDialect() + if dialectVal == dialect.Unknown { return nil, errors.New("schema dialect is not set, please call schema.SetDialect() before using schema functions") } - builder, err := NewBuilder(cfg.Dialect.String()) + builder, err := NewBuilder(dialectVal.String()) if err != nil { return nil, err } diff --git a/schema/schema_test.go b/schema/schema_test.go index 4f705ca..74d9f88 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -6,7 +6,8 @@ import ( "fmt" "testing" - "github.com/afkdevs/migris" + "github.com/afkdevs/migris/internal/config" + "github.com/afkdevs/migris/internal/dialect" "github.com/afkdevs/migris/schema" "github.com/stretchr/testify/suite" ) @@ -22,6 +23,7 @@ type schemaTestSuite struct { } func (s *schemaTestSuite) SetupSuite() { + config.SetDialect(dialect.Postgres) ctx := context.Background() s.ctx = ctx @@ -34,10 +36,7 @@ func (s *schemaTestSuite) SetupSuite() { err = db.Ping() s.Require().NoError(err) - s.db = db - err = migris.SetDialect("postgres") - s.Require().NoError(err) } func (s *schemaTestSuite) TearDownSuite() { diff --git a/status.go b/status.go index ad1ec36..8f30626 100644 --- a/status.go +++ b/status.go @@ -14,7 +14,7 @@ func (m *Migrate) Status() error { // StatusContext returns the status of the migrations. func (m *Migrate) StatusContext(ctx context.Context) error { - provider, err := newProvider(m.db, m.dir) + provider, err := m.newProvider() if err != nil { return err } diff --git a/up.go b/up.go index 3508309..684dc37 100644 --- a/up.go +++ b/up.go @@ -16,7 +16,18 @@ func (m *Migrate) Up() error { // UpContext applies the migrations in the specified directory. func (m *Migrate) UpContext(ctx context.Context) error { - provider, err := newProvider(m.db, m.dir) + return m.UpToContext(ctx, goose.MaxVersion) +} + +// UpTo applies the migrations up to the specified version. +func (m *Migrate) UpTo(version int64) error { + ctx := context.Background() + return m.UpToContext(ctx, version) +} + +// UpToContext applies the migrations up to the specified version. +func (m *Migrate) UpToContext(ctx context.Context, version int64) error { + provider, err := m.newProvider() if err != nil { return err } @@ -28,9 +39,9 @@ func (m *Migrate) UpContext(ctx context.Context) error { logger.Info("Nothing to migrate.") return nil } - logger.Infof("Running migrations.\n") - results, err := provider.Up(ctx) + logger.Infof("Running migrations.\n") + results, err := provider.UpTo(ctx, version) if err != nil { var partialErr *goose.PartialError if errors.As(err, &partialErr) { @@ -41,5 +52,6 @@ func (m *Migrate) UpContext(ctx context.Context) error { return err } logger.PrintResults(results) + return nil } From 30c8e6f363b39657d3c49e291f795f25db125645 Mon Sep 17 00:00:00 2001 From: Ahmad Faiz Kamaludin Date: Sun, 31 Aug 2025 04:24:30 +0700 Subject: [PATCH 7/7] feat: update package name --- README.md | 17 ++++++---------- create.go | 16 +++++++-------- down.go | 2 +- examples/basic/cmd/migrate/migrate.go | 6 +++--- examples/basic/cmd/root.go | 2 +- examples/basic/go.mod | 6 +++--- examples/basic/main.go | 2 +- .../20250830103612_create_users_table.go | 4 ++-- .../20250830103653_create_roles_table.go | 4 ++-- .../20250830103714_create_user_roles_table.go | 4 ++-- go.mod | 2 +- internal/config/config.go | 2 +- internal/parser/migration_test.go | 2 +- migrate.go | 20 +++++++++---------- options.go | 6 +++--- register.go | 2 +- reset.go | 2 +- schema/blueprint.go | 4 ++-- schema/builder.go | 2 +- schema/column_definition.go | 2 +- schema/foreign_key_definition.go | 2 +- schema/grammar.go | 2 +- schema/index_definition.go | 2 +- schema/mysql_builder_test.go | 2 +- schema/mysql_grammar.go | 2 +- schema/mysql_grammar_test.go | 2 +- schema/postgres_builder_test.go | 2 +- schema/schema.go | 4 ++-- schema/schema_test.go | 6 +++--- status.go | 2 +- up.go | 2 +- 31 files changed, 65 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index e2a3acc..a725956 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ It combines the power of [pressly/goose](https://github.com/pressly/goose) with ## πŸš€ Installation ```bash -go get github.com/afkdevs/migris +go get -u github.com/akfaiz/migris ``` ## πŸ“š Usage @@ -27,8 +27,8 @@ Migrations are defined in Go files using the schema builder: package migrations import ( - "github.com/afkdevs/migris" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" ) func init() { @@ -64,9 +64,8 @@ import ( "database/sql" "fmt" - "github.com/afkdevs/migris" - "github.com/afkdevs/migris/examples/basic/config" - _ "github.com/afkdevs/migris/examples/basic/migrations" + "github.com/akfaiz/migris" + _ "migrations" // Import migrations _ "github.com/lib/pq" // PostgreSQL driver ) @@ -111,21 +110,17 @@ func Status() error { } func newMigrate() (*migris.Migrate, error) { - if err := migris.SetDialect("postgres"); err != nil { - return nil, fmt.Errorf("failed to set schema dialect: %w", err) - } dsn := "postgres://user:pass@localhost:5432/mydb?sslmode=disable" db, err := sql.Open("postgres", dsn) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } - return migris.New("postgres", migris.WithDB(db), migris.WithMigrationPath("migrations")), nil + return migris.New("postgres", migris.WithDB(db), migris.WithMigrationDir("migrations")), nil } ``` ## πŸ“– Roadmap -- [ ] Add dry-run mode - [ ] Add SQLite support - [ ] CLI wrapper for quick usage diff --git a/create.go b/create.go index 39a3e44..f348cb4 100644 --- a/create.go +++ b/create.go @@ -3,14 +3,14 @@ package migris import ( "text/template" - "github.com/afkdevs/migris/internal/parser" + "github.com/akfaiz/migris/internal/parser" "github.com/pressly/goose/v3" ) // Create creates a new migration file with the given name in the specified directory. func (m *Migrate) Create(name string) error { tmpl := getMigrationTemplate(name) - return goose.CreateWithTemplate(nil, m.migrationPath, tmpl, name, "go") + return goose.CreateWithTemplate(nil, m.migrationDir, tmpl, name, "go") } func getMigrationTemplate(name string) *template.Template { @@ -27,8 +27,8 @@ func getMigrationTemplate(name string) *template.Template { var migrationTemplate = template.Must(template.New("migrator.go-migration").Parse(`package migrations import ( - "github.com/afkdevs/migris" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" ) func init() { @@ -50,8 +50,8 @@ func migrationCreateTemplate(table string) *template.Template { tmpl := `package migrations import ( - "github.com/afkdevs/migris" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" ) func init() { @@ -75,8 +75,8 @@ func migrationUpdateTemplate(table string) *template.Template { tmpl := `package migrations import ( - "github.com/afkdevs/migris" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" ) func init() { diff --git a/down.go b/down.go index 5a4b437..9140a19 100644 --- a/down.go +++ b/down.go @@ -4,7 +4,7 @@ import ( "context" "errors" - "github.com/afkdevs/migris/internal/logger" + "github.com/akfaiz/migris/internal/logger" "github.com/pressly/goose/v3" ) diff --git a/examples/basic/cmd/migrate/migrate.go b/examples/basic/cmd/migrate/migrate.go index 8cc107c..7981229 100644 --- a/examples/basic/cmd/migrate/migrate.go +++ b/examples/basic/cmd/migrate/migrate.go @@ -4,9 +4,9 @@ import ( "database/sql" "fmt" - "github.com/afkdevs/migris" - "github.com/afkdevs/migris/examples/basic/config" - _ "github.com/afkdevs/migris/examples/basic/migrations" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/examples/basic/config" + _ "github.com/akfaiz/migris/examples/basic/migrations" _ "github.com/lib/pq" // PostgreSQL driver ) diff --git a/examples/basic/cmd/root.go b/examples/basic/cmd/root.go index 7f6846a..ee336c3 100644 --- a/examples/basic/cmd/root.go +++ b/examples/basic/cmd/root.go @@ -3,7 +3,7 @@ package cmd import ( "context" - "github.com/afkdevs/migris/examples/basic/cmd/migrate" + "github.com/akfaiz/migris/examples/basic/cmd/migrate" "github.com/urfave/cli/v3" ) diff --git a/examples/basic/go.mod b/examples/basic/go.mod index ded60c3..6d2324f 100644 --- a/examples/basic/go.mod +++ b/examples/basic/go.mod @@ -1,9 +1,9 @@ -module github.com/afkdevs/migris/examples/basic +module github.com/akfaiz/migris/examples/basic go 1.23.0 require ( - github.com/afkdevs/migris v0.0.0 + github.com/akfaiz/migris v0.0.0 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 github.com/urfave/cli/v3 v3.3.8 @@ -22,4 +22,4 @@ require ( golang.org/x/term v0.34.0 // indirect ) -replace github.com/afkdevs/migris => ../.. +replace github.com/akfaiz/migris => ../.. diff --git a/examples/basic/main.go b/examples/basic/main.go index ac85745..050e9ea 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -4,7 +4,7 @@ import ( "log" "os" - "github.com/afkdevs/migris/examples/basic/cmd" + "github.com/akfaiz/migris/examples/basic/cmd" ) func main() { diff --git a/examples/basic/migrations/20250830103612_create_users_table.go b/examples/basic/migrations/20250830103612_create_users_table.go index 640a37f..158e385 100644 --- a/examples/basic/migrations/20250830103612_create_users_table.go +++ b/examples/basic/migrations/20250830103612_create_users_table.go @@ -1,8 +1,8 @@ package migrations import ( - "github.com/afkdevs/migris" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" ) func init() { diff --git a/examples/basic/migrations/20250830103653_create_roles_table.go b/examples/basic/migrations/20250830103653_create_roles_table.go index 99d5b79..1e961ce 100644 --- a/examples/basic/migrations/20250830103653_create_roles_table.go +++ b/examples/basic/migrations/20250830103653_create_roles_table.go @@ -1,8 +1,8 @@ package migrations import ( - "github.com/afkdevs/migris" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" ) func init() { diff --git a/examples/basic/migrations/20250830103714_create_user_roles_table.go b/examples/basic/migrations/20250830103714_create_user_roles_table.go index 7f5fbb0..764a76b 100644 --- a/examples/basic/migrations/20250830103714_create_user_roles_table.go +++ b/examples/basic/migrations/20250830103714_create_user_roles_table.go @@ -1,8 +1,8 @@ package migrations import ( - "github.com/afkdevs/migris" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" ) func init() { diff --git a/go.mod b/go.mod index 6bab72b..241fe24 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/afkdevs/migris +module github.com/akfaiz/migris go 1.23.0 diff --git a/internal/config/config.go b/internal/config/config.go index cf3511d..856cfd6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,7 +3,7 @@ package config import ( "sync/atomic" - "github.com/afkdevs/migris/internal/dialect" + "github.com/akfaiz/migris/internal/dialect" ) type Config struct { diff --git a/internal/parser/migration_test.go b/internal/parser/migration_test.go index f7c7b99..a8c2de5 100644 --- a/internal/parser/migration_test.go +++ b/internal/parser/migration_test.go @@ -3,7 +3,7 @@ package parser_test import ( "testing" - "github.com/afkdevs/migris/internal/parser" + "github.com/akfaiz/migris/internal/parser" ) func TestParseMigrationName(t *testing.T) { diff --git a/migrate.go b/migrate.go index 5991261..581564c 100644 --- a/migrate.go +++ b/migrate.go @@ -5,18 +5,18 @@ import ( "errors" "os" - "github.com/afkdevs/migris/internal/config" - "github.com/afkdevs/migris/internal/dialect" + "github.com/akfaiz/migris/internal/config" + "github.com/akfaiz/migris/internal/dialect" "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/database" ) // Migrate handles database migrations type Migrate struct { - dialect dialect.Dialect - migrationPath string - db *sql.DB - tableName string + dialect dialect.Dialect + db *sql.DB + migrationDir string + tableName string } // New creates a new Migrate instance @@ -28,9 +28,9 @@ func New(dialectValue string, opts ...Option) (*Migrate, error) { config.SetDialect(dialectVal) m := &Migrate{ - dialect: dialectVal, - migrationPath: "migrations", - tableName: "migris_db_version", + dialect: dialectVal, + migrationDir: "migrations", + tableName: "migris_db_version", } for _, opt := range opts { opt(m) @@ -51,7 +51,7 @@ func (m *Migrate) newProvider() (*goose.Provider, error) { if err != nil { return nil, err } - provider, err := goose.NewProvider(database.DialectCustom, m.db, os.DirFS(m.migrationPath), + provider, err := goose.NewProvider(database.DialectCustom, m.db, os.DirFS(m.migrationDir), goose.WithStore(store), goose.WithDisableGlobalRegistry(true), goose.WithGoMigrations(gooseMigrations()...), diff --git a/options.go b/options.go index bfaeefa..8da4f71 100644 --- a/options.go +++ b/options.go @@ -11,10 +11,10 @@ func WithTableName(name string) Option { } } -// WithMigrationPath sets the directory for the migration files. -func WithMigrationPath(path string) Option { +// WithMigrationDir sets the directory for the migration files. +func WithMigrationDir(dir string) Option { return func(m *Migrate) { - m.migrationPath = path + m.migrationDir = dir } } diff --git a/register.go b/register.go index 514fc6d..96c148d 100644 --- a/register.go +++ b/register.go @@ -7,7 +7,7 @@ import ( "path" "runtime" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris/schema" "github.com/pressly/goose/v3" ) diff --git a/reset.go b/reset.go index 65346ab..124bcb5 100644 --- a/reset.go +++ b/reset.go @@ -4,7 +4,7 @@ import ( "context" "errors" - "github.com/afkdevs/migris/internal/logger" + "github.com/akfaiz/migris/internal/logger" "github.com/pressly/goose/v3" ) diff --git a/schema/blueprint.go b/schema/blueprint.go index 069bc4a..b0d6e55 100644 --- a/schema/blueprint.go +++ b/schema/blueprint.go @@ -3,8 +3,8 @@ package schema import ( "fmt" - "github.com/afkdevs/migris/internal/dialect" - "github.com/afkdevs/migris/internal/util" + "github.com/akfaiz/migris/internal/dialect" + "github.com/akfaiz/migris/internal/util" ) const ( diff --git a/schema/builder.go b/schema/builder.go index 6e3cc2c..35f56c0 100644 --- a/schema/builder.go +++ b/schema/builder.go @@ -3,7 +3,7 @@ package schema import ( "errors" - "github.com/afkdevs/migris/internal/dialect" + "github.com/akfaiz/migris/internal/dialect" ) // Builder is an interface that defines methods for creating, dropping, and managing database tables. diff --git a/schema/column_definition.go b/schema/column_definition.go index 4c4f006..639ed1a 100644 --- a/schema/column_definition.go +++ b/schema/column_definition.go @@ -3,7 +3,7 @@ package schema import ( "slices" - "github.com/afkdevs/migris/internal/util" + "github.com/akfaiz/migris/internal/util" ) // ColumnDefinition defines the interface for defining a column in a database table. diff --git a/schema/foreign_key_definition.go b/schema/foreign_key_definition.go index cd86613..49ad922 100644 --- a/schema/foreign_key_definition.go +++ b/schema/foreign_key_definition.go @@ -1,6 +1,6 @@ package schema -import "github.com/afkdevs/migris/internal/util" +import "github.com/akfaiz/migris/internal/util" // ForeignKeyDefinition defines the interface for defining a foreign key constraint in a database table. type ForeignKeyDefinition interface { diff --git a/schema/grammar.go b/schema/grammar.go index 45ab2e3..78d7c3a 100644 --- a/schema/grammar.go +++ b/schema/grammar.go @@ -5,7 +5,7 @@ import ( "slices" "strings" - "github.com/afkdevs/migris/internal/util" + "github.com/akfaiz/migris/internal/util" ) type grammar interface { diff --git a/schema/index_definition.go b/schema/index_definition.go index 6aa5b67..cbcb14d 100644 --- a/schema/index_definition.go +++ b/schema/index_definition.go @@ -1,6 +1,6 @@ package schema -import "github.com/afkdevs/migris/internal/util" +import "github.com/akfaiz/migris/internal/util" // IndexDefinition defines the interface for defining an index in a database table. type IndexDefinition interface { diff --git a/schema/mysql_builder_test.go b/schema/mysql_builder_test.go index 6f6250a..0d8a643 100644 --- a/schema/mysql_builder_test.go +++ b/schema/mysql_builder_test.go @@ -6,7 +6,7 @@ import ( "fmt" "testing" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris/schema" _ "github.com/go-sql-driver/mysql" // MySQL driver "github.com/stretchr/testify/suite" ) diff --git a/schema/mysql_grammar.go b/schema/mysql_grammar.go index 4c9d49d..eccbf12 100644 --- a/schema/mysql_grammar.go +++ b/schema/mysql_grammar.go @@ -5,7 +5,7 @@ import ( "slices" "strings" - "github.com/afkdevs/migris/internal/util" + "github.com/akfaiz/migris/internal/util" ) type mysqlGrammar struct { diff --git a/schema/mysql_grammar_test.go b/schema/mysql_grammar_test.go index a50e33d..91534f2 100644 --- a/schema/mysql_grammar_test.go +++ b/schema/mysql_grammar_test.go @@ -3,7 +3,7 @@ package schema import ( "testing" - "github.com/afkdevs/migris/internal/dialect" + "github.com/akfaiz/migris/internal/dialect" "github.com/stretchr/testify/assert" ) diff --git a/schema/postgres_builder_test.go b/schema/postgres_builder_test.go index 9ad049d..2b5b0b5 100644 --- a/schema/postgres_builder_test.go +++ b/schema/postgres_builder_test.go @@ -7,7 +7,7 @@ import ( "os" "testing" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris/schema" _ "github.com/lib/pq" "github.com/stretchr/testify/suite" ) diff --git a/schema/schema.go b/schema/schema.go index e5ab707..edc0350 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -4,8 +4,8 @@ import ( "database/sql" "errors" - "github.com/afkdevs/migris/internal/config" - "github.com/afkdevs/migris/internal/dialect" + "github.com/akfaiz/migris/internal/config" + "github.com/akfaiz/migris/internal/dialect" ) // Column represents a database column with its properties. diff --git a/schema/schema_test.go b/schema/schema_test.go index 74d9f88..399181c 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -6,9 +6,9 @@ import ( "fmt" "testing" - "github.com/afkdevs/migris/internal/config" - "github.com/afkdevs/migris/internal/dialect" - "github.com/afkdevs/migris/schema" + "github.com/akfaiz/migris/internal/config" + "github.com/akfaiz/migris/internal/dialect" + "github.com/akfaiz/migris/schema" "github.com/stretchr/testify/suite" ) diff --git a/status.go b/status.go index 8f30626..7fa50d9 100644 --- a/status.go +++ b/status.go @@ -3,7 +3,7 @@ package migris import ( "context" - "github.com/afkdevs/migris/internal/logger" + "github.com/akfaiz/migris/internal/logger" ) // Status returns the status of the migrations. diff --git a/up.go b/up.go index 684dc37..0c230e7 100644 --- a/up.go +++ b/up.go @@ -4,7 +4,7 @@ import ( "context" "errors" - "github.com/afkdevs/migris/internal/logger" + "github.com/akfaiz/migris/internal/logger" "github.com/pressly/goose/v3" )