diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b13bcce..1e76341 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ -# Contributing to \[Schema] +# Contributing to \[Migris] Thank you for considering contributing to this Go library! We welcome contributions of all kindsβ€”bug reports, feature requests, documentation updates, tests, and code improvements. diff --git a/Makefile b/Makefile index 3370ac2..2dc2199 100644 --- a/Makefile +++ b/Makefile @@ -1,49 +1,54 @@ COVERAGE_FILE := coverage.out COVERAGE_HTML := coverage.html -MIN_COVERAGE := 80 FORMAT ?= dots -# Format code .PHONY: fmt -fmt: +fmt: # Format code @echo "Formatting code..." @go fmt ./... -# Lint code .PHONY: lint -lint: +lint: # Lint code @echo "Linting code..." @golangci-lint run --timeout 5m -# Install dependencies and tools .PHONY: install -install: +install: # Install dependencies @echo "Installing dependencies..." @go mod download @go mod tidy -# Run tests +install-tools: # Install tools + @echo "Installing tools..." + @go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest + @go install gotest.tools/gotestsum@latest + .PHONY: test -test: +test: # Run tests @echo "Running tests..." - go test -v -cover -race ./... -coverprofile=$(COVERAGE_FILE) -coverpkg=./... + @gotestsum --format=$(FORMAT) -- ./... + +.PHONY: testcov +testcov: # Run tests with coverage + @echo "Running tests with coverage..." + @gotestsum --format=$(FORMAT) -- -coverprofile=$(COVERAGE_FILE) ./... + @echo "Total coverage is: " $(shell go tool cover -func=$(COVERAGE_FILE) | grep total | awk '{print $$3}') -.PHONY: coverage -coverage: - @echo "Generating test coverage report..." +.PHONY: testcov-html +testcov-html: testcov # Generate coverage HTML report @go tool cover -html=$(COVERAGE_FILE) -o $(COVERAGE_HTML) - @go tool cover -func=$(COVERAGE_FILE) | tee coverage.txt @echo "Coverage HTML report generated: $(COVERAGE_HTML)" @open $(COVERAGE_HTML) -.PHONY: coverage-check -coverage-check: - @COVERAGE=$$(go tool cover -func=$(COVERAGE_FILE) | grep total | awk '{print $$3}' | sed 's/%//'); \ - RESULT=$$(echo "$$COVERAGE < $(MIN_COVERAGE)" | bc); \ - if [ "$$RESULT" -eq 1 ]; then \ - echo "Coverage is below $(MIN_COVERAGE)%: $$COVERAGE%"; \ - exit 1; \ - else \ - echo "Coverage is sufficient: $$COVERAGE%"; \ - fi \ No newline at end of file +.PHONY: help +help: + @echo "Available commands:" + @echo " fmt - Format code" + @echo " lint - Lint code" + @echo " install - Install dependencies" + @echo " install-tools- Install development tools" + @echo " test - Run tests" + @echo " testcov - Run tests with coverage" + @echo " testcov-html - Generate coverage HTML report" + @echo " help - Show this help message" \ No newline at end of file diff --git a/README.md b/README.md index b6c800f..a725956 100644 --- a/README.md +++ b/README.md @@ -1,74 +1,130 @@ -# Go-Schema -[![Go](https://github.com/afkdevs/go-schema/actions/workflows/ci.yml/badge.svg)](https://github.com/afkdevs/go-schema/actions/workflows/ci.yml) -[![Go Report Card](https://goreportcard.com/badge/github.com/afkdevs/go-schema)](https://goreportcard.com/report/github.com/afkdevs/go-schema) -[![codecov](https://codecov.io/gh/afkdevs/go-schema/graph/badge.svg?token=7tbSVRaD4b)](https://codecov.io/gh/afkdevs/go-schema) -[![GoDoc](https://pkg.go.dev/badge/github.com/afkdevs/go-schema)](https://pkg.go.dev/github.com/afkdevs/go-schema) -[![Go Version](https://img.shields.io/github/go-mod/go-version/afkdevs/go-schema)](https://golang.org/doc/devel/release.html) -[![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) +# Migris -`Go-Schema` is a simple Go library for building and running SQL schema (DDL) code in a clean, readable, and migration-friendly way. Inspired by Laravel's Schema Builder, it helps you easily create or change database tablesβ€”and works well with tools like [`goose`](https://github.com/pressly/goose). +**Migris** is a database migration library for Go, inspired by Laravel's migrations. +It combines the power of [pressly/goose](https://github.com/pressly/goose) with a fluent schema builder, making migrations easy to write, run, and maintain. -## Features +## ✨ Features -- πŸ“Š Programmatic table and column definitions -- πŸ—ƒοΈ Support for common data types and constraints -- βš™οΈ Auto-generates `CREATE TABLE`, `ALTER TABLE`, index and foreign key SQL -- πŸ”€ Designed to work with database transactions -- πŸ§ͺ Built-in types and functions make migration code clear and testable -- πŸ” Provides helper functions to get list tables, columns, and indexes +- πŸ“¦ Migration management (`up`, `down`, `reset`, `status`, `create`) +- πŸ—οΈ Fluent schema builder (similar to Laravel migrations) +- πŸ—„οΈ Supports PostgreSQL, MySQL, and MariaDB +- πŸ”„ Transaction-based migrations +- πŸ› οΈ Integration with Go projects (no external CLI required) -## Supported Databases +## πŸš€ Installation -Currently, `schema` is tested and optimized for: +```bash +go get -u github.com/akfaiz/migris +``` -* PostgreSQL -* MySQL / MariaDB -* SQLite (TODO) +## πŸ“š Usage -## Installation +### 1. Create a Migration -```bash -go get github.com/afkdevs/go-schema -``` +Migrations are defined in Go files using the schema builder: -## Integration Example (with goose) ```go package migrations import ( - "context" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" +) + +func init() { + migris.AddMigrationContext(upCreateUsersTable, downCreateUsersTable) +} + +func upCreateUsersTable(c *schema.Context) error { + return schema.Create(c, "users", func(table *schema.Blueprint) { + table.ID() + table.String("name") + table.String("email") + table.Timestamp("email_verified_at").Nullable() + table.String("password") + table.Timestamps() + }) +} + +func downCreateUsersTable(c *schema.Context) error { + return schema.DropIfExists(c, "users") +} +``` + +This creates a `users` table with common fields. + +### 2. Run Migrations + +You can manage migrations directly from Go code: + +```go +package migrate + +import ( "database/sql" + "fmt" - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" + "github.com/akfaiz/migris" + _ "migrations" // Import migrations + _ "github.com/lib/pq" // PostgreSQL driver ) -func init() { - goose.AddMigrationContext(upCreateUsersTable, downCreateUsersTable) +func Up() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Up() } -func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error { - return schema.Create(ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name") - table.String("email") - table.Timestamp("email_verified_at").Nullable() - table.String("password") - table.Timestamps() - }) +func Create(name string) error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Create(name) } -func downCreateUsersTable(ctx context.Context, tx *sql.Tx) error { - return schema.Drop(ctx, tx, "users") +func Reset() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Reset() +} + +func Down() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Down() +} + +func Status() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Status() +} + +func newMigrate() (*migris.Migrate, error) { + dsn := "postgres://user:pass@localhost:5432/mydb?sslmode=disable" + db, err := sql.Open("postgres", dsn) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + return migris.New("postgres", migris.WithDB(db), migris.WithMigrationDir("migrations")), nil } ``` -For more examples, check out the [examples](examples/basic) directory. -## Documentation -For detailed documentation, please refer to the [GoDoc](https://pkg.go.dev/github.com/afkdevs/go-schema) page. +## πŸ“– Roadmap + +- [ ] Add SQLite support +- [ ] CLI wrapper for quick usage -## Contributing -Contributions are welcome! Please read the [contributing guidelines](CONTRIBUTING.md) and submit a pull request. +## πŸ“„ License -## License -This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. \ No newline at end of file +MIT License. +See [LICENSE](./LICENSE) for details. \ No newline at end of file diff --git a/blueprint.go b/blueprint.go deleted file mode 100644 index 0a3f7d6..0000000 --- a/blueprint.go +++ /dev/null @@ -1,887 +0,0 @@ -package schema - -import ( - "fmt" - "slices" -) - -type columnType uint8 - -const ( - columnTypeBoolean columnType = iota - columnTypeChar - columnTypeString - columnTypeLongText - columnTypeMediumText - columnTypeText - columnTypeTinyText - columnTypeBigInteger - columnTypeInteger - columnTypeMediumInteger - columnTypeSmallInteger - columnTypeTinyInteger - columnTypeDecimal - columnTypeDouble - columnTypeFloat - columnTypeDateTime - columnTypeDateTimeTz - columnTypeDate - columnTypeTime - columnTypeTimestamp - columnTypeTimestampTz - columnTypeYear - columnTypeBinary - columnTypeJSON - columnTypeJSONB - columnTypeGeography - columnTypeGeometry - columnTypePoint - columnTypeUUID - columnTypeEnum - columnTypeCustom // Custom type for user-defined types -) - -type indexType int - -const ( - indexTypeIndex indexType = iota - indexTypeUnique - indexTypePrimary - indexTypeFullText -) - -// Blueprint represents a schema blueprint for creating or altering a database table. -type Blueprint struct { - commands []string // commands to be executed - name string - newName string - charset string - collation string - engine string - columns []*columnDefinition - indexes []*indexDefinition - foreignKeys []*foreignKeyDefinition - dropColumns []string - renameColumns map[string]string // old column name to new column name - dropIndexes []string // indexes to be dropped - dropForeignKeys []string // foreign keys to be dropped - dropPrimaryKeys []string // primary keys to be dropped - dropUniqueKeys []string // unique keys to be dropped - dropFullText []string // fulltext indexes to be dropped - renameIndexes map[string]string // old index name to new index name -} - -// Charset sets the character set for the table in the blueprint. -func (b *Blueprint) Charset(charset string) { - b.charset = charset -} - -// Collation sets the collation for the table in the blueprint. -func (b *Blueprint) Collation(collation string) { - b.collation = collation -} - -// Engine sets the storage engine for the table in the blueprint. -func (b *Blueprint) Engine(engine string) { - b.engine = engine -} - -// Column creates a new custom column definition in the blueprint with the specified name and type. -func (b *Blueprint) Column(name string, columnType string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeCustom, - customColumnType: columnType, - }) -} - -// Boolean creates a new boolean column definition in the blueprint. -func (b *Blueprint) Boolean(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeBoolean, - }) -} - -// Char creates a new char column definition in the blueprint. -func (b *Blueprint) Char(name string, length ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeChar, - length: optional(0, length...), - }) -} - -// String creates a new string column definition in the blueprint. -func (b *Blueprint) String(name string, length ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeString, - length: optional(0, length...), - }) -} - -// LongText creates a new long text column definition in the blueprint. -func (b *Blueprint) LongText(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeLongText, - }) -} - -// Text creates a new text column definition in the blueprint. -func (b *Blueprint) Text(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeText, - }) -} - -// MediumText creates a new medium text column definition in the blueprint. -func (b *Blueprint) MediumText(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeMediumText, - }) -} - -// TinyText creates a new tiny text column definition in the blueprint. -func (b *Blueprint) TinyText(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTinyText, - }) -} - -// BigIncrements creates a new big increments column definition in the blueprint. -func (b *Blueprint) BigIncrements(name string) ColumnDefinition { - return b.UnsignedBigInteger(name, true) -} - -// BigInteger creates a new big integer column definition in the blueprint. -func (b *Blueprint) BigInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeBigInteger, - autoIncrement: optional(false, autoIncrement...), - }) -} - -// Decimal creates a new decimal column definition in the blueprint. -// -// The total and places parameters are optional. -// -// Example: -// -// table.Decimal("price", 10, 2) // creates a decimal column with total 10 and places 2 -// -// table.Decimal("price") // creates a decimal column with default total 8 and places 2 -func (b *Blueprint) Decimal(name string, params ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDecimal, - total: optional(8, params...), - places: optionalAtIndex(1, 2, params...), - }) -} - -// Double creates a new double column definition in the blueprint. -func (b *Blueprint) Double(name string, params ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDouble, - total: optional(0, params...), - places: optionalAtIndex(1, 0, params...), - }) -} - -// Float creates a new float column definition in the blueprint. -// -// The total and places parameters are optional. -// -// Example: -// -// table.Float("price", 10, 2) // creates a float column with total 10 and places 2 -// -// table.Float("price") // creates a float column with default total 8 and places 2 -func (b *Blueprint) Float(name string, params ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeFloat, - total: optional(8, params...), - places: optionalAtIndex(1, 2, params...), - }) -} - -// ID creates a new big increments column definition in the blueprint with the name "id" or a custom name. -// -// If a name is provided, it will be used as the column name; otherwise, "id" will be used. -func (b *Blueprint) ID(name ...string) ColumnDefinition { - return b.BigIncrements(optional("id", name...)).Primary() -} - -// Increments create a new increment column definition in the blueprint. -func (b *Blueprint) Increments(name string) ColumnDefinition { - return b.UnsignedInteger(name, true) -} - -// Integer creates a new integer column definition in the blueprint. -func (b *Blueprint) Integer(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeInteger, - autoIncrement: optional(false, autoIncrement...), - }) -} - -// MediumIncrements creates a new medium increments column definition in the blueprint. -func (b *Blueprint) MediumIncrements(name string) ColumnDefinition { - return b.UnsignedMediumInteger(name, true) -} - -func (b *Blueprint) MediumInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeMediumInteger, - autoIncrement: optional(false, autoIncrement...), - }) -} - -// SmallIncrements creates a new small increments column definition in the blueprint. -func (b *Blueprint) SmallIncrements(name string) ColumnDefinition { - return b.UnsignedSmallInteger(name, true) -} - -// SmallInteger creates a new small integer column definition in the blueprint. -func (b *Blueprint) SmallInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeSmallInteger, - autoIncrement: optional(false, autoIncrement...), - }) -} - -// TinyIncrements creates a new tiny increments column definition in the blueprint. -func (b *Blueprint) TinyIncrements(name string) ColumnDefinition { - return b.UnsignedTinyInteger(name, true) -} - -// TinyInteger creates a new tiny integer column definition in the blueprint. -func (b *Blueprint) TinyInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTinyInteger, - autoIncrement: optional(false, autoIncrement...), - }) -} - -// UnsignedBigInteger creates a new unsigned big integer column definition in the blueprint. -func (b *Blueprint) UnsignedBigInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeBigInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) -} - -// UnsignedInteger creates a new unsigned integer column definition in the blueprint. -func (b *Blueprint) UnsignedInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) -} - -// UnsignedMediumInteger creates a new unsigned medium integer column definition in the blueprint. -func (b *Blueprint) UnsignedMediumInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeMediumInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) -} - -// UnsignedSmallInteger creates a new unsigned small integer column definition in the blueprint. -func (b *Blueprint) UnsignedSmallInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeSmallInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) -} - -// UnsignedTinyInteger creates a new unsigned tiny integer column definition in the blueprint. -func (b *Blueprint) UnsignedTinyInteger(name string, autoIncrement ...bool) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTinyInteger, - autoIncrement: optional(false, autoIncrement...), - unsigned: true, - }) -} - -// DateTime creates a new date time column definition in the blueprint. -func (b *Blueprint) DateTime(name string, precision ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDateTime, - precision: optional(0, precision...), - }) -} - -// DateTimeTz creates a new date time with a time zone column definition in the blueprint. -func (b *Blueprint) DateTimeTz(name string, precision ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDateTimeTz, - precision: optional(0, precision...), - }) -} - -// Date creates a new date column definition in the blueprint. -func (b *Blueprint) Date(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeDate, - }) -} - -// Time creates a new time column definition in the blueprint. -func (b *Blueprint) Time(name string, precission ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTime, - precision: optional(0, precission...), - }) -} - -// Timestamp creates a new timestamp column definition in the blueprint. -// The precision parameter is optional and defaults to 0 if not provided. -func (b *Blueprint) Timestamp(name string, precision ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTimestamp, - precision: optional(0, precision...), - }) -} - -// TimestampTz creates a new timestamp with time zone column definition in the blueprint. -// The precision parameter is optional and defaults to 0 if not provided. -func (b *Blueprint) TimestampTz(name string, precision ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeTimestampTz, - precision: optional(0, precision...), - }) -} - -// Timestamps adds created_at and updated_at timestamp columns to the blueprint. -func (b *Blueprint) Timestamps() { - b.Timestamp("created_at").Nullable(false).UseCurrent() - b.Timestamp("updated_at").Nullable(false).UseCurrent().UseCurrentOnUpdate() -} - -// TimestampsTz adds created_at and updated_at timestamp with time zone columns to the blueprint. -func (b *Blueprint) TimestampsTz() { - b.TimestampTz("created_at").Nullable(false).UseCurrent() - b.TimestampTz("updated_at").Nullable(false).UseCurrent().UseCurrentOnUpdate() -} - -// Year creates a new year column definition in the blueprint. -func (b *Blueprint) Year(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeYear, - }) -} - -// Binary creates a new binary column definition in the blueprint. -func (b *Blueprint) Binary(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeBinary, - }) -} - -// JSON creates a new JSON column definition in the blueprint. -func (b *Blueprint) JSON(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeJSON, - }) -} - -// JSONB creates a new JSONB column definition in the blueprint. -func (b *Blueprint) JSONB(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeJSONB, - }) -} - -// UUID creates a new UUID column definition in the blueprint. -func (b *Blueprint) UUID(name string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeUUID, - }) -} - -// Geography creates a new geography column definition in the blueprint. -// The subType parameter is optional and can be used to specify the type of geography (e.g., "Point", "LineString", "Polygon"). -// The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geography type. -func (b *Blueprint) Geography(name string, subType string, srid ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeGeography, - subType: subType, - srid: optional(4326, srid...), - }) -} - -// Geometry creates a new geometry column definition in the blueprint. -// The subType parameter is optional and can be used to specify the type of geometry (e.g., "Point", "LineString", "Polygon"). -// The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geometry type. -func (b *Blueprint) Geometry(name string, subType string, srid ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeGeometry, - subType: subType, - srid: optional(0, srid...), - }) -} - -// Point creates a new point column definition in the blueprint. -func (b *Blueprint) Point(name string, srid ...int) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypePoint, - srid: optional(4326, srid...), - }) -} - -// Enum creates a new enum column definition in the blueprint. -// The allowedEnums parameter is a slice of strings that defines the allowed values for the enum column. -// -// Example: -// -// table.Enum("status", []string{"active", "inactive", "pending"}) -// table.Enum("role", []string{"admin", "user", "guest"}).Comment("User role in the system") -func (b *Blueprint) Enum(name string, allowedEnums []string) ColumnDefinition { - return b.addColumn(&columnDefinition{ - name: name, - columnType: columnTypeEnum, - allowedEnums: allowedEnums, - }) -} - -// Index creates a new index definition in the blueprint. -// -// Example: -// -// table.Index("email") -// table.Index("email", "username") // creates a composite index -// table.Index("email").Algorithm("btree") // creates a btree index -func (b *Blueprint) Index(column string, otherColumns ...string) IndexDefinition { - index := &indexDefinition{ - indexType: indexTypeIndex, - columns: append([]string{column}, otherColumns...), - } - b.indexes = append(b.indexes, index) - b.addCommand("index") - - return index -} - -// Unique creates a new unique index definition in the blueprint. -// -// Example: -// -// table.Unique("email") -// table.Unique("email", "username") // creates a composite unique index -func (b *Blueprint) Unique(column string, otherColumns ...string) IndexDefinition { - index := &indexDefinition{ - indexType: indexTypeUnique, - columns: append([]string{column}, otherColumns...), - } - b.indexes = append(b.indexes, index) - b.addCommand("unique") - - return index -} - -// Primary creates a new primary key index definition in the blueprint. -// -// Example: -// -// table.Primary("id") -// table.Primary("id", "email") // creates a composite primary key -func (b *Blueprint) Primary(column string, otherColumns ...string) IndexDefinition { - index := &indexDefinition{ - indexType: indexTypePrimary, - columns: append([]string{column}, otherColumns...), - } - b.indexes = append(b.indexes, index) - b.addCommand("primary") - return index -} - -// FullText creates a new fulltext index definition in the blueprint. -func (b *Blueprint) FullText(column string, otherColumns ...string) IndexDefinition { - index := &indexDefinition{ - indexType: indexTypeFullText, - columns: append([]string{column}, otherColumns...), - } - b.indexes = append(b.indexes, index) - b.addCommand("fullText") - - return index -} - -// Foreign creates a new foreign key definition in the blueprint. -// -// Example: -// -// table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE") -func (b *Blueprint) Foreign(column string) ForeignKeyDefinition { - fk := &foreignKeyDefinition{ - tableName: b.name, - column: column, - } - b.foreignKeys = append(b.foreignKeys, fk) - b.addCommand("foreign") - return fk -} - -// DropColumn adds a column to be dropped from the table. -// -// Example: -// -// table.DropColumn("old_column") -// table.DropColumn("old_column", "another_old_column") // drops multiple columns -func (b *Blueprint) DropColumn(column string, otherColumns ...string) { - b.dropColumns = append(b.dropColumns, append([]string{column}, otherColumns...)...) - b.addCommand("dropColumn") -} - -// RenameColumn changes the name of the table in the blueprint. -// -// Example: -// -// table.RenameColumn("old_table_name", "new_table_name") -func (b *Blueprint) RenameColumn(oldColumn string, newColumn string) { - if b.renameColumns == nil { - b.renameColumns = make(map[string]string) - } - b.renameColumns[oldColumn] = newColumn - b.addCommand("renameColumn") -} - -// DropIndex adds an index to be dropped from the table. -func (b *Blueprint) DropIndex(indexName string) { - b.dropIndexes = append(b.dropIndexes, indexName) - b.addCommand("dropIndex") -} - -// DropForeign adds a foreign key to be dropped from the table. -func (b *Blueprint) DropForeign(foreignKeyName string) { - b.dropForeignKeys = append(b.dropForeignKeys, foreignKeyName) - b.addCommand("dropForeign") -} - -// DropPrimary adds a primary key to be dropped from the table. -func (b *Blueprint) DropPrimary(primaryKeyName string) { - b.dropPrimaryKeys = append(b.dropPrimaryKeys, primaryKeyName) - b.addCommand("dropPrimary") -} - -// DropUnique adds a unique key to be dropped from the table. -func (b *Blueprint) DropUnique(uniqueKeyName string) { - b.dropUniqueKeys = append(b.dropUniqueKeys, uniqueKeyName) - b.addCommand("dropUnique") -} - -func (b *Blueprint) DropFulltext(indexName string) { - b.dropFullText = append(b.dropFullText, indexName) - b.addCommand("dropFullText") -} - -// RenameIndex changes the name of an index in the blueprint. -// Example: -// -// table.RenameIndex("old_index_name", "new_index_name") -func (b *Blueprint) RenameIndex(oldIndexName string, newIndexName string) { - if b.renameIndexes == nil { - b.renameIndexes = make(map[string]string) - } - b.renameIndexes[oldIndexName] = newIndexName - b.addCommand("renameIndex") -} - -func (b *Blueprint) getAddedColumns() []*columnDefinition { - var addedColumns []*columnDefinition - for _, col := range b.columns { - if !col.changed { - addedColumns = append(addedColumns, col) - } - } - return addedColumns -} - -func (b *Blueprint) getChangedColumns() []*columnDefinition { - var changedColumns []*columnDefinition - for _, col := range b.columns { - if col.changed { - changedColumns = append(changedColumns, col) - } - } - return changedColumns -} - -func (b *Blueprint) create() { - b.addCommand("create") -} - -func (b *Blueprint) creating() bool { - for _, command := range b.commands { - if command == "create" || command == "createIfNotExists" { - return true - } - } - return false -} - -func (b *Blueprint) createIfNotExists() { - b.addCommand("createIfNotExists") -} - -func (b *Blueprint) drop() { - b.addCommand("drop") -} - -func (b *Blueprint) dropIfExists() { - b.addCommand("dropIfNotExists") -} - -func (b *Blueprint) rename() { - b.addCommand("rename") -} - -func (b *Blueprint) addImpliedCommands() { - if len(b.getAddedColumns()) > 0 && !b.creating() { - b.commands = append([]string{"add"}, b.commands...) - } - if len(b.getChangedColumns()) > 0 && !b.creating() { - b.commands = append([]string{"change"}, b.commands...) - } -} - -func (b *Blueprint) toSql(grammar grammar) ([]string, error) { - b.addImpliedCommands() - - var statements []string - - commandMap := map[string]func(*Blueprint) (string, error){ - "create": grammar.compileCreate, - "createIfNotExists": grammar.compileCreateIfNotExists, - "add": grammar.compileAdd, - "drop": grammar.compileDrop, - "dropIfNotExists": grammar.compileDropIfExists, - "rename": grammar.compileRename, - } - for _, command := range b.commands { - if compileFunc, exists := commandMap[command]; exists { - sql, err := compileFunc(b) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - switch command { - case "create", "createIfNotExists", "add": - for _, col := range b.getAddedColumns() { - if col.index { - indexDef := &indexDefinition{ - indexType: indexTypeIndex, - columns: []string{col.name}, - name: col.indexName, - } - sql, err := grammar.compileIndex(b, indexDef) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - } - case "change": - changedStatements, err := grammar.compileChange(b) - if err != nil { - return nil, err - } - statements = append(statements, changedStatements...) - case "index": - indexStatements, err := b.getIndexStatements(grammar, indexTypeIndex) - if err != nil { - return nil, err - } - statements = append(statements, indexStatements...) - case "unique": - uniqueStatements, err := b.getIndexStatements(grammar, indexTypeUnique) - if err != nil { - return nil, err - } - statements = append(statements, uniqueStatements...) - case "primary": - primaryStatements, err := b.getIndexStatements(grammar, indexTypePrimary) - if err != nil { - return nil, err - } - statements = append(statements, primaryStatements...) - case "fullText": - fulltextStatements, err := b.getIndexStatements(grammar, indexTypeFullText) - if err != nil { - return nil, err - } - statements = append(statements, fulltextStatements...) - case "foreign": - for _, foreignKey := range b.foreignKeys { - sql, err := grammar.compileForeign(b, foreignKey) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropColumn": - sql, err := grammar.compileDropColumn(b) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - case "renameColumn": - for oldName, newName := range b.renameColumns { - sql, err := grammar.compileRenameColumn(b, oldName, newName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropIndex": - for _, indexName := range b.dropIndexes { - sql, err := grammar.compileDropIndex(b, indexName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropUnique": - for _, uniqueKeyName := range b.dropUniqueKeys { - sql, err := grammar.compileDropUnique(b, uniqueKeyName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropFullText": - for _, fullTextIndexName := range b.dropFullText { - sql, err := grammar.compileDropFulltext(b, fullTextIndexName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropPrimary": - for _, primaryKeyName := range b.dropPrimaryKeys { - sql, err := grammar.compileDropPrimary(b, primaryKeyName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "renameIndex": - for oldName, newName := range b.renameIndexes { - sql, err := grammar.compileRenameIndex(b, oldName, newName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - case "dropForeign": - for _, foreignKeyName := range b.dropForeignKeys { - sql, err := grammar.compileDropForeign(b, foreignKeyName) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - } - } - - return statements, nil -} - -func (b *Blueprint) getIndexStatements(grammar grammar, idxType indexType) ([]string, error) { - indexCommandMap := map[indexType]func(*Blueprint, *indexDefinition) (string, error){ - indexTypeIndex: grammar.compileIndex, - indexTypeUnique: grammar.compileUnique, - indexTypePrimary: grammar.compilePrimary, - indexTypeFullText: grammar.compileFullText, - } - var statements []string - for _, index := range b.indexes { - if index.indexType == idxType { - compileFunc, exists := indexCommandMap[idxType] - if !exists { - return nil, fmt.Errorf("unsupported index type: %d", idxType) - } - sql, err := compileFunc(b, index) - if err != nil { - return nil, err - } - if sql != "" { - statements = append(statements, sql) - } - } - } - return statements, nil -} - -func (b *Blueprint) addColumn(col *columnDefinition) *columnDefinition { - b.columns = append(b.columns, col) - return col -} - -func (b *Blueprint) addCommand(command string) { - if command == "" { - return - } - if !slices.Contains(b.commands, command) { - b.commands = append(b.commands, command) - } -} diff --git a/builder.go b/builder.go deleted file mode 100644 index e63c1aa..0000000 --- a/builder.go +++ /dev/null @@ -1,177 +0,0 @@ -package schema - -import ( - "context" - "database/sql" - "errors" -) - -// Builder is an interface that defines methods for creating, dropping, and managing database tables. -type Builder interface { - // Create creates a new table with the given name and applies the provided blueprint. - Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error - // CreateIfNotExists creates a new table with the given name and applies the provided blueprint if it does not already exist. - CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error - // Drop removes the table with the given name. - Drop(ctx context.Context, tx *sql.Tx, name string) error - // DropIfExists removes the table with the given name if it exists. - DropIfExists(ctx context.Context, tx *sql.Tx, name string) error - // GetColumns retrieves the columns of the specified table. - GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) - // GetIndexes retrieves the indexes of the specified table. - GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) - // GetTables retrieves all tables in the database. - GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) - // HasColumn checks if the specified table has the given column. - HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) - // HasColumns checks if the specified table has all the given columns. - HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) - // HasIndex checks if the specified table has the given index. - HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) - // HasTable checks if a table with the given name exists. - HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) - // Rename renames a table from oldName to newName. - Rename(ctx context.Context, tx *sql.Tx, oldName string, newName string) error - // Table applies the provided blueprint to the specified table. - Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error -} - -// NewBuilder creates a new Builder instance based on the specified dialect. -// It returns an error if the dialect is not supported. -// -// Supported dialects are "postgres", "pgx", "mysql", and "mariadb". -func NewBuilder(dialect string) (Builder, error) { - dialectVal := dialectFromString(dialect) - switch dialectVal { - case dialectMySQL: - return newMysqlBuilder(), nil - case dialectPostgres: - return newPostgresBuilder(), nil - default: - return nil, errors.New("unsupported dialect: " + dialect) - } -} - -type baseBuilder struct { - grammar grammar -} - -func (b *baseBuilder) validateTxAndName(tx *sql.Tx, name string) error { - if name == "" { - return errors.New("table name is empty") - } - if tx == nil { - return errors.New("transaction is nil") - } - return nil -} - -func (b *baseBuilder) validateCreateAndAlter(tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if name == "" { - return errors.New("table name is empty") - } - if blueprint == nil { - return errors.New("blueprint function is nil") - } - if tx == nil { - return errors.New("transaction is nil") - } - return nil -} - -func (b *baseBuilder) Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil { - return err - } - - bp := &Blueprint{name: name} - bp.create() - blueprint(bp) - - statements, err := bp.toSql(b.grammar) - if err != nil { - return err - } - - return execContext(ctx, tx, statements...) -} - -func (b *baseBuilder) CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil { - return err - } - - bp := &Blueprint{name: name} - bp.createIfNotExists() - blueprint(bp) - - statements, err := bp.toSql(b.grammar) - if err != nil { - return err - } - - return execContext(ctx, tx, statements...) -} - -func (b *baseBuilder) Drop(ctx context.Context, tx *sql.Tx, name string) error { - if err := b.validateTxAndName(tx, name); err != nil { - return err - } - - bp := &Blueprint{name: name} - bp.drop() - statements, err := bp.toSql(b.grammar) - if err != nil { - return err - } - - return execContext(ctx, tx, statements...) -} - -func (b *baseBuilder) DropIfExists(ctx context.Context, tx *sql.Tx, name string) error { - if err := b.validateTxAndName(tx, name); err != nil { - return err - } - - bp := &Blueprint{name: name} - bp.dropIfExists() - statements, err := bp.toSql(b.grammar) - if err != nil { - return err - } - - return execContext(ctx, tx, statements...) -} - -func (b *baseBuilder) Rename(ctx context.Context, tx *sql.Tx, oldName string, newName string) error { - if oldName == "" || newName == "" { - return errors.New("old or new table name is empty") - } - if tx == nil { - return errors.New("transaction is nil") - } - bp := &Blueprint{name: oldName, newName: newName} - bp.rename() - statements, err := bp.toSql(b.grammar) - if err != nil { - return err - } - - return execContext(ctx, tx, statements...) -} - -func (b *baseBuilder) Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil { - return err - } - - bp := &Blueprint{name: name} - blueprint(bp) - - statements, err := bp.toSql(b.grammar) - if err != nil { - return err - } - - return execContext(ctx, tx, statements...) -} diff --git a/column_definition.go b/column_definition.go deleted file mode 100644 index 516bea9..0000000 --- a/column_definition.go +++ /dev/null @@ -1,137 +0,0 @@ -package schema - -import "slices" - -// ColumnDefinition defines the interface for defining a column in a database table. -type ColumnDefinition interface { - // AutoIncrement sets the column to auto-increment. - // This is typically used for primary key columns. - AutoIncrement() ColumnDefinition - // Change changes the column definition. - Change() ColumnDefinition - // Comment adds a comment to the column definition. - Comment(comment string) ColumnDefinition - // Default sets a default value for the column. - Default(value any) ColumnDefinition - // Index adds an index to the column. - Index(indexName ...string) ColumnDefinition - // Nullable sets the column to be nullable or not. - Nullable(value ...bool) ColumnDefinition - // Primary sets the column as a primary key. - Primary() ColumnDefinition - // Unique sets the column to be unique. - Unique(indexName ...string) ColumnDefinition - // Unsigned sets the column to be unsigned (applicable for numeric types). - Unsigned() ColumnDefinition - // UseCurrent sets the column to use the current timestamp as default. - UseCurrent() ColumnDefinition - // UseCurrentOnUpdate sets the column to use the current timestamp on update. - UseCurrentOnUpdate() ColumnDefinition -} - -var _ ColumnDefinition = &columnDefinition{} - -type columnDefinition struct { - name string - columnType columnType - customColumnType string // for custom column types - commands []string - comment string - defaultValue any - onUpdateValue string - nullable bool - autoIncrement bool - unsigned bool - primary bool - index bool - indexName string - unique bool - uniqueName string - length int - precision int - total int - places int - changed bool - allowedEnums []string // for enum type columns - subType string // for geography and geometry types - srid int // for geography and geometry types -} - -func (c *columnDefinition) addCommand(command string) { - if command == "" { - return - } - if !slices.Contains(c.commands, command) { - c.commands = append(c.commands, command) - } -} - -func (c *columnDefinition) hasCommand(command string) bool { - return slices.Contains(c.commands, command) -} - -func (c *columnDefinition) AutoIncrement() ColumnDefinition { - c.addCommand("autoIncrement") - c.autoIncrement = true - return c -} - -func (c *columnDefinition) Comment(comment string) ColumnDefinition { - c.addCommand("comment") - c.comment = comment - return c -} - -func (c *columnDefinition) Default(value any) ColumnDefinition { - c.addCommand("default") - c.defaultValue = value - - return c -} - -func (c *columnDefinition) Index(indexName ...string) ColumnDefinition { - c.index = true - c.indexName = optional("", indexName...) - c.addCommand("index") - return c -} - -func (c *columnDefinition) Nullable(value ...bool) ColumnDefinition { - c.addCommand("nullable") - c.nullable = optional(true, value...) - return c -} - -func (c *columnDefinition) Primary() ColumnDefinition { - c.addCommand("primary") - c.primary = true - return c -} - -func (c *columnDefinition) Unique(indexName ...string) ColumnDefinition { - c.addCommand("unique") - c.unique = true - c.uniqueName = optional("", indexName...) - return c -} - -func (c *columnDefinition) Change() ColumnDefinition { - c.addCommand("change") - c.changed = true - return c -} - -func (c *columnDefinition) Unsigned() ColumnDefinition { - c.unsigned = true - return c -} - -func (c *columnDefinition) UseCurrent() ColumnDefinition { - c.Default("CURRENT_TIMESTAMP") - return c -} - -func (c *columnDefinition) UseCurrentOnUpdate() ColumnDefinition { - c.onUpdateValue = "CURRENT_TIMESTAMP" - return c -} diff --git a/create.go b/create.go new file mode 100644 index 0000000..f348cb4 --- /dev/null +++ b/create.go @@ -0,0 +1,99 @@ +package migris + +import ( + "text/template" + + "github.com/akfaiz/migris/internal/parser" + "github.com/pressly/goose/v3" +) + +// Create creates a new migration file with the given name in the specified directory. +func (m *Migrate) Create(name string) error { + tmpl := getMigrationTemplate(name) + return goose.CreateWithTemplate(nil, m.migrationDir, tmpl, name, "go") +} + +func getMigrationTemplate(name string) *template.Template { + tableName, create := parser.ParseMigrationName(name) + if create { + return migrationCreateTemplate(tableName) + } + if tableName != "" { + return migrationUpdateTemplate(tableName) + } + return migrationTemplate +} + +var migrationTemplate = template.Must(template.New("migrator.go-migration").Parse(`package migrations + +import ( + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" +) + +func init() { + migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) +} + +func up{{.CamelName}}(c *schema.Context) error { + // This code is executed when the migration is applied. + return nil +} + +func down{{.CamelName}}(c *schema.Context) error { + // This code is executed when the migration is rolled back. + return nil +} +`)) + +func migrationCreateTemplate(table string) *template.Template { + tmpl := `package migrations + +import ( + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" +) + +func init() { + migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) +} + +func up{{.CamelName}}(c *schema.Context) error { + return schema.Create(c, "` + table + `", func(table *schema.Blueprint) { + // Define your table schema here + }) +} + +func down{{.CamelName}}(c *schema.Context) error { + return schema.DropIfExists(c, "` + table + `") +} +` + return template.Must(template.New("migration-create").Parse(tmpl)) +} + +func migrationUpdateTemplate(table string) *template.Template { + tmpl := `package migrations + +import ( + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" +) + +func init() { + migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) +} + +func up{{.CamelName}}(c *schema.Context) error { + return schema.Table(c, "` + table + `", func(table *schema.Blueprint) { + // Define your table schema changes here + }) +} + +func down{{.CamelName}}(c *schema.Context) error { + return schema.Table(c, "` + table + `", func(table *schema.Blueprint) { + // Define your table schema changes here + }) +} +` + return template.Must(template.New("migration-update").Parse(tmpl)) +} diff --git a/dialect.go b/dialect.go deleted file mode 100644 index a355a9d..0000000 --- a/dialect.go +++ /dev/null @@ -1,53 +0,0 @@ -package schema - -import ( - "fmt" -) - -type dialect uint8 - -const ( - dialectUnknown dialect = iota - dialectMySQL - dialectPostgres -) - -func (d dialect) String() string { - switch d { - case dialectMySQL: - return "mysql" - case dialectPostgres: - return "postgres" - default: - return "" - } -} - -var dialectValue dialect = dialectUnknown -var debug = false - -// SetDialect sets the current dialect for the schema package. -func SetDialect(d string) error { - dialectValue = dialectFromString(d) - if dialectValue == dialectUnknown { - return fmt.Errorf("unknown dialect: %s", d) - } - - return nil -} - -func dialectFromString(d string) dialect { - switch d { - case "mysql", "mariadb": - return dialectMySQL - case "postgres", "pgx": - return dialectPostgres - default: - return dialectUnknown - } -} - -// SetDebug enables or disables debug mode for the schema package. -func SetDebug(d bool) { - debug = d -} diff --git a/dialect_test.go b/dialect_test.go deleted file mode 100644 index 5889903..0000000 --- a/dialect_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package schema - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestSetDialect(t *testing.T) { - testCases := []struct { - dialect string - expectError bool - }{ - {"postgres", false}, - {"pgx", false}, - {"mysql", false}, - {"mariadb", false}, - {"sqlite", true}, - } - for _, tc := range testCases { - t.Run(tc.dialect, func(t *testing.T) { - err := SetDialect(tc.dialect) - if tc.expectError { - assert.Error(t, err, "Expected error for unsupported dialect %s", tc.dialect) - } else { - assert.NoError(t, err, "Did not expect error for supported dialect %s", tc.dialect) - } - }) - } -} diff --git a/down.go b/down.go new file mode 100644 index 0000000..9140a19 --- /dev/null +++ b/down.go @@ -0,0 +1,78 @@ +package migris + +import ( + "context" + "errors" + + "github.com/akfaiz/migris/internal/logger" + "github.com/pressly/goose/v3" +) + +// Down rolls back the last migration. +func (m *Migrate) Down() error { + ctx := context.Background() + return m.DownContext(ctx) +} + +// DownContext rolls back the last migration. +func (m *Migrate) DownContext(ctx context.Context) error { + provider, err := m.newProvider() + if err != nil { + return err + } + currentVersion, err := provider.GetDBVersion(ctx) + if err != nil { + return err + } + if currentVersion == 0 { + logger.Info("Nothing to rollback.") + return nil + } + logger.Info("Rolling back migrations.\n") + result, err := provider.Down(ctx) + if err != nil { + var partialErr *goose.PartialError + if errors.As(err, &partialErr) { + logger.PrintResult(partialErr.Failed) + } + return err + } + if result != nil { + logger.PrintResult(result) + } + return nil +} + +// DownTo rolls back the migrations to the specified version. +func (m *Migrate) DownTo(version int64) error { + ctx := context.Background() + return m.DownToContext(ctx, version) +} + +// DownToContext rolls back the migrations to the specified version. +func (m *Migrate) DownToContext(ctx context.Context, version int64) error { + provider, err := m.newProvider() + if err != nil { + return err + } + currentVersion, err := provider.GetDBVersion(ctx) + if err != nil { + return err + } + if currentVersion == 0 { + logger.Info("Nothing to rollback.") + return nil + } + logger.Info("Rolling back migrations.\n") + results, err := provider.DownTo(ctx, version) + if err != nil { + var partialErr *goose.PartialError + if errors.As(err, &partialErr) { + logger.PrintResults(partialErr.Applied) + logger.PrintResult(partialErr.Failed) + } + return err + } + logger.PrintResults(results) + return nil +} diff --git a/examples/basic/cmd/migrate/cmd.go b/examples/basic/cmd/migrate/cmd.go new file mode 100644 index 0000000..9b39d89 --- /dev/null +++ b/examples/basic/cmd/migrate/cmd.go @@ -0,0 +1,60 @@ +package migrate + +import ( + "context" + + "github.com/urfave/cli/v3" +) + +func Command() *cli.Command { + cmd := &cli.Command{ + Name: "migrate", + Usage: "Migration tool", + Commands: []*cli.Command{ + { + Name: "create", + Usage: "Create a new migration file", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "name", + Aliases: []string{"n"}, + Usage: "Name of the migration", + Required: true, + }, + }, + Action: func(ctx context.Context, c *cli.Command) error { + return Create(c.String("name")) + }, + }, + { + Name: "up", + Usage: "Run all pending migrations", + Action: func(ctx context.Context, c *cli.Command) error { + return Up() + }, + }, + { + Name: "reset", + Usage: "Reset all migrations", + Action: func(ctx context.Context, c *cli.Command) error { + return Reset() + }, + }, + { + Name: "down", + Usage: "Revert the last migration", + Action: func(ctx context.Context, c *cli.Command) error { + return Down() + }, + }, + { + Name: "status", + Usage: "Show the status of migrations", + Action: func(ctx context.Context, c *cli.Command) error { + return Status() + }, + }, + }, + } + return cmd +} diff --git a/examples/basic/cmd/migrate/migrate.go b/examples/basic/cmd/migrate/migrate.go index d55c33b..7981229 100644 --- a/examples/basic/cmd/migrate/migrate.go +++ b/examples/basic/cmd/migrate/migrate.go @@ -4,85 +4,73 @@ import ( "database/sql" "fmt" - "github.com/afkdevs/go-schema" - "github.com/afkdevs/go-schema/examples/basic/config" - _ "github.com/afkdevs/go-schema/examples/basic/migrations" + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/examples/basic/config" + _ "github.com/akfaiz/migris/examples/basic/migrations" _ "github.com/lib/pq" // PostgreSQL driver - "github.com/pressly/goose/v3" ) func Up() error { - m, err := newMigrator() + m, err := newMigrate() if err != nil { return err } - return goose.Up(m.db, m.dir) + return m.Up() } func Create(name string) error { - m, err := newMigrator() + m, err := newMigrate() if err != nil { return err } - return goose.Create(m.db, m.dir, name, m.migrationType) + return m.Create(name) +} + +func Reset() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Reset() } func Down() error { - m, err := newMigrator() + m, err := newMigrate() if err != nil { return err } - return goose.Reset(m.db, m.dir) + return m.Down() } -type migrator struct { - dir string - dialect string - tableName string - migrationType string - db *sql.DB +func Status() error { + m, err := newMigrate() + if err != nil { + return err + } + return m.Status() } -func newMigrator() (*migrator, error) { - db, err := newDatabase(config.GetDatabase()) +func newMigrate() (*migris.Migrate, error) { + cfg, err := config.Load() if err != nil { return nil, err } - m := &migrator{ - dir: "migrations", - dialect: "postgres", - tableName: "schema_migrations", - migrationType: "go", - db: db, - } - if err := m.init(); err != nil { + db, err := openDatabase(cfg.Database) + if err != nil { return nil, err } - - return m, nil + migrate, err := migris.New("postgres", migris.WithDB(db)) + if err != nil { + return nil, fmt.Errorf("failed to create migris instance: %w", err) + } + return migrate, nil } -func newDatabase(cfg config.Database) (*sql.DB, error) { +func openDatabase(cfg config.Database) (*sql.DB, error) { dsn := cfg.DSN() db, err := sql.Open("postgres", dsn) if err != nil { - return nil, fmt.Errorf("failed to connect to database: %w", err) - } - - if err := db.Ping(); err != nil { - return nil, fmt.Errorf("failed to ping database: %w", err) + return nil, fmt.Errorf("failed to open database: %w", err) } - return db, nil } - -func (m *migrator) init() error { - goose.SetTableName(m.tableName) - if err := goose.SetDialect(m.dialect); err != nil { - return err - } - if err := schema.SetDialect(m.dialect); err != nil { - return err - } - return nil -} diff --git a/examples/basic/cmd/root.go b/examples/basic/cmd/root.go new file mode 100644 index 0000000..ee336c3 --- /dev/null +++ b/examples/basic/cmd/root.go @@ -0,0 +1,20 @@ +package cmd + +import ( + "context" + + "github.com/akfaiz/migris/examples/basic/cmd/migrate" + "github.com/urfave/cli/v3" +) + +var cmd = &cli.Command{ + Name: "schema-example", + Usage: "A simple schema example", + Commands: []*cli.Command{ + migrate.Command(), + }, +} + +func Execute(args []string) error { + return cmd.Run(context.Background(), args) +} diff --git a/examples/basic/config/config.go b/examples/basic/config/config.go new file mode 100644 index 0000000..1b028be --- /dev/null +++ b/examples/basic/config/config.go @@ -0,0 +1,20 @@ +package config + +import "github.com/joho/godotenv" + +type Config struct { + Database Database +} + +func Load() (Config, error) { + var config Config + err := godotenv.Load() + if err != nil { + return config, err + } + config = Config{ + Database: getDatabaseConfig(), + } + + return config, nil +} diff --git a/examples/basic/config/database.go b/examples/basic/config/database.go index 722d0ec..8065182 100644 --- a/examples/basic/config/database.go +++ b/examples/basic/config/database.go @@ -16,13 +16,13 @@ func (db Database) DSN() string { db.User, db.Password, db.Host, db.Port, db.Database, db.SSLMode) } -func GetDatabase() Database { +func getDatabaseConfig() Database { return Database{ - Host: "localhost", - Port: 5432, - User: "root", - Password: "password", - Database: "schema_example", - SSLMode: "disable", + Host: getEnv("DB_HOST"), + Port: getEnvInt("DB_PORT"), + User: getEnv("DB_USER"), + Password: getEnv("DB_PASSWORD"), + Database: getEnv("DB_NAME"), + SSLMode: getEnv("DB_SSLMODE"), } } diff --git a/examples/basic/config/util.go b/examples/basic/config/util.go new file mode 100644 index 0000000..d2b5c43 --- /dev/null +++ b/examples/basic/config/util.go @@ -0,0 +1,23 @@ +package config + +import ( + "os" + "strconv" +) + +func getEnv(key string) string { + value := os.Getenv(key) + return value +} + +func getEnvInt(key string) int { + value := os.Getenv(key) + if value == "" { + return 0 + } + intValue, err := strconv.Atoi(value) + if err != nil { + return 0 + } + return intValue +} diff --git a/examples/basic/go.mod b/examples/basic/go.mod index a4c82dd..6d2324f 100644 --- a/examples/basic/go.mod +++ b/examples/basic/go.mod @@ -1,19 +1,25 @@ -module github.com/afkdevs/go-schema/examples/basic +module github.com/akfaiz/migris/examples/basic go 1.23.0 require ( - github.com/afkdevs/go-schema v0.0.0-000000000000-00010101000000 + github.com/akfaiz/migris v0.0.0 + github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 - github.com/pressly/goose/v3 v3.24.3 github.com/urfave/cli/v3 v3.3.8 ) -replace github.com/afkdevs/go-schema => ../.. - require ( + github.com/fatih/color v1.18.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mfridman/interpolate v0.0.2 // indirect + github.com/pressly/goose/v3 v3.25.0 // indirect github.com/sethvargo/go-retry v0.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/sync v0.14.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/term v0.34.0 // indirect ) + +replace github.com/akfaiz/migris => ../.. diff --git a/examples/basic/go.sum b/examples/basic/go.sum index 69cac7b..1b92f73 100644 --- a/examples/basic/go.sum +++ b/examples/basic/go.sum @@ -4,12 +4,19 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= @@ -18,31 +25,35 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pressly/goose/v3 v3.24.3 h1:DSWWNwwggVUsYZ0X2VitiAa9sKuqtBfe+Jr9zFGwWlM= -github.com/pressly/goose/v3 v3.24.3/go.mod h1:v9zYL4xdViLHCUUJh/mhjnm6JrK7Eul8AS93IxiZM4E= +github.com/pressly/goose/v3 v3.25.0 h1:6WeYhMWGRCzpyd89SpODFnCBCKz41KrVbRT58nVjGng= +github.com/pressly/goose/v3 v3.25.0/go.mod h1:4hC1KrritdCxtuFsqgs1R4AU5bWtTAf+cnWvfhf2DNY= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/urfave/cli/v3 v3.3.8 h1:BzolUExliMdet9NlJ/u4m5vHSotJ3PzEqSAZ1oPMa/E= github.com/urfave/cli/v3 v3.3.8/go.mod h1:FJSKtM/9AiiTOJL4fJ6TbMUkxBXn7GO9guZqoZtpYpo= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= -golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= -golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -modernc.org/libc v1.65.0 h1:e183gLDnAp9VJh6gWKdTy0CThL9Pt7MfcR/0bgb7Y1Y= -modernc.org/libc v1.65.0/go.mod h1:7m9VzGq7APssBTydds2zBcxGREwvIGpuUBaKTXdm2Qs= +modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ= +modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8= modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= -modernc.org/memory v1.10.0 h1:fzumd51yQ1DxcOxSO+S6X7+QTuVU+n8/Aj7swYjFfC4= -modernc.org/memory v1.10.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= -modernc.org/sqlite v1.37.0 h1:s1TMe7T3Q3ovQiK2Ouz4Jwh7dw4ZDqbebSDTlSJdfjI= -modernc.org/sqlite v1.37.0/go.mod h1:5YiWv+YviqGMuGw4V+PNplcyaJ5v+vQd7TQOgkACoJM= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.38.2 h1:Aclu7+tgjgcQVShZqim41Bbw9Cho0y/7WzYptXqkEek= +modernc.org/sqlite v1.38.2/go.mod h1:cPTJYSlgg3Sfg046yBShXENNtPrWrDX8bsbAQBzgQ5E= diff --git a/examples/basic/main.go b/examples/basic/main.go index 6a5590b..050e9ea 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -1,60 +1,14 @@ package main import ( - "context" "log" "os" - "github.com/afkdevs/go-schema/examples/basic/cmd/migrate" - "github.com/urfave/cli/v3" + "github.com/akfaiz/migris/examples/basic/cmd" ) func main() { - app := &cli.Command{ - Name: "schema-example", - Usage: "A simple schema example", - Commands: []*cli.Command{ - { - Name: "migrate", - Aliases: []string{"m"}, - Usage: "Run database migrations", - Commands: []*cli.Command{ - { - Name: "create", - Usage: "Create a new migration file", - Flags: []cli.Flag{ - &cli.StringFlag{ - Name: "name", - Aliases: []string{"n"}, - Usage: "Name of the migration", - Required: true, - }, - }, - Action: func(ctx context.Context, c *cli.Command) error { - return migrate.Create(c.String("name")) - }, - }, - { - Name: "up", - Aliases: []string{"u"}, - Usage: "Run all pending migrations", - Action: func(ctx context.Context, c *cli.Command) error { - return migrate.Up() - }, - }, - { - Name: "down", - Aliases: []string{"d"}, - Usage: "Down all migrations", - Action: func(ctx context.Context, c *cli.Command) error { - return migrate.Down() - }, - }, - }, - }, - }, - } - if err := app.Run(context.Background(), os.Args); err != nil { + if err := cmd.Execute(os.Args); err != nil { log.Printf("Error running app: %v\n", err) os.Exit(1) } diff --git a/examples/basic/migrations/20250626235117_create_users_table.go b/examples/basic/migrations/20250626235117_create_users_table.go deleted file mode 100644 index 23d318e..0000000 --- a/examples/basic/migrations/20250626235117_create_users_table.go +++ /dev/null @@ -1,29 +0,0 @@ -package migrations - -import ( - "context" - "database/sql" - - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" -) - -func init() { - goose.AddMigrationContext(upCreateUsersTable, downCreateUsersTable) -} - -func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error { - return schema.Create(ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name") - table.String("email") - table.Timestamp("email_verified_at").Nullable() - table.String("password") - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") - }) -} - -func downCreateUsersTable(ctx context.Context, tx *sql.Tx) error { - return schema.DropIfExists(ctx, tx, "users") -} diff --git a/examples/basic/migrations/20250628092119_create_roles_table.go b/examples/basic/migrations/20250628092119_create_roles_table.go deleted file mode 100644 index d857cc5..0000000 --- a/examples/basic/migrations/20250628092119_create_roles_table.go +++ /dev/null @@ -1,24 +0,0 @@ -package migrations - -import ( - "context" - "database/sql" - - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" -) - -func init() { - goose.AddMigrationContext(upCreateRolesTable, downCreateRolesTable) -} - -func upCreateRolesTable(ctx context.Context, tx *sql.Tx) error { - return schema.Create(ctx, tx, "roles", func(table *schema.Blueprint) { - table.Increments("id").Primary() - table.String("name").Unique().Nullable(false) - }) -} - -func downCreateRolesTable(ctx context.Context, tx *sql.Tx) error { - return schema.DropIfExists(ctx, tx, "roles") -} diff --git a/examples/basic/migrations/20250628092223_create_user_roles_table.go b/examples/basic/migrations/20250628092223_create_user_roles_table.go deleted file mode 100644 index 713023b..0000000 --- a/examples/basic/migrations/20250628092223_create_user_roles_table.go +++ /dev/null @@ -1,27 +0,0 @@ -package migrations - -import ( - "context" - "database/sql" - - "github.com/afkdevs/go-schema" - "github.com/pressly/goose/v3" -) - -func init() { - goose.AddMigrationContext(upCreateUserRolesTable, downCreateUserRolesTable) -} - -func upCreateUserRolesTable(ctx context.Context, tx *sql.Tx) error { - return schema.Create(ctx, tx, "user_roles", func(table *schema.Blueprint) { - table.BigInteger("user_id") - table.Integer("role_id") - table.Primary("user_id", "role_id") - table.Foreign("user_id").References("id").On("users") - table.Foreign("role_id").References("id").On("roles") - }) -} - -func downCreateUserRolesTable(ctx context.Context, tx *sql.Tx) error { - return schema.DropIfExists(ctx, tx, "user_roles") -} diff --git a/examples/basic/migrations/20250830103612_create_users_table.go b/examples/basic/migrations/20250830103612_create_users_table.go new file mode 100644 index 0000000..158e385 --- /dev/null +++ b/examples/basic/migrations/20250830103612_create_users_table.go @@ -0,0 +1,25 @@ +package migrations + +import ( + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" +) + +func init() { + migris.AddMigrationContext(upCreateUsersTable, downCreateUsersTable) +} + +func upCreateUsersTable(c *schema.Context) error { + return schema.Create(c, "users", func(table *schema.Blueprint) { + table.ID() + table.String("name") + table.String("email") + table.Timestamp("email_verified_at").Nullable() + table.String("password") + table.Timestamps() + }) +} + +func downCreateUsersTable(c *schema.Context) error { + return schema.DropIfExists(c, "users") +} diff --git a/examples/basic/migrations/20250830103653_create_roles_table.go b/examples/basic/migrations/20250830103653_create_roles_table.go new file mode 100644 index 0000000..1e961ce --- /dev/null +++ b/examples/basic/migrations/20250830103653_create_roles_table.go @@ -0,0 +1,21 @@ +package migrations + +import ( + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" +) + +func init() { + migris.AddMigrationContext(upCreateRolesTable, downCreateRolesTable) +} + +func upCreateRolesTable(c *schema.Context) error { + return schema.Create(c, "roles", func(table *schema.Blueprint) { + table.Increments("id").Primary() + table.String("name").Unique().Nullable(false) + }) +} + +func downCreateRolesTable(c *schema.Context) error { + return schema.DropIfExists(c, "roles") +} diff --git a/examples/basic/migrations/20250830103714_create_user_roles_table.go b/examples/basic/migrations/20250830103714_create_user_roles_table.go new file mode 100644 index 0000000..764a76b --- /dev/null +++ b/examples/basic/migrations/20250830103714_create_user_roles_table.go @@ -0,0 +1,24 @@ +package migrations + +import ( + "github.com/akfaiz/migris" + "github.com/akfaiz/migris/schema" +) + +func init() { + migris.AddMigrationContext(upCreateUserRolesTable, downCreateUserRolesTable) +} + +func upCreateUserRolesTable(c *schema.Context) error { + return schema.Create(c, "user_roles", func(table *schema.Blueprint) { + table.BigInteger("user_id") + table.Integer("role_id") + table.Primary("user_id", "role_id") + table.Foreign("user_id").References("id").On("users") + table.Foreign("role_id").References("id").On("roles") + }) +} + +func downCreateUserRolesTable(c *schema.Context) error { + return schema.DropIfExists(c, "user_roles") +} diff --git a/foreign_key_definition.go b/foreign_key_definition.go deleted file mode 100644 index 4c6828b..0000000 --- a/foreign_key_definition.go +++ /dev/null @@ -1,127 +0,0 @@ -package schema - -// ForeignKeyDefinition defines the interface for defining a foreign key constraint in a database table. -type ForeignKeyDefinition interface { - // CascadeOnDelete sets the foreign key to cascade on delete. - CascadeOnDelete() ForeignKeyDefinition - // CascadeOnUpdate sets the foreign key to cascade on update. - CascadeOnUpdate() ForeignKeyDefinition - // Deferrable sets the foreign key as deferrable. - Deferrable(value ...bool) ForeignKeyDefinition - // InitiallyImmediate sets the foreign key to be initially immediate. - InitiallyImmediate(value ...bool) ForeignKeyDefinition - // Name sets the name of the foreign key constraint. - // This is optional and can be used to give a specific name to the foreign key. - Name(name string) ForeignKeyDefinition - // NoActionOnDelete sets the foreign key to do nothing on delete. - NoActionOnDelete() ForeignKeyDefinition - // NoActionOnUpdate sets the foreign key to do nothing on update. - NoActionOnUpdate() ForeignKeyDefinition - // NullOnDelete sets the foreign key to set the column to NULL on delete. - NullOnDelete() ForeignKeyDefinition - // NullOnUpdate sets the foreign key to set the column to NULL on update. - NullOnUpdate() ForeignKeyDefinition - // On sets the table that this foreign key references. - On(table string) ForeignKeyDefinition - // OnDelete sets the action to take when the referenced row is deleted. - OnDelete(action string) ForeignKeyDefinition - // OnUpdate sets the action to take when the referenced row is updated. - OnUpdate(action string) ForeignKeyDefinition - // References sets the column that this foreign key references in the other table. - References(column string) ForeignKeyDefinition - // RestrictOnDelete sets the foreign key to restrict deletion of the referenced row. - RestrictOnDelete() ForeignKeyDefinition - // RestrictOnUpdate sets the foreign key to restrict updating of the referenced row. - RestrictOnUpdate() ForeignKeyDefinition -} - -var _ ForeignKeyDefinition = &foreignKeyDefinition{} - -type foreignKeyDefinition struct { - tableName string - column string - constaintName string // name of the foreign key constraint - references string - on string - onDelete string - onUpdate string - deferrable *bool - initiallyImmediate *bool -} - -func (fk *foreignKeyDefinition) CascadeOnDelete() ForeignKeyDefinition { - fk.onDelete = "CASCADE" - return fk -} - -func (fk *foreignKeyDefinition) CascadeOnUpdate() ForeignKeyDefinition { - fk.onUpdate = "CASCADE" - return fk -} - -func (fk *foreignKeyDefinition) Deferrable(value ...bool) ForeignKeyDefinition { - val := optional(true, value...) - fk.deferrable = &val - return fk -} - -func (fk *foreignKeyDefinition) InitiallyImmediate(value ...bool) ForeignKeyDefinition { - val := optional(true, value...) - fk.initiallyImmediate = &val - return fk -} - -func (fk *foreignKeyDefinition) Name(name string) ForeignKeyDefinition { - fk.constaintName = name - return fk -} - -func (fk *foreignKeyDefinition) NoActionOnDelete() ForeignKeyDefinition { - fk.onDelete = "NO ACTION" - return fk -} - -func (fk *foreignKeyDefinition) NoActionOnUpdate() ForeignKeyDefinition { - fk.onUpdate = "NO ACTION" - return fk -} - -func (fk *foreignKeyDefinition) NullOnDelete() ForeignKeyDefinition { - fk.onDelete = "SET NULL" - return fk -} - -func (fk *foreignKeyDefinition) NullOnUpdate() ForeignKeyDefinition { - fk.onUpdate = "SET NULL" - return fk -} - -func (fk *foreignKeyDefinition) On(table string) ForeignKeyDefinition { - fk.on = table - return fk -} - -func (fk *foreignKeyDefinition) OnDelete(action string) ForeignKeyDefinition { - fk.onDelete = action - return fk -} - -func (fk *foreignKeyDefinition) OnUpdate(action string) ForeignKeyDefinition { - fk.onUpdate = action - return fk -} - -func (fk *foreignKeyDefinition) References(column string) ForeignKeyDefinition { - fk.references = column - return fk -} - -func (fk *foreignKeyDefinition) RestrictOnDelete() ForeignKeyDefinition { - fk.onDelete = "RESTRICT" - return fk -} - -func (fk *foreignKeyDefinition) RestrictOnUpdate() ForeignKeyDefinition { - fk.onUpdate = "RESTRICT" - return fk -} diff --git a/go.mod b/go.mod index 1f0d73d..241fe24 100644 --- a/go.mod +++ b/go.mod @@ -1,19 +1,26 @@ -module github.com/afkdevs/go-schema +module github.com/akfaiz/migris -go 1.23 +go 1.23.0 require ( + github.com/fatih/color v1.18.0 github.com/go-sql-driver/mysql v1.9.3 github.com/lib/pq v1.10.9 - github.com/stretchr/testify v1.10.0 + github.com/pressly/goose/v3 v3.25.0 + github.com/stretchr/testify v1.11.1 + golang.org/x/term v0.34.0 ) require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/kr/pretty v0.3.1 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mfridman/interpolate v0.0.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.13.1 // indirect - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + github.com/sethvargo/go-retry v0.3.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.35.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d87df1b..e53afdf 100644 --- a/go.sum +++ b/go.sum @@ -1,29 +1,57 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY= +github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/pressly/goose/v3 v3.25.0 h1:6WeYhMWGRCzpyd89SpODFnCBCKz41KrVbRT58nVjGng= +github.com/pressly/goose/v3 v3.25.0/go.mod h1:4hC1KrritdCxtuFsqgs1R4AU5bWtTAf+cnWvfhf2DNY= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE= +github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= +golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ= +modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.38.2 h1:Aclu7+tgjgcQVShZqim41Bbw9Cho0y/7WzYptXqkEek= +modernc.org/sqlite v1.38.2/go.mod h1:cPTJYSlgg3Sfg046yBShXENNtPrWrDX8bsbAQBzgQ5E= diff --git a/grammar.go b/grammar.go deleted file mode 100644 index ccdce5e..0000000 --- a/grammar.go +++ /dev/null @@ -1,138 +0,0 @@ -package schema - -import ( - "fmt" - "slices" - "strings" -) - -type grammar interface { - compileCreate(bp *Blueprint) (string, error) - compileCreateIfNotExists(bp *Blueprint) (string, error) - compileAdd(bp *Blueprint) (string, error) - compileChange(bp *Blueprint) ([]string, error) - compileDrop(bp *Blueprint) (string, error) - compileDropIfExists(bp *Blueprint) (string, error) - compileRename(bp *Blueprint) (string, error) - compileDropColumn(blueprint *Blueprint) (string, error) - compileRenameColumn(blueprint *Blueprint, oldName, newName string) (string, error) - compileIndex(blueprint *Blueprint, index *indexDefinition) (string, error) - compileUnique(blueprint *Blueprint, index *indexDefinition) (string, error) - compilePrimary(blueprint *Blueprint, index *indexDefinition) (string, error) - compileFullText(blueprint *Blueprint, index *indexDefinition) (string, error) - compileDropIndex(blueprint *Blueprint, indexName string) (string, error) - compileDropUnique(blueprint *Blueprint, indexName string) (string, error) - compileDropFulltext(blueprint *Blueprint, indexName string) (string, error) - compileDropPrimary(blueprint *Blueprint, indexName string) (string, error) - compileRenameIndex(blueprint *Blueprint, oldName, newName string) (string, error) - compileForeign(blueprint *Blueprint, foreignKey *foreignKeyDefinition) (string, error) - compileDropForeign(blueprint *Blueprint, foreignKeyName string) (string, error) -} - -type baseGrammar struct{} - -func (g *baseGrammar) compileForeign(blueprint *Blueprint, foreignKey *foreignKeyDefinition) (string, error) { - if foreignKey.column == "" || foreignKey.on == "" || foreignKey.references == "" { - return "", fmt.Errorf("foreign key definition is incomplete: column, on, and references must be set") - } - onDelete := "" - if foreignKey.onDelete != "" { - onDelete = fmt.Sprintf(" ON DELETE %s", foreignKey.onDelete) - } - onUpdate := "" - if foreignKey.onUpdate != "" { - onUpdate = fmt.Sprintf(" ON UPDATE %s", foreignKey.onUpdate) - } - containtName := foreignKey.constaintName - if containtName == "" { - containtName = g.createForeignKeyName(blueprint, foreignKey) - } - - return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s(%s)%s%s", - blueprint.name, - containtName, - foreignKey.column, - foreignKey.on, - foreignKey.references, - onDelete, - onUpdate, - ), nil -} - -func (g *baseGrammar) quoteString(s string) string { - return "'" + s + "'" -} - -func (g *baseGrammar) prefixArray(prefix string, items []string) []string { - prefixed := make([]string, len(items)) - for i, item := range items { - prefixed[i] = fmt.Sprintf("%s%s", prefix, item) - } - return prefixed -} - -func (g *baseGrammar) columnize(columns []string) string { - if len(columns) == 0 { - return "" - } - return strings.Join(columns, ", ") -} - -func (g *baseGrammar) getDefaultValue(col *columnDefinition) string { - if col.defaultValue == nil { - return "NULL" - } - useQuote := slices.Contains([]columnType{columnTypeString, columnTypeChar, columnTypeEnum}, col.columnType) - - switch v := col.defaultValue.(type) { - case string: - if useQuote { - return g.quoteString(v) - } - return v - case int, int64, float64: - return fmt.Sprintf("%v", v) - case bool: - if v { - return "true" - } - return "false" - default: - return fmt.Sprintf("'%v'", v) // Fallback for other types - } -} - -func (g *baseGrammar) createIndexName(blueprint *Blueprint, index *indexDefinition) string { - tableName := blueprint.name - if strings.Contains(tableName, ".") { - parts := strings.Split(tableName, ".") - tableName = parts[len(parts)-1] // Use the last part as the table name - } - - switch index.indexType { - case indexTypePrimary: - return fmt.Sprintf("pk_%s", tableName) - case indexTypeUnique: - return fmt.Sprintf("uk_%s_%s", tableName, strings.Join(index.columns, "_")) - case indexTypeIndex: - return fmt.Sprintf("idx_%s_%s", tableName, strings.Join(index.columns, "_")) - case indexTypeFullText: - return fmt.Sprintf("ft_%s_%s", tableName, strings.Join(index.columns, "_")) - default: - return "" - } -} - -func (g *baseGrammar) createForeignKeyName(blueprint *Blueprint, foreignKey *foreignKeyDefinition) string { - tableName := blueprint.name - if strings.Contains(tableName, ".") { - parts := strings.Split(tableName, ".") - tableName = parts[len(parts)-1] // Use the last part as the table name - } - on := foreignKey.on - if strings.Contains(on, ".") { - parts := strings.Split(on, ".") - on = parts[len(parts)-1] // Use the last part as the column name - } - return fmt.Sprintf("fk_%s_%s", tableName, on) -} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..856cfd6 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,29 @@ +package config + +import ( + "sync/atomic" + + "github.com/akfaiz/migris/internal/dialect" +) + +type Config struct { + Dialect dialect.Dialect +} + +var config = atomic.Pointer[Config]{} + +func init() { + config.Store(&Config{ + Dialect: dialect.Unknown, + }) +} + +func SetDialect(dialect dialect.Dialect) { + cfg := config.Load() + cfg.Dialect = dialect + config.Store(cfg) +} + +func GetDialect() dialect.Dialect { + return config.Load().Dialect +} diff --git a/internal/dialect/dialect.go b/internal/dialect/dialect.go new file mode 100644 index 0000000..c9f2fb4 --- /dev/null +++ b/internal/dialect/dialect.go @@ -0,0 +1,38 @@ +package dialect + +import "github.com/pressly/goose/v3/database" + +// Dialect is the type of database dialect. +type Dialect string + +const ( + MySQL Dialect = "mysql" + Postgres Dialect = "postgres" + Unknown Dialect = "" +) + +func (d Dialect) String() string { + return string(d) +} + +func (d Dialect) GooseDialect() database.Dialect { + switch d { + case MySQL: + return database.DialectMySQL + case Postgres: + return database.DialectPostgres + default: + return database.DialectCustom + } +} + +func FromString(dialect string) Dialect { + switch dialect { + case "mysql", "mariadb": + return MySQL + case "postgres", "pgx": + return Postgres + default: + return Unknown + } +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 0000000..232e71e --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,101 @@ +package logger + +import ( + "fmt" + "os" + "strings" + + "github.com/fatih/color" + "github.com/pressly/goose/v3" + "golang.org/x/term" +) + +var ( + grey = color.New(color.FgHiBlack).SprintFunc() + greenBold = color.New(color.FgGreen, color.Bold).SprintFunc() + yellowBold = color.New(color.FgYellow, color.Bold).SprintFunc() + redBold = color.New(color.FgRed, color.Bold).SprintFunc() +) + +func Info(msg string) { + Infof("%s", msg) +} + +func Infof(format string, args ...any) { + msg := fmt.Sprintf(format, args...) + + whiteBgBlue := color.New(color.FgWhite, color.BgBlue).SprintFunc() + fmt.Printf("%s %s\n", whiteBgBlue(" INFO "), msg) +} + +func PrintResults(results []*goose.MigrationResult) { + for _, result := range results { + PrintResult(result) + } +} + +func PrintResult(result *goose.MigrationResult) { + // Get terminal width + width, _, err := term.GetSize(int(os.Stdout.Fd())) + if err != nil || width <= 0 { + width = 80 // fallback + } + + durText := fmt.Sprintf(" %.2fms", result.Duration.Seconds()*1000) + statusText := " DONE" + if result.Error != nil { + statusText = " FAIL" + } + + name := result.Source.Path + + // Calculate how many dots to add + fillLen := width - len(name) - len(durText) - len(statusText) - 1 + if fillLen < 0 { + fillLen = 0 + } + dots := strings.Repeat(".", fillLen) + + // Print + fmt.Printf("%s %s%s", name, grey(dots), grey(durText)) + if result.Error != nil { + fmt.Printf("%s\n", redBold(statusText)) + } else { + fmt.Printf("%s\n", greenBold(statusText)) + } +} + +func PrintStatuses(statuses []*goose.MigrationStatus) { + for _, status := range statuses { + PrintStatus(status) + } +} + +func PrintStatus(status *goose.MigrationStatus) { + width, _, err := term.GetSize(int(os.Stdout.Fd())) + if err != nil || width <= 0 { + width = 80 // fallback + } + + var statusText string + if status.State == goose.StateApplied { + statusText = " Applied" + } else { + statusText = " Pending" + } + + name := status.Source.Path + + fillLen := width - len(name) - len(statusText) - 1 + if fillLen < 0 { + fillLen = 0 + } + dots := strings.Repeat(".", fillLen) + + fmt.Printf("%s %s", name, grey(dots)) + if status.State == goose.StateApplied { + fmt.Printf("%s\n", greenBold(statusText)) + } else { + fmt.Printf("%s\n", yellowBold(statusText)) + } +} diff --git a/internal/parser/migration.go b/internal/parser/migration.go new file mode 100644 index 0000000..5e89a25 --- /dev/null +++ b/internal/parser/migration.go @@ -0,0 +1,21 @@ +package parser + +import "regexp" + +func ParseMigrationName(filename string) (tableName string, create bool) { + // Regex patterns for common migration styles + createPattern := regexp.MustCompile(`^create_(?P[a-z0-9_]+?)(?:_table)?$`) + addColPattern := regexp.MustCompile(`^add_(?P[a-z0-9_]+?)_to_(?P
[a-z0-9_]+?)(?:_table)?$`) + removeColPattern := regexp.MustCompile(`^remove_(?P[a-z0-9_]+?)_from_(?P
[a-z0-9_]+?)(?:_table)?$`) + + switch { + case createPattern.MatchString(filename): + return createPattern.ReplaceAllString(filename, "${table}"), true + case addColPattern.MatchString(filename): + return addColPattern.ReplaceAllString(filename, "${table}"), false + case removeColPattern.MatchString(filename): + return removeColPattern.ReplaceAllString(filename, "${table}"), false + default: + return "", false // Unknown migration style + } +} diff --git a/internal/parser/migration_test.go b/internal/parser/migration_test.go new file mode 100644 index 0000000..a8c2de5 --- /dev/null +++ b/internal/parser/migration_test.go @@ -0,0 +1,132 @@ +package parser_test + +import ( + "testing" + + "github.com/akfaiz/migris/internal/parser" +) + +func TestParseMigrationName(t *testing.T) { + tests := []struct { + name string + filename string + wantTable string + wantCreate bool + }{ + // Create table patterns + { + name: "create table with table suffix", + filename: "create_users_table", + wantTable: "users", + wantCreate: true, + }, + { + name: "create table without table suffix", + filename: "create_posts", + wantTable: "posts", + wantCreate: true, + }, + { + name: "create table with underscores", + filename: "create_user_profiles_table", + wantTable: "user_profiles", + wantCreate: true, + }, + { + name: "create table with numbers", + filename: "create_table_v2", + wantTable: "table_v2", + wantCreate: true, + }, + + // Add column patterns + { + name: "add column with table suffix", + filename: "add_email_to_users_table", + wantTable: "users", + wantCreate: false, + }, + { + name: "add column without table suffix", + filename: "add_name_to_posts", + wantTable: "posts", + wantCreate: false, + }, + { + name: "add multiple columns", + filename: "add_first_name_last_name_to_users_table", + wantTable: "users", + wantCreate: false, + }, + { + name: "add column with underscores", + filename: "add_created_at_to_user_profiles", + wantTable: "user_profiles", + wantCreate: false, + }, + + // Remove column patterns + { + name: "remove column with table suffix", + filename: "remove_email_from_users_table", + wantTable: "users", + wantCreate: false, + }, + { + name: "remove column without table suffix", + filename: "remove_name_from_posts", + wantTable: "posts", + wantCreate: false, + }, + { + name: "remove multiple columns", + filename: "remove_old_field_from_users_table", + wantTable: "users", + wantCreate: false, + }, + { + name: "remove column with underscores", + filename: "remove_updated_at_from_user_profiles", + wantTable: "user_profiles", + wantCreate: false, + }, + + // Unknown patterns + { + name: "unknown migration pattern", + filename: "update_users_table", + wantTable: "", + wantCreate: false, + }, + { + name: "empty filename", + filename: "", + wantTable: "", + wantCreate: false, + }, + { + name: "invalid format", + filename: "invalid_migration_name", + wantTable: "", + wantCreate: false, + }, + { + name: "uppercase letters", + filename: "create_Users_table", + wantTable: "", + wantCreate: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotTable, gotCreate := parser.ParseMigrationName(tt.filename) + if gotTable != tt.wantTable { + t.Errorf("ParseMigrationName() gotTable = %v, want %v", gotTable, tt.wantTable) + } + if gotCreate != tt.wantCreate { + t.Errorf("ParseMigrationName() gotCreate = %v, want %v", gotCreate, tt.wantCreate) + } + }) + } +} diff --git a/internal/util/util.go b/internal/util/util.go new file mode 100644 index 0000000..61a4ca3 --- /dev/null +++ b/internal/util/util.go @@ -0,0 +1,33 @@ +package util + +func Optional[T any](defaultValue T, values ...T) T { + if len(values) > 0 { + return values[0] + } + return defaultValue +} + +func OptionalPtr[T any](defaultValue T, values ...T) *T { + if len(values) > 0 { + return &values[0] + } + return &defaultValue +} + +func OptionalNil[T any](values ...T) *T { + if len(values) > 0 { + return &values[0] + } + return nil +} + +func PtrOf[T any](value T) *T { + return &value +} + +func Ternary[T any](condition bool, trueValue, falseValue T) T { + if condition { + return trueValue + } + return falseValue +} diff --git a/migrate.go b/migrate.go new file mode 100644 index 0000000..581564c --- /dev/null +++ b/migrate.go @@ -0,0 +1,63 @@ +package migris + +import ( + "database/sql" + "errors" + "os" + + "github.com/akfaiz/migris/internal/config" + "github.com/akfaiz/migris/internal/dialect" + "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/database" +) + +// Migrate handles database migrations +type Migrate struct { + dialect dialect.Dialect + db *sql.DB + migrationDir string + tableName string +} + +// New creates a new Migrate instance +func New(dialectValue string, opts ...Option) (*Migrate, error) { + dialectVal := dialect.FromString(dialectValue) + if dialectVal == dialect.Unknown { + return nil, errors.New("unknown database dialect") + } + config.SetDialect(dialectVal) + + m := &Migrate{ + dialect: dialectVal, + migrationDir: "migrations", + tableName: "migris_db_version", + } + for _, opt := range opts { + opt(m) + } + if m.db == nil { + return nil, errors.New("database connection is not set, please call WithDB option") + } + return m, nil +} + +func (m *Migrate) newProvider() (*goose.Provider, error) { + val := config.GetDialect() + if val == dialect.Unknown { + return nil, errors.New("unknown database dialect") + } + gooseDialect := val.GooseDialect() + store, err := database.NewStore(gooseDialect, m.tableName) + if err != nil { + return nil, err + } + provider, err := goose.NewProvider(database.DialectCustom, m.db, os.DirFS(m.migrationDir), + goose.WithStore(store), + goose.WithDisableGlobalRegistry(true), + goose.WithGoMigrations(gooseMigrations()...), + ) + if err != nil { + return nil, err + } + return provider, nil +} diff --git a/mysql_grammar.go b/mysql_grammar.go deleted file mode 100644 index 7437bca..0000000 --- a/mysql_grammar.go +++ /dev/null @@ -1,441 +0,0 @@ -package schema - -import ( - "fmt" - "slices" - "strings" -) - -type mysqlGrammar struct { - baseGrammar -} - -var _ grammar = (*mysqlGrammar)(nil) - -func newMysqlGrammar() *mysqlGrammar { - return &mysqlGrammar{} -} - -func (g *mysqlGrammar) compileCurrentDatabase() string { - return "SELECT DATABASE()" -} - -func (g *mysqlGrammar) compileTableExists(database string, table string) (string, error) { - return fmt.Sprintf( - "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'", - g.quoteString(database), - g.quoteString(table), - ), nil -} - -func (g *mysqlGrammar) compileTables(database string) (string, error) { - return fmt.Sprintf( - "select table_name as `name`, (data_length + index_length) as `size`, "+ - "table_comment as `comment`, engine as `engine`, table_collation as `collation` "+ - "from information_schema.tables where table_schema = %s and table_type in ('BASE TABLE', 'SYSTEM VERSIONED') "+ - "order by table_name", - g.quoteString(database), - ), nil -} - -func (g *mysqlGrammar) compileColumns(database, table string) (string, error) { - return fmt.Sprintf( - "select column_name as `name`, data_type as `type_name`, column_type as `type`, "+ - "collation_name as `collation`, is_nullable as `nullable`, "+ - "column_default as `default`, column_comment as `comment`, extra as `extra` "+ - "from information_schema.columns where table_schema = %s and table_name = %s "+ - "order by ordinal_position asc", - g.quoteString(database), - g.quoteString(table), - ), nil -} - -func (g *mysqlGrammar) compileIndexes(database, table string) (string, error) { - return fmt.Sprintf( - "select index_name as `name`, group_concat(column_name order by seq_in_index) as `columns`, "+ - "index_type as `type`, not non_unique as `unique` "+ - "from information_schema.statistics where table_schema = %s and table_name = %s "+ - "group by index_name, index_type, non_unique", - g.quoteString(database), - g.quoteString(table), - ), nil -} - -func (g *mysqlGrammar) compileCreate(blueprint *Blueprint) (string, error) { - sql, err := g.compileCreateTable(blueprint) - if err != nil { - return "", err - } - sql = g.compileCreateEncoding(sql, blueprint) - - return g.compileCreateEngine(sql, blueprint), nil -} - -func (g *mysqlGrammar) compileCreateTable(blueprint *Blueprint) (string, error) { - columns, err := g.getColumns(blueprint) - if err != nil { - return "", err - } - - constraints := g.getConstraints(blueprint) - columns = append(columns, constraints...) - - return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil -} - -func (g *mysqlGrammar) compileCreateEncoding(sql string, blueprint *Blueprint) string { - if blueprint.charset != "" { - sql += fmt.Sprintf(" DEFAULT CHARACTER SET %s", blueprint.charset) - } - if blueprint.collation != "" { - sql += fmt.Sprintf(" COLLATE %s", blueprint.collation) - } - - return sql -} - -func (g *mysqlGrammar) compileCreateEngine(sql string, blueprint *Blueprint) string { - if blueprint.engine != "" { - sql += fmt.Sprintf(" ENGINE = %s", blueprint.engine) - } - return sql -} - -func (g *mysqlGrammar) compileCreateIfNotExists(blueprint *Blueprint) (string, error) { - return g.compileCreate(blueprint) -} - -func (g *mysqlGrammar) compileAdd(blueprint *Blueprint) (string, error) { - if len(blueprint.getAddedColumns()) == 0 { - return "", nil - } - - columns, err := g.getColumns(blueprint) - if err != nil { - return "", err - } - columns = g.prefixArray("ADD COLUMN ", columns) - constraints := g.getConstraints(blueprint) - constraints = g.prefixArray("ADD ", constraints) - columns = append(columns, constraints...) - - return fmt.Sprintf("ALTER TABLE %s %s", - blueprint.name, - strings.Join(columns, ", "), - ), nil -} - -func (g *mysqlGrammar) compileChange(bp *Blueprint) ([]string, error) { - if len(bp.getChangedColumns()) == 0 { - return nil, nil - } - - var sqls []string - for _, col := range bp.getChangedColumns() { - if col.name == "" { - return nil, fmt.Errorf("column name cannot be empty for change operation") - } - sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s %s", bp.name, col.name, g.getType(col)) - - if col.hasCommand("nullable") { - if col.nullable { - sql += " NULL" - } else { - sql += " NOT NULL" - } - } - if col.hasCommand("default") { - if col.defaultValue != nil { - sql += fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col)) - } else { - sql += " DEFAULT NULL" - } - } - if col.hasCommand("comment") { - if col.comment != "" { - sql += fmt.Sprintf(" COMMENT '%s'", col.comment) - } else { - sql += " COMMENT ''" - } - } - - sqls = append(sqls, sql) - } - - return sqls, nil -} - -func (g *mysqlGrammar) compileRename(blueprint *Blueprint) (string, error) { - if blueprint.newName == "" { - return "", fmt.Errorf("new table name cannot be empty") - } - return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, blueprint.newName), nil -} - -func (g *mysqlGrammar) compileDrop(blueprint *Blueprint) (string, error) { - if blueprint.name == "" { - return "", fmt.Errorf("table name cannot be empty") - } - return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil -} - -func (g *mysqlGrammar) compileDropIfExists(blueprint *Blueprint) (string, error) { - if blueprint.name == "" { - return "", fmt.Errorf("table name cannot be empty") - } - return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil -} - -func (g *mysqlGrammar) compileDropColumn(blueprint *Blueprint) (string, error) { - if len(blueprint.dropColumns) == 0 { - return "", fmt.Errorf("no columns to drop") - } - columns := make([]string, len(blueprint.dropColumns)) - for i, col := range blueprint.dropColumns { - if col == "" { - return "", fmt.Errorf("column name cannot be empty") - } - columns[i] = col - } - columns = g.prefixArray("DROP COLUMN ", columns) - return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil -} - -func (g *mysqlGrammar) compileRenameColumn(blueprint *Blueprint, oldName, newName string) (string, error) { - if oldName == "" || newName == "" { - return "", fmt.Errorf("old and new column names cannot be empty") - } - return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, oldName, newName), nil -} - -func (g *mysqlGrammar) compileIndex(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { - return "", fmt.Errorf("index column cannot be empty") - } - - indexName := index.name - if indexName == "" { - indexName = g.createIndexName(blueprint, index) - } - - sql := fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(index.columns)) - if index.algorithm != "" { - sql += fmt.Sprintf(" USING %s", index.algorithm) - } - - return sql, nil -} - -func (g *mysqlGrammar) compileUnique(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { - return "", fmt.Errorf("unique column cannot be empty") - } - - indexName := index.name - if indexName == "" { - indexName = g.createIndexName(blueprint, index) - } - sql := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(index.columns)) - if index.algorithm != "" { - sql += fmt.Sprintf(" USING %s", index.algorithm) - } - - return sql, nil -} - -func (g *mysqlGrammar) compileFullText(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { - return "", fmt.Errorf("fulltext index column cannot be empty") - } - - indexName := index.name - if indexName == "" { - indexName = g.createIndexName(blueprint, index) - } - - return fmt.Sprintf("CREATE FULLTEXT INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(index.columns)), nil -} - -func (g *mysqlGrammar) compilePrimary(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { - return "", fmt.Errorf("primary key column cannot be empty") - } - - indexName := index.name - if indexName == "" { - indexName = g.createIndexName(blueprint, index) - } - - return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(index.columns)), nil -} - -func (g *mysqlGrammar) compileDropIndex(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { - return "", fmt.Errorf("index name cannot be empty") - } - return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, indexName), nil -} - -func (g *mysqlGrammar) compileDropUnique(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { - return "", fmt.Errorf("unique index name cannot be empty") - } - return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, indexName), nil -} - -func (g *mysqlGrammar) compileDropFulltext(blueprint *Blueprint, indexName string) (string, error) { - return g.compileDropIndex(blueprint, indexName) -} - -func (g *mysqlGrammar) compileDropPrimary(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { - return "", fmt.Errorf("primary key index name cannot be empty") - } - return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", blueprint.name), nil -} - -func (g *mysqlGrammar) compileRenameIndex(blueprint *Blueprint, oldName, newName string) (string, error) { - if oldName == "" || newName == "" { - return "", fmt.Errorf("old and new index names cannot be empty") - } - return fmt.Sprintf("ALTER TABLE %s RENAME INDEX %s TO %s", blueprint.name, oldName, newName), nil -} - -func (g *mysqlGrammar) compileDropForeign(blueprint *Blueprint, foreignKeyName string) (string, error) { - if foreignKeyName == "" { - return "", fmt.Errorf("foreign key name cannot be empty") - } - return fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s", blueprint.name, foreignKeyName), nil -} - -func (g *mysqlGrammar) getColumns(blueprint *Blueprint) ([]string, error) { - var columns []string - for _, col := range blueprint.getAddedColumns() { - if col.name == "" { - return nil, fmt.Errorf("column name cannot be empty") - } - sql := col.name + " " + g.getType(col) - if col.hasCommand("default") { - if col.defaultValue != nil { - sql += fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col)) - } else { - sql += " DEFAULT NULL" - } - } - if col.onUpdateValue != "" { - sql += fmt.Sprintf(" ON UPDATE %s", col.onUpdateValue) - } - if col.nullable { - sql += " NULL" - } else { - sql += " NOT NULL" - } - if col.comment != "" { - sql += fmt.Sprintf(" COMMENT '%s'", col.comment) - } - columns = append(columns, sql) - } - - return columns, nil -} - -func (g *mysqlGrammar) getConstraints(blueprint *Blueprint) []string { - var constrains []string - for _, col := range blueprint.getAddedColumns() { - if col.primary { - pkConstraintName := g.createIndexName(blueprint, &indexDefinition{indexType: indexTypePrimary}) - sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")" - constrains = append(constrains, sql) - continue - } - if col.unique { - uniqueName := col.uniqueName - if uniqueName == "" { - uniqueName = g.createIndexName(blueprint, &indexDefinition{ - indexType: indexTypeUnique, - columns: []string{col.name}, - }) - } - sql := fmt.Sprintf("CONSTRAINT %s UNIQUE (%s)", uniqueName, col.name) - constrains = append(constrains, sql) - } - } - - return constrains -} - -func (g *mysqlGrammar) getType(col *columnDefinition) string { - switch col.columnType { - case columnTypeCustom: - return col.customColumnType - case columnTypeBoolean: - return "TINYINT(1)" - case columnTypeChar: - return fmt.Sprintf("CHAR(%d)", col.length) - case columnTypeString: - return fmt.Sprintf("VARCHAR(%d)", col.length) - case columnTypeDecimal: - return fmt.Sprintf("DECIMAL(%d, %d)", col.total, col.places) - case columnTypeDouble, columnTypeFloat: - if col.total > 0 && col.places > 0 { - return fmt.Sprintf("DOUBLE(%d, %d)", col.total, col.places) - } - return "DOUBLE" - case columnTypeBigInteger: - return g.modifyUnsignedAndAutoIncrement("BIGINT", col) - case columnTypeInteger: - return g.modifyUnsignedAndAutoIncrement("INT", col) - case columnTypeSmallInteger: - return g.modifyUnsignedAndAutoIncrement("SMALLINT", col) - case columnTypeMediumInteger: - return g.modifyUnsignedAndAutoIncrement("MEDIUMINT", col) - case columnTypeTinyInteger: - return g.modifyUnsignedAndAutoIncrement("TINYINT", col) - case columnTypeTime: - return fmt.Sprintf("TIME(%d)", col.precision) - case columnTypeDateTime, columnTypeDateTimeTz: - if col.precision > 0 { - return fmt.Sprintf("DATETIME(%d)", col.precision) - } - return "DATETIME" - case columnTypeTimestamp, columnTypeTimestampTz: - if col.precision > 0 { - return fmt.Sprintf("TIMESTAMP(%d)", col.precision) - } - return "TIMESTAMP" - case columnTypeGeography: - return fmt.Sprintf("GEOGRAPHY(%s, %d)", col.subType, col.srid) - case columnTypeEnum: - return fmt.Sprintf("ENUM(%s)", g.quoteString(strings.Join(col.allowedEnums, "','"))) - default: - colType, ok := map[columnType]string{ - columnTypeBoolean: "BOOLEAN", - columnTypeLongText: "LONGTEXT", - columnTypeText: "TEXT", - columnTypeMediumText: "MEDIUMTEXT", - columnTypeTinyText: "TINYTEXT", - columnTypeDate: "DATE", - columnTypeYear: "YEAR", - columnTypeJSON: "JSON", - columnTypeJSONB: "JSON", - columnTypeUUID: "UUID", - columnTypeBinary: "BLOB", - columnTypeGeometry: "GEOMETRY", - columnTypePoint: "POINT", - }[col.columnType] - if !ok { - return "ERROR: Unknown column type " + string(col.columnType) - } - return colType - } -} - -func (g *mysqlGrammar) modifyUnsignedAndAutoIncrement(sql string, col *columnDefinition) string { - if col.unsigned { - sql += " UNSIGNED" - } - if col.autoIncrement { - sql += " AUTO_INCREMENT" - } - return sql -} diff --git a/options.go b/options.go new file mode 100644 index 0000000..8da4f71 --- /dev/null +++ b/options.go @@ -0,0 +1,26 @@ +package migris + +import "database/sql" + +type Option func(*Migrate) + +// WithTableName sets the table name for the migration. +func WithTableName(name string) Option { + return func(m *Migrate) { + m.tableName = name + } +} + +// WithMigrationDir sets the directory for the migration files. +func WithMigrationDir(dir string) Option { + return func(m *Migrate) { + m.migrationDir = dir + } +} + +// WithDB sets the database connection for the migration. +func WithDB(db *sql.DB) Option { + return func(m *Migrate) { + m.db = db + } +} diff --git a/postgres_grammar.go b/postgres_grammar.go deleted file mode 100644 index 4e02e7b..0000000 --- a/postgres_grammar.go +++ /dev/null @@ -1,450 +0,0 @@ -package schema - -import ( - "fmt" - "slices" - "strings" -) - -type pgGrammar struct { - baseGrammar -} - -var _ grammar = (*pgGrammar)(nil) - -func newPgGrammar() *pgGrammar { - return &pgGrammar{} -} - -func (g *pgGrammar) compileTableExists(schema string, table string) (string, error) { - return fmt.Sprintf( - "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'", - g.quoteString(schema), - g.quoteString(table), - ), nil -} - -func (g *pgGrammar) compileTables() (string, error) { - return "select c.relname as name, n.nspname as schema, pg_total_relation_size(c.oid) as size, " + - "obj_description(c.oid, 'pg_class') as comment from pg_class c, pg_namespace n " + - "where c.relkind in ('r', 'p') and n.oid = c.relnamespace and n.nspname not in ('pg_catalog', 'information_schema') " + - "order by c.relname", nil -} - -func (g *pgGrammar) compileColumns(schema, table string) (string, error) { - return fmt.Sprintf( - "select a.attname as name, t.typname as type_name, format_type(a.atttypid, a.atttypmod) as type, "+ - "(select tc.collcollate from pg_catalog.pg_collation tc where tc.oid = a.attcollation) as collation, "+ - "not a.attnotnull as nullable, "+ - "(select pg_get_expr(adbin, adrelid) from pg_attrdef where c.oid = pg_attrdef.adrelid and pg_attrdef.adnum = a.attnum) as default, "+ - "col_description(c.oid, a.attnum) as comment "+ - "from pg_attribute a, pg_class c, pg_type t, pg_namespace n "+ - "where c.relname = %s and n.nspname = %s and a.attnum > 0 and a.attrelid = c.oid and a.atttypid = t.oid and n.oid = c.relnamespace "+ - "order by a.attnum", - g.quoteString(table), - g.quoteString(schema), - ), nil -} - -func (g *pgGrammar) compileIndexes(schema, table string) (string, error) { - return fmt.Sprintf( - "select ic.relname as name, string_agg(a.attname, ',' order by indseq.ord) as columns, "+ - "am.amname as \"type\", i.indisunique as \"unique\", i.indisprimary as \"primary\" "+ - "from pg_index i "+ - "join pg_class tc on tc.oid = i.indrelid "+ - "join pg_namespace tn on tn.oid = tc.relnamespace "+ - "join pg_class ic on ic.oid = i.indexrelid "+ - "join pg_am am on am.oid = ic.relam "+ - "join lateral unnest(i.indkey) with ordinality as indseq(num, ord) on true "+ - "left join pg_attribute a on a.attrelid = i.indrelid and a.attnum = indseq.num "+ - "where tc.relname = %s and tn.nspname = %s "+ - "group by ic.relname, am.amname, i.indisunique, i.indisprimary", - g.quoteString(table), - g.quoteString(schema), - ), nil -} - -func (g *pgGrammar) compileCreate(blueprint *Blueprint) (string, error) { - columns, err := g.getColumns(blueprint) - if err != nil { - return "", err - } - constraints := g.getConstraints(blueprint) - columns = append(columns, constraints...) - return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil -} - -func (g *pgGrammar) compileCreateIfNotExists(blueprint *Blueprint) (string, error) { - columns, err := g.getColumns(blueprint) - if err != nil { - return "", err - } - constraints := g.getConstraints(blueprint) - columns = append(columns, constraints...) - return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil -} - -func (g *pgGrammar) compileAdd(blueprint *Blueprint) (string, error) { - if len(blueprint.getAddedColumns()) == 0 { - return "", nil - } - - columns, err := g.getColumns(blueprint) - if err != nil { - return "", err - } - columns = g.prefixArray("ADD COLUMN ", columns) - - constraints := g.getConstraints(blueprint) - constraints = g.prefixArray("ADD ", constraints) - - columns = append(columns, constraints...) - - return fmt.Sprintf("ALTER TABLE %s %s", - blueprint.name, - strings.Join(columns, ", "), - ), nil -} - -func (g *pgGrammar) compileChange(bp *Blueprint) ([]string, error) { - if len(bp.getChangedColumns()) == 0 { - return nil, nil - } - - var queries []string - for _, col := range bp.getChangedColumns() { - if col.name == "" { - return nil, fmt.Errorf("column name cannot be empty for change operation") - } - var statements []string - - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s TYPE %s", col.name, g.getType(col))) - if col.hasCommand("default") { - if col.defaultValue != nil { - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", col.name, g.getDefaultValue(col))) - } else { - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s DROP DEFAULT", col.name)) - } - } - if col.hasCommand("nullable") { - if col.nullable { - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s DROP NOT NULL", col.name)) - } else { - statements = append(statements, fmt.Sprintf("ALTER COLUMN %s SET NOT NULL", col.name)) - } - } - sql := "ALTER TABLE " + bp.name + " " + strings.Join(statements, ", ") - queries = append(queries, sql) - - if col.hasCommand("comment") { - if col.comment != "" { - queries = append(queries, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", bp.name, col.name, col.comment)) - } else { - queries = append(queries, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS NULL", bp.name, col.name)) - } - } - } - - return queries, nil -} - -func (g *pgGrammar) compileDrop(blueprint *Blueprint) (string, error) { - return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil -} - -func (g *pgGrammar) compileDropIfExists(blueprint *Blueprint) (string, error) { - return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil -} - -func (g *pgGrammar) compileRename(blueprint *Blueprint) (string, error) { - return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, blueprint.newName), nil -} - -func (g *pgGrammar) compileDropColumn(blueprint *Blueprint) (string, error) { - if len(blueprint.dropColumns) == 0 { - return "", nil - } - columns := g.prefixArray("DROP COLUMN ", blueprint.dropColumns) - - return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil -} - -func (g *pgGrammar) compileRenameColumn(blueprint *Blueprint, oldName, newName string) (string, error) { - if oldName == "" || newName == "" { - return "", fmt.Errorf("table name, old column name, and new column name cannot be empty for rename operation") - } - return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, oldName, newName), nil -} - -func (g *pgGrammar) compileIndex(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { - return "", fmt.Errorf("index column cannot be empty") - } - indexName := index.name - if indexName == "" { - indexName = g.createIndexName(blueprint, index) - } - - sql := fmt.Sprintf("CREATE INDEX %s ON %s", indexName, blueprint.name) - if index.algorithm != "" { - sql += fmt.Sprintf(" USING %s", index.algorithm) - } - return fmt.Sprintf("%s (%s)", sql, g.columnize(index.columns)), nil -} - -func (g *pgGrammar) compileUnique(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { - return "", fmt.Errorf("unique index column cannot be empty") - } - indexName := index.name - if indexName == "" { - indexName = g.createIndexName(blueprint, index) - } - sql := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)", - blueprint.name, - indexName, - g.columnize(index.columns), - ) - - if index.deferrable != nil { - if *index.deferrable { - sql += " DEFERRABLE" - } else { - sql += " NOT DEFERRABLE" - } - } - if index.deferrable != nil && *index.deferrable && index.initiallyImmediate != nil { - if *index.initiallyImmediate { - sql += " INITIALLY IMMEDIATE" - } else { - sql += " INITIALLY DEFERRED" - } - } - - return sql, nil -} - -func (g *pgGrammar) compilePrimary(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { - return "", fmt.Errorf("primary key index column cannot be empty") - } - indexName := index.name - if indexName == "" { - indexName = g.createIndexName(blueprint, index) - } - return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(index.columns)), nil -} - -func (g *pgGrammar) compileFullText(blueprint *Blueprint, index *indexDefinition) (string, error) { - if slices.Contains(index.columns, "") { - return "", fmt.Errorf("fulltext index column cannot be empty") - } - indexName := index.name - if indexName == "" { - indexName = g.createIndexName(blueprint, index) - } - language := index.language - if language == "" { - language = "english" // Default language for full-text search - } - var columns []string - for _, col := range index.columns { - columns = append(columns, fmt.Sprintf("to_tsvector(%s, %s)", g.quoteString(language), col)) - } - - return fmt.Sprintf("CREATE INDEX %s ON %s USING GIN (%s)", indexName, blueprint.name, strings.Join(columns, " || ")), nil -} - -func (g *pgGrammar) compileDropIndex(_ *Blueprint, indexName string) (string, error) { - if indexName == "" { - return "", fmt.Errorf("index name cannot be empty for drop operation") - } - return fmt.Sprintf("DROP INDEX %s", indexName), nil -} - -func (g *pgGrammar) compileDropUnique(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { - return "", fmt.Errorf("index name cannot be empty for drop operation") - } - return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, indexName), nil -} - -func (g *pgGrammar) compileDropFulltext(blueprint *Blueprint, indexName string) (string, error) { - return g.compileDropIndex(blueprint, indexName) -} - -func (g *pgGrammar) compileRenameIndex(_ *Blueprint, oldName, newName string) (string, error) { - if oldName == "" || newName == "" { - return "", fmt.Errorf("index names for rename operation cannot be empty: oldName=%s, newName=%s", oldName, newName) - } - return fmt.Sprintf("ALTER INDEX %s RENAME TO %s", oldName, newName), nil -} - -func (g *pgGrammar) compileDropPrimary(blueprint *Blueprint, indexName string) (string, error) { - if indexName == "" { - indexName = g.createIndexName(blueprint, &indexDefinition{indexType: indexTypePrimary}) - } - return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, indexName), nil -} - -func (g *pgGrammar) compileForeign(blueprint *Blueprint, foreignKey *foreignKeyDefinition) (string, error) { - sql, err := g.baseGrammar.compileForeign(blueprint, foreignKey) - if err != nil { - return "", err - } - - if foreignKey.deferrable != nil { - if *foreignKey.deferrable { - sql += " DEFERRABLE" - } else { - sql += " NOT DEFERRABLE" - } - } - if foreignKey.deferrable != nil && *foreignKey.deferrable && foreignKey.initiallyImmediate != nil { - if *foreignKey.initiallyImmediate { - sql += " INITIALLY IMMEDIATE" - } else { - sql += " INITIALLY DEFERRED" - } - } - - return sql, nil -} - -func (g *pgGrammar) compileDropForeign(blueprint *Blueprint, foreignKeyName string) (string, error) { - if foreignKeyName == "" { - return "", fmt.Errorf("foreign key name cannot be empty for drop operation") - } - return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, foreignKeyName), nil -} - -func (g *pgGrammar) getColumns(blueprint *Blueprint) ([]string, error) { - var columns []string - for _, col := range blueprint.getAddedColumns() { - if col.name == "" { - return nil, fmt.Errorf("column name cannot be empty") - } - sql := col.name + " " + g.getType(col) - if col.hasCommand("default") { - if col.defaultValue != nil { - sql += fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col)) - } else { - sql += " DEFAULT NULL" - } - } - if col.nullable { - sql += " NULL" - } else { - sql += " NOT NULL" - } - if col.comment != "" { - sql += fmt.Sprintf(" COMMENT '%s'", col.comment) - } - columns = append(columns, sql) - } - - return columns, nil -} - -func (g *pgGrammar) getConstraints(blueprint *Blueprint) []string { - var constrains []string - for _, col := range blueprint.getAddedColumns() { - if col.primary { - pkConstraintName := g.createIndexName(blueprint, &indexDefinition{indexType: indexTypePrimary}) - sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")" - constrains = append(constrains, sql) - continue - } - if col.unique { - uniqueName := col.uniqueName - if uniqueName == "" { - uniqueName = g.createIndexName(blueprint, &indexDefinition{ - indexType: indexTypeUnique, - columns: []string{col.name}, - }) - } - sql := fmt.Sprintf("CONSTRAINT %s UNIQUE (%s)", uniqueName, col.name) - constrains = append(constrains, sql) - } - } - - return constrains -} - -func (g *pgGrammar) getType(col *columnDefinition) string { - switch col.columnType { - case columnTypeCustom: - return col.customColumnType - case columnTypeChar: - if col.length > 0 { - return fmt.Sprintf("CHAR(%d)", col.length) - } - return "CHAR" - case columnTypeString: - if col.length > 0 { - return fmt.Sprintf("VARCHAR(%d)", col.length) - } - return "VARCHAR" - case columnTypeDecimal: - return fmt.Sprintf("DECIMAL(%d, %d)", col.total, col.places) - case columnTypeBigInteger: - if col.autoIncrement { - return "BIGSERIAL" - } - return "BIGINT" - case columnTypeInteger, columnTypeMediumInteger: - if col.autoIncrement { - return "SERIAL" - } - return "INTEGER" - case columnTypeSmallInteger, columnTypeTinyInteger: - if col.autoIncrement { - return "SMALLSERIAL" - } - return "SMALLINT" - case columnTypeTime: - return fmt.Sprintf("TIME(%d)", col.precision) - case columnTypeTimestamp, columnTypeDateTime: - return fmt.Sprintf("TIMESTAMP(%d)", col.precision) - case columnTypeTimestampTz, columnTypeDateTimeTz: - return fmt.Sprintf("TIMESTAMPTZ(%d)", col.precision) - case columnTypeGeography: - return fmt.Sprintf("GEOGRAPHY(%s, %d)", col.subType, col.srid) - case columnTypeGeometry: - if col.srid > 0 { - return fmt.Sprintf("GEOMETRY(%s, %d)", col.subType, col.srid) - } - return fmt.Sprintf("GEOMETRY(%s)", col.subType) - case columnTypePoint: - return fmt.Sprintf("POINT(%d)", col.srid) - case columnTypeEnum: - if len(col.allowedEnums) == 0 { - return "VARCHAR(255)" // Fallback to VARCHAR if no enums are defined - } - enumValues := make([]string, len(col.allowedEnums)) - for i, v := range col.allowedEnums { - enumValues[i] = fmt.Sprintf("'%s'", v) - } - return "VARCHAR(255) CHECK (" + col.name + " IN (" + strings.Join(enumValues, ", ") + "))" - default: - colType, ok := map[columnType]string{ - columnTypeBoolean: "BOOLEAN", - columnTypeLongText: "TEXT", - columnTypeText: "TEXT", - columnTypeMediumText: "TEXT", - columnTypeTinyText: "TEXT", - columnTypeDouble: "DOUBLE PRECISION", - columnTypeFloat: "REAL", - columnTypeDate: "DATE", - columnTypeYear: "INTEGER", // PostgreSQL does not have a YEAR type, using INTEGER instead - columnTypeJSON: "JSON", - columnTypeJSONB: "JSONB", - columnTypeUUID: "UUID", - columnTypeBinary: "BYTEA", - }[col.columnType] - if !ok { - return "ERROR: Unknown column type " + string(col.columnType) - } - return colType - } -} diff --git a/register.go b/register.go new file mode 100644 index 0000000..96c148d --- /dev/null +++ b/register.go @@ -0,0 +1,90 @@ +package migris + +import ( + "context" + "database/sql" + "fmt" + "path" + "runtime" + + "github.com/akfaiz/migris/schema" + "github.com/pressly/goose/v3" +) + +var ( + registeredVersions = make(map[int64]string) + registeredMigrations = make([]*Migration, 0) +) + +type Migration struct { + version int64 + source string + upFnContext, downFnContext MigrationContext +} + +// MigrationContext is a Go migration func that is run within a transaction and receives a +// context. +type MigrationContext func(ctx *schema.Context) error + +func (m MigrationContext) runTxFunc(source string) func(ctx context.Context, tx *sql.Tx) error { + return func(ctx context.Context, tx *sql.Tx) error { + filename := path.Base(source) + c := schema.NewContext(ctx, tx, schema.WithFilename(filename)) + return m(c) + } +} + +// AddMigrationContext adds Go migrations. +func AddMigrationContext(up, down MigrationContext) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationContext(filename, up, down) +} + +// AddNamedMigrationContext adds named Go migrations. +func AddNamedMigrationContext(source string, up, down MigrationContext) { + if err := register( + source, + up, + down, + ); err != nil { + panic(err) + } +} + +func register(source string, up, down MigrationContext) error { + v, _ := goose.NumericComponent(source) + if existing, ok := registeredVersions[v]; ok { + return fmt.Errorf("failed to add migration %q: version %d conflicts with %q", + source, + v, + existing, + ) + } + // Add to global as a registered migration. + m := &Migration{ + version: v, + source: source, + upFnContext: up, + downFnContext: down, + } + registeredVersions[v] = source + registeredMigrations = append(registeredMigrations, m) + return nil +} + +func gooseMigrations() []*goose.Migration { + migrations := make([]*goose.Migration, 0, len(registeredMigrations)) + for _, m := range registeredMigrations { + upFunc := &goose.GoFunc{ + RunTx: m.upFnContext.runTxFunc(m.source), + Mode: goose.TransactionEnabled, + } + downFunc := &goose.GoFunc{ + RunTx: m.downFnContext.runTxFunc(m.source), + Mode: goose.TransactionEnabled, + } + gm := goose.NewGoMigration(m.version, upFunc, downFunc) + migrations = append(migrations, gm) + } + return migrations +} diff --git a/reset.go b/reset.go new file mode 100644 index 0000000..124bcb5 --- /dev/null +++ b/reset.go @@ -0,0 +1,44 @@ +package migris + +import ( + "context" + "errors" + + "github.com/akfaiz/migris/internal/logger" + "github.com/pressly/goose/v3" +) + +// Reset rolls back all migrations. +func (m *Migrate) Reset() error { + ctx := context.Background() + return m.ResetContext(ctx) +} + +// ResetContext rolls back all migrations. +func (m *Migrate) ResetContext(ctx context.Context) error { + provider, err := m.newProvider() + if err != nil { + return err + } + currentVersion, err := provider.GetDBVersion(ctx) + if err != nil { + return err + } + if currentVersion == 0 { + logger.Info("Nothing to rollback.") + return nil + } + logger.Info("Rolling back migrations.\n") + results, err := provider.DownTo(ctx, 0) + if err != nil { + var partialErr *goose.PartialError + if errors.As(err, &partialErr) { + logger.PrintResults(partialErr.Applied) + logger.PrintResult(partialErr.Failed) + } + + return err + } + logger.PrintResults(results) + return nil +} diff --git a/schema/blueprint.go b/schema/blueprint.go new file mode 100644 index 0000000..b0d6e55 --- /dev/null +++ b/schema/blueprint.go @@ -0,0 +1,712 @@ +package schema + +import ( + "fmt" + + "github.com/akfaiz/migris/internal/dialect" + "github.com/akfaiz/migris/internal/util" +) + +const ( + columnTypeBoolean string = "boolean" + columnTypeChar string = "char" + columnTypeString string = "string" + columnTypeLongText string = "longText" + columnTypeMediumText string = "mediumText" + columnTypeText string = "text" + columnTypeTinyText string = "tinyText" + columnTypeBigInteger string = "bigInteger" + columnTypeInteger string = "integer" + columnTypeMediumInteger string = "mediumInteger" + columnTypeSmallInteger string = "smallInteger" + columnTypeTinyInteger string = "tinyInteger" + columnTypeDecimal string = "decimal" + columnTypeDouble string = "double" + columnTypeFloat string = "float" + columnTypeDateTime string = "dateTime" + columnTypeDateTimeTz string = "dateTimeTz" + columnTypeDate string = "date" + columnTypeTime string = "time" + columnTypeTimeTz string = "timeTz" + columnTypeTimestamp string = "timestamp" + columnTypeTimestampTz string = "timestampTz" + columnTypeYear string = "year" + columnTypeBinary string = "binary" + columnTypeJson string = "json" + columnTypeJsonb string = "jsonb" + columnTypeGeography string = "geography" + columnTypeGeometry string = "geometry" + columnTypePoint string = "point" + columnTypeUuid string = "uuid" + columnTypeEnum string = "enum" +) + +const ( + defaultStringLength int = 255 + defaultTimePrecision int = 0 +) + +// Blueprint represents a schema blueprint for creating or altering a database table. +type Blueprint struct { + dialect dialect.Dialect + columns []*columnDefinition + commands []*command + grammar grammar + name string + charset string + collation string + engine string +} + +// Charset sets the character set for the table in the blueprint. +func (b *Blueprint) Charset(charset string) { + b.charset = charset +} + +// Collation sets the collation for the table in the blueprint. +func (b *Blueprint) Collation(collation string) { + b.collation = collation +} + +// Engine sets the storage engine for the table in the blueprint. +func (b *Blueprint) Engine(engine string) { + b.engine = engine +} + +// Column creates a new custom column definition in the blueprint with the specified name and type. +func (b *Blueprint) Column(name string, columnType string) ColumnDefinition { + return b.addColumn(columnType, name) +} + +// Boolean creates a new boolean column definition in the blueprint. +func (b *Blueprint) Boolean(name string) ColumnDefinition { + return b.addColumn(columnTypeBoolean, name) +} + +// Char creates a new char column definition in the blueprint. +func (b *Blueprint) Char(name string, length ...int) ColumnDefinition { + return b.addColumn(columnTypeChar, name, &columnDefinition{ + length: util.OptionalPtr(defaultStringLength, length...), + }) +} + +// String creates a new string column definition in the blueprint. +func (b *Blueprint) String(name string, length ...int) ColumnDefinition { + return b.addColumn(columnTypeString, name, &columnDefinition{ + length: util.OptionalPtr(defaultStringLength, length...), + }) +} + +// LongText creates a new long text column definition in the blueprint. +func (b *Blueprint) LongText(name string) ColumnDefinition { + return b.addColumn(columnTypeLongText, name) +} + +// Text creates a new text column definition in the blueprint. +func (b *Blueprint) Text(name string) ColumnDefinition { + return b.addColumn(columnTypeText, name) +} + +// MediumText creates a new medium text column definition in the blueprint. +func (b *Blueprint) MediumText(name string) ColumnDefinition { + return b.addColumn(columnTypeMediumText, name) +} + +// TinyText creates a new tiny text column definition in the blueprint. +func (b *Blueprint) TinyText(name string) ColumnDefinition { + return b.addColumn(columnTypeTinyText, name) +} + +// BigIncrements creates a new big increments column definition in the blueprint. +func (b *Blueprint) BigIncrements(name string) ColumnDefinition { + return b.UnsignedBigInteger(name).AutoIncrement() +} + +// BigInteger creates a new big integer column definition in the blueprint. +func (b *Blueprint) BigInteger(name string) ColumnDefinition { + return b.addColumn(columnTypeBigInteger, name) +} + +// Decimal creates a new decimal column definition in the blueprint. +// +// The total and places parameters are optional. +// +// Example: +// +// table.Decimal("price", 10, 2) // creates a decimal column with total 10 and places 2 +// +// table.Decimal("price") // creates a decimal column with default total 8 and places 2 +func (b *Blueprint) Decimal(name string, params ...int) ColumnDefinition { + defaultPlaces := 2 + if len(params) > 1 { + defaultPlaces = params[1] + } + return b.addColumn(columnTypeDecimal, name, &columnDefinition{ + total: util.OptionalPtr(8, params...), + places: util.PtrOf(defaultPlaces), + }) +} + +// Double creates a new double column definition in the blueprint. +func (b *Blueprint) Double(name string) ColumnDefinition { + return b.addColumn(columnTypeDouble, name) +} + +// Float creates a new float column definition in the blueprint. +func (b *Blueprint) Float(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeFloat, name, &columnDefinition{ + precision: util.OptionalPtr(53, precision...), + }) +} + +// ID creates a new big increments column definition in the blueprint with the name "id" or a custom name. +// +// If a name is provided, it will be used as the column name; otherwise, "id" will be used. +func (b *Blueprint) ID(name ...string) ColumnDefinition { + return b.BigIncrements(util.Optional("id", name...)).Primary() +} + +// Increments create a new increment column definition in the blueprint. +func (b *Blueprint) Increments(name string) ColumnDefinition { + return b.UnsignedInteger(name).AutoIncrement() +} + +// Integer creates a new integer column definition in the blueprint. +func (b *Blueprint) Integer(name string) ColumnDefinition { + return b.addColumn(columnTypeInteger, name) +} + +// MediumIncrements creates a new medium increments column definition in the blueprint. +func (b *Blueprint) MediumIncrements(name string) ColumnDefinition { + return b.UnsignedMediumInteger(name).AutoIncrement() +} + +func (b *Blueprint) MediumInteger(name string) ColumnDefinition { + return b.addColumn(columnTypeMediumInteger, name) +} + +// SmallIncrements creates a new small increments column definition in the blueprint. +func (b *Blueprint) SmallIncrements(name string) ColumnDefinition { + return b.UnsignedSmallInteger(name).AutoIncrement() +} + +// SmallInteger creates a new small integer column definition in the blueprint. +func (b *Blueprint) SmallInteger(name string) ColumnDefinition { + return b.addColumn(columnTypeSmallInteger, name) +} + +// TinyIncrements creates a new tiny increments column definition in the blueprint. +func (b *Blueprint) TinyIncrements(name string) ColumnDefinition { + return b.UnsignedTinyInteger(name).AutoIncrement() +} + +// TinyInteger creates a new tiny integer column definition in the blueprint. +func (b *Blueprint) TinyInteger(name string) ColumnDefinition { + return b.addColumn(columnTypeTinyInteger, name) +} + +// UnsignedBigInteger creates a new unsigned big integer column definition in the blueprint. +func (b *Blueprint) UnsignedBigInteger(name string) ColumnDefinition { + return b.BigInteger(name).Unsigned() +} + +// UnsignedInteger creates a new unsigned integer column definition in the blueprint. +func (b *Blueprint) UnsignedInteger(name string) ColumnDefinition { + return b.Integer(name).Unsigned() +} + +// UnsignedMediumInteger creates a new unsigned medium integer column definition in the blueprint. +func (b *Blueprint) UnsignedMediumInteger(name string) ColumnDefinition { + return b.MediumInteger(name).Unsigned() +} + +// UnsignedSmallInteger creates a new unsigned small integer column definition in the blueprint. +func (b *Blueprint) UnsignedSmallInteger(name string) ColumnDefinition { + return b.SmallInteger(name).Unsigned() +} + +// UnsignedTinyInteger creates a new unsigned tiny integer column definition in the blueprint. +func (b *Blueprint) UnsignedTinyInteger(name string) ColumnDefinition { + return b.TinyInteger(name).Unsigned() +} + +// DateTime creates a new date time column definition in the blueprint. +func (b *Blueprint) DateTime(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeDateTime, name, &columnDefinition{ + precision: util.OptionalPtr(defaultTimePrecision, precision...), + }) +} + +// DateTimeTz creates a new date time with a time zone column definition in the blueprint. +func (b *Blueprint) DateTimeTz(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeDateTimeTz, name, &columnDefinition{ + precision: util.OptionalPtr(defaultTimePrecision, precision...), + }) +} + +// Date creates a new date column definition in the blueprint. +func (b *Blueprint) Date(name string) ColumnDefinition { + return b.addColumn(columnTypeDate, name) +} + +// Time creates a new time column definition in the blueprint. +func (b *Blueprint) Time(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeTime, name, &columnDefinition{ + precision: util.OptionalPtr(defaultTimePrecision, precision...), + }) +} + +// TimeTz creates a new time with time zone column definition in the blueprint. +func (b *Blueprint) TimeTz(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeTimeTz, name, &columnDefinition{ + precision: util.OptionalPtr(defaultTimePrecision, precision...), + }) +} + +// Timestamp creates a new timestamp column definition in the blueprint. +// The precision parameter is optional and defaults to 0 if not provided. +func (b *Blueprint) Timestamp(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeTimestamp, name, &columnDefinition{ + precision: util.OptionalPtr(defaultTimePrecision, precision...), + }) +} + +// TimestampTz creates a new timestamp with time zone column definition in the blueprint. +// The precision parameter is optional and defaults to 0 if not provided. +func (b *Blueprint) TimestampTz(name string, precision ...int) ColumnDefinition { + return b.addColumn(columnTypeTimestampTz, name, &columnDefinition{ + precision: util.OptionalPtr(defaultTimePrecision, precision...), + }) +} + +// Timestamps adds created_at and updated_at timestamp columns to the blueprint. +func (b *Blueprint) Timestamps(precision ...int) { + b.Timestamp("created_at", precision...).UseCurrent() + b.Timestamp("updated_at", precision...).UseCurrent().UseCurrentOnUpdate() +} + +// TimestampsTz adds created_at and updated_at timestamp with time zone columns to the blueprint. +func (b *Blueprint) TimestampsTz(precision ...int) { + b.TimestampTz("created_at", precision...).UseCurrent() + b.TimestampTz("updated_at", precision...).UseCurrent().UseCurrentOnUpdate() +} + +// Year creates a new year column definition in the blueprint. +func (b *Blueprint) Year(name string) ColumnDefinition { + return b.addColumn(columnTypeYear, name) +} + +// Binary creates a new binary column definition in the blueprint. +func (b *Blueprint) Binary(name string, length ...int) ColumnDefinition { + return b.addColumn(columnTypeBinary, name, &columnDefinition{ + length: util.OptionalNil(length...), + }) +} + +// JSON creates a new JSON column definition in the blueprint. +func (b *Blueprint) JSON(name string) ColumnDefinition { + return b.addColumn(columnTypeJson, name) +} + +// JSONB creates a new JSONB column definition in the blueprint. +func (b *Blueprint) JSONB(name string) ColumnDefinition { + return b.addColumn(columnTypeJsonb, name) +} + +// UUID creates a new UUID column definition in the blueprint. +func (b *Blueprint) UUID(name string) ColumnDefinition { + return b.addColumn(columnTypeUuid, name) +} + +// Geography creates a new geography column definition in the blueprint. +// The subType parameter is optional and can be used to specify the type of geography (e.g., "Point", "LineString", "Polygon"). +// The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geography type. +func (b *Blueprint) Geography(name string, subtype string, srid ...int) ColumnDefinition { + return b.addColumn(columnTypeGeography, name, &columnDefinition{ + subtype: util.OptionalPtr("", subtype), + srid: util.OptionalPtr(4326, srid...), + }) +} + +// Geometry creates a new geometry column definition in the blueprint. +// The subType parameter is optional and can be used to specify the type of geometry (e.g., "Point", "LineString", "Polygon"). +// The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geometry type. +func (b *Blueprint) Geometry(name string, subtype string, srid ...int) ColumnDefinition { + return b.addColumn(columnTypeGeometry, name, &columnDefinition{ + subtype: util.OptionalPtr("", subtype), + srid: util.OptionalNil(srid...), + }) +} + +// Point creates a new point column definition in the blueprint. +func (b *Blueprint) Point(name string, srid ...int) ColumnDefinition { + return b.addColumn(columnTypePoint, name, &columnDefinition{ + srid: util.OptionalPtr(4326, srid...), + }) +} + +// Enum creates a new enum column definition in the blueprint. +// The allowedEnums parameter is a slice of strings that defines the allowed values for the enum column. +// +// Example: +// +// table.Enum("status", []string{"active", "inactive", "pending"}) +// table.Enum("role", []string{"admin", "user", "guest"}).Comment("User role in the system") +func (b *Blueprint) Enum(name string, allowed []string) ColumnDefinition { + return b.addColumn(columnTypeEnum, name, &columnDefinition{ + allowed: allowed, + }) +} + +// DropTimestamps removes the created_at and updated_at timestamp columns from the blueprint. +func (b *Blueprint) DropTimestamps() { + b.DropColumn("created_at", "updated_at") +} + +// DropTimestampsTz removes the created_at and updated_at timestamp with time zone columns from the blueprint. +func (b *Blueprint) DropTimestampsTz() { + b.DropTimestamps() +} + +// Index creates a new index definition in the blueprint. +// +// Example: +// +// table.Index("email") +// table.Index("email", "username") // creates a composite index +// table.Index("email").Algorithm("btree") // creates a btree index +func (b *Blueprint) Index(column string, otherColumns ...string) IndexDefinition { + return b.indexCommand(commandIndex, append([]string{column}, otherColumns...)...) +} + +// Unique creates a new unique index definition in the blueprint. +// +// Example: +// +// table.Unique("email") +// table.Unique("email", "username") // creates a composite unique index +func (b *Blueprint) Unique(column string, otherColumns ...string) IndexDefinition { + return b.indexCommand(commandUnique, append([]string{column}, otherColumns...)...) +} + +// Primary creates a new primary key index definition in the blueprint. +// +// Example: +// +// table.Primary("id") +// table.Primary("id", "email") // creates a composite primary key +func (b *Blueprint) Primary(column string, otherColumns ...string) IndexDefinition { + return b.indexCommand(commandPrimary, append([]string{column}, otherColumns...)...) +} + +// FullText creates a new fulltext index definition in the blueprint. +func (b *Blueprint) FullText(column string, otherColumns ...string) IndexDefinition { + return b.indexCommand(commandFullText, append([]string{column}, otherColumns...)...) +} + +// Foreign creates a new foreign key definition in the blueprint. +// +// Example: +// +// table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE") +func (b *Blueprint) Foreign(column string) ForeignKeyDefinition { + command := b.addCommand(commandForeign, &command{ + columns: []string{column}, + }) + return &foreignKeyDefinition{command: command} +} + +// DropColumn adds a column to be dropped from the table. +// +// Example: +// +// table.DropColumn("old_column") +// table.DropColumn("old_column", "another_old_column") // drops multiple columns +func (b *Blueprint) DropColumn(column string, otherColumns ...string) { + b.addCommand(commandDropColumn, &command{ + columns: append([]string{column}, otherColumns...), + }) +} + +// RenameColumn changes the name of the table in the blueprint. +// +// Example: +// +// table.RenameColumn("old_table_name", "new_table_name") +func (b *Blueprint) RenameColumn(oldColumn string, newColumn string) { + b.addCommand(commandRenameColumn, &command{ + from: oldColumn, + to: newColumn, + }) +} + +// DropIndex adds an index to be dropped from the table. +func (b *Blueprint) DropIndex(index any) { + b.dropIndexCommand(commandDropIndex, commandIndex, index) +} + +// DropForeign adds a foreign key to be dropped from the table. +func (b *Blueprint) DropForeign(index any) { + b.dropIndexCommand(commandDropForeign, commandForeign, index) +} + +// DropPrimary adds a primary key to be dropped from the table. +func (b *Blueprint) DropPrimary(index any) { + b.dropIndexCommand(commandDropPrimary, commandPrimary, index) +} + +// DropUnique adds a unique key to be dropped from the table. +func (b *Blueprint) DropUnique(index any) { + b.dropIndexCommand(commandDropUnique, commandUnique, index) +} + +func (b *Blueprint) DropFulltext(index any) { + b.dropIndexCommand(commandDropFullText, commandFullText, index) +} + +// RenameIndex changes the name of an index in the blueprint. +// Example: +// +// table.RenameIndex("old_index_name", "new_index_name") +func (b *Blueprint) RenameIndex(oldIndexName string, newIndexName string) { + b.addCommand(commandRenameIndex, &command{ + from: oldIndexName, + to: newIndexName, + }) +} + +func (b *Blueprint) getAddedColumns() []*columnDefinition { + var addedColumns []*columnDefinition + for _, col := range b.columns { + if !col.change { + addedColumns = append(addedColumns, col) + } + } + return addedColumns +} + +func (b *Blueprint) getChangedColumns() []*columnDefinition { + var changedColumns []*columnDefinition + for _, col := range b.columns { + if col.change { + changedColumns = append(changedColumns, col) + } + } + return changedColumns +} + +func (b *Blueprint) create() { + b.addCommand(commandCreate) +} + +func (b *Blueprint) creating() bool { + for _, command := range b.commands { + if command.name == commandCreate { + return true + } + } + return false +} + +func (b *Blueprint) drop() { + b.addCommand(commandDrop) +} + +func (b *Blueprint) dropIfExists() { + b.addCommand(commandDropIfExists) +} + +func (b *Blueprint) rename(to string) { + b.addCommand(commandRename, &command{ + to: to, + }) +} + +func (b *Blueprint) addImpliedCommands() { + b.addFluentIndexes() + + if !b.creating() { + if len(b.getAddedColumns()) > 0 { + b.commands = append([]*command{{name: commandAdd}}, b.commands...) + } + if len(b.getChangedColumns()) > 0 { + changedCommands := make([]*command, 0, len(b.getChangedColumns())) + for _, col := range b.getChangedColumns() { + changedCommands = append(changedCommands, &command{name: commandChange, column: col}) + } + b.commands = append(changedCommands, b.commands...) + } + } +} + +func (b *Blueprint) addFluentIndexes() { + for _, col := range b.columns { + if col.primary != nil { + if b.dialect == dialect.MySQL { + continue + } + if !*col.primary && col.change { + b.DropPrimary([]string{col.name}) + col.primary = nil + } + } + if col.index != nil { + if *col.index { + b.Index(col.name).Name(col.indexName) + col.index = nil + } else if !*col.index && col.change { + b.DropIndex([]string{col.name}) + col.index = nil + } + } + + if col.unique != nil { + if *col.unique { + b.Unique(col.name).Name(col.uniqueName) + col.unique = nil + } else if !*col.unique && col.change { + b.DropUnique([]string{col.name}) + col.unique = nil + } + } + } +} + +func (b *Blueprint) getFluentStatements() []string { + var statements []string + for _, column := range b.columns { + for _, fluentCommand := range b.grammar.GetFluentCommands() { + if statement := fluentCommand(b, &command{column: column}); statement != "" { + statements = append(statements, statement) + } + } + } + return statements +} + +func (b *Blueprint) build(ctx *Context) error { + statements, err := b.toSql() + if err != nil { + return err + } + for _, statement := range statements { + // if b.verbose { + // log.Println(statement) + // } + if _, err := ctx.Exec(statement); err != nil { + return err + } + } + return nil +} + +func (b *Blueprint) toSql() ([]string, error) { + b.addImpliedCommands() + + var statements []string + + mainCommandMap := map[string]func(blueprint *Blueprint) (string, error){ + commandCreate: b.grammar.CompileCreate, + commandAdd: b.grammar.CompileAdd, + commandDrop: b.grammar.CompileDrop, + commandDropIfExists: b.grammar.CompileDropIfExists, + } + secondaryCommandMap := map[string]func(blueprint *Blueprint, command *command) (string, error){ + commandChange: b.grammar.CompileChange, + commandDropColumn: b.grammar.CompileDropColumn, + commandDropIndex: b.grammar.CompileDropIndex, + commandDropForeign: b.grammar.CompileDropForeign, + commandDropFullText: b.grammar.CompileDropFulltext, + commandDropPrimary: b.grammar.CompileDropPrimary, + commandDropUnique: b.grammar.CompileDropUnique, + commandForeign: b.grammar.CompileForeign, + commandFullText: b.grammar.CompileFullText, + commandIndex: b.grammar.CompileIndex, + commandPrimary: b.grammar.CompilePrimary, + commandRename: b.grammar.CompileRename, + commandRenameColumn: b.grammar.CompileRenameColumn, + commandRenameIndex: b.grammar.CompileRenameIndex, + commandUnique: b.grammar.CompileUnique, + } + for _, cmd := range b.commands { + if compileFunc, exists := mainCommandMap[cmd.name]; exists { + sql, err := compileFunc(b) + if err != nil { + return nil, err + } + if sql != "" { + statements = append(statements, sql) + } + continue + } + if compileFunc, exists := secondaryCommandMap[cmd.name]; exists { + sql, err := compileFunc(b, cmd) + if err != nil { + return nil, err + } + if sql != "" { + statements = append(statements, sql) + } + continue + } + return nil, fmt.Errorf("unknown command: %s", cmd.name) + } + + statements = append(statements, b.getFluentStatements()...) + + return statements, nil +} + +func (b *Blueprint) addColumn(colType string, name string, columnDefs ...*columnDefinition) *columnDefinition { + var col *columnDefinition + if len(columnDefs) > 0 { + col = columnDefs[0] + } else { + col = &columnDefinition{} + } + col.columnType = colType + col.name = name + + return b.addColumnDefinition(col) +} + +func (b *Blueprint) addColumnDefinition(col *columnDefinition) *columnDefinition { + b.columns = append(b.columns, col) + return col +} + +func (b *Blueprint) indexCommand(name string, columns ...string) IndexDefinition { + command := b.addCommand(name, &command{ + columns: columns, + }) + return &indexDefinition{command} +} + +func (b *Blueprint) dropIndexCommand(name string, indexType string, index any) { + switch index := index.(type) { + case string: + b.addCommand(name, &command{ + index: index, + }) + case []string: + indexName := b.grammar.CreateIndexName(b, indexType, index...) + b.addCommand(name, &command{ + index: indexName, + }) + default: + panic(fmt.Sprintf("unsupported index type: %T", index)) + } +} + +func (b *Blueprint) addCommand(name string, parameters ...*command) *command { + var parameter *command + if len(parameters) > 0 { + parameter = parameters[0] + } else { + parameter = &command{} + } + parameter.name = name + b.commands = append(b.commands, parameter) + + return parameter +} diff --git a/schema/builder.go b/schema/builder.go new file mode 100644 index 0000000..35f56c0 --- /dev/null +++ b/schema/builder.go @@ -0,0 +1,135 @@ +package schema + +import ( + "errors" + + "github.com/akfaiz/migris/internal/dialect" +) + +// Builder is an interface that defines methods for creating, dropping, and managing database tables. +type Builder interface { + // Create creates a new table with the given name and applies the provided blueprint. + Create(c *Context, name string, blueprint func(table *Blueprint)) error + // Drop removes the table with the given name. + Drop(c *Context, name string) error + // DropIfExists removes the table with the given name if it exists. + DropIfExists(c *Context, name string) error + // GetColumns retrieves the columns of the specified table. + GetColumns(c *Context, tableName string) ([]*Column, error) + // GetIndexes retrieves the indexes of the specified table. + GetIndexes(c *Context, tableName string) ([]*Index, error) + // GetTables retrieves all tables in the database. + GetTables(c *Context) ([]*TableInfo, error) + // HasColumn checks if the specified table has the given column. + HasColumn(c *Context, tableName string, columnName string) (bool, error) + // HasColumns checks if the specified table has all the given columns. + HasColumns(c *Context, tableName string, columnNames []string) (bool, error) + // HasIndex checks if the specified table has the given index. + HasIndex(c *Context, tableName string, indexes []string) (bool, error) + // HasTable checks if a table with the given name exists. + HasTable(c *Context, name string) (bool, error) + // Rename renames a table from oldName to newName. + Rename(c *Context, oldName string, newName string) error + // Table applies the provided blueprint to the specified table. + Table(c *Context, name string, blueprint func(table *Blueprint)) error +} + +// NewBuilder creates a new Builder instance based on the specified dialect. +// It returns an error if the dialect is not supported. +// +// Supported dialects are "postgres", "pgx", "mysql", and "mariadb". +func NewBuilder(dialectValue string) (Builder, error) { + dialectVal := dialect.FromString(dialectValue) + switch dialectVal { + case dialect.MySQL: + return newMysqlBuilder(), nil + case dialect.Postgres: + return newPostgresBuilder(), nil + default: + return nil, errors.New("unsupported dialect: " + dialectValue) + } +} + +type baseBuilder struct { + grammar grammar +} + +func (b *baseBuilder) newBlueprint(name string) *Blueprint { + return &Blueprint{name: name, grammar: b.grammar} +} + +func (b *baseBuilder) Create(c *Context, name string, blueprint func(table *Blueprint)) error { + if c == nil || name == "" || blueprint == nil { + return errors.New("invalid arguments: context, name, or blueprint is nil/empty") + } + + bp := b.newBlueprint(name) + bp.create() + blueprint(bp) + + if err := bp.build(c); err != nil { + return err + } + + return nil +} + +func (b *baseBuilder) Drop(c *Context, name string) error { + if c == nil || name == "" { + return errors.New("invalid arguments: context is nil or name is empty") + } + + bp := b.newBlueprint(name) + bp.drop() + + if err := bp.build(c); err != nil { + return err + } + + return nil +} + +func (b *baseBuilder) DropIfExists(c *Context, name string) error { + if c == nil || name == "" { + return errors.New("invalid arguments: context is nil or name is empty") + } + + bp := b.newBlueprint(name) + bp.dropIfExists() + + if err := bp.build(c); err != nil { + return err + } + + return nil +} + +func (b *baseBuilder) Rename(c *Context, oldName string, newName string) error { + if c == nil || oldName == "" || newName == "" { + return errors.New("invalid arguments: context is nil or old/new table name is empty") + } + + bp := b.newBlueprint(oldName) + bp.rename(newName) + + if err := bp.build(c); err != nil { + return err + } + + return nil +} + +func (b *baseBuilder) Table(c *Context, name string, blueprint func(table *Blueprint)) error { + if c == nil || name == "" || blueprint == nil { + return errors.New("invalid arguments: context is nil or name/blueprint is empty") + } + + bp := b.newBlueprint(name) + blueprint(bp) + + if err := bp.build(c); err != nil { + return err + } + + return nil +} diff --git a/schema/column_definition.go b/schema/column_definition.go new file mode 100644 index 0000000..639ed1a --- /dev/null +++ b/schema/column_definition.go @@ -0,0 +1,198 @@ +package schema + +import ( + "slices" + + "github.com/akfaiz/migris/internal/util" +) + +// ColumnDefinition defines the interface for defining a column in a database table. +type ColumnDefinition interface { + // AutoIncrement sets the column to auto-increment. + // This is typically used for primary key columns. + AutoIncrement() ColumnDefinition + // Change changes the column definition. + Change() ColumnDefinition + // Charset sets the character set for the column. + Charset(charset string) ColumnDefinition + // Collation sets the collation for the column. + Collation(collation string) ColumnDefinition + // Comment adds a comment to the column definition. + Comment(comment string) ColumnDefinition + // Default sets a default value for the column. + Default(value any) ColumnDefinition + // Index adds an index to the column. + Index(params ...any) ColumnDefinition + // Nullable sets the column to be nullable or not. + Nullable(value ...bool) ColumnDefinition + // OnUpdate sets the value to be used when the column is updated. + OnUpdate(value any) ColumnDefinition + // Primary sets the column as a primary key. + Primary(value ...bool) ColumnDefinition + // Unique sets the column to be unique. + Unique(params ...any) ColumnDefinition + // Unsigned sets the column to be unsigned (applicable for numeric types). + Unsigned() ColumnDefinition + // UseCurrent sets the column to use the current timestamp as default. + UseCurrent() ColumnDefinition + // UseCurrentOnUpdate sets the column to use the current timestamp on update. + UseCurrentOnUpdate() ColumnDefinition +} + +type columnDefinition struct { + commands []string + name string + columnType string + charset *string + collation *string + comment *string + defaultValue any + onUpdateValue any + useCurrent bool + useCurrentOnUpdate bool + nullable *bool + autoIncrement *bool + unsigned *bool + primary *bool + index *bool + indexName string + unique *bool + uniqueName string + length *int + precision *int + total *int + places *int + change bool + allowed []string // for enum type columns + subtype *string // for geography and geometry types + srid *int // for geography and geometry types +} + +// Expression is a type for expressions that can be used as default values for columns. +// +// Example: +// +// schema.Timestamp("created_at").Default(schema.Expression("CURRENT_TIMESTAMP")) +type Expression string + +func (e Expression) String() string { + return string(e) +} + +var _ ColumnDefinition = &columnDefinition{} + +func (c *columnDefinition) addCommand(command string) { + c.commands = append(c.commands, command) +} + +func (c *columnDefinition) hasCommand(command string) bool { + return slices.Contains(c.commands, command) +} + +func (c *columnDefinition) SetDefault(value any) { + c.addCommand("default") + c.defaultValue = value +} + +func (c *columnDefinition) SetOnUpdate(value any) { + c.addCommand("onUpdate") + c.onUpdateValue = value +} + +func (c *columnDefinition) SetSubtype(value *string) { + c.subtype = value +} + +func (c *columnDefinition) AutoIncrement() ColumnDefinition { + c.autoIncrement = util.PtrOf(true) + return c +} + +func (c *columnDefinition) Charset(charset string) ColumnDefinition { + c.charset = &charset + return c +} + +func (c *columnDefinition) Change() ColumnDefinition { + c.change = true + return c +} + +func (c *columnDefinition) Collation(collation string) ColumnDefinition { + c.collation = &collation + return c +} + +func (c *columnDefinition) Comment(comment string) ColumnDefinition { + c.addCommand("comment") + c.comment = &comment + return c +} + +func (c *columnDefinition) Default(value any) ColumnDefinition { + c.addCommand("default") + c.defaultValue = value + + return c +} + +func (c *columnDefinition) Index(params ...any) ColumnDefinition { + index := true + for _, param := range params { + switch v := param.(type) { + case bool: + index = v + case string: + c.indexName = v + } + } + c.index = &index + return c +} + +func (c *columnDefinition) Nullable(value ...bool) ColumnDefinition { + c.addCommand("nullable") + c.nullable = util.OptionalPtr(true, value...) + return c +} + +func (c *columnDefinition) OnUpdate(value any) ColumnDefinition { + c.addCommand("onUpdate") + c.onUpdateValue = value + return c +} + +func (c *columnDefinition) Primary(value ...bool) ColumnDefinition { + val := util.Optional(true, value...) + c.primary = &val + return c +} + +func (c *columnDefinition) Unique(params ...any) ColumnDefinition { + unique := true + for _, param := range params { + switch v := param.(type) { + case bool: + unique = v + case string: + c.uniqueName = v + } + } + c.unique = &unique + return c +} + +func (c *columnDefinition) Unsigned() ColumnDefinition { + c.unsigned = util.PtrOf(true) + return c +} + +func (c *columnDefinition) UseCurrent() ColumnDefinition { + c.useCurrent = true + return c +} + +func (c *columnDefinition) UseCurrentOnUpdate() ColumnDefinition { + c.useCurrentOnUpdate = true + return c +} diff --git a/schema/command.go b/schema/command.go new file mode 100644 index 0000000..907798a --- /dev/null +++ b/schema/command.go @@ -0,0 +1,40 @@ +package schema + +const ( + commandAdd string = "add" + commandChange string = "change" + commandCreate string = "create" + commandDrop string = "drop" + commandDropIfExists string = "dropIfExists" + commandDropColumn string = "dropColumn" + commandDropForeign string = "dropForeign" + commandDropFullText string = "dropFullText" + commandDropIndex string = "dropIndex" + commandDropPrimary string = "dropPrimary" + commandDropUnique string = "dropUnique" + commandForeign string = "foreign" + commandFullText string = "fullText" + commandIndex string = "index" + commandPrimary string = "primary" + commandRename string = "rename" + commandRenameColumn string = "renameColumn" + commandRenameIndex string = "renameIndex" + commandUnique string = "unique" +) + +type command struct { + column *columnDefinition + deferrable *bool + initiallyImmediate *bool + algorithm string + from string + index string + language string + name string + on string + onDelete string + onUpdate string + to string + columns []string + references []string +} diff --git a/schema/context.go b/schema/context.go new file mode 100644 index 0000000..aa74605 --- /dev/null +++ b/schema/context.go @@ -0,0 +1,43 @@ +package schema + +import ( + "context" + "database/sql" +) + +type Context struct { + ctx context.Context + tx *sql.Tx + filename string +} + +type ContextOptions func(*Context) + +func WithFilename(filename string) ContextOptions { + return func(c *Context) { + c.filename = filename + } +} + +func NewContext(ctx context.Context, tx *sql.Tx, opts ...ContextOptions) *Context { + c := &Context{ + ctx: ctx, + tx: tx, + } + for _, opt := range opts { + opt(c) + } + return c +} + +func (c *Context) Exec(query string, args ...any) (sql.Result, error) { + return c.tx.ExecContext(c.ctx, query, args...) +} + +func (c *Context) Query(query string, args ...any) (*sql.Rows, error) { + return c.tx.QueryContext(c.ctx, query, args...) +} + +func (c *Context) QueryRow(query string, args ...any) *sql.Row { + return c.tx.QueryRowContext(c.ctx, query, args...) +} diff --git a/schema/foreign_key_definition.go b/schema/foreign_key_definition.go new file mode 100644 index 0000000..49ad922 --- /dev/null +++ b/schema/foreign_key_definition.go @@ -0,0 +1,111 @@ +package schema + +import "github.com/akfaiz/migris/internal/util" + +// ForeignKeyDefinition defines the interface for defining a foreign key constraint in a database table. +type ForeignKeyDefinition interface { + // CascadeOnDelete sets the foreign key to cascade on delete. + CascadeOnDelete() ForeignKeyDefinition + // CascadeOnUpdate sets the foreign key to cascade on update. + CascadeOnUpdate() ForeignKeyDefinition + // Deferrable sets the foreign key as deferrable. + Deferrable(value ...bool) ForeignKeyDefinition + // InitiallyImmediate sets the foreign key to be initially immediate. + InitiallyImmediate(value ...bool) ForeignKeyDefinition + // Name sets the name of the foreign key constraint. + // This is optional and can be used to give a specific name to the foreign key. + Name(name string) ForeignKeyDefinition + // NoActionOnDelete set the foreign key to do nothing on delete. + NoActionOnDelete() ForeignKeyDefinition + // NoActionOnUpdate set the foreign key to do nothing on the update. + NoActionOnUpdate() ForeignKeyDefinition + // NullOnDelete set the foreign key to set the column to NULL on delete. + NullOnDelete() ForeignKeyDefinition + // NullOnUpdate set the foreign key to set the column to NULL on update. + NullOnUpdate() ForeignKeyDefinition + // On sets the table that these foreign key references. + On(table string) ForeignKeyDefinition + // OnDelete set the action to take when the referenced row is deleted. + OnDelete(action string) ForeignKeyDefinition + // OnUpdate set the action to take when the referenced row is updated. + OnUpdate(action string) ForeignKeyDefinition + // References set the column that this foreign key references in the other table. + References(column string) ForeignKeyDefinition + // RestrictOnDelete set the foreign key to restrict deletion of the referenced row. + RestrictOnDelete() ForeignKeyDefinition + // RestrictOnUpdate set the foreign key to restrict updating of the referenced row. + RestrictOnUpdate() ForeignKeyDefinition +} + +type foreignKeyDefinition struct { + *command +} + +func (fd *foreignKeyDefinition) CascadeOnDelete() ForeignKeyDefinition { + return fd.OnDelete("CASCADE") +} + +func (fd *foreignKeyDefinition) CascadeOnUpdate() ForeignKeyDefinition { + return fd.OnUpdate("CASCADE") +} + +func (fd *foreignKeyDefinition) Deferrable(value ...bool) ForeignKeyDefinition { + val := util.Optional(true, value...) + fd.deferrable = &val + return fd +} + +func (fd *foreignKeyDefinition) InitiallyImmediate(value ...bool) ForeignKeyDefinition { + val := util.Optional(true, value...) + fd.initiallyImmediate = &val + return fd +} + +func (fd *foreignKeyDefinition) Name(name string) ForeignKeyDefinition { + fd.index = name + return fd +} + +func (fd *foreignKeyDefinition) NoActionOnDelete() ForeignKeyDefinition { + return fd.OnDelete("NO ACTION") +} + +func (fd *foreignKeyDefinition) NoActionOnUpdate() ForeignKeyDefinition { + return fd.OnUpdate("NO ACTION") +} + +func (fd *foreignKeyDefinition) NullOnDelete() ForeignKeyDefinition { + return fd.OnDelete("SET NULL") +} + +func (fd *foreignKeyDefinition) NullOnUpdate() ForeignKeyDefinition { + return fd.OnUpdate("SET NULL") +} + +func (fd *foreignKeyDefinition) On(table string) ForeignKeyDefinition { + fd.on = table + return fd +} + +func (fd *foreignKeyDefinition) OnDelete(action string) ForeignKeyDefinition { + fd.onDelete = action + return fd +} + +func (fd *foreignKeyDefinition) OnUpdate(action string) ForeignKeyDefinition { + fd.onUpdate = action + return fd +} + +func (fd *foreignKeyDefinition) References(columns string) ForeignKeyDefinition { + fd.references = []string{columns} + return fd +} + +func (fd *foreignKeyDefinition) RestrictOnDelete() ForeignKeyDefinition { + return fd.OnDelete("RESTRICT") +} + +func (fd *foreignKeyDefinition) RestrictOnUpdate() ForeignKeyDefinition { + return fd.OnUpdate("RESTRICT") +} diff --git a/schema/grammar.go b/schema/grammar.go new file mode 100644 index 0000000..78d7c3a --- /dev/null +++ b/schema/grammar.go @@ -0,0 +1,141 @@ +package schema + +import ( + "fmt" + "slices" + "strings" + + "github.com/akfaiz/migris/internal/util" +) + +type grammar interface { + CompileCreate(bp *Blueprint) (string, error) + CompileAdd(bp *Blueprint) (string, error) + CompileChange(bp *Blueprint, command *command) (string, error) + CompileDrop(bp *Blueprint) (string, error) + CompileDropIfExists(bp *Blueprint) (string, error) + CompileRename(bp *Blueprint, command *command) (string, error) + CompileDropColumn(blueprint *Blueprint, command *command) (string, error) + CompileRenameColumn(blueprint *Blueprint, command *command) (string, error) + CompileIndex(blueprint *Blueprint, command *command) (string, error) + CompileUnique(blueprint *Blueprint, command *command) (string, error) + CompilePrimary(blueprint *Blueprint, command *command) (string, error) + CompileFullText(blueprint *Blueprint, command *command) (string, error) + CompileDropIndex(blueprint *Blueprint, command *command) (string, error) + CompileDropUnique(blueprint *Blueprint, command *command) (string, error) + CompileDropFulltext(blueprint *Blueprint, command *command) (string, error) + CompileDropPrimary(blueprint *Blueprint, command *command) (string, error) + CompileRenameIndex(blueprint *Blueprint, command *command) (string, error) + CompileForeign(blueprint *Blueprint, command *command) (string, error) + CompileDropForeign(blueprint *Blueprint, command *command) (string, error) + GetFluentCommands() []func(blueprint *Blueprint, command *command) string + CreateIndexName(blueprint *Blueprint, idxType string, columns ...string) string +} + +type baseGrammar struct{} + +func (g *baseGrammar) CompileForeign(blueprint *Blueprint, command *command) (string, error) { + if len(command.columns) == 0 || slices.Contains(command.columns, "") || command.on == "" || + len(command.references) == 0 || slices.Contains(command.references, "") { + return "", fmt.Errorf("foreign key definition is incomplete: column, on, and references must be set") + } + onDelete := "" + if command.onDelete != "" { + onDelete = fmt.Sprintf(" ON DELETE %s", command.onDelete) + } + onUpdate := "" + if command.onUpdate != "" { + onUpdate = fmt.Sprintf(" ON UPDATE %s", command.onUpdate) + } + index := command.index + if index == "" { + index = g.CreateForeignKeyName(blueprint, command) + } + + return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s(%s)%s%s", + blueprint.name, + index, + command.columns[0], + command.on, + command.references[0], + onDelete, + onUpdate, + ), nil +} + +func (g *baseGrammar) CreateIndexName(blueprint *Blueprint, idxType string, columns ...string) string { + tableName := blueprint.name + if strings.Contains(tableName, ".") { + parts := strings.Split(tableName, ".") + tableName = parts[len(parts)-1] // Use the last part as the table name + } + + switch idxType { + case "primary": + return fmt.Sprintf("pk_%s", tableName) + case "unique": + return fmt.Sprintf("uk_%s_%s", tableName, strings.Join(columns, "_")) + case "index": + return fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_")) + case "fulltext": + return fmt.Sprintf("ft_%s_%s", tableName, strings.Join(columns, "_")) + default: + return "" + } +} + +func (g *baseGrammar) CreateForeignKeyName(blueprint *Blueprint, command *command) string { + tableName := blueprint.name + if strings.Contains(tableName, ".") { + parts := strings.Split(tableName, ".") + tableName = parts[len(parts)-1] // Use the last part as the table name + } + on := command.on + if strings.Contains(on, ".") { + parts := strings.Split(on, ".") + on = parts[len(parts)-1] // Use the last part as the column name + } + return fmt.Sprintf("fk_%s_%s", tableName, on) +} + +func (g *baseGrammar) QuoteString(s string) string { + return "'" + s + "'" +} + +func (g *baseGrammar) PrefixArray(prefix string, items []string) []string { + prefixed := make([]string, len(items)) + for i, item := range items { + prefixed[i] = fmt.Sprintf("%s%s", prefix, item) + } + return prefixed +} + +func (g *baseGrammar) Columnize(columns []string) string { + if len(columns) == 0 { + return "" + } + return strings.Join(columns, ", ") +} + +func (g *baseGrammar) GetValue(value any) string { + switch v := value.(type) { + case Expression: + return v.String() + default: + return fmt.Sprintf("'%v'", v) + } +} + +func (g *baseGrammar) GetDefaultValue(value any) string { + if value == nil { + return "NULL" + } + switch v := value.(type) { + case Expression: + return v.String() + case bool: + return util.Ternary(v, "'1'", "'0'") + default: + return fmt.Sprintf("'%v'", v) + } +} diff --git a/index_definition.go b/schema/index_definition.go similarity index 79% rename from index_definition.go rename to schema/index_definition.go index 2798584..cbcb14d 100644 --- a/index_definition.go +++ b/schema/index_definition.go @@ -1,5 +1,7 @@ package schema +import "github.com/akfaiz/migris/internal/util" + // IndexDefinition defines the interface for defining an index in a database table. type IndexDefinition interface { // Algorithm sets the algorithm for the index. @@ -15,16 +17,8 @@ type IndexDefinition interface { Name(name string) IndexDefinition } -var _ IndexDefinition = &indexDefinition{} - type indexDefinition struct { - name string - indexType indexType - algorithm string - columns []string - language string - deferrable *bool - initiallyImmediate *bool + *command } func (id *indexDefinition) Algorithm(algorithm string) IndexDefinition { @@ -33,13 +27,13 @@ func (id *indexDefinition) Algorithm(algorithm string) IndexDefinition { } func (id *indexDefinition) Deferrable(value ...bool) IndexDefinition { - val := optional(true, value...) + val := util.Optional(true, value...) id.deferrable = &val return id } func (id *indexDefinition) InitiallyImmediate(value ...bool) IndexDefinition { - val := optional(true, value...) + val := util.Optional(true, value...) id.initiallyImmediate = &val return id } @@ -50,6 +44,6 @@ func (id *indexDefinition) Language(language string) IndexDefinition { } func (id *indexDefinition) Name(name string) IndexDefinition { - id.name = name + id.index = name return id } diff --git a/mysql_builder.go b/schema/mysql_builder.go similarity index 57% rename from mysql_builder.go rename to schema/mysql_builder.go index 653b58c..a02b984 100644 --- a/mysql_builder.go +++ b/schema/mysql_builder.go @@ -1,7 +1,6 @@ package schema import ( - "context" "database/sql" "errors" "strings" @@ -12,6 +11,8 @@ type mysqlBuilder struct { grammar *mysqlGrammar } +var _ Builder = (*mysqlBuilder)(nil) + func newMysqlBuilder() Builder { grammar := newMysqlGrammar() @@ -21,13 +22,9 @@ func newMysqlBuilder() Builder { } } -func (b *mysqlBuilder) getCurrentDatabase(ctx context.Context, tx *sql.Tx) (string, error) { - if tx == nil { - return "", errors.New("transaction is nil") - } - - query := b.grammar.compileCurrentDatabase() - row := queryRowContext(ctx, tx, query) +func (b *mysqlBuilder) getCurrentDatabase(c *Context) (string, error) { + query := b.grammar.CompileCurrentDatabase() + row := c.QueryRow(query) var dbName string if err := row.Scan(&dbName); err != nil { return "", err @@ -35,47 +32,22 @@ func (b *mysqlBuilder) getCurrentDatabase(ctx context.Context, tx *sql.Tx) (stri return dbName, nil } -func (b *mysqlBuilder) CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil { - return err - } - - exist, err := b.HasTable(ctx, tx, name) - if err != nil { - return err - } - if exist { - return nil // Table already exists, no need to create it - } - - bp := &Blueprint{name: name} - bp.createIfNotExists() - blueprint(bp) - - statements, err := bp.toSql(b.grammar) - if err != nil { - return err - } - - return execContext(ctx, tx, statements...) -} - -func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return nil, err +func (b *mysqlBuilder) GetColumns(c *Context, tableName string) ([]*Column, error) { + if c == nil || tableName == "" { + return nil, errors.New("invalid arguments: context is nil or table name is empty") } - database, err := b.getCurrentDatabase(ctx, tx) + database, err := b.getCurrentDatabase(c) if err != nil { return nil, err } - query, err := b.grammar.compileColumns(database, tableName) + query, err := b.grammar.CompileColumns(database, tableName) if err != nil { return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -101,22 +73,22 @@ func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName str return columns, nil } -func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return nil, err +func (b *mysqlBuilder) GetIndexes(c *Context, tableName string) ([]*Index, error) { + if c == nil || tableName == "" { + return nil, errors.New("invalid arguments: context is nil or table name is empty") } - database, err := b.getCurrentDatabase(ctx, tx) + database, err := b.getCurrentDatabase(c) if err != nil { return nil, err } - query, err := b.grammar.compileIndexes(database, tableName) + query, err := b.grammar.CompileIndexes(database, tableName) if err != nil { return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -135,17 +107,21 @@ func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName str return indexes, nil } -func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) { - database, err := b.getCurrentDatabase(ctx, tx) +func (b *mysqlBuilder) GetTables(c *Context) ([]*TableInfo, error) { + if c == nil { + return nil, errors.New("invalid arguments: context is nil") + } + + database, err := b.getCurrentDatabase(c) if err != nil { return nil, err } - query, err := b.grammar.compileTables(database) + query, err := b.grammar.CompileTables(database) if err != nil { return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -162,22 +138,22 @@ func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, return tables, nil } -func (b *mysqlBuilder) HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) { - if columnName == "" { - return false, errors.New("column name cannot be empty") +func (b *mysqlBuilder) HasColumn(c *Context, tableName string, columnName string) (bool, error) { + if c == nil || columnName == "" { + return false, errors.New("invalid arguments: context is nil or column name is empty") } - return b.HasColumns(ctx, tx, tableName, []string{columnName}) + return b.HasColumns(c, tableName, []string{columnName}) } -func (b *mysqlBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return false, err +func (b *mysqlBuilder) HasColumns(c *Context, tableName string, columnNames []string) (bool, error) { + if c == nil || tableName == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } if len(columnNames) == 0 { return false, errors.New("no column names provided") } - columns, err := b.GetColumns(ctx, tx, tableName) + columns, err := b.GetColumns(c, tableName) if err != nil { return false, err } @@ -193,12 +169,12 @@ func (b *mysqlBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName str return true, nil // All specified columns exist } -func (b *mysqlBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return false, err +func (b *mysqlBuilder) HasIndex(c *Context, tableName string, indexes []string) (bool, error) { + if c == nil || tableName == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } - existingIndexes, err := b.GetIndexes(ctx, tx, tableName) + existingIndexes, err := b.GetIndexes(c, tableName) if err != nil { return false, err } @@ -237,22 +213,22 @@ func (b *mysqlBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName strin return false, nil // If no specified index exists, return false } -func (b *mysqlBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { - if err := b.validateTxAndName(tx, name); err != nil { - return false, err +func (b *mysqlBuilder) HasTable(c *Context, name string) (bool, error) { + if c == nil || name == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } - database, err := b.getCurrentDatabase(ctx, tx) + database, err := b.getCurrentDatabase(c) if err != nil { return false, err } - query, err := b.grammar.compileTableExists(database, name) + query, err := b.grammar.CompileTableExists(database, name) if err != nil { return false, err } - row := queryRowContext(ctx, tx, query) + row := c.QueryRow(query) var exists bool if err := row.Scan(&exists); err != nil { if errors.Is(err, sql.ErrNoRows) { diff --git a/mysql_builder_test.go b/schema/mysql_builder_test.go similarity index 60% rename from mysql_builder_test.go rename to schema/mysql_builder_test.go index 39f50c6..0d8a643 100644 --- a/mysql_builder_test.go +++ b/schema/mysql_builder_test.go @@ -6,7 +6,7 @@ import ( "fmt" "testing" - "github.com/afkdevs/go-schema" + "github.com/akfaiz/migris/schema" _ "github.com/go-sql-driver/mysql" // MySQL driver "github.com/stretchr/testify/suite" ) @@ -17,8 +17,9 @@ func TestMysqlBuilderSuite(t *testing.T) { type mysqlBuilderSuite struct { suite.Suite - ctx context.Context - db *sql.DB + ctx context.Context + db *sql.DB + builder schema.Builder } func (s *mysqlBuilderSuite) SetupSuite() { @@ -42,7 +43,8 @@ func (s *mysqlBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - schema.SetDebug(false) + s.builder, err = schema.NewBuilder("mysql") + s.Require().NoError(err) } func (s *mysqlBuilderSuite) TearDownSuite() { @@ -50,13 +52,14 @@ func (s *mysqlBuilderSuite) TearDownSuite() { } func (s *mysqlBuilderSuite) AfterTest(_, _ string) { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) - tables, err := builder.GetTables(s.ctx, tx) + c := schema.NewContext(s.ctx, tx) + tables, err := builder.GetTables(c) s.Require().NoError(err) for _, table := range tables { - err := builder.DropIfExists(s.ctx, tx, table.Name) + err := builder.DropIfExists(c, table.Name) if err != nil { s.T().Logf("error dropping table %s: %v", table.Name, err) } @@ -66,29 +69,31 @@ func (s *mysqlBuilderSuite) AfterTest(_, _ string) { } func (s *mysqlBuilderSuite) TestCreate() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Create(s.ctx, nil, "test_table", func(table *schema.Blueprint) { + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Create(nil, "test_table", func(table *schema.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.Create(s.ctx, tx, "", func(table *schema.Blueprint) { + err := builder.Create(c, "", func(table *schema.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when blueprint is nil, should return error", func() { - err := builder.Create(s.ctx, tx, "test_table", nil) + err := builder.Create(c, "test_table", nil) s.Error(err) }) s.Run("when all parameters are valid, should create table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -98,7 +103,7 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with valid parameters") }) s.Run("when have composite primary key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "user_roles", func(table *schema.Blueprint) { + err = builder.Create(c, "user_roles", func(table *schema.Blueprint) { table.Integer("user_id") table.Integer("role_id") @@ -107,30 +112,30 @@ func (s *mysqlBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with composite primary key") }) s.Run("when have foreign key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders", func(table *schema.Blueprint) { + err = builder.Create(c, "orders", func(table *schema.Blueprint) { table.ID() table.UnsignedBigInteger("user_id") table.String("order_id", 255).Unique() table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE") }) s.NoError(err, "expected no error when creating table with foreign key") }) s.Run("when have custom index should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders_2", func(table *schema.Blueprint) { + err = builder.Create(c, "orders_2", func(table *schema.Blueprint) { table.ID() table.String("order_id", 255).Unique("uk_orders_2_order_id") table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Index("created_at").Name("idx_orders_created_at").Algorithm("BTREE") }) s.NoError(err, "expected no error when creating table with custom index") }) s.Run("when table already exists, should return error", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() @@ -139,250 +144,212 @@ func (s *mysqlBuilderSuite) TestCreate() { }) } -func (s *mysqlBuilderSuite) TestCreateIfNotExists() { - builder, _ := schema.NewBuilder("mysql") - tx, err := s.db.BeginTx(s.ctx, nil) - s.Require().NoError(err) - defer tx.Rollback() //nolint:errcheck - - s.Run("when tx is nil, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, nil, "test_table", func(table *schema.Blueprint) { - table.String("name") - }) - s.Error(err) - }) - s.Run("when table name is empty, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, tx, "", func(table *schema.Blueprint) { - table.String("name") - }) - s.Error(err) - }) - s.Run("when blueprint is nil, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, tx, "test_table", nil) - s.Error(err) - }) - s.Run("when all parameters are valid, should create table successfully", func() { - err = builder.CreateIfNotExists(context.Background(), tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") - }) - s.NoError(err, "expected no error when creating table with valid parameters") - }) - s.Run("when table already exists, should not return error", func() { - err = builder.CreateIfNotExists(context.Background(), tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - }) - s.NoError(err, "expected no error when creating table that already exists with CreateIfNotExists") - }) -} - func (s *mysqlBuilderSuite) TestDrop() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Drop(s.ctx, nil, "test_table") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Drop(nil, "test_table") s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.Drop(s.ctx, tx, "") + err := builder.Drop(c, "") s.Error(err) }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") - err = builder.Drop(context.Background(), tx, "users") + err = builder.Drop(c, "users") s.NoError(err, "expected no error when dropping table with valid parameters") }) s.Run("when table does not exist, should return error", func() { - err = builder.Drop(context.Background(), tx, "non_existent_table") + err = builder.Drop(c, "non_existent_table") s.Error(err, "expected error when dropping a table that does not exist") }) } func (s *mysqlBuilderSuite) TestDropIfExists() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.DropIfExists(s.ctx, nil, "test_table") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.DropIfExists(nil, "test_table") s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.DropIfExists(s.ctx, tx, "") + err := builder.DropIfExists(c, "") s.Error(err) }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") - err = builder.DropIfExists(context.Background(), tx, "users") + err = builder.DropIfExists(c, "users") s.NoError(err, "expected no error when dropping table with valid parameters") }) s.Run("when table does not exist, should not return error", func() { - err = builder.DropIfExists(context.Background(), tx, "non_existent_table") + err = builder.DropIfExists(c, "non_existent_table") s.NoError(err, "expected no error when dropping a table that does not exist with DropIfExists") }) } func (s *mysqlBuilderSuite) TestRename() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Rename(s.ctx, nil, "old_table", "new_table") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Rename(nil, "old_table", "new_table") s.Error(err) }) s.Run("when old table name is empty, should return error", func() { - err := builder.Rename(s.ctx, tx, "", "new_table") + err := builder.Rename(c, "", "new_table") s.Error(err) }) s.Run("when new table name is empty, should return error", func() { - err := builder.Rename(s.ctx, tx, "old_table", "") + err := builder.Rename(c, "old_table", "") s.Error(err) }) s.Run("when all parameters are valid, should rename table successfully", func() { - err = builder.Create(context.Background(), tx, "old_table", func(table *schema.Blueprint) { + err = builder.Create(c, "old_table", func(table *schema.Blueprint) { table.ID() table.String("name", 255) }) s.NoError(err, "expected no error when creating old_table before renaming it") - err = builder.Rename(context.Background(), tx, "old_table", "new_table") + err = builder.Rename(c, "old_table", "new_table") s.NoError(err, "expected no error when renaming table with valid parameters") }) } func (s *mysqlBuilderSuite) TestTable() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Table(s.ctx, nil, "test_table", func(table *schema.Blueprint) { + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Table(nil, "test_table", func(table *schema.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when table name is empty, should return error", func() { - err := builder.Table(s.ctx, tx, "", func(table *schema.Blueprint) { + err := builder.Table(c, "", func(table *schema.Blueprint) { table.String("name") }) s.Error(err) }) s.Run("when blueprint is nil, should return error", func() { - err := builder.Table(s.ctx, tx, "test_table", nil) + err := builder.Table(c, "test_table", nil) s.Error(err) }) s.Run("when all parameters are valid, should modify table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique("uk_users_email") table.String("password", 255).Nullable() table.Text("bio").Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.FullText("bio") }) s.NoError(err, "expected no error when creating table before modifying it") s.Run("should add new columns and modify existing ones", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.String("address", 255).Nullable() table.String("phone", 20).Nullable().Unique("uk_users_phone") }) s.NoError(err, "expected no error when modifying table with valid parameters") }) s.Run("should modify existing column", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.String("email", 255).Nullable().Change() }) s.NoError(err, "expected no error when modifying existing column") }) s.Run("should drop column and rename existing one", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropColumn("password") table.RenameColumn("name", "full_name") }) s.NoError(err, "expected no error when dropping column and renaming existing one") }) s.Run("should add index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.Index("phone").Name("idx_users_phone").Algorithm("BTREE") }) s.NoError(err, "expected no error when adding index to table") }) s.Run("should rename index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.RenameIndex("idx_users_phone", "idx_users_contact") }) s.NoError(err, "expected no error when renaming index in table") }) s.Run("should drop index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropIndex("idx_users_contact") }) s.NoError(err, "expected no error when dropping index from table") }) s.Run("should drop unique constraint", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropUnique("uk_users_email") }) s.NoError(err, "expected no error when dropping unique constraint from table") }) s.Run("should drop fulltext index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropFulltext("ft_users_bio") }) s.NoError(err, "expected no error when dropping fulltext index from table") }) s.Run("should add foreign key", func() { - err = builder.Create(s.ctx, tx, "roles", func(table *schema.Blueprint) { + err = builder.Create(c, "roles", func(table *schema.Blueprint) { table.UnsignedInteger("id").Primary() table.String("role_name", 255).Unique("uk_roles_role_name") }) s.NoError(err, "expected no error when creating roles table before adding foreign key") - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.UnsignedInteger("role_id").Nullable() table.Foreign("role_id").References("id").On("roles").OnDelete("SET NULL").OnUpdate("CASCADE") }) s.NoError(err, "expected no error when adding foreign key to users table") }) s.Run("should drop foreign key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropForeign("fk_users_roles") }) s.NoError(err, "expected no error when dropping foreign key from users table") }) s.Run("should drop primary key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.UnsignedBigInteger("id").Change() table.DropPrimary("users_pkey") }) @@ -392,38 +359,39 @@ func (s *mysqlBuilderSuite) TestTable() { } func (s *mysqlBuilderSuite) TestGetColumns() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - columns, err := builder.GetColumns(s.ctx, nil, "test_table") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + columns, err := builder.GetColumns(nil, "test_table") s.Error(err) s.Nil(columns) }) s.Run("when table name is empty, should return error", func() { - columns, err := builder.GetColumns(s.ctx, tx, "") + columns, err := builder.GetColumns(c, "") s.Error(err) s.Nil(columns) }) s.Run("when table does not exist, should return empty slice", func() { - columns, err := builder.GetColumns(s.ctx, tx, "non_existent_table") + columns, err := builder.GetColumns(c, "non_existent_table") s.NoError(err) s.Empty(columns) }) s.Run("when table exists, should return columns successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before getting columns") - columns, err := builder.GetColumns(context.Background(), tx, "users") + columns, err := builder.GetColumns(c, "users") s.NoError(err, "expected no error when getting columns from existing table") s.NotEmpty(columns) s.Len(columns, 6, "expected 6 columns in the users table") @@ -431,67 +399,69 @@ func (s *mysqlBuilderSuite) TestGetColumns() { } func (s *mysqlBuilderSuite) TestGetIndexes() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - _, err := builder.GetIndexes(s.ctx, nil, "users_indexes") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + _, err := builder.GetIndexes(nil, "users_indexes") + s.Error(err, "expected error when context is nil") }) s.Run("when table name is empty, should return error", func() { - _, err := builder.GetIndexes(s.ctx, tx, "") + _, err := builder.GetIndexes(c, "") s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.Index("name").Name("idx_users_name") }) s.NoError(err, "expected no error when creating table before getting indexes") - indexes, err := builder.GetIndexes(s.ctx, tx, "users") + indexes, err := builder.GetIndexes(c, "users") s.NoError(err, "expected no error when getting indexes with valid parameters") s.Len(indexes, 3, "expected 3 index to be returned") }) s.Run("when table does not exist, should return empty indexes", func() { - indexes, err := builder.GetIndexes(s.ctx, tx, "non_existent_table") + indexes, err := builder.GetIndexes(c, "non_existent_table") s.NoError(err, "expected no error when getting indexes of non-existent table") s.Empty(indexes, "expected empty indexes for non-existent table") }) } func (s *mysqlBuilderSuite) TestGetTables() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when tx is nil, should return error", func() { - tables, err := builder.GetTables(s.ctx, nil) + tables, err := builder.GetTables(nil) s.Error(err, "expected error when transaction is nil") s.Nil(tables) }) s.Run("when all parameters are valid", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before getting tables") - tables, err := builder.GetTables(context.Background(), tx) + tables, err := builder.GetTables(c) s.NoError(err, "expected no error when getting tables after creating one") s.NotEmpty(tables, "expected non-empty tables slice after creating a table") found := false @@ -506,165 +476,170 @@ func (s *mysqlBuilderSuite) TestGetTables() { } func (s *mysqlBuilderSuite) TestHasColumn() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasColumn(s.ctx, nil, "users", "name") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasColumn(nil, "users", "name") + s.Error(err, "expected error when context is nil") s.False(exists) }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasColumn(s.ctx, tx, "", "name") + exists, err := builder.HasColumn(c, "", "name") s.Error(err, "expected error when table name is empty") s.False(exists) }) s.Run("when column name is empty, should return error", func() { - exists, err := builder.HasColumn(s.ctx, tx, "users", "") + exists, err := builder.HasColumn(c, "users", "") s.Error(err, "expected error when column name is empty") s.False(exists) }) s.Run("when all parameters are valid", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking for column existence") - exists, err := builder.HasColumn(context.Background(), tx, "users", "name") + exists, err := builder.HasColumn(c, "users", "name") s.NoError(err, "expected no error when checking for existing column") s.True(exists, "expected 'name' column to exist in users table") - exists, err = builder.HasColumn(context.Background(), tx, "users", "non_existent_column") + exists, err = builder.HasColumn(c, "users", "non_existent_column") s.NoError(err, "expected no error when checking for non-existing column") s.False(exists, "expected 'non_existent_column' to not exist in users table") }) } func (s *mysqlBuilderSuite) TestHasColumns() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasColumns(s.ctx, nil, "users", []string{"name"}) - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasColumns(nil, "users", []string{"name"}) + s.Error(err, "expected error when context is nil") s.False(exists) }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasColumns(s.ctx, tx, "", []string{"name"}) + exists, err := builder.HasColumns(c, "", []string{"name"}) s.Error(err, "expected error when table name is empty") s.False(exists) }) s.Run("when column names are empty, should return error", func() { - exists, err := builder.HasColumns(s.ctx, tx, "users", []string{}) + exists, err := builder.HasColumns(c, "users", []string{}) s.Error(err, "expected error when column names are empty") s.False(exists) }) s.Run("when all parameters are valid", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking for columns existence") - exists, err := builder.HasColumns(context.Background(), tx, "users", []string{"name", "email"}) + exists, err := builder.HasColumns(c, "users", []string{"name", "email"}) s.NoError(err, "expected no error when checking for existing columns") s.True(exists, "expected 'name' and 'email' columns to exist in users_has_columns table") - exists, err = builder.HasColumns(context.Background(), tx, "users", []string{"name", "non_existent_column"}) + exists, err = builder.HasColumns(c, "users", []string{"name", "non_existent_column"}) s.NoError(err, "expected no error when checking for mixed existing and non-existing columns") s.False(exists, "expected 'non_existent_column' to not exist in users_has_columns table") }) } func (s *mysqlBuilderSuite) TestHasIndex() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasIndex(s.ctx, nil, "orders", []string{"idx_users_name"}) - s.Error(err, "expected error when transaction is nil") - s.False(exists, "expected exists to be false when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasIndex(nil, "orders", []string{"idx_users_name"}) + s.Error(err, "expected error when context is nil") + s.False(exists, "expected exists to be false when context is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasIndex(s.ctx, tx, "", []string{"idx_users_name"}) + exists, err := builder.HasIndex(c, "", []string{"idx_users_name"}) s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "orders", func(table *schema.Blueprint) { + err = builder.Create(c, "orders", func(table *schema.Blueprint) { table.ID() table.Integer("company_id") table.Integer("user_id") table.String("order_id", 255) table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.Index("company_id", "user_id") table.Unique("order_id").Name("uk_orders3_order_id").Algorithm("BTREE") }) s.NoError(err, "expected no error when creating table with index") - exists, err := builder.HasIndex(s.ctx, tx, "orders", []string{"uk_orders3_order_id"}) + exists, err := builder.HasIndex(c, "orders", []string{"uk_orders3_order_id"}) s.NoError(err, "expected no error when checking if index exists with valid parameters") s.True(exists, "expected exists to be true for existing index") - exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"company_id", "user_id"}) + exists, err = builder.HasIndex(c, "orders", []string{"company_id", "user_id"}) s.NoError(err, "expected no error when checking non-existent index") s.True(exists, "expected exists to be true for existing composite index") - exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"non_existent_index"}) + exists, err = builder.HasIndex(c, "orders", []string{"non_existent_index"}) s.NoError(err, "expected no error when checking non-existent index") s.False(exists, "expected exists to be false for non-existent index") }) } func (s *mysqlBuilderSuite) TestHasTable() { - builder, _ := schema.NewBuilder("mysql") + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasTable(s.ctx, nil, "users") - s.Error(err, "expected error when transaction is nil") - s.False(exists, "expected exists to be false when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasTable(nil, "users") + s.Error(err, "expected error when context is nil") + s.False(exists, "expected exists to be false when context is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasTable(s.ctx, tx, "") + exists, err := builder.HasTable(c, "") s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking if it exists") - exists, err := builder.HasTable(s.ctx, tx, "users") + exists, err := builder.HasTable(c, "users") s.NoError(err, "expected no error when checking if table exists with valid parameters") s.True(exists, "expected exists to be true for existing table") - exists, err = builder.HasTable(s.ctx, tx, "non_existent_table") + exists, err = builder.HasTable(c, "non_existent_table") s.NoError(err, "expected no error when checking non-existent table") s.False(exists, "expected exists to be false for non-existent table") }) diff --git a/schema/mysql_grammar.go b/schema/mysql_grammar.go new file mode 100644 index 0000000..eccbf12 --- /dev/null +++ b/schema/mysql_grammar.go @@ -0,0 +1,608 @@ +package schema + +import ( + "fmt" + "slices" + "strings" + + "github.com/akfaiz/migris/internal/util" +) + +type mysqlGrammar struct { + baseGrammar + + serials []string +} + +func newMysqlGrammar() *mysqlGrammar { + return &mysqlGrammar{ + serials: []string{ + "bigInteger", "integer", "mediumInteger", "smallInteger", + "tinyInteger", + }, + } +} + +func (g *mysqlGrammar) CompileCurrentDatabase() string { + return "SELECT DATABASE()" +} + +func (g *mysqlGrammar) CompileTableExists(database string, table string) (string, error) { + return fmt.Sprintf( + "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'", + g.QuoteString(database), + g.QuoteString(table), + ), nil +} + +func (g *mysqlGrammar) CompileTables(database string) (string, error) { + return fmt.Sprintf( + "select table_name as `name`, (data_length + index_length) as `size`, "+ + "table_comment as `comment`, engine as `engine`, table_collation as `collation` "+ + "from information_schema.tables where table_schema = %s and table_type in ('BASE TABLE', 'SYSTEM VERSIONED') "+ + "order by table_name", + g.QuoteString(database), + ), nil +} + +func (g *mysqlGrammar) CompileColumns(database, table string) (string, error) { + return fmt.Sprintf( + "select column_name as `name`, data_type as `type_name`, column_type as `type`, "+ + "collation_name as `collation`, is_nullable as `nullable`, "+ + "column_default as `default`, column_comment as `comment`, extra as `extra` "+ + "from information_schema.columns where table_schema = %s and table_name = %s "+ + "order by ordinal_position asc", + g.QuoteString(database), + g.QuoteString(table), + ), nil +} + +func (g *mysqlGrammar) CompileIndexes(database, table string) (string, error) { + return fmt.Sprintf( + "select index_name as `name`, group_concat(column_name order by seq_in_index) as `columns`, "+ + "index_type as `type`, not non_unique as `unique` "+ + "from information_schema.statistics where table_schema = %s and table_name = %s "+ + "group by index_name, index_type, non_unique", + g.QuoteString(database), + g.QuoteString(table), + ), nil +} + +func (g *mysqlGrammar) CompileCreate(blueprint *Blueprint) (string, error) { + sql, err := g.compileCreateTable(blueprint) + if err != nil { + return "", err + } + sql = g.compileCreateEncoding(sql, blueprint) + + return g.compileCreateEngine(sql, blueprint), nil +} + +func (g *mysqlGrammar) compileCreateTable(blueprint *Blueprint) (string, error) { + columns, err := g.getColumns(blueprint) + if err != nil { + return "", err + } + + constraints := g.getConstraints(blueprint) + columns = append(columns, constraints...) + + return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil +} + +func (g *mysqlGrammar) compileCreateEncoding(sql string, blueprint *Blueprint) string { + if blueprint.charset != "" { + sql += fmt.Sprintf(" DEFAULT CHARACTER SET %s", blueprint.charset) + } + if blueprint.collation != "" { + sql += fmt.Sprintf(" COLLATE %s", blueprint.collation) + } + + return sql +} + +func (g *mysqlGrammar) compileCreateEngine(sql string, blueprint *Blueprint) string { + if blueprint.engine != "" { + sql += fmt.Sprintf(" ENGINE = %s", blueprint.engine) + } + return sql +} + +func (g *mysqlGrammar) CompileAdd(blueprint *Blueprint) (string, error) { + if len(blueprint.getAddedColumns()) == 0 { + return "", nil + } + + columns, err := g.getColumns(blueprint) + if err != nil { + return "", err + } + columns = g.PrefixArray("ADD COLUMN ", columns) + constraints := g.getConstraints(blueprint) + constraints = g.PrefixArray("ADD ", constraints) + columns = append(columns, constraints...) + + return fmt.Sprintf("ALTER TABLE %s %s", + blueprint.name, + strings.Join(columns, ", "), + ), nil +} + +func (g *mysqlGrammar) CompileChange(bp *Blueprint, command *command) (string, error) { + column := command.column + if column.name == "" { + return "", fmt.Errorf("column name cannot be empty for change operation") + } + + sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s %s", bp.name, column.name, g.getType(column)) + for _, modifier := range g.modifiers() { + sql += modifier(column) + } + + return sql, nil +} + +func (g *mysqlGrammar) CompileRename(blueprint *Blueprint, command *command) (string, error) { + return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, command.to), nil +} + +func (g *mysqlGrammar) CompileDrop(blueprint *Blueprint) (string, error) { + if blueprint.name == "" { + return "", fmt.Errorf("table name cannot be empty") + } + return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil +} + +func (g *mysqlGrammar) CompileDropIfExists(blueprint *Blueprint) (string, error) { + if blueprint.name == "" { + return "", fmt.Errorf("table name cannot be empty") + } + return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil +} + +func (g *mysqlGrammar) CompileDropColumn(blueprint *Blueprint, command *command) (string, error) { + if len(command.columns) == 0 { + return "", fmt.Errorf("no columns to drop") + } + columns := make([]string, len(command.columns)) + for i, col := range command.columns { + if col == "" { + return "", fmt.Errorf("column name cannot be empty") + } + columns[i] = col + } + columns = g.PrefixArray("DROP COLUMN ", columns) + return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil +} + +func (g *mysqlGrammar) CompileRenameColumn(blueprint *Blueprint, command *command) (string, error) { + if command.from == "" || command.to == "" { + return "", fmt.Errorf("old and new column names cannot be empty") + } + return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, command.from, command.to), nil +} + +func (g *mysqlGrammar) CompileIndex(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { + return "", fmt.Errorf("index column cannot be empty") + } + + indexName := command.index + if indexName == "" { + indexName = g.CreateIndexName(blueprint, "index", command.columns...) + } + + sql := fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, blueprint.name, g.Columnize(command.columns)) + if command.algorithm != "" { + sql += fmt.Sprintf(" USING %s", command.algorithm) + } + + return sql, nil +} + +func (g *mysqlGrammar) CompileUnique(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { + return "", fmt.Errorf("unique column cannot be empty") + } + + indexName := command.index + if indexName == "" { + indexName = g.CreateIndexName(blueprint, "unique", command.columns...) + } + sql := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, blueprint.name, g.Columnize(command.columns)) + if command.algorithm != "" { + sql += fmt.Sprintf(" USING %s", command.algorithm) + } + + return sql, nil +} + +func (g *mysqlGrammar) CompileFullText(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { + return "", fmt.Errorf("fulltext index column cannot be empty") + } + + indexName := command.index + if indexName == "" { + indexName = g.CreateIndexName(blueprint, "fulltext", command.columns...) + } + + return fmt.Sprintf("CREATE FULLTEXT INDEX %s ON %s (%s)", indexName, blueprint.name, g.Columnize(command.columns)), nil +} + +func (g *mysqlGrammar) CompilePrimary(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { + return "", fmt.Errorf("primary key column cannot be empty") + } + + indexName := command.index + if indexName == "" { + indexName = g.CreateIndexName(blueprint, "primary", command.columns...) + } + + return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.Columnize(command.columns)), nil +} + +func (g *mysqlGrammar) CompileDropIndex(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { + return "", fmt.Errorf("index name cannot be empty") + } + return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, command.index), nil +} + +func (g *mysqlGrammar) CompileDropUnique(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { + return "", fmt.Errorf("unique index name cannot be empty") + } + return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, command.index), nil +} + +func (g *mysqlGrammar) CompileDropFulltext(blueprint *Blueprint, command *command) (string, error) { + return g.CompileDropIndex(blueprint, command) +} + +func (g *mysqlGrammar) CompileDropPrimary(blueprint *Blueprint, _ *command) (string, error) { + return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", blueprint.name), nil +} + +func (g *mysqlGrammar) CompileRenameIndex(blueprint *Blueprint, command *command) (string, error) { + if command.from == "" || command.to == "" { + return "", fmt.Errorf("old and new index names cannot be empty") + } + return fmt.Sprintf("ALTER TABLE %s RENAME INDEX %s TO %s", blueprint.name, command.from, command.to), nil +} + +func (g *mysqlGrammar) CompileDropForeign(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { + return "", fmt.Errorf("foreign key name cannot be empty") + } + return fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s", blueprint.name, command.index), nil +} + +func (g *mysqlGrammar) GetFluentCommands() []func(*Blueprint, *command) string { + return []func(*Blueprint, *command) string{} +} + +func (g *mysqlGrammar) getColumns(blueprint *Blueprint) ([]string, error) { + var columns []string + for _, col := range blueprint.getAddedColumns() { + if col.name == "" { + return nil, fmt.Errorf("column name cannot be empty") + } + sql := col.name + " " + g.getType(col) + sql += g.modifyUnsigned(col) + sql += g.modifyIncrement(col) + sql += g.modifyDefault(col) + sql += g.modifyOnUpdate(col) + sql += g.modifyCharset(col) + sql += g.modifyCollate(col) + sql += g.modifyNullable(col) + sql += g.modifyComment(col) + + columns = append(columns, sql) + } + + return columns, nil +} + +func (g *mysqlGrammar) getConstraints(blueprint *Blueprint) []string { + var constrains []string + for _, col := range blueprint.getAddedColumns() { + if col.primary != nil && *col.primary { + pkConstraintName := g.CreateIndexName(blueprint, "primary") + sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")" + constrains = append(constrains, sql) + continue + } + } + + return constrains +} + +func (g *mysqlGrammar) getType(col *columnDefinition) string { + typeFuncMap := map[string]func(*columnDefinition) string{ + columnTypeChar: g.typeChar, + columnTypeString: g.typeString, + columnTypeTinyText: g.typeTinyText, + columnTypeText: g.typeText, + columnTypeMediumText: g.typeMediumText, + columnTypeLongText: g.typeLongText, + columnTypeInteger: g.typeInteger, + columnTypeBigInteger: g.typeBigInteger, + columnTypeMediumInteger: g.typeMediumInteger, + columnTypeSmallInteger: g.typeSmallInteger, + columnTypeTinyInteger: g.typeTinyInteger, + columnTypeFloat: g.typeFloat, + columnTypeDouble: g.typeDouble, + columnTypeDecimal: g.typeDecimal, + columnTypeBoolean: g.typeBoolean, + columnTypeEnum: g.typeEnum, + columnTypeJson: g.typeJson, + columnTypeJsonb: g.typeJsonb, + columnTypeDate: g.typeDate, + columnTypeDateTime: g.typeDateTime, + columnTypeDateTimeTz: g.typeDateTimeTz, + columnTypeTime: g.typeTime, + columnTypeTimeTz: g.typeTimeTz, + columnTypeTimestamp: g.typeTimestamp, + columnTypeTimestampTz: g.typeTimestampTz, + columnTypeYear: g.typeYear, + columnTypeBinary: g.typeBinary, + columnTypeUuid: g.typeUuid, + columnTypeGeography: g.typeGeography, + columnTypeGeometry: g.typeGeometry, + columnTypePoint: g.typePoint, + } + if fn, ok := typeFuncMap[col.columnType]; ok { + return fn(col) + } + return col.columnType +} + +func (g *mysqlGrammar) typeChar(col *columnDefinition) string { + return fmt.Sprintf("CHAR(%d)", *col.length) +} + +func (g *mysqlGrammar) typeString(col *columnDefinition) string { + return fmt.Sprintf("VARCHAR(%d)", *col.length) +} + +func (g *mysqlGrammar) typeTinyText(col *columnDefinition) string { + return "TINYTEXT" +} + +func (g *mysqlGrammar) typeText(col *columnDefinition) string { + return "TEXT" +} + +func (g *mysqlGrammar) typeMediumText(col *columnDefinition) string { + return "MEDIUMTEXT" +} + +func (g *mysqlGrammar) typeLongText(col *columnDefinition) string { + return "LONGTEXT" +} + +func (g *mysqlGrammar) typeBigInteger(col *columnDefinition) string { + return "BIGINT" +} + +func (g *mysqlGrammar) typeInteger(col *columnDefinition) string { + return "INT" +} + +func (g *mysqlGrammar) typeMediumInteger(col *columnDefinition) string { + return "MEDIUMINT" +} + +func (g *mysqlGrammar) typeSmallInteger(col *columnDefinition) string { + return "SMALLINT" +} + +func (g *mysqlGrammar) typeTinyInteger(col *columnDefinition) string { + return "TINYINT" +} + +func (g *mysqlGrammar) typeFloat(col *columnDefinition) string { + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("FLOAT(%d)", *col.precision) + } + return "FLOAT" +} + +func (g *mysqlGrammar) typeDouble(col *columnDefinition) string { + return "DOUBLE" +} + +func (g *mysqlGrammar) typeDecimal(col *columnDefinition) string { + return fmt.Sprintf("DECIMAL(%d, %d)", *col.total, *col.places) +} + +func (g *mysqlGrammar) typeBoolean(col *columnDefinition) string { + return "TINYINT(1)" +} + +func (g *mysqlGrammar) typeEnum(col *columnDefinition) string { + allowedValues := make([]string, len(col.allowed)) + for i, e := range col.allowed { + allowedValues[i] = g.QuoteString(e) + } + return fmt.Sprintf("ENUM(%s)", strings.Join(allowedValues, ", ")) +} + +func (g *mysqlGrammar) typeJson(col *columnDefinition) string { + return "JSON" +} + +func (g *mysqlGrammar) typeJsonb(col *columnDefinition) string { + return "JSON" +} + +func (g *mysqlGrammar) typeDate(col *columnDefinition) string { + return "DATE" +} + +func (g *mysqlGrammar) typeDateTime(col *columnDefinition) string { + current := "CURRENT_TIMESTAMP" + if col.precision != nil && *col.precision > 0 { + current = fmt.Sprintf("CURRENT_TIMESTAMP(%d)", *col.precision) + } + if col.useCurrent { + col.SetDefault(Expression(current)) + } + if col.useCurrentOnUpdate { + col.SetOnUpdate(Expression(current)) + } + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("DATETIME(%d)", *col.precision) + } + return "DATETIME" +} + +func (g *mysqlGrammar) typeDateTimeTz(col *columnDefinition) string { + return g.typeDateTime(col) +} + +func (g *mysqlGrammar) typeTime(col *columnDefinition) string { + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("TIME(%d)", *col.precision) + } + return "TIME" +} + +func (g *mysqlGrammar) typeTimeTz(col *columnDefinition) string { + return g.typeTime(col) +} + +func (g *mysqlGrammar) typeTimestamp(col *columnDefinition) string { + current := "CURRENT_TIMESTAMP" + if col.precision != nil && *col.precision > 0 { + current = fmt.Sprintf("CURRENT_TIMESTAMP(%d)", *col.precision) + } + if col.useCurrent { + col.SetDefault(Expression(current)) + } + if col.useCurrentOnUpdate { + col.SetOnUpdate(Expression(current)) + } + if col.precision != nil && *col.precision > 0 { + return fmt.Sprintf("TIMESTAMP(%d)", *col.precision) + } + return "TIMESTAMP" +} + +func (g *mysqlGrammar) typeTimestampTz(col *columnDefinition) string { + return g.typeTimestamp(col) +} + +func (g *mysqlGrammar) typeYear(col *columnDefinition) string { + return "YEAR" +} + +func (g *mysqlGrammar) typeBinary(col *columnDefinition) string { + if col.length != nil && *col.length > 0 { + return fmt.Sprintf("BINARY(%d)", *col.length) + } + return "BLOB" +} + +func (g *mysqlGrammar) typeUuid(col *columnDefinition) string { + return "CHAR(36)" // Default UUID length +} + +func (g *mysqlGrammar) typeGeometry(col *columnDefinition) string { + subtype := util.Ternary(col.subtype != nil, util.PtrOf(strings.ToUpper(*col.subtype)), nil) + if subtype != nil { + if !slices.Contains([]string{"POINT", "LINESTRING", "POLYGON", "GEOMETRYCOLLECTION", "MULTIPOINT", "MULTILINESTRING"}, *subtype) { + subtype = nil + } + } + + if subtype == nil { + subtype = util.PtrOf("GEOMETRY") + } + if col.srid != nil && *col.srid > 0 { + return fmt.Sprintf("%s SRID %d", *subtype, *col.srid) + } + return *subtype +} + +func (g *mysqlGrammar) typeGeography(col *columnDefinition) string { + return g.typeGeometry(col) +} + +func (g *mysqlGrammar) typePoint(col *columnDefinition) string { + col.SetSubtype(util.PtrOf("POINT")) + return g.typeGeometry(col) +} + +func (g *mysqlGrammar) modifiers() []func(*columnDefinition) string { + return []func(*columnDefinition) string{ + g.modifyUnsigned, + g.modifyCharset, + g.modifyCollate, + g.modifyNullable, + g.modifyDefault, + g.modifyOnUpdate, + g.modifyIncrement, + g.modifyComment, + } +} + +func (g *mysqlGrammar) modifyCharset(col *columnDefinition) string { + if col.charset != nil && *col.charset != "" { + return fmt.Sprintf(" CHARACTER SET %s", *col.charset) + } + return "" +} + +func (g *mysqlGrammar) modifyCollate(col *columnDefinition) string { + if col.collation != nil && *col.collation != "" { + return fmt.Sprintf(" COLLATE %s", *col.collation) + } + return "" +} + +func (g *mysqlGrammar) modifyComment(col *columnDefinition) string { + if col.comment != nil { + return fmt.Sprintf(" COMMENT '%s'", *col.comment) + } + return "" +} + +func (g *mysqlGrammar) modifyDefault(col *columnDefinition) string { + if col.hasCommand("default") { + return fmt.Sprintf(" DEFAULT %s", g.GetDefaultValue(col.defaultValue)) + } + return "" +} + +func (g *mysqlGrammar) modifyIncrement(col *columnDefinition) string { + if slices.Contains(g.serials, col.columnType) && + col.autoIncrement != nil && *col.autoIncrement && + col.primary != nil && *col.primary { + return " AUTO_INCREMENT" + } + return "" +} + +func (g *mysqlGrammar) modifyNullable(col *columnDefinition) string { + if col.nullable != nil && *col.nullable { + return " NULL" + } + return " NOT NULL" +} + +func (g *mysqlGrammar) modifyOnUpdate(col *columnDefinition) string { + if col.hasCommand("onUpdate") { + return fmt.Sprintf(" ON UPDATE %s", g.GetValue(col.onUpdateValue)) + } + return "" +} + +func (g *mysqlGrammar) modifyUnsigned(col *columnDefinition) string { + if col.unsigned != nil && *col.unsigned { + return " UNSIGNED" + } + return "" +} diff --git a/mysql_grammar_test.go b/schema/mysql_grammar_test.go similarity index 90% rename from mysql_grammar_test.go rename to schema/mysql_grammar_test.go index 583c18e..91534f2 100644 --- a/mysql_grammar_test.go +++ b/schema/mysql_grammar_test.go @@ -3,6 +3,7 @@ package schema import ( "testing" + "github.com/akfaiz/migris/internal/dialect" "github.com/stretchr/testify/assert" ) @@ -82,7 +83,7 @@ func TestMysqlGrammar_CompileCreate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileCreate(bp) + got, err := g.CompileCreate(bp) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -172,7 +173,7 @@ func TestMysqlGrammar_CompileAdd(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileAdd(bp) + got, err := g.CompileAdd(bp) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -206,7 +207,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.Integer("age").Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN age INT"}, + want: []string{"ALTER TABLE users MODIFY COLUMN age INT NOT NULL"}, wantErr: false, }, { @@ -233,16 +234,16 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.String("status", 50).Default("active").Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN status VARCHAR(50) DEFAULT 'active'"}, + want: []string{"ALTER TABLE users MODIFY COLUMN status VARCHAR(50) NOT NULL DEFAULT 'active'"}, wantErr: false, }, { name: "change column with null default", table: "users", blueprint: func(table *Blueprint) { - table.Text("description").Default(nil).Change() + table.Text("description").Nullable().Default(nil).Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN description TEXT DEFAULT NULL"}, + want: []string{"ALTER TABLE users MODIFY COLUMN description TEXT NULL DEFAULT NULL"}, wantErr: false, }, { @@ -251,7 +252,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.Integer("age").Comment("User age in years").Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN age INT COMMENT 'User age in years'"}, + want: []string{"ALTER TABLE users MODIFY COLUMN age INT NOT NULL COMMENT 'User age in years'"}, wantErr: false, }, { @@ -260,7 +261,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.Text("notes").Comment("").Change() }, - want: []string{"ALTER TABLE users MODIFY COLUMN notes TEXT COMMENT ''"}, + want: []string{"ALTER TABLE users MODIFY COLUMN notes TEXT NOT NULL COMMENT ''"}, wantErr: false, }, { @@ -284,7 +285,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { table.SmallInteger("age").Nullable().Change() }, want: []string{ - "ALTER TABLE users MODIFY COLUMN name VARCHAR(200)", + "ALTER TABLE users MODIFY COLUMN name VARCHAR(200) NOT NULL", "ALTER TABLE users MODIFY COLUMN age SMALLINT NULL", }, wantErr: false, @@ -301,15 +302,15 @@ func TestMysqlGrammar_CompileChange(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.table} + bp := &Blueprint{name: tt.table, grammar: g, dialect: dialect.MySQL} tt.blueprint(bp) - got, err := g.compileChange(bp) + statements, err := bp.toSql() if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return } assert.NoError(t, err, "Did not expect error for test case: %s", tt.name) - assert.Equal(t, tt.want, got, "Expected SQL to match for test case: %s", tt.name) + assert.Equal(t, tt.want, statements, "Expected SQL to match for test case: %s", tt.name) }) } } @@ -345,21 +346,15 @@ func TestMysqlGrammar_CompileRename(t *testing.T) { want: "ALTER TABLE table1 RENAME TO table2", wantErr: false, }, - { - name: "empty new name should return error", - table: "users", - newName: "", - wantErr: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{ - name: tt.table, - newName: tt.newName, + name: tt.table, } - got, err := g.compileRename(bp) + bp.rename(tt.newName) + got, err := g.CompileRename(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -413,7 +408,7 @@ func TestMysqlGrammar_CompileDrop(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDrop(bp) + got, err := g.CompileDrop(bp) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -467,7 +462,7 @@ func TestMysqlGrammar_CompileDropIfExists(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropIfExists(bp) + got, err := g.CompileDropIfExists(bp) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -506,12 +501,6 @@ func TestMysqlGrammar_CompileDropColumn(t *testing.T) { want: "ALTER TABLE users DROP COLUMN email, DROP COLUMN phone, DROP COLUMN address", wantErr: false, }, - { - name: "no columns to drop should return error", - table: "users", - blueprint: func(table *Blueprint) {}, - wantErr: true, - }, { name: "empty column name should return error", table: "users", @@ -536,7 +525,7 @@ func TestMysqlGrammar_CompileDropColumn(t *testing.T) { name: tt.table, } tt.blueprint(bp) - got, err := g.compileDropColumn(bp) + got, err := g.CompileDropColumn(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -592,7 +581,8 @@ func TestMysqlGrammar_CompileRenameColumn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileRenameColumn(bp, tt.oldName, tt.newName) + command := &command{from: tt.oldName, to: tt.newName} + got, err := g.CompileRenameColumn(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -714,7 +704,7 @@ func TestMysqlGrammar_CompileForeign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileForeign(bp, bp.foreignKeys[0]) + got, err := g.CompileForeign(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -753,7 +743,8 @@ func TestMysqlGrammar_CompileDropForeign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropForeign(bp, tt.fkName) + command := &command{index: tt.fkName} + got, err := g.CompileDropForeign(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -833,7 +824,7 @@ func TestMysqlGrammar_CompileIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileIndex(bp, bp.indexes[0]) + got, err := g.CompileIndex(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -921,7 +912,7 @@ func TestMysqlGrammar_CompileUnique(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileUnique(bp, bp.indexes[0]) + got, err := g.CompileUnique(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1017,7 +1008,7 @@ func TestMysqlGrammar_CompilePrimary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compilePrimary(bp, bp.indexes[0]) + got, err := g.CompilePrimary(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1113,7 +1104,7 @@ func TestMysqlGrammar_CompileFullText(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := g.compileFullText(bp, bp.indexes[0]) + got, err := g.CompileFullText(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1152,7 +1143,8 @@ func TestMysqlGrammar_CompileDropIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropIndex(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := g.CompileDropIndex(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1190,7 +1182,8 @@ func TestMysqlGrammar_CompileDropUnique(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropUnique(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := g.CompileDropUnique(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1228,7 +1221,8 @@ func TestMysqlGrammar_CompileDropFulltext(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropFulltext(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := g.CompileDropFulltext(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1284,18 +1278,13 @@ func TestMysqlGrammar_CompileDropPrimary(t *testing.T) { want: "ALTER TABLE orders DROP PRIMARY KEY", wantErr: false, }, - { - name: "empty primary key index name should return error", - table: "users", - indexName: "", - wantErr: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileDropPrimary(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := g.CompileDropPrimary(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1383,7 +1372,8 @@ func TestMysqlGrammar_CompileRenameIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := g.compileRenameIndex(bp, tt.oldName, tt.newName) + command := &command{from: tt.oldName, to: tt.newName} + got, err := g.CompileRenameIndex(bp, command) if tt.wantErr { assert.Error(t, err, "Expected error for test case: %s", tt.name) return @@ -1438,32 +1428,25 @@ func TestMysqlGrammar_GetType(t *testing.T) { want: "DECIMAL(10, 2)", }, { - name: "double column type with precision", - blueprint: func(table *Blueprint) { - table.Double("value", 8, 2) - }, - want: "DOUBLE(8, 2)", - }, - { - name: "double column type without precision", + name: "double column type", blueprint: func(table *Blueprint) { - table.Double("value", 0, 0) + table.Double("value") }, want: "DOUBLE", }, { name: "float column type with precision", blueprint: func(table *Blueprint) { - table.Float("value", 6, 2) + table.Float("value", 6) }, - want: "DOUBLE(6, 2)", + want: "FLOAT(6)", }, { name: "float column type without precision", blueprint: func(table *Blueprint) { - table.Float("value", 0, 0) + table.Float("value") }, - want: "DOUBLE", + want: "FLOAT(53)", }, { name: "big integer column type", @@ -1472,27 +1455,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "BIGINT", }, - { - name: "big integer unsigned", - blueprint: func(table *Blueprint) { - table.BigInteger("id").Unsigned() - }, - want: "BIGINT UNSIGNED", - }, - { - name: "big integer auto increment", - blueprint: func(table *Blueprint) { - table.BigInteger("id").AutoIncrement() - }, - want: "BIGINT AUTO_INCREMENT", - }, - { - name: "big integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.BigInteger("id").Unsigned().AutoIncrement() - }, - want: "BIGINT UNSIGNED AUTO_INCREMENT", - }, { name: "integer column type", blueprint: func(table *Blueprint) { @@ -1500,27 +1462,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "INT", }, - { - name: "integer unsigned", - blueprint: func(table *Blueprint) { - table.Integer("count").Unsigned() - }, - want: "INT UNSIGNED", - }, - { - name: "integer auto increment", - blueprint: func(table *Blueprint) { - table.Integer("id").AutoIncrement() - }, - want: "INT AUTO_INCREMENT", - }, - { - name: "integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.Increments("id") - }, - want: "INT UNSIGNED AUTO_INCREMENT", - }, { name: "small integer column type", blueprint: func(table *Blueprint) { @@ -1528,13 +1469,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "SMALLINT", }, - { - name: "small integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.SmallInteger("id").Unsigned().AutoIncrement() - }, - want: "SMALLINT UNSIGNED AUTO_INCREMENT", - }, { name: "medium integer column type", blueprint: func(table *Blueprint) { @@ -1542,13 +1476,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "MEDIUMINT", }, - { - name: "medium integer unsigned", - blueprint: func(table *Blueprint) { - table.UnsignedMediumInteger("value") - }, - want: "MEDIUMINT UNSIGNED", - }, { name: "small integer column type", blueprint: func(table *Blueprint) { @@ -1556,13 +1483,6 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "SMALLINT", }, - { - name: "small integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.SmallIncrements("id") - }, - want: "SMALLINT UNSIGNED AUTO_INCREMENT", - }, { name: "tiny integer column type", blueprint: func(table *Blueprint) { @@ -1570,26 +1490,12 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "TINYINT", }, - { - name: "tiny integer auto increment", - blueprint: func(table *Blueprint) { - table.TinyInteger("id").AutoIncrement() - }, - want: "TINYINT AUTO_INCREMENT", - }, - { - name: "tiny integer unsigned auto increment", - blueprint: func(table *Blueprint) { - table.TinyIncrements("id") - }, - want: "TINYINT UNSIGNED AUTO_INCREMENT", - }, { name: "time column type", blueprint: func(table *Blueprint) { table.Time("created_at") }, - want: "TIME(0)", + want: "TIME", }, { name: "datetime column type with precision", @@ -1647,19 +1553,12 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "TIMESTAMP", }, - { - name: "geography column type", - blueprint: func(table *Blueprint) { - table.Geography("location", "POINT", 4326) - }, - want: "GEOGRAPHY(POINT, 4326)", - }, { name: "enum column type", blueprint: func(table *Blueprint) { table.Enum("status", []string{"active", "inactive", "pending"}) }, - want: "ENUM('active','inactive','pending')", + want: "ENUM('active', 'inactive', 'pending')", }, { name: "long text column type", @@ -1722,7 +1621,7 @@ func TestMysqlGrammar_GetType(t *testing.T) { blueprint: func(table *Blueprint) { table.UUID("uuid") }, - want: "UUID", + want: "CHAR(36)", }, { name: "binary column type", @@ -1731,19 +1630,26 @@ func TestMysqlGrammar_GetType(t *testing.T) { }, want: "BLOB", }, + { + name: "geography column type", + blueprint: func(table *Blueprint) { + table.Geography("location", "LINESTRING", 4326) + }, + want: "LINESTRING SRID 4326", + }, { name: "geometry column type", blueprint: func(table *Blueprint) { - table.Geometry("shape", "GEOMETRY", 4326) + table.Geometry("shape", "", 4326) }, - want: "GEOMETRY", + want: "GEOMETRY SRID 4326", }, { name: "point column type", blueprint: func(table *Blueprint) { table.Point("location") }, - want: "POINT", + want: "POINT SRID 4326", }, } @@ -1791,14 +1697,6 @@ func TestMysqlGrammar_GetColumns(t *testing.T) { want: []string{"status VARCHAR(50) DEFAULT 'active' NOT NULL"}, wantErr: false, }, - { - name: "column with on update value", - blueprint: func(table *Blueprint) { - table.Timestamp("updated_at", 0).UseCurrentOnUpdate() - }, - want: []string{"updated_at TIMESTAMP ON UPDATE CURRENT_TIMESTAMP NOT NULL"}, - wantErr: false, - }, { name: "nullable column", blueprint: func(table *Blueprint) { @@ -1850,21 +1748,13 @@ func TestMysqlGrammar_GetColumns(t *testing.T) { want: []string{"id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL"}, wantErr: false, }, - { - name: "timestamp with default", - blueprint: func(table *Blueprint) { - table.Timestamp("created_at", 0).UseCurrent() - }, - want: []string{"created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL"}, - wantErr: false, - }, { name: "multiple columns with different attributes", blueprint: func(table *Blueprint) { table.BigInteger("id").Unsigned().AutoIncrement().Primary() table.String("name", 255).Comment("User name") table.String("email", 255).Nullable() - table.Timestamp("created_at", 0).Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at", 0).UseCurrent() }, want: []string{ "id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL", diff --git a/postgres_builder.go b/schema/postgres_builder.go similarity index 68% rename from postgres_builder.go rename to schema/postgres_builder.go index f46b9fd..be214e3 100644 --- a/postgres_builder.go +++ b/schema/postgres_builder.go @@ -1,7 +1,6 @@ package schema import ( - "context" "database/sql" "errors" "strings" @@ -9,11 +8,11 @@ import ( type postgresBuilder struct { baseBuilder - grammar *pgGrammar + grammar *postgresGrammar } func newPostgresBuilder() Builder { - grammar := newPgGrammar() + grammar := newPostgresGrammar() return &postgresBuilder{ baseBuilder: baseBuilder{grammar: grammar}, @@ -29,21 +28,21 @@ func (b *postgresBuilder) parseSchemaAndTable(name string) (string, string) { return "", names[0] } -func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return nil, err +func (b *postgresBuilder) GetColumns(c *Context, tableName string) ([]*Column, error) { + if c == nil || tableName == "" { + return nil, errors.New("invalid arguments: context is nil or table name is empty") } schema, name := b.parseSchemaAndTable(tableName) if schema == "" { schema = "public" // Default schema for PostgreSQL } - query, err := b.grammar.compileColumns(schema, name) + query, err := b.grammar.CompileColumns(schema, name) if err != nil { return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -64,19 +63,19 @@ func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName return columns, nil } -func (b *postgresBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return nil, err +func (b *postgresBuilder) GetIndexes(c *Context, tableName string) ([]*Index, error) { + if c == nil || tableName == "" { + return nil, errors.New("invalid arguments: context is nil or table name is empty") } schema, name := b.parseSchemaAndTable(tableName) if schema == "" { schema = "public" // Default schema for PostgreSQL } - query, err := b.grammar.compileIndexes(schema, name) + query, err := b.grammar.CompileIndexes(schema, name) if err != nil { return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -96,17 +95,17 @@ func (b *postgresBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName return indexes, nil } -func (b *postgresBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) { - if tx == nil { - return nil, errors.New("transaction is nil") +func (b *postgresBuilder) GetTables(c *Context) ([]*TableInfo, error) { + if c == nil { + return nil, errors.New("invalid arguments: context is nil") } - query, err := b.grammar.compileTables() + query, err := b.grammar.CompileTables() if err != nil { return nil, err } - rows, err := queryContext(ctx, tx, query) + rows, err := c.Query(query) if err != nil { return nil, err } @@ -124,15 +123,18 @@ func (b *postgresBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableIn return tables, nil } -func (b *postgresBuilder) HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) { - return b.HasColumns(ctx, tx, tableName, []string{columnName}) +func (b *postgresBuilder) HasColumn(c *Context, tableName string, columnName string) (bool, error) { + return b.HasColumns(c, tableName, []string{columnName}) } -func (b *postgresBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) { +func (b *postgresBuilder) HasColumns(c *Context, tableName string, columnNames []string) (bool, error) { + if c == nil || tableName == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") + } if len(columnNames) == 0 { return false, errors.New("no column names provided") } - existingColumns, err := b.GetColumns(ctx, tx, tableName) + existingColumns, err := b.GetColumns(c, tableName) if err != nil { return false, err } @@ -157,12 +159,12 @@ func (b *postgresBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName return true, nil // All specified columns exist } -func (b *postgresBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) { - if err := b.validateTxAndName(tx, tableName); err != nil { - return false, err +func (b *postgresBuilder) HasIndex(c *Context, tableName string, indexes []string) (bool, error) { + if c == nil || tableName == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } - existingIndexes, err := b.GetIndexes(ctx, tx, tableName) + existingIndexes, err := b.GetIndexes(c, tableName) if err != nil { return false, err } @@ -201,22 +203,22 @@ func (b *postgresBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName st return false, nil // If no specified index exists, return false } -func (b *postgresBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { - if err := b.validateTxAndName(tx, name); err != nil { - return false, err +func (b *postgresBuilder) HasTable(c *Context, name string) (bool, error) { + if c == nil || name == "" { + return false, errors.New("invalid arguments: context is nil or table name is empty") } schema, name := b.parseSchemaAndTable(name) if schema == "" { schema = "public" // Default schema for PostgreSQL } - query, err := b.grammar.compileTableExists(schema, name) + query, err := b.grammar.CompileTableExists(schema, name) if err != nil { return false, err } var exists bool - if err := queryRowContext(ctx, tx, query).Scan(&exists); err != nil { + if err := c.QueryRow(query).Scan(&exists); err != nil { if errors.Is(err, sql.ErrNoRows) { return false, nil // Table does not exist } diff --git a/postgres_builder_test.go b/schema/postgres_builder_test.go similarity index 59% rename from postgres_builder_test.go rename to schema/postgres_builder_test.go index 334aea7..2b5b0b5 100644 --- a/postgres_builder_test.go +++ b/schema/postgres_builder_test.go @@ -7,7 +7,7 @@ import ( "os" "testing" - "github.com/afkdevs/go-schema" + "github.com/akfaiz/migris/schema" _ "github.com/lib/pq" "github.com/stretchr/testify/suite" ) @@ -16,12 +16,6 @@ func TestPostgresBuilderSuite(t *testing.T) { suite.Run(t, new(postgresBuilderSuite)) } -type postgresBuilderSuite struct { - suite.Suite - ctx context.Context - db *sql.DB -} - type dbConfig struct { Database string Username string @@ -43,6 +37,13 @@ func parseTestConfig() dbConfig { } } +type postgresBuilderSuite struct { + suite.Suite + ctx context.Context + db *sql.DB + builder schema.Builder +} + func (s *postgresBuilderSuite) SetupSuite() { s.ctx = context.Background() @@ -57,7 +58,8 @@ func (s *postgresBuilderSuite) SetupSuite() { s.Require().NoError(err) s.db = db - schema.SetDebug(true) + s.builder, err = schema.NewBuilder("postgres") + s.Require().NoError(err) } func (s *postgresBuilderSuite) TearDownSuite() { @@ -65,48 +67,49 @@ func (s *postgresBuilderSuite) TearDownSuite() { } func (s *postgresBuilderSuite) TestCreate() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Create(s.ctx, nil, "test_table", func(table *schema.Blueprint) {}) - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Create(nil, "test_table", func(table *schema.Blueprint) {}) + s.Error(err, "expected error when context is nil") }) s.Run("when table name is empty, should return error", func() { - err := builder.Create(s.ctx, tx, "", func(table *schema.Blueprint) {}) + err := builder.Create(c, "", func(table *schema.Blueprint) {}) s.Error(err, "expected error when table name is empty") }) s.Run("when blueprint is nil, should return error", func() { - err := builder.Create(s.ctx, tx, "test_table", nil) + err := builder.Create(c, "test_table", nil) s.Error(err, "expected error when blueprint is nil") }) s.Run("when all parameters are valid, should create table successfully", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() + table.String("name") + table.String("email").Unique() + table.String("password").Nullable() table.Timestamps() }) s.NoError(err, "expected no error when creating table with valid parameters") }) s.Run("when use custom schema should create it successfully", func() { - _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_publics") + _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_public") s.NoError(err, "expected no error when creating custom schema") - err = builder.Create(context.Background(), tx, "custom_publics.users", func(table *schema.Blueprint) { + err = builder.Create(c, "custom_public.users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() + table.String("name") + table.String("email").Unique() + table.String("password").Nullable() table.TimestampsTz() }) s.NoError(err, "expected no error when creating table with custom schema") }) s.Run("when have composite primary key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "user_roles", func(table *schema.Blueprint) { + err = builder.Create(c, "user_roles", func(table *schema.Blueprint) { table.Integer("user_id") table.Integer("role_id") @@ -115,280 +118,249 @@ func (s *postgresBuilderSuite) TestCreate() { s.NoError(err, "expected no error when creating table with composite primary key") }) s.Run("when have foreign key should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders", func(table *schema.Blueprint) { + err = builder.Create(c, "orders", func(table *schema.Blueprint) { table.ID() table.BigInteger("user_id") - table.String("order_id", 255).Unique() + table.String("order_id").Unique() table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE") }) s.NoError(err, "expected no error when creating table with foreign key") }) s.Run("when have custom index should create it successfully", func() { - err = builder.Create(context.Background(), tx, "orders_2", func(table *schema.Blueprint) { + err = builder.Create(c, "orders_2", func(table *schema.Blueprint) { table.ID() - table.String("order_id", 255).Unique("uk_orders_2_order_id") + table.String("order_id").Unique("uk_orders_2_order_id") table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Index("created_at").Name("idx_orders_created_at").Algorithm("BTREE") }) s.NoError(err, "expected no error when creating table with custom index") }) s.Run("when table already exists, should return error", func() { - err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() + table.String("name") + table.String("email").Unique() }) s.Error(err, "expected error when creating table that already exists") }) } -func (s *postgresBuilderSuite) TestCreateIfNotExists() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) - tx, err := s.db.BeginTx(s.ctx, nil) - s.Require().NoError(err) - defer tx.Rollback() //nolint:errcheck - - s.Run("when tx is nil, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, nil, "test_table", func(table *schema.Blueprint) {}) - s.Error(err, "expected error when transaction is nil") - }) - s.Run("when table name is empty, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, tx, "", func(table *schema.Blueprint) {}) - s.Error(err, "expected error when table name is empty") - }) - s.Run("when blueprint is nil, should return error", func() { - err := builder.CreateIfNotExists(s.ctx, tx, "test_table", nil) - s.Error(err, "expected error when blueprint is nil") - }) - s.Run("when all parameters are valid, should create table successfully", func() { - err = builder.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") - }) - s.NoError(err, "expected no error when creating table with valid parameters") - }) - s.Run("when table already exists, should not return error", func() { - err = builder.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name", 255) - table.String("email", 255) - }) - s.NoError(err, "expected no error when creating table that already exists") - }) -} - func (s *postgresBuilderSuite) TestDrop() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Drop(nil, "test_table") + s.Error(err, "expected error when context is nil") + }) s.Run("when table name is empty, should return error", func() { - err := builder.Drop(s.ctx, nil, "") + err := builder.Drop(c, "") s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.String("name") + table.String("email").Unique() + table.String("password").Nullable() + table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") - err = builder.Drop(s.ctx, tx, "users") + err = builder.Drop(c, "users") s.NoError(err, "expected no error when dropping table with valid parameters") }) s.Run("when table does not exist, should return error", func() { - err = builder.Drop(s.ctx, tx, "non_existent_table") + err = builder.Drop(c, "non_existent_table") s.Error(err, "expected error when dropping table that does not exist") }) } func (s *postgresBuilderSuite) TestDropIfExists() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.DropIfExists(nil, "test_table") + s.Error(err, "expected error when context is nil") + }) s.Run("when table name is empty, should return error", func() { - err := builder.DropIfExists(s.ctx, nil, "") + err := builder.DropIfExists(c, "") s.Error(err, "expected error when table name is empty") }) - s.Run("when tx is nil, should return error", func() { - err = builder.DropIfExists(s.ctx, nil, "test_table") - s.Error(err, "expected error when transaction is nil") + s.Run("when context is nil, should return error", func() { + err = builder.DropIfExists(nil, "test_table") + s.Error(err, "expected error when context is nil") }) s.Run("when all parameters are valid, should drop table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255).Unique() - table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.String("name") + table.String("email").Unique() + table.String("password").Nullable() + table.Timestamps() }) s.NoError(err, "expected no error when creating table before dropping it") - err = builder.DropIfExists(s.ctx, tx, "users") + err = builder.DropIfExists(c, "users") s.NoError(err, "expected no error when dropping table with valid parameters") }) s.Run("when table does not exist, should not return error", func() { - err = builder.DropIfExists(s.ctx, tx, "non_existent_table") + err = builder.DropIfExists(c, "non_existent_table") s.NoError(err, "expected no error when dropping non-existent table with IF EXISTS clause") }) } func (s *postgresBuilderSuite) TestRename() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Rename(s.ctx, nil, "old_table", "new_table") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Rename(nil, "old_table", "new_table") + s.Error(err, "expected error when context is nil") }) s.Run("when old table name is empty, should return error", func() { - err := builder.Rename(s.ctx, tx, "", "new_table") + err := builder.Rename(c, "", "new_table") s.Error(err, "expected error when old table name is empty") }) s.Run("when new table name is empty, should return error", func() { - err := builder.Rename(s.ctx, tx, "old_table", "") + err := builder.Rename(c, "old_table", "") s.Error(err, "expected error when new table name is empty") }) s.Run("when all parameters are valid, should rename table successfully", func() { - err = builder.Create(s.ctx, tx, "old_table", func(table *schema.Blueprint) { + err = builder.Create(c, "old_table", func(table *schema.Blueprint) { table.ID() table.String("name", 255) }) s.NoError(err, "expected no error when creating table before renaming it") - err = builder.Rename(s.ctx, tx, "old_table", "new_table") + err = builder.Rename(c, "old_table", "new_table") s.NoError(err, "expected no error when renaming table with valid parameters") }) s.Run("when renaming non-existent table, should return error", func() { - err = builder.Rename(s.ctx, tx, "non_existent_table", "new_table") + err = builder.Rename(c, "non_existent_table", "new_table") s.Error(err, "expected error when renaming non-existent table") s.ErrorContains(err, "does not exist", "expected error message to contain 'does not exist'") }) } func (s *postgresBuilderSuite) TestTable() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - err := builder.Table(s.ctx, nil, "test_table", func(table *schema.Blueprint) {}) - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + err := builder.Table(nil, "test_table", func(table *schema.Blueprint) {}) + s.Error(err, "expected error when context is nil") }) s.Run("when table name is empty, should return error", func() { - err := builder.Table(s.ctx, tx, "", func(table *schema.Blueprint) {}) + err := builder.Table(c, "", func(table *schema.Blueprint) {}) s.Error(err, "expected error when table name is empty") }) s.Run("when blueprint is nil, should return error", func() { - err := builder.Table(s.ctx, tx, "test_table", nil) + err := builder.Table(c, "test_table", nil) s.Error(err, "expected error when blueprint is nil") }) s.Run("when all parameters are valid, should modify table successfully", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique("uk_users_email") table.String("password", 255).Nullable() table.Text("bio").Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.FullText("bio") }) s.NoError(err, "expected no error when creating table before modifying it") s.Run("should add new columns and modify existing ones", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.String("address", 255).Nullable() table.String("phone", 20).Nullable().Unique("uk_users_phone") }) s.NoError(err, "expected no error when modifying table with valid parameters") }) s.Run("should modify existing column", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.String("email", 255).Nullable().Change() }) s.NoError(err, "expected no error when modifying existing column") }) s.Run("should drop column and rename existing one", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropColumn("password") table.RenameColumn("name", "full_name") }) s.NoError(err, "expected no error when dropping column and renaming existing one") }) s.Run("should add index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.Index("phone").Name("idx_users_phone").Algorithm("BTREE") }) s.NoError(err, "expected no error when adding index to table") }) s.Run("should rename index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.RenameIndex("idx_users_phone", "idx_users_contact") }) s.NoError(err, "expected no error when renaming index in table") }) s.Run("should drop index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropIndex("idx_users_contact") }) s.NoError(err, "expected no error when dropping index from table") }) s.Run("should drop unique constraint", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.DropUnique("uk_users_email") + err = builder.Table(c, "users", func(table *schema.Blueprint) { + table.DropUnique([]string{"email"}) }) s.NoError(err, "expected no error when dropping unique constraint from table") }) s.Run("should drop fulltext index", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropFulltext("ft_users_bio") }) s.NoError(err, "expected no error when dropping fulltext index from table") }) s.Run("should add foreign key", func() { - err = builder.Create(s.ctx, tx, "roles", func(table *schema.Blueprint) { + err = builder.Create(c, "roles", func(table *schema.Blueprint) { table.ID() table.String("role_name", 255).Unique("uk_roles_role_name") }) s.NoError(err, "expected no error when creating roles table before adding foreign key") - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.Integer("role_id").Nullable() table.Foreign("role_id").References("id").On("roles").OnDelete("SET NULL").OnUpdate("CASCADE") }) s.NoError(err, "expected no error when adding foreign key to users table") }) s.Run("should drop foreign key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropForeign("fk_users_roles") }) s.NoError(err, "expected no error when dropping foreign key from users table") }) s.Run("should drop primary key", func() { - err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Table(c, "users", func(table *schema.Blueprint) { table.DropPrimary("pk_users") }) s.NoError(err, "expected no error when dropping primary key from users table") @@ -397,105 +369,105 @@ func (s *postgresBuilderSuite) TestTable() { } func (s *postgresBuilderSuite) TestGetColumns() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - _, err := builder.GetColumns(s.ctx, nil, "users") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + _, err := builder.GetColumns(nil, "users") + s.Error(err, "expected error when context is nil") }) s.Run("when table name is empty, should return error", func() { - _, err := builder.GetColumns(s.ctx, tx, "") + _, err := builder.GetColumns(c, "") s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before getting columns") - columns, err := builder.GetColumns(s.ctx, tx, "users") + columns, err := builder.GetColumns(c, "users") s.NoError(err, "expected no error when getting columns with valid parameters") s.Len(columns, 6, "expected 6 columns to be returned") }) s.Run("when table does not exist, should return empty columns", func() { - columns, err := builder.GetColumns(s.ctx, tx, "non_existent_table") + columns, err := builder.GetColumns(c, "non_existent_table") s.NoError(err, "expected no error when getting columns of non-existent table") s.Empty(columns, "expected empty columns for non-existent table") }) } func (s *postgresBuilderSuite) TestGetIndexes() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - _, err := builder.GetIndexes(s.ctx, nil, "users") - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + _, err := builder.GetIndexes(nil, "users") + s.Error(err, "expected error when contexts is nil") }) s.Run("when table name is empty, should return error", func() { - _, err := builder.GetIndexes(s.ctx, tx, "") + _, err := builder.GetIndexes(c, "") s.Error(err, "expected error when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() table.Index("name").Name("idx_users_name") }) s.NoError(err, "expected no error when creating table before getting indexes") - indexes, err := builder.GetIndexes(s.ctx, tx, "users") + indexes, err := builder.GetIndexes(c, "users") s.NoError(err, "expected no error when getting indexes with valid parameters") s.Len(indexes, 3, "expected 3 index to be returned") }) s.Run("when table does not exist, should return empty indexes", func() { - indexes, err := builder.GetIndexes(s.ctx, tx, "non_existent_table") + indexes, err := builder.GetIndexes(c, "non_existent_table") s.NoError(err, "expected no error when getting indexes of non-existent table") s.Empty(indexes, "expected empty indexes for non-existent table") }) } func (s *postgresBuilderSuite) TestGetTables() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - _, err := builder.GetTables(s.ctx, nil) - s.Error(err, "expected error when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + _, err := builder.GetTables(nil) + s.Error(err, "expected error when context is nil") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before getting tables") - tables, err := builder.GetTables(s.ctx, tx) + tables, err := builder.GetTables(c) s.NoError(err, "expected no error when getting tables with valid parameters") s.Len(tables, 1, "expected 1 table to be returned") userTable := tables[0] @@ -506,159 +478,160 @@ func (s *postgresBuilderSuite) TestGetTables() { } func (s *postgresBuilderSuite) TestHasColumn() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasColumn(s.ctx, nil, "users", "name") + exists, err := builder.HasColumn(nil, "users", "name") s.Error(err, "expected error when transaction is nil") s.False(exists, "expected exists to be false when transaction is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasColumn(s.ctx, tx, "", "name") + exists, err := builder.HasColumn(c, "", "name") s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking column existence") - exists, err := builder.HasColumn(s.ctx, tx, "users", "name") + exists, err := builder.HasColumn(c, "users", "name") s.NoError(err, "expected no error when checking if column exists with valid parameters") s.True(exists, "expected exists to be true for existing column") - exists, err = builder.HasColumn(s.ctx, tx, "users", "non_existent_column") + exists, err = builder.HasColumn(c, "users", "non_existent_column") s.NoError(err, "expected no error when checking non-existent column") s.False(exists, "expected exists to be false for non-existent column") }) } func (s *postgresBuilderSuite) TestHasColumns() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasColumns(s.ctx, nil, "users", []string{"name"}) + exists, err := builder.HasColumns(nil, "users", []string{"name"}) s.Error(err, "expected error when transaction is nil") s.False(exists, "expected exists to be false when transaction is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasColumns(s.ctx, tx, "", []string{"name"}) + exists, err := builder.HasColumns(c, "", []string{"name"}) s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking column existence") - exists, err := builder.HasColumns(s.ctx, tx, "users", []string{"name", "email"}) + exists, err := builder.HasColumns(c, "users", []string{"name", "email"}) s.NoError(err, "expected no error when checking if columns exist with valid parameters") s.True(exists, "expected exists to be true for existing columns") - exists, err = builder.HasColumns(s.ctx, tx, "users", []string{"name", "non_existent_column"}) + exists, err = builder.HasColumns(c, "users", []string{"name", "non_existent_column"}) s.NoError(err, "expected no error when checking mixed existing and non-existent columns") s.False(exists, "expected exists to be false for mixed existing and non-existent columns") }) } func (s *postgresBuilderSuite) TestHasIndex() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasIndex(s.ctx, nil, "users", []string{"idx_users_name"}) - s.Error(err, "expected error when transaction is nil") - s.False(exists, "expected exists to be false when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasIndex(nil, "users", []string{"idx_users_name"}) + s.Error(err, "expected error when context is nil") + s.False(exists, "expected exists to be false when context is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasIndex(s.ctx, tx, "", []string{"idx_users_name"}) + exists, err := builder.HasIndex(c, "", []string{"idx_users_name"}) s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "orders", func(table *schema.Blueprint) { + err = builder.Create(c, "orders", func(table *schema.Blueprint) { table.ID() table.Integer("company_id") table.Integer("user_id") table.String("order_id", 255) table.Decimal("amount", 10, 2) - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() table.Index("company_id", "user_id") table.Unique("order_id").Name("uk_orders_order_id") }) s.Require().NoError(err, "expected no error when creating table with index") - exists, err := builder.HasIndex(s.ctx, tx, "orders", []string{"uk_orders_order_id"}) + exists, err := builder.HasIndex(c, "orders", []string{"uk_orders_order_id"}) s.NoError(err, "expected no error when checking if index exists with valid parameters") s.True(exists, "expected exists to be true for existing index") - exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"company_id", "user_id"}) + exists, err = builder.HasIndex(c, "orders", []string{"company_id", "user_id"}) s.NoError(err, "expected no error when checking non-existent index") s.True(exists, "expected exists to be true for existing composite index") - exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"non_existent_index"}) + exists, err = builder.HasIndex(c, "orders", []string{"non_existent_index"}) s.NoError(err, "expected no error when checking non-existent index") s.False(exists, "expected exists to be false for non-existent index") }) } func (s *postgresBuilderSuite) TestHasTable() { - builder, err := schema.NewBuilder("postgres") - s.Require().NoError(err) + builder := s.builder tx, err := s.db.BeginTx(s.ctx, nil) s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck - s.Run("when tx is nil, should return error", func() { - exists, err := builder.HasTable(s.ctx, nil, "users") - s.Error(err, "expected error when transaction is nil") - s.False(exists, "expected exists to be false when transaction is nil") + c := schema.NewContext(s.ctx, tx) + + s.Run("when context is nil, should return error", func() { + exists, err := builder.HasTable(nil, "users") + s.Error(err, "expected error when context is nil") + s.False(exists, "expected exists to be false when context is nil") }) s.Run("when table name is empty, should return error", func() { - exists, err := builder.HasTable(s.ctx, tx, "") + exists, err := builder.HasTable(c, "") s.Error(err, "expected error when table name is empty") s.False(exists, "expected exists to be false when table name is empty") }) s.Run("when all parameters are valid", func() { - err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = builder.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table before checking existence") - exists, err := builder.HasTable(s.ctx, tx, "users") + exists, err := builder.HasTable(c, "users") s.NoError(err, "expected no error when checking if table exists with valid parameters") s.True(exists, "expected exists to be true for existing table") - exists, err = builder.HasTable(s.ctx, tx, "non_existent_table") + exists, err = builder.HasTable(c, "non_existent_table") s.NoError(err, "expected no error when checking non-existent table") s.False(exists, "expected exists to be false for non-existent table") }) @@ -666,20 +639,19 @@ func (s *postgresBuilderSuite) TestHasTable() { _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_publics") s.NoError(err, "expected no error when creating custom schema") - err = builder.Create(s.ctx, tx, "custom_publics.users", func(table *schema.Blueprint) { + err = builder.Create(c, "custom_publics.users", func(table *schema.Blueprint) { table.ID() table.String("name", 255) table.String("email", 255).Unique() table.String("password", 255).Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamps() }) s.NoError(err, "expected no error when creating table with custom schema") - exists, err := builder.HasTable(s.ctx, tx, "custom_publics.users") + exists, err := builder.HasTable(c, "custom_publics.users") s.NoError(err, "expected no error when checking if table with custom schema exists") s.True(exists, "expected exists to be true for existing table with custom schema") - exists, err = builder.HasTable(s.ctx, tx, "custom_publics.non_existent_table") + exists, err = builder.HasTable(c, "custom_publics.non_existent_table") s.NoError(err, "expected no error when checking non-existent table with custom schema") s.False(exists, "expected exists to be false for non-existent table with custom schema") }) diff --git a/schema/postgres_grammar.go b/schema/postgres_grammar.go new file mode 100644 index 0000000..a5fdb25 --- /dev/null +++ b/schema/postgres_grammar.go @@ -0,0 +1,581 @@ +package schema + +import ( + "fmt" + "slices" + "strings" +) + +type postgresGrammar struct { + baseGrammar +} + +func newPostgresGrammar() *postgresGrammar { + return &postgresGrammar{} +} + +func (g *postgresGrammar) CompileTableExists(schema string, table string) (string, error) { + return fmt.Sprintf( + "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'", + g.QuoteString(schema), + g.QuoteString(table), + ), nil +} + +func (g *postgresGrammar) CompileTables() (string, error) { + return "select c.relname as name, n.nspname as schema, pg_total_relation_size(c.oid) as size, " + + "obj_description(c.oid, 'pg_class') as comment from pg_class c, pg_namespace n " + + "where c.relkind in ('r', 'p') and n.oid = c.relnamespace and n.nspname not in ('pg_catalog', 'information_schema') " + + "order by c.relname", nil +} + +func (g *postgresGrammar) CompileColumns(schema, table string) (string, error) { + return fmt.Sprintf( + "select a.attname as name, t.typname as type_name, format_type(a.atttypid, a.atttypmod) as type, "+ + "(select tc.collcollate from pg_catalog.pg_collation tc where tc.oid = a.attcollation) as collation, "+ + "not a.attnotnull as nullable, "+ + "(select pg_get_expr(adbin, adrelid) from pg_attrdef where c.oid = pg_attrdef.adrelid and pg_attrdef.adnum = a.attnum) as default, "+ + "col_description(c.oid, a.attnum) as comment "+ + "from pg_attribute a, pg_class c, pg_type t, pg_namespace n "+ + "where c.relname = %s and n.nspname = %s and a.attnum > 0 and a.attrelid = c.oid and a.atttypid = t.oid and n.oid = c.relnamespace "+ + "order by a.attnum", + g.QuoteString(table), + g.QuoteString(schema), + ), nil +} + +func (g *postgresGrammar) CompileIndexes(schema, table string) (string, error) { + return fmt.Sprintf( + "select ic.relname as name, string_agg(a.attname, ',' order by indseq.ord) as columns, "+ + "am.amname as \"type\", i.indisunique as \"unique\", i.indisprimary as \"primary\" "+ + "from pg_index i "+ + "join pg_class tc on tc.oid = i.indrelid "+ + "join pg_namespace tn on tn.oid = tc.relnamespace "+ + "join pg_class ic on ic.oid = i.indexrelid "+ + "join pg_am am on am.oid = ic.relam "+ + "join lateral unnest(i.indkey) with ordinality as indseq(num, ord) on true "+ + "left join pg_attribute a on a.attrelid = i.indrelid and a.attnum = indseq.num "+ + "where tc.relname = %s and tn.nspname = %s "+ + "group by ic.relname, am.amname, i.indisunique, i.indisprimary", + g.QuoteString(table), + g.QuoteString(schema), + ), nil +} + +func (g *postgresGrammar) CompileCreate(blueprint *Blueprint) (string, error) { + columns, err := g.getColumns(blueprint) + if err != nil { + return "", err + } + columns = append(columns, g.getConstraints(blueprint)...) + return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil +} + +func (g *postgresGrammar) CompileAdd(blueprint *Blueprint) (string, error) { + if len(blueprint.getAddedColumns()) == 0 { + return "", nil + } + + columns, err := g.getColumns(blueprint) + if err != nil { + return "", err + } + columns = g.PrefixArray("ADD COLUMN ", columns) + constraints := g.getConstraints(blueprint) + if len(constraints) > 0 { + constraints = g.PrefixArray("ADD ", constraints) + columns = append(columns, constraints...) + } + + return fmt.Sprintf("ALTER TABLE %s %s", + blueprint.name, + strings.Join(columns, ", "), + ), nil +} + +func (g *postgresGrammar) CompileChange(bp *Blueprint, command *command) (string, error) { + column := command.column + if column.name == "" { + return "", fmt.Errorf("column name cannot be empty for change operation") + } + + var changes []string + changes = append(changes, fmt.Sprintf("TYPE %s", g.getType(command.column))) + for _, modifier := range g.modifiers() { + change := modifier(command.column) + if change != "" { + changes = append(changes, strings.TrimSpace(change)) + } + } + + return fmt.Sprintf("ALTER TABLE %s %s", + bp.name, + strings.Join(g.PrefixArray(fmt.Sprintf("ALTER COLUMN %s ", column.name), changes), ", "), + ), nil +} + +func (g *postgresGrammar) CompileDrop(blueprint *Blueprint) (string, error) { + return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil +} + +func (g *postgresGrammar) CompileDropIfExists(blueprint *Blueprint) (string, error) { + return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil +} + +func (g *postgresGrammar) CompileRename(blueprint *Blueprint, command *command) (string, error) { + return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, command.to), nil +} + +func (g *postgresGrammar) CompileDropColumn(blueprint *Blueprint, command *command) (string, error) { + if len(command.columns) == 0 { + return "", nil + } + columns := g.PrefixArray("DROP COLUMN ", command.columns) + + return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil +} + +func (g *postgresGrammar) CompileRenameColumn(blueprint *Blueprint, command *command) (string, error) { + if command.from == "" || command.to == "" { + return "", fmt.Errorf("table name, old column name, and new column name cannot be empty for rename operation") + } + return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, command.from, command.to), nil +} + +func (g *postgresGrammar) CompileFullText(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { + return "", fmt.Errorf("fulltext index column cannot be empty") + } + indexName := command.index + if indexName == "" { + indexName = g.CreateIndexName(blueprint, "fulltext", command.columns...) + } + language := command.language + if language == "" { + language = "english" // Default language for full-text search + } + var columns []string + for _, col := range command.columns { + columns = append(columns, fmt.Sprintf("to_tsvector(%s, %s)", g.QuoteString(language), col)) + } + + return fmt.Sprintf("CREATE INDEX %s ON %s USING GIN (%s)", indexName, blueprint.name, strings.Join(columns, " || ")), nil +} + +func (g *postgresGrammar) CompileIndex(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { + return "", fmt.Errorf("index column cannot be empty") + } + indexName := command.index + if indexName == "" { + indexName = g.CreateIndexName(blueprint, "index", command.columns...) + } + + sql := fmt.Sprintf("CREATE INDEX %s ON %s", indexName, blueprint.name) + if command.algorithm != "" { + sql += fmt.Sprintf(" USING %s", command.algorithm) + } + return fmt.Sprintf("%s (%s)", sql, g.Columnize(command.columns)), nil +} + +func (g *postgresGrammar) CompileUnique(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { + return "", fmt.Errorf("unique index column cannot be empty") + } + indexName := command.index + if indexName == "" { + indexName = g.CreateIndexName(blueprint, "unique", command.columns...) + } + sql := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)", + blueprint.name, + indexName, + g.Columnize(command.columns), + ) + + if command.deferrable != nil { + if *command.deferrable { + sql += " DEFERRABLE" + } else { + sql += " NOT DEFERRABLE" + } + } + if command.deferrable != nil && *command.deferrable && command.initiallyImmediate != nil { + if *command.initiallyImmediate { + sql += " INITIALLY IMMEDIATE" + } else { + sql += " INITIALLY DEFERRED" + } + } + + return sql, nil +} + +func (g *postgresGrammar) CompilePrimary(blueprint *Blueprint, command *command) (string, error) { + if slices.Contains(command.columns, "") { + return "", fmt.Errorf("primary key index column cannot be empty") + } + indexName := command.index + if indexName == "" { + indexName = g.CreateIndexName(blueprint, "primary", command.columns...) + } + return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.Columnize(command.columns)), nil +} + +func (g *postgresGrammar) CompileDropIndex(_ *Blueprint, command *command) (string, error) { + if command.index == "" { + return "", fmt.Errorf("index name cannot be empty for drop operation") + } + return fmt.Sprintf("DROP INDEX %s", command.index), nil +} + +func (g *postgresGrammar) CompileDropFulltext(blueprint *Blueprint, command *command) (string, error) { + return g.CompileDropIndex(blueprint, command) +} + +func (g *postgresGrammar) CompileDropUnique(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { + return "", fmt.Errorf("index name cannot be empty for drop operation") + } + return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil +} + +func (g *postgresGrammar) CompileDropPrimary(blueprint *Blueprint, command *command) (string, error) { + index := command.index + if index == "" { + index = g.CreateIndexName(blueprint, "primary", command.columns...) + } + return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, index), nil +} + +func (g *postgresGrammar) CompileRenameIndex(_ *Blueprint, command *command) (string, error) { + if command.from == "" || command.to == "" { + return "", fmt.Errorf("index names for rename operation cannot be empty: oldName=%s, newName=%s", command.from, command.to) + } + return fmt.Sprintf("ALTER INDEX %s RENAME TO %s", command.from, command.to), nil +} + +func (g *postgresGrammar) CompileForeign(blueprint *Blueprint, command *command) (string, error) { + sql, err := g.baseGrammar.CompileForeign(blueprint, command) + if err != nil { + return "", err + } + + if command.deferrable != nil { + if *command.deferrable { + sql += " DEFERRABLE" + } else { + sql += " NOT DEFERRABLE" + } + } + if command.deferrable != nil && *command.deferrable && command.initiallyImmediate != nil { + if *command.initiallyImmediate { + sql += " INITIALLY IMMEDIATE" + } else { + sql += " INITIALLY DEFERRED" + } + } + + return sql, nil +} + +func (g *postgresGrammar) CompileDropForeign(blueprint *Blueprint, command *command) (string, error) { + if command.index == "" { + return "", fmt.Errorf("foreign key name cannot be empty for drop operation") + } + return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil +} + +func (g *postgresGrammar) GetFluentCommands() []func(blueprint *Blueprint, command *command) string { + return []func(blueprint *Blueprint, command *command) string{ + g.CompileComment, + } +} + +func (g *postgresGrammar) CompileComment(blueprint *Blueprint, command *command) string { + if command.column.comment != nil { + sql := fmt.Sprintf("COMMENT ON COLUMN %s.%s IS ", blueprint.name, command.column.name) + if command.column.comment == nil { + return sql + "NULL" + } else { + return sql + fmt.Sprintf("'%s'", *command.column.comment) + } + } + return "" +} + +func (g *postgresGrammar) getColumns(blueprint *Blueprint) ([]string, error) { + var columns []string + for _, col := range blueprint.getAddedColumns() { + if col.name == "" { + return nil, fmt.Errorf("column name cannot be empty") + } + sql := col.name + " " + g.getType(col) + for _, modifier := range g.modifiers() { + sql += modifier(col) + } + columns = append(columns, sql) + } + + return columns, nil +} + +func (g *postgresGrammar) getConstraints(blueprint *Blueprint) []string { + var constrains []string + for _, col := range blueprint.getAddedColumns() { + if col.primary != nil && *col.primary { + pkConstraintName := g.CreateIndexName(blueprint, "primary") + sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")" + constrains = append(constrains, sql) + continue + } + } + + return constrains +} + +func (g *postgresGrammar) getType(col *columnDefinition) string { + typeMapFunc := map[string]func(*columnDefinition) string{ + columnTypeChar: g.typeChar, + columnTypeString: g.typeString, + columnTypeTinyText: g.typeTinyText, + columnTypeText: g.typeText, + columnTypeMediumText: g.typeMediumText, + columnTypeLongText: g.typeLongText, + columnTypeInteger: g.typeInteger, + columnTypeBigInteger: g.typeBigInteger, + columnTypeMediumInteger: g.typeMediumInteger, + columnTypeSmallInteger: g.typeSmallInteger, + columnTypeTinyInteger: g.typeTinyInteger, + columnTypeFloat: g.typeFloat, + columnTypeDouble: g.typeDouble, + columnTypeDecimal: g.typeDecimal, + columnTypeBoolean: g.typeBoolean, + columnTypeEnum: g.typeEnum, + columnTypeJson: g.typeJson, + columnTypeJsonb: g.typeJsonb, + columnTypeDate: g.typeDate, + columnTypeDateTime: g.typeDateTime, + columnTypeDateTimeTz: g.typeDateTimeTz, + columnTypeTime: g.typeTime, + columnTypeTimeTz: g.typeTimeTz, + columnTypeTimestamp: g.typeTimestamp, + columnTypeTimestampTz: g.typeTimestampTz, + columnTypeYear: g.typeYear, + columnTypeBinary: g.typeBinary, + columnTypeUuid: g.typeUuid, + columnTypeGeography: g.typeGeography, + columnTypeGeometry: g.typeGeometry, + columnTypePoint: g.typePoint, + } + if fn, ok := typeMapFunc[col.columnType]; ok { + return fn(col) + } + return col.columnType +} + +func (g *postgresGrammar) typeChar(col *columnDefinition) string { + if col.length != nil && *col.length > 0 { + return fmt.Sprintf("CHAR(%d)", *col.length) + } + return "CHAR" +} + +func (g *postgresGrammar) typeString(col *columnDefinition) string { + if col.length != nil && *col.length > 0 { + return fmt.Sprintf("VARCHAR(%d)", *col.length) + } + return "VARCHAR" +} + +func (g *postgresGrammar) typeTinyText(_ *columnDefinition) string { + return "VARCHAR(255)" +} + +func (g *postgresGrammar) typeText(_ *columnDefinition) string { + return "TEXT" +} + +func (g *postgresGrammar) typeMediumText(_ *columnDefinition) string { + return "TEXT" +} + +func (g *postgresGrammar) typeLongText(_ *columnDefinition) string { + return "TEXT" +} + +func (g *postgresGrammar) typeBigInteger(col *columnDefinition) string { + if col.autoIncrement != nil && *col.autoIncrement { + return "BIGSERIAL" + } + return "BIGINT" +} + +func (g *postgresGrammar) typeInteger(col *columnDefinition) string { + if col.autoIncrement != nil && *col.autoIncrement { + return "SERIAL" + } + return "INTEGER" +} + +func (g *postgresGrammar) typeMediumInteger(col *columnDefinition) string { + return g.typeInteger(col) +} + +func (g *postgresGrammar) typeSmallInteger(col *columnDefinition) string { + if col.autoIncrement != nil && *col.autoIncrement { + return "SMALLSERIAL" + } + return "SMALLINT" +} + +func (g *postgresGrammar) typeTinyInteger(col *columnDefinition) string { + return g.typeSmallInteger(col) +} + +func (g *postgresGrammar) typeFloat(_ *columnDefinition) string { + return "REAL" +} + +func (g *postgresGrammar) typeDouble(_ *columnDefinition) string { + return "DOUBLE PRECISION" +} + +func (g *postgresGrammar) typeDecimal(col *columnDefinition) string { + return fmt.Sprintf("DECIMAL(%d, %d)", *col.total, *col.places) +} + +func (g *postgresGrammar) typeBoolean(_ *columnDefinition) string { + return "BOOLEAN" +} + +func (g *postgresGrammar) typeEnum(col *columnDefinition) string { + enumValues := make([]string, len(col.allowed)) + for i, v := range col.allowed { + enumValues[i] = g.QuoteString(v) + } + return "VARCHAR(255) CHECK (" + col.name + " IN (" + strings.Join(enumValues, ", ") + "))" +} + +func (g *postgresGrammar) typeJson(_ *columnDefinition) string { + return "JSON" +} + +func (g *postgresGrammar) typeJsonb(_ *columnDefinition) string { + return "JSONB" +} + +func (g *postgresGrammar) typeDate(_ *columnDefinition) string { + return "DATE" +} + +func (g *postgresGrammar) typeDateTime(col *columnDefinition) string { + return g.typeTimestamp(col) +} + +func (g *postgresGrammar) typeDateTimeTz(col *columnDefinition) string { + return g.typeTimestampTz(col) +} + +func (g *postgresGrammar) typeTime(col *columnDefinition) string { + if col.precision != nil { + return fmt.Sprintf("TIME(%d)", *col.precision) + } + return "TIME" +} + +func (g *postgresGrammar) typeTimeTz(col *columnDefinition) string { + if col.precision != nil { + return fmt.Sprintf("TIMETZ(%d)", *col.precision) + } + return "TIMETZ" +} + +func (g *postgresGrammar) typeTimestamp(col *columnDefinition) string { + if col.useCurrent { + col.SetDefault(Expression("CURRENT_TIMESTAMP")) + } + if col.precision != nil { + return fmt.Sprintf("TIMESTAMP(%d)", *col.precision) + } + return "TIMESTAMP" +} + +func (g *postgresGrammar) typeTimestampTz(col *columnDefinition) string { + if col.useCurrent { + col.SetDefault(Expression("CURRENT_TIMESTAMP")) + } + if col.precision != nil { + return fmt.Sprintf("TIMESTAMPTZ(%d)", *col.precision) + } + return "TIMESTAMPTZ" +} + +func (g *postgresGrammar) typeYear(_ *columnDefinition) string { + return "INTEGER" +} + +func (g *postgresGrammar) typeBinary(_ *columnDefinition) string { + return "BYTEA" +} + +func (g *postgresGrammar) typeUuid(_ *columnDefinition) string { + return "UUID" +} + +func (g *postgresGrammar) typeGeography(col *columnDefinition) string { + if col.subtype != nil && col.srid != nil { + return fmt.Sprintf("GEOGRAPHY(%s, %d)", *col.subtype, *col.srid) + } else if col.subtype != nil { + return fmt.Sprintf("GEOGRAPHY(%s)", *col.subtype) + } + return "GEOGRAPHY" +} + +func (g *postgresGrammar) typeGeometry(col *columnDefinition) string { + if col.subtype != nil && col.srid != nil { + return fmt.Sprintf("GEOMETRY(%s, %d)", *col.subtype, *col.srid) + } else if col.subtype != nil { + return fmt.Sprintf("GEOMETRY(%s)", *col.subtype) + } + return "GEOMETRY" +} + +func (g *postgresGrammar) typePoint(col *columnDefinition) string { + if col.srid != nil { + return fmt.Sprintf("POINT(%d)", *col.srid) + } + return "POINT" +} + +func (g *postgresGrammar) modifiers() []func(*columnDefinition) string { + return []func(*columnDefinition) string{ + g.modifyDefault, + g.modifyNullable, + } +} + +func (g *postgresGrammar) modifyNullable(col *columnDefinition) string { + if col.change { + if col.nullable == nil { + return "" + } + if *col.nullable { + return " DROP NOT NULL" + } + return " SET NOT NULL" + } + if col.nullable != nil && *col.nullable { + return " NULL" + } + return " NOT NULL" +} + +func (g *postgresGrammar) modifyDefault(col *columnDefinition) string { + if col.hasCommand("default") { + if col.change { + return fmt.Sprintf(" SET DEFAULT %s", g.GetDefaultValue(col.defaultValue)) + } + return fmt.Sprintf(" DEFAULT %s", g.GetDefaultValue(col.defaultValue)) + } + return "" +} diff --git a/postgres_grammar_test.go b/schema/postgres_grammar_test.go similarity index 89% rename from postgres_grammar_test.go rename to schema/postgres_grammar_test.go index 843f791..2a805f0 100644 --- a/postgres_grammar_test.go +++ b/schema/postgres_grammar_test.go @@ -7,7 +7,7 @@ import ( ) func TestPgGrammar_CompileCreate(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -21,13 +21,13 @@ func TestPgGrammar_CompileCreate(t *testing.T) { table: "users", blueprint: func(table *Blueprint) { table.ID() - table.String("name", 255) - table.String("email", 255) + table.String("name") + table.String("email") table.String("password").Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") + table.Timestamp("created_at").UseCurrent() + table.Timestamp("updated_at").UseCurrent() }, - want: "CREATE TABLE users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))", + want: "CREATE TABLE users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR(255) NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))", }, { name: "Create table with foreign key", @@ -35,7 +35,7 @@ func TestPgGrammar_CompileCreate(t *testing.T) { blueprint: func(table *Blueprint) { table.ID() table.Integer("user_id") - table.String("title", 255) + table.String("title") table.Text("content").Nullable() table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE") }, @@ -45,7 +45,7 @@ func TestPgGrammar_CompileCreate(t *testing.T) { name: "Create table with column name is empty", table: "empty_column_table", blueprint: func(table *Blueprint) { - table.String("", 255) // Intentionally empty column name + table.String("") // Intentionally empty column name }, wantErr: true, }, @@ -55,57 +55,7 @@ func TestPgGrammar_CompileCreate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileCreate(bp) - if tt.wantErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - assert.Equal(t, tt.want, got, "SQL statement mismatch for %s", tt.name) - }) - } -} - -func TestPgGrammar_CompileCreateIfNotExists(t *testing.T) { - grammar := newPgGrammar() - - tests := []struct { - name string - table string - blueprint func(table *Blueprint) - want string - wantErr bool - }{ - { - name: "Create simple table if not exists", - table: "users", - blueprint: func(table *Blueprint) { - table.ID() - table.String("name", 255) - table.String("email", 255) - table.String("password").Nullable() - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP") - }, - want: "CREATE TABLE IF NOT EXISTS users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))", - wantErr: false, - }, - { - name: "Create table with column name is empty", - table: "empty_column_table", - blueprint: func(table *Blueprint) { - table.String("", 255) // Intentionally empty column name - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.table} - tt.blueprint(bp) - got, err := grammar.compileCreateIfNotExists(bp) + got, err := grammar.CompileCreate(bp) if tt.wantErr { assert.Error(t, err) return @@ -118,7 +68,7 @@ func TestPgGrammar_CompileCreateIfNotExists(t *testing.T) { } func TestPgGrammar_CompileAdd(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -153,7 +103,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) { blueprint: func(table *Blueprint) { table.Boolean("active").Default(true) }, - want: "ALTER TABLE users ADD COLUMN active BOOLEAN DEFAULT true NOT NULL", + want: "ALTER TABLE users ADD COLUMN active BOOLEAN DEFAULT '1' NOT NULL", wantErr: false, }, { @@ -162,7 +112,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) { blueprint: func(table *Blueprint) { table.String("notes", 500).Comment("User notes") }, - want: "ALTER TABLE users ADD COLUMN notes VARCHAR(500) NOT NULL COMMENT 'User notes'", + want: "ALTER TABLE users ADD COLUMN notes VARCHAR(500) NOT NULL", wantErr: false, }, { @@ -187,17 +137,17 @@ func TestPgGrammar_CompileAdd(t *testing.T) { name: "Add complex column with all attributes", table: "products", blueprint: func(table *Blueprint) { - table.Decimal("price", 10, 2).Default(0).Comment("Product price") + table.Decimal("price", 10, 2).Default(0) }, - want: "ALTER TABLE products ADD COLUMN price DECIMAL(10, 2) DEFAULT 0 NOT NULL COMMENT 'Product price'", + want: "ALTER TABLE products ADD COLUMN price DECIMAL(10, 2) DEFAULT '0' NOT NULL", wantErr: false, }, { name: "Add timestamp columns", table: "orders", blueprint: func(table *Blueprint) { - table.Timestamp("created_at").Default("CURRENT_TIMESTAMP") - table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP").Nullable() + table.Timestamp("created_at").UseCurrent() + table.Timestamp("updated_at").UseCurrent().Nullable() }, want: "ALTER TABLE orders ADD COLUMN created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, ADD COLUMN updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NULL", wantErr: false, @@ -253,7 +203,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileAdd(bp) + got, err := grammar.CompileAdd(bp) if tt.wantErr { assert.Error(t, err) return @@ -266,7 +216,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) { } func TestPgGrammar_CompileChange(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -309,7 +259,7 @@ func TestPgGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.String("email", 500).Default(nil).Change() }, - want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500), ALTER COLUMN email DROP DEFAULT"}, + want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500), ALTER COLUMN email SET DEFAULT NULL"}, }, { name: "Add comment to column", @@ -328,7 +278,7 @@ func TestPgGrammar_CompileChange(t *testing.T) { blueprint: func(table *Blueprint) { table.String("email", 500).Comment("").Change() }, - want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500)", "COMMENT ON COLUMN users.email IS NULL"}, + want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500)", "COMMENT ON COLUMN users.email IS ''"}, }, { name: "Set column to not nullable", @@ -356,9 +306,9 @@ func TestPgGrammar_CompileChange(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.table} + bp := &Blueprint{name: tt.table, grammar: grammar} tt.blueprint(bp) - got, err := grammar.compileChange(bp) + got, err := bp.toSql() if tt.wantErr { assert.Error(t, err) return @@ -371,7 +321,7 @@ func TestPgGrammar_CompileChange(t *testing.T) { } func TestPgGrammar_CompileDrop(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -390,7 +340,7 @@ func TestPgGrammar_CompileDrop(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileDrop(bp) + got, err := grammar.CompileDrop(bp) if tt.wantErr { assert.Error(t, err) return @@ -403,7 +353,7 @@ func TestPgGrammar_CompileDrop(t *testing.T) { } func TestPgGrammar_CompileDropIfExists(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -422,7 +372,7 @@ func TestPgGrammar_CompileDropIfExists(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileDropIfExists(bp) + got, err := grammar.CompileDropIfExists(bp) if tt.wantErr { assert.Error(t, err) return @@ -435,7 +385,7 @@ func TestPgGrammar_CompileDropIfExists(t *testing.T) { } func TestPgGrammar_CompileRename(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -455,8 +405,9 @@ func TestPgGrammar_CompileRename(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.oldName, newName: tt.newName} - got, err := grammar.compileRename(bp) + bp := &Blueprint{name: tt.oldName} + bp.rename(tt.newName) + got, err := grammar.CompileRename(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -469,7 +420,7 @@ func TestPgGrammar_CompileRename(t *testing.T) { } func TestPgGrammar_GetColumns(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -503,14 +454,7 @@ func TestPgGrammar_GetColumns(t *testing.T) { blueprint: func(table *Blueprint) { table.Boolean("active").Default(true) }, - want: []string{"active BOOLEAN DEFAULT true NOT NULL"}, - }, - { - name: "Column with comment", - blueprint: func(table *Blueprint) { - table.String("description", 500).Comment("User description") - }, - want: []string{"description VARCHAR(500) NOT NULL COMMENT 'User description'"}, + want: []string{"active BOOLEAN DEFAULT '1' NOT NULL"}, }, { name: "Primary key column", @@ -543,13 +487,13 @@ func TestPgGrammar_GetColumns(t *testing.T) { } func TestPgGrammar_CompileDropColumn(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string table string blueprint func(table *Blueprint) - want string + wants []string wantErr bool }{ { @@ -558,45 +502,46 @@ func TestPgGrammar_CompileDropColumn(t *testing.T) { blueprint: func(table *Blueprint) { table.DropColumn("email") }, - want: "ALTER TABLE users DROP COLUMN email", + wants: []string{"ALTER TABLE users DROP COLUMN email"}, wantErr: false, }, { name: "Drop multiple columns", table: "users", blueprint: func(table *Blueprint) { - table.DropColumn("email", "phone", "address") + table.DropColumn("email", "phone") + table.DropColumn("address") }, - want: "ALTER TABLE users DROP COLUMN email, DROP COLUMN phone, DROP COLUMN address", + wants: []string{"ALTER TABLE users DROP COLUMN email, DROP COLUMN phone", "ALTER TABLE users DROP COLUMN address"}, wantErr: false, }, { name: "No columns to drop", table: "users", blueprint: func(table *Blueprint) {}, - want: "", + wants: nil, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bp := &Blueprint{name: tt.table} + bp := &Blueprint{name: tt.table, grammar: grammar} tt.blueprint(bp) - got, err := grammar.compileDropColumn(bp) + got, err := bp.toSql() if tt.wantErr { assert.Error(t, err) return } assert.NoError(t, err) - assert.Equal(t, tt.want, got) + assert.Equal(t, tt.wants, got) }) } } func TestPgGrammar_CompileRenameColumn(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -640,7 +585,8 @@ func TestPgGrammar_CompileRenameColumn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileRenameColumn(bp, tt.oldName, tt.newName) + command := &command{from: tt.oldName, to: tt.newName} + got, err := grammar.CompileRenameColumn(bp, command) if tt.wantErr { assert.Error(t, err) return @@ -653,7 +599,7 @@ func TestPgGrammar_CompileRenameColumn(t *testing.T) { } func TestPgGrammar_CompileDropIndex(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -684,7 +630,8 @@ func TestPgGrammar_CompileDropIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{} - got, err := grammar.compileDropIndex(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := grammar.CompileDropIndex(bp, command) if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) @@ -698,7 +645,7 @@ func TestPgGrammar_CompileDropIndex(t *testing.T) { } func TestPgGrammar_CompileDropPrimary(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -729,7 +676,8 @@ func TestPgGrammar_CompileDropPrimary(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := grammar.compileDropPrimary(tt.blueprint, tt.indexName) + command := &command{index: tt.indexName} + got, err := grammar.CompileDropPrimary(tt.blueprint, command) if tt.wantErr { assert.Error(t, err) return @@ -742,7 +690,7 @@ func TestPgGrammar_CompileDropPrimary(t *testing.T) { } func TestPgGrammar_CompileRenameIndex(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -794,7 +742,8 @@ func TestPgGrammar_CompileRenameIndex(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileRenameIndex(bp, tt.oldName, tt.newName) + command := &command{from: tt.oldName, to: tt.newName} + got, err := grammar.CompileRenameIndex(bp, command) if tt.wantErr { assert.Error(t, err) return @@ -807,7 +756,7 @@ func TestPgGrammar_CompileRenameIndex(t *testing.T) { } func TestPgGrammar_CompileForeign(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -972,7 +921,7 @@ func TestPgGrammar_CompileForeign(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileForeign(bp, bp.foreignKeys[0]) + got, err := grammar.CompileForeign(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -985,7 +934,7 @@ func TestPgGrammar_CompileForeign(t *testing.T) { } func TestPgGrammar_CompileDropForeign(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1020,7 +969,8 @@ func TestPgGrammar_CompileDropForeign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} - got, err := grammar.compileDropForeign(bp, tt.foreignKeyName) + command := &command{index: tt.foreignKeyName} + got, err := grammar.CompileDropForeign(bp, command) if tt.wantErr { assert.Error(t, err) return @@ -1033,7 +983,7 @@ func TestPgGrammar_CompileDropForeign(t *testing.T) { } func TestPgGrammar_CompileIndex(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1100,7 +1050,7 @@ func TestPgGrammar_CompileIndex(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileIndex(bp, bp.indexes[0]) + got, err := grammar.CompileIndex(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1113,7 +1063,7 @@ func TestPgGrammar_CompileIndex(t *testing.T) { } func TestPgGrammar_CompileUnique(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1208,7 +1158,7 @@ func TestPgGrammar_CompileUnique(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileUnique(bp, bp.indexes[0]) + got, err := grammar.CompileUnique(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1221,7 +1171,7 @@ func TestPgGrammar_CompileUnique(t *testing.T) { } func TestPgGrammar_CompileFullText(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1306,7 +1256,7 @@ func TestPgGrammar_CompileFullText(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compileFullText(bp, bp.indexes[0]) + got, err := grammar.CompileFullText(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1319,7 +1269,7 @@ func TestPgGrammar_CompileFullText(t *testing.T) { } func TestPgGrammar_CompileDropUnique(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1356,7 +1306,8 @@ func TestPgGrammar_CompileDropUnique(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{} - got, err := grammar.compileDropUnique(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := grammar.CompileDropUnique(bp, command) if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) @@ -1370,7 +1321,7 @@ func TestPgGrammar_CompileDropUnique(t *testing.T) { } func TestPgGrammar_CompileDropFulltext(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1413,7 +1364,8 @@ func TestPgGrammar_CompileDropFulltext(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{} - got, err := grammar.compileDropFulltext(bp, tt.indexName) + command := &command{index: tt.indexName} + got, err := grammar.CompileDropFulltext(bp, command) if tt.wantErr { assert.Error(t, err) assert.Empty(t, got) @@ -1427,7 +1379,7 @@ func TestPgGrammar_CompileDropFulltext(t *testing.T) { } func TestPgGrammar_CompilePrimary(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1494,7 +1446,7 @@ func TestPgGrammar_CompilePrimary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { bp := &Blueprint{name: tt.table} tt.blueprint(bp) - got, err := grammar.compilePrimary(bp, bp.indexes[0]) + got, err := grammar.CompilePrimary(bp, bp.commands[0]) if tt.wantErr { assert.Error(t, err) return @@ -1507,7 +1459,7 @@ func TestPgGrammar_CompilePrimary(t *testing.T) { } func TestPgGrammar_GetType(t *testing.T) { - grammar := newPgGrammar() + grammar := newPostgresGrammar() tests := []struct { name string @@ -1540,7 +1492,7 @@ func TestPgGrammar_GetType(t *testing.T) { blueprint: func(table *Blueprint) { table.Char("code") }, - want: "CHAR", + want: "CHAR(255)", }, { name: "string column type", @@ -1557,16 +1509,9 @@ func TestPgGrammar_GetType(t *testing.T) { want: "DECIMAL(10, 2)", }, { - name: "double column type with precision", - blueprint: func(table *Blueprint) { - table.Double("value", 8, 2) - }, - want: "DOUBLE PRECISION", - }, - { - name: "double column type without precision", + name: "double column type", blueprint: func(table *Blueprint) { - table.Double("value", 0, 0) + table.Double("value") }, want: "DOUBLE PRECISION", }, @@ -1743,7 +1688,7 @@ func TestPgGrammar_GetType(t *testing.T) { blueprint: func(table *Blueprint) { table.TinyText("notes") }, - want: "TEXT", + want: "VARCHAR(255)", }, { name: "date column type", @@ -1822,13 +1767,6 @@ func TestPgGrammar_GetType(t *testing.T) { }, want: "VARCHAR(255) CHECK (status IN ('active', 'inactive'))", }, - { - name: "Enum with empty values", - blueprint: func(table *Blueprint) { - table.Enum("status", []string{}) - }, - want: "VARCHAR(255)", - }, } for _, tt := range tests { diff --git a/schema.go b/schema/schema.go similarity index 71% rename from schema.go rename to schema/schema.go index ea067f5..edc0350 100644 --- a/schema.go +++ b/schema/schema.go @@ -1,9 +1,11 @@ package schema import ( - "context" "database/sql" "errors" + + "github.com/akfaiz/migris/internal/config" + "github.com/akfaiz/migris/internal/dialect" ) // Column represents a database column with its properties. @@ -39,11 +41,12 @@ type TableInfo struct { } func newBuilder() (Builder, error) { - if dialectValue == dialectUnknown { + dialectVal := config.GetDialect() + if dialectVal == dialect.Unknown { return nil, errors.New("schema dialect is not set, please call schema.SetDialect() before using schema functions") } - builder, err := NewBuilder(dialectValue.String()) + builder, err := NewBuilder(dialectVal.String()) if err != nil { return nil, err } @@ -65,36 +68,13 @@ func newBuilder() (Builder, error) { // table.Timestamp("created_at").Default("CURRENT_TIMESTAMP").Nullable(false) // table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP").Nullable(false) // }) -func Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { - builder, err := newBuilder() - if err != nil { - return err - } - - return builder.Create(ctx, tx, name, blueprint) -} - -// CreateIfNotExists creates a new table with the given name and blueprint if it does not already exist. -// The blueprint function is used to define the structure of the table. -// It returns an error if the table creation fails. -// -// Example: -// -// err := schema.CreateIfNotExists(ctx, tx, "users", func(table *schema.Blueprint) { -// table.ID() -// table.String("name").Nullable(false) -// table.String("email").Unique().Nullable(false) -// table.String("password").Nullable() -// table.Timestamp("created_at").Default("CURRENT_TIMESTAMP").Nullable(false) -// table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP").Nullable(false) -// }) -func CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { +func Create(c *Context, name string, blueprint func(table *Blueprint)) error { builder, err := newBuilder() if err != nil { return err } - return builder.CreateIfNotExists(ctx, tx, name, blueprint) + return builder.Create(c, name, blueprint) } // Drop removes the table with the given name. @@ -103,13 +83,13 @@ func CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint f // Example: // // err := schema.Drop(ctx, tx, "users") -func Drop(ctx context.Context, tx *sql.Tx, name string) error { +func Drop(c *Context, name string) error { builder, err := newBuilder() if err != nil { return err } - return builder.Drop(ctx, tx, name) + return builder.Drop(c, name) } // DropIfExists removes the table with the given name if it exists. @@ -118,13 +98,13 @@ func Drop(ctx context.Context, tx *sql.Tx, name string) error { // Example: // // err := schema.DropIfExists(ctx, tx, "users") -func DropIfExists(ctx context.Context, tx *sql.Tx, name string) error { +func DropIfExists(c *Context, name string) error { builder, err := newBuilder() if err != nil { return err } - return builder.DropIfExists(ctx, tx, name) + return builder.DropIfExists(c, name) } // GetColumns retrieves the columns of the specified table. @@ -133,13 +113,13 @@ func DropIfExists(ctx context.Context, tx *sql.Tx, name string) error { // Example: // // columns, err := schema.GetColumns(ctx, tx, "users") -func GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) { +func GetColumns(c *Context, tableName string) ([]*Column, error) { builder, err := newBuilder() if err != nil { return nil, err } - return builder.GetColumns(ctx, tx, tableName) + return builder.GetColumns(c, tableName) } // GetIndexes retrieves the indexes of the specified table. @@ -148,13 +128,13 @@ func GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, e // Example: // // indexes, err := schema.GetIndexes(ctx, tx, "users") -func GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) { +func GetIndexes(c *Context, tableName string) ([]*Index, error) { builder, err := newBuilder() if err != nil { return nil, err } - return builder.GetIndexes(ctx, tx, tableName) + return builder.GetIndexes(c, tableName) } // GetTables retrieves all tables in the database. @@ -163,13 +143,13 @@ func GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, er // Example: // // tables, err := schema.GetTables(ctx, tx) -func GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) { +func GetTables(c *Context) ([]*TableInfo, error) { builder, err := newBuilder() if err != nil { return nil, err } - return builder.GetTables(ctx, tx) + return builder.GetTables(c) } // HasColumn checks if a column with the given name exists in the specified table. @@ -178,13 +158,13 @@ func GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) { // Example: // // exists, err := schema.HasColumn(ctx, tx, "users", "email") -func HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) { +func HasColumn(c *Context, tableName string, columnName string) (bool, error) { builder, err := newBuilder() if err != nil { return false, err } - return builder.HasColumn(ctx, tx, tableName, columnName) + return builder.HasColumn(c, tableName, columnName) } // HasColumns checks if the specified columns exist in the given table. @@ -195,13 +175,13 @@ func HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName str // exists, err := schema.HasColumns(ctx, tx, "users", []string{"email", "name"}) // // If any of the specified columns do not exist, it returns false. -func HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) { +func HasColumns(c *Context, tableName string, columnNames []string) (bool, error) { builder, err := newBuilder() if err != nil { return false, err } - return builder.HasColumns(ctx, tx, tableName, columnNames) + return builder.HasColumns(c, tableName, columnNames) } // HasIndex checks if an index with the given name exists in the specified table. @@ -212,13 +192,13 @@ func HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames [ // exists, err := schema.HasIndex(ctx, tx, "users", []string{"uk_users_email"}) // Checks if the index with name "uk_users_email" exists in the "users" table. // // exists, err := schema.HasIndex(ctx, tx, "users", []string{"email", "name"}) // Checks if a composite index exists on the "email" and "name" columns in the "users" table. -func HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) { +func HasIndex(c *Context, tableName string, indexes []string) (bool, error) { builder, err := newBuilder() if err != nil { return false, err } - return builder.HasIndex(ctx, tx, tableName, indexes) + return builder.HasIndex(c, tableName, indexes) } // HasTable checks if a table with the given name exists in the database. @@ -228,13 +208,13 @@ func HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []strin // Example: // // exists, err := schema.HasTable(ctx, tx, "users") -func HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { +func HasTable(c *Context, name string) (bool, error) { builder, err := newBuilder() if err != nil { return false, err } - return builder.HasTable(ctx, tx, name) + return builder.HasTable(c, name) } // Rename changes the name of the table from name to newName. @@ -243,13 +223,13 @@ func HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) { // Example: // // err := schema.Rename(ctx, tx, "users", "people") -func Rename(ctx context.Context, tx *sql.Tx, name string, newName string) error { +func Rename(c *Context, name string, newName string) error { builder, err := newBuilder() if err != nil { return err } - return builder.Rename(ctx, tx, name, newName) + return builder.Rename(c, name, newName) } // Table modifies an existing table with the given name and blueprint. @@ -263,11 +243,11 @@ func Rename(ctx context.Context, tx *sql.Tx, name string, newName string) error // table.DropColumn("password") // table.RenameColumn("email", "contact_email") // }) -func Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error { +func Table(c *Context, name string, blueprint func(table *Blueprint)) error { builder, err := newBuilder() if err != nil { return err } - return builder.Table(ctx, tx, name, blueprint) + return builder.Table(c, name, blueprint) } diff --git a/schema_test.go b/schema/schema_test.go similarity index 54% rename from schema_test.go rename to schema/schema_test.go index 5a7bfcb..399181c 100644 --- a/schema_test.go +++ b/schema/schema_test.go @@ -6,7 +6,9 @@ import ( "fmt" "testing" - "github.com/afkdevs/go-schema" + "github.com/akfaiz/migris/internal/config" + "github.com/akfaiz/migris/internal/dialect" + "github.com/akfaiz/migris/schema" "github.com/stretchr/testify/suite" ) @@ -21,6 +23,7 @@ type schemaTestSuite struct { } func (s *schemaTestSuite) SetupSuite() { + config.SetDialect(dialect.Postgres) ctx := context.Background() s.ctx = ctx @@ -33,38 +36,7 @@ func (s *schemaTestSuite) SetupSuite() { err = db.Ping() s.Require().NoError(err) - s.db = db - schema.SetDebug(false) - - s.Run("when dialect is not set should return error", func() { - err := schema.SetDialect("") - s.Error(err) - s.ErrorContains(err, "unknown dialect") - - builderFuncs := []func() error{ - func() error { return schema.Create(ctx, nil, "", nil) }, - func() error { return schema.CreateIfNotExists(ctx, nil, "", nil) }, - func() error { return schema.Drop(ctx, nil, "") }, - func() error { return schema.DropIfExists(ctx, nil, "") }, - func() error { _, err := schema.GetColumns(ctx, nil, ""); return err }, - func() error { _, err := schema.GetIndexes(ctx, nil, ""); return err }, - func() error { _, err := schema.GetTables(ctx, nil); return err }, - func() error { _, err := schema.HasColumn(ctx, nil, "", ""); return err }, - func() error { _, err := schema.HasColumns(ctx, nil, "", nil); return err }, - func() error { _, err := schema.HasIndex(ctx, nil, "", nil); return err }, - func() error { _, err := schema.HasTable(ctx, nil, ""); return err }, - func() error { return schema.Rename(ctx, nil, "", "") }, - func() error { return schema.Table(ctx, nil, "", nil) }, - } - for _, fn := range builderFuncs { - s.Error(fn(), "Expected error when dialect is not set") - } - }) - s.Run("when dialect is set to postgres should not return error", func() { - err = schema.SetDialect("postgres") - s.Require().NoError(err) - }) } func (s *schemaTestSuite) TearDownSuite() { @@ -76,8 +48,10 @@ func (s *schemaTestSuite) TestCreate() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should create table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -88,78 +62,23 @@ func (s *schemaTestSuite) TestCreate() { s.NoError(err) }) s.Run("when table already exists should return error", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() }) s.Error(err) s.ErrorContains(err, "\"users\" already exists") }) s.Run("when table name is empty should return error", func() { - err := schema.Create(s.ctx, tx, "", func(table *schema.Blueprint) { + err := schema.Create(c, "", func(table *schema.Blueprint) { table.ID() }) s.Error(err) - s.ErrorContains(err, "table name is empty") + s.ErrorContains(err, "invalid arguments") }) s.Run("when blueprint function is nil should return error", func() { - err := schema.Create(s.ctx, tx, "test", nil) + err := schema.Create(c, "test", nil) s.Error(err) - s.ErrorContains(err, "blueprint function is nil") - }) - s.Run("when transaction is nil should return error", func() { - err := schema.Create(s.ctx, nil, "test", func(table *schema.Blueprint) { - table.ID() - }) - s.Error(err) - s.ErrorContains(err, "transaction is nil") - }) -} - -func (s *schemaTestSuite) TestCreateIfNotExists() { - tx, err := s.db.BeginTx(s.ctx, nil) - s.Require().NoError(err) - defer tx.Rollback() //nolint:errcheck - - s.Run("when parameters are valid should create table", func() { - err := schema.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name") - table.String("email").Unique() - table.String("password") - table.Timestamp("created_at").UseCurrent() - table.Timestamp("updated_at").UseCurrent() - }) - s.NoError(err) - }) - s.Run("when table already exists should return no error", func() { - err := schema.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) { - table.ID() - table.String("name") - table.String("email") - table.String("password") - table.Timestamp("created_at").UseCurrent() - table.Timestamp("updated_at").UseCurrent() - }) - s.NoError(err) - }) - s.Run("when table name is empty should return error", func() { - err := schema.CreateIfNotExists(s.ctx, tx, "", func(table *schema.Blueprint) { - table.ID() - }) - s.Error(err) - s.ErrorContains(err, "table name is empty") - }) - s.Run("when blueprint function is nil should return error", func() { - err := schema.CreateIfNotExists(s.ctx, tx, "test", nil) - s.Error(err) - s.ErrorContains(err, "blueprint function is nil") - }) - s.Run("when transaction is nil should return error", func() { - err := schema.CreateIfNotExists(s.ctx, nil, "test", func(table *schema.Blueprint) { - table.ID() - }) - s.Error(err) - s.ErrorContains(err, "transaction is nil") + s.ErrorContains(err, "invalid arguments") }) } @@ -168,8 +87,10 @@ func (s *schemaTestSuite) TestDrop() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should drop table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -178,22 +99,22 @@ func (s *schemaTestSuite) TestDrop() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - err = schema.Drop(s.ctx, tx, "users") + err = schema.Drop(c, "users") s.NoError(err) }) s.Run("when table does not exist should return error", func() { - err := schema.Drop(s.ctx, tx, "non_existing_table") + err := schema.Drop(c, "non_existing_table") s.Error(err) }) s.Run("when table name is empty should return error", func() { - err := schema.Drop(s.ctx, tx, "") + err := schema.Drop(c, "") s.Error(err) - s.ErrorContains(err, "table name is empty") + s.ErrorContains(err, "invalid arguments") }) - s.Run("when transaction is nil should return error", func() { - err := schema.Drop(s.ctx, nil, "test") + s.Run("when context is nil should return error", func() { + err := schema.Drop(nil, "test") s.Error(err) - s.ErrorContains(err, "transaction is nil") + s.ErrorContains(err, "invalid arguments") }) } @@ -202,8 +123,10 @@ func (s *schemaTestSuite) TestDropIfExists() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should drop table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -212,22 +135,22 @@ func (s *schemaTestSuite) TestDropIfExists() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - err = schema.DropIfExists(s.ctx, tx, "users") + err = schema.DropIfExists(c, "users") s.NoError(err) }) s.Run("when table does not exist should return no error", func() { - err := schema.DropIfExists(s.ctx, tx, "non_existing_table") + err := schema.DropIfExists(c, "non_existing_table") s.NoError(err) }) s.Run("when table name is empty should return error", func() { - err := schema.DropIfExists(s.ctx, tx, "") + err := schema.DropIfExists(c, "") s.Error(err) - s.ErrorContains(err, "table name is empty") + s.ErrorContains(err, "invalid arguments") }) - s.Run("when transaction is nil should return error", func() { - err := schema.DropIfExists(s.ctx, nil, "test") + s.Run("when context is nil should return error", func() { + err := schema.DropIfExists(nil, "test") s.Error(err) - s.ErrorContains(err, "transaction is nil") + s.ErrorContains(err, "invalid arguments") }) } @@ -236,8 +159,10 @@ func (s *schemaTestSuite) TestGetColumns() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should return columns", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -246,23 +171,23 @@ func (s *schemaTestSuite) TestGetColumns() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - columns, err := schema.GetColumns(s.ctx, tx, "users") + columns, err := schema.GetColumns(c, "users") s.NoError(err) s.NotEmpty(columns) s.Len(columns, 6) }) s.Run("when table does not exist should empty columns", func() { - columns, err := schema.GetColumns(s.ctx, tx, "non_existing_table") + columns, err := schema.GetColumns(c, "non_existing_table") s.NoError(err) s.Nil(columns) }) s.Run("when table name is empty should return error", func() { - columns, err := schema.GetColumns(s.ctx, tx, "") + columns, err := schema.GetColumns(c, "") s.Error(err) s.Nil(columns) }) - s.Run("when transaction is nil should return error", func() { - columns, err := schema.GetColumns(s.ctx, nil, "test") + s.Run("when context is nil should return error", func() { + columns, err := schema.GetColumns(nil, "test") s.Error(err) s.Nil(columns) }) @@ -273,8 +198,10 @@ func (s *schemaTestSuite) TestGetIndexes() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should return indexes", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -283,23 +210,23 @@ func (s *schemaTestSuite) TestGetIndexes() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - indexes, err := schema.GetIndexes(s.ctx, tx, "users") + indexes, err := schema.GetIndexes(c, "users") s.NoError(err) s.NotEmpty(indexes) s.Len(indexes, 2) // Expecting the unique index on email and the primary key index on id }) s.Run("when table does not exist should return empty indexes", func() { - indexes, err := schema.GetIndexes(s.ctx, tx, "non_existing_table") + indexes, err := schema.GetIndexes(c, "non_existing_table") s.NoError(err) s.Nil(indexes) }) s.Run("when table name is empty should return error", func() { - indexes, err := schema.GetIndexes(s.ctx, tx, "") + indexes, err := schema.GetIndexes(c, "") s.Error(err) s.Nil(indexes) }) - s.Run("when transaction is nil should return error", func() { - indexes, err := schema.GetIndexes(s.ctx, nil, "test") + s.Run("when context is nil should return error", func() { + indexes, err := schema.GetIndexes(nil, "test") s.Error(err) s.Nil(indexes) }) @@ -310,13 +237,15 @@ func (s *schemaTestSuite) TestGetTables() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when no tables exist should return empty", func() { - tables, err := schema.GetTables(s.ctx, tx) + tables, err := schema.GetTables(c) s.NoError(err) s.Empty(tables) }) s.Run("when transaction is valid should return tables", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -325,13 +254,13 @@ func (s *schemaTestSuite) TestGetTables() { table.Timestamp("updated_at").UseCurrent() }) s.Require().NoError(err) - tables, err := schema.GetTables(s.ctx, tx) + tables, err := schema.GetTables(c) s.NoError(err) s.NotEmpty(tables) s.Len(tables, 1) // Expecting at least the 'users' table created in previous tests }) - s.Run("when transaction is nil should return error", func() { - tables, err := schema.GetTables(s.ctx, nil) + s.Run("when context is nil should return error", func() { + tables, err := schema.GetTables(nil) s.Error(err) s.Nil(tables) }) @@ -342,8 +271,10 @@ func (s *schemaTestSuite) TestHasColumn() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when column exists should return true", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -353,19 +284,19 @@ func (s *schemaTestSuite) TestHasColumn() { }) s.Require().NoError(err) - exists, err := schema.HasColumn(s.ctx, tx, "users", "email") + exists, err := schema.HasColumn(c, "users", "email") s.NoError(err) s.True(exists) }) s.Run("when column does not exist should return false", func() { - exists, err := schema.HasColumn(s.ctx, tx, "users", "non_existing_column") + exists, err := schema.HasColumn(c, "users", "non_existing_column") s.NoError(err) s.False(exists) }) - s.Run("when transaction is nil should return error", func() { - exists, err := schema.HasColumn(s.ctx, nil, "users", "email") + s.Run("when context is nil should return error", func() { + exists, err := schema.HasColumn(nil, "users", "email") s.Error(err) s.False(exists) }) @@ -376,8 +307,10 @@ func (s *schemaTestSuite) TestHasColumns() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when all columns exist should return true", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -387,17 +320,17 @@ func (s *schemaTestSuite) TestHasColumns() { }) s.Require().NoError(err) - exists, err := schema.HasColumns(s.ctx, tx, "users", []string{"email", "name"}) + exists, err := schema.HasColumns(c, "users", []string{"email", "name"}) s.NoError(err) s.True(exists) }) s.Run("when some columns do not exist should return false", func() { - exists, err := schema.HasColumns(s.ctx, tx, "users", []string{"email", "non_existing_column"}) + exists, err := schema.HasColumns(c, "users", []string{"email", "non_existing_column"}) s.NoError(err) s.False(exists) }) s.Run("when no columns are provided should return error", func() { - exists, err := schema.HasColumns(s.ctx, tx, "users", []string{}) + exists, err := schema.HasColumns(c, "users", []string{}) s.Error(err) s.False(exists) }) @@ -408,8 +341,10 @@ func (s *schemaTestSuite) TestHasIndex() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when index exists should return true", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.Integer("company_id") table.String("name") @@ -422,20 +357,20 @@ func (s *schemaTestSuite) TestHasIndex() { }) s.Require().NoError(err) - exists, err := schema.HasIndex(s.ctx, tx, "users", []string{"email"}) + exists, err := schema.HasIndex(c, "users", []string{"email"}) s.NoError(err) s.True(exists) - exists, err = schema.HasIndex(s.ctx, tx, "users", []string{"company_id", "id"}) + exists, err = schema.HasIndex(c, "users", []string{"company_id", "id"}) s.NoError(err) s.True(exists) - exists, err = schema.HasIndex(s.ctx, tx, "users", []string{"uk_users_email"}) + exists, err = schema.HasIndex(c, "users", []string{"uk_users_email"}) s.NoError(err) s.True(exists) }) s.Run("when index does not exist should return false", func() { - exists, err := schema.HasIndex(s.ctx, tx, "users", []string{"non_existing_index"}) + exists, err := schema.HasIndex(c, "users", []string{"non_existing_index"}) s.NoError(err) s.False(exists) }) @@ -446,8 +381,10 @@ func (s *schemaTestSuite) TestHasTable() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when table exists should return true", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -457,22 +394,22 @@ func (s *schemaTestSuite) TestHasTable() { }) s.Require().NoError(err) - exists, err := schema.HasTable(s.ctx, tx, "users") + exists, err := schema.HasTable(c, "users") s.NoError(err) s.True(exists) }) s.Run("when table does not exist should return false", func() { - exists, err := schema.HasTable(s.ctx, tx, "non_existing_table") + exists, err := schema.HasTable(c, "non_existing_table") s.NoError(err) s.False(exists) }) s.Run("when table name is empty should return error", func() { - exists, err := schema.HasTable(s.ctx, tx, "") + exists, err := schema.HasTable(c, "") s.Error(err) s.False(exists) }) - s.Run("when transaction is nil should return error", func() { - exists, err := schema.HasTable(s.ctx, nil, "users") + s.Run("when context is nil should return error", func() { + exists, err := schema.HasTable(nil, "users") s.Error(err) s.False(exists) }) @@ -483,8 +420,10 @@ func (s *schemaTestSuite) TestRename() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should rename table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -494,27 +433,27 @@ func (s *schemaTestSuite) TestRename() { }) s.Require().NoError(err) - err = schema.Rename(s.ctx, tx, "users", "members") + err = schema.Rename(c, "users", "members") s.NoError(err) - columns, err := schema.GetColumns(s.ctx, tx, "members") + columns, err := schema.GetColumns(c, "members") s.NoError(err) s.NotEmpty(columns) s.Len(columns, 6) }) s.Run("when table does not exist should return error", func() { - err := schema.Rename(s.ctx, tx, "non_existing_table", "new_name") + err := schema.Rename(c, "non_existing_table", "new_name") s.Error(err) }) s.Run("when new name is empty should return error", func() { - err := schema.Rename(s.ctx, tx, "users", "") + err := schema.Rename(c, "users", "") s.Error(err) }) - s.Run("when transaction is nil should return error", func() { - err := schema.Rename(s.ctx, nil, "users", "new_name") + s.Run("when context is nil should return error", func() { + err := schema.Rename(nil, "users", "new_name") s.Error(err) }) } @@ -524,8 +463,10 @@ func (s *schemaTestSuite) TestTable() { s.Require().NoError(err) defer tx.Rollback() //nolint:errcheck + c := schema.NewContext(s.ctx, tx) + s.Run("when parameters are valid should alter table", func() { - err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) { + err := schema.Create(c, "users", func(table *schema.Blueprint) { table.ID() table.String("name") table.String("email").Unique() @@ -535,25 +476,25 @@ func (s *schemaTestSuite) TestTable() { }) s.Require().NoError(err) - err = schema.Table(s.ctx, tx, "users", func(table *schema.Blueprint) { + err = schema.Table(c, "users", func(table *schema.Blueprint) { table.String("phone").Nullable() }) s.NoError(err) - columns, err := schema.GetColumns(s.ctx, tx, "users") + columns, err := schema.GetColumns(c, "users") s.NoError(err) s.Len(columns, 7) // Expecting the new 'phone' column to be added }) s.Run("when table does not exist should return error", func() { - err := schema.Table(s.ctx, tx, "non_existing_table", func(table *schema.Blueprint) { + err := schema.Table(c, "non_existing_table", func(table *schema.Blueprint) { table.String("new_column") }) s.Error(err) }) - s.Run("when transaction is nil should return error", func() { - err := schema.Table(s.ctx, nil, "users", func(table *schema.Blueprint) { + s.Run("when context is nil should return error", func() { + err := schema.Table(nil, "users", func(table *schema.Blueprint) { table.String("new_column") }) s.Error(err) diff --git a/status.go b/status.go new file mode 100644 index 0000000..7fa50d9 --- /dev/null +++ b/status.go @@ -0,0 +1,27 @@ +package migris + +import ( + "context" + + "github.com/akfaiz/migris/internal/logger" +) + +// Status returns the status of the migrations. +func (m *Migrate) Status() error { + ctx := context.Background() + return m.StatusContext(ctx) +} + +// StatusContext returns the status of the migrations. +func (m *Migrate) StatusContext(ctx context.Context) error { + provider, err := m.newProvider() + if err != nil { + return err + } + migrations, err := provider.Status(ctx) + if err != nil { + return err + } + logger.PrintStatuses(migrations) + return nil +} diff --git a/up.go b/up.go new file mode 100644 index 0000000..0c230e7 --- /dev/null +++ b/up.go @@ -0,0 +1,57 @@ +package migris + +import ( + "context" + "errors" + + "github.com/akfaiz/migris/internal/logger" + "github.com/pressly/goose/v3" +) + +// Up applies the migrations in the specified directory. +func (m *Migrate) Up() error { + ctx := context.Background() + return m.UpContext(ctx) +} + +// UpContext applies the migrations in the specified directory. +func (m *Migrate) UpContext(ctx context.Context) error { + return m.UpToContext(ctx, goose.MaxVersion) +} + +// UpTo applies the migrations up to the specified version. +func (m *Migrate) UpTo(version int64) error { + ctx := context.Background() + return m.UpToContext(ctx, version) +} + +// UpToContext applies the migrations up to the specified version. +func (m *Migrate) UpToContext(ctx context.Context, version int64) error { + provider, err := m.newProvider() + if err != nil { + return err + } + hasPending, err := provider.HasPending(ctx) + if err != nil { + return err + } + if !hasPending { + logger.Info("Nothing to migrate.") + return nil + } + + logger.Infof("Running migrations.\n") + results, err := provider.UpTo(ctx, version) + if err != nil { + var partialErr *goose.PartialError + if errors.As(err, &partialErr) { + logger.PrintResults(partialErr.Applied) + logger.PrintResult(partialErr.Failed) + } + + return err + } + logger.PrintResults(results) + + return nil +} diff --git a/util.go b/util.go deleted file mode 100644 index 994156b..0000000 --- a/util.go +++ /dev/null @@ -1,44 +0,0 @@ -package schema - -import ( - "context" - "database/sql" - "log" -) - -func optional[T any](defaultValue T, values ...T) T { - return optionalAtIndex(0, defaultValue, values...) -} - -func optionalAtIndex[T any](index int, defaultValue T, values ...T) T { - if index < len(values) { - return values[index] - } - return defaultValue -} - -func execContext(ctx context.Context, tx *sql.Tx, queries ...string) error { - for _, query := range queries { - if debug { - log.Printf("Executing SQL: %s\n", query) - } - if _, err := tx.ExecContext(ctx, query); err != nil { - return err - } - } - return nil -} - -func queryRowContext(ctx context.Context, tx *sql.Tx, query string, args ...any) *sql.Row { - if debug { - log.Printf("Executing Query: %s with args: %v\n", query, args) - } - return tx.QueryRowContext(ctx, query, args...) -} - -func queryContext(ctx context.Context, tx *sql.Tx, query string, args ...any) (*sql.Rows, error) { - if debug { - log.Printf("Executing Query: %s with args: %v\n", query, args) - } - return tx.QueryContext(ctx, query, args...) -}