diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index b13bcce..1e76341 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,4 +1,4 @@
-# Contributing to \[Schema]
+# Contributing to \[Migris]
Thank you for considering contributing to this Go library! We welcome contributions of all kindsβbug reports, feature requests, documentation updates, tests, and code improvements.
diff --git a/Makefile b/Makefile
index 3370ac2..2dc2199 100644
--- a/Makefile
+++ b/Makefile
@@ -1,49 +1,54 @@
COVERAGE_FILE := coverage.out
COVERAGE_HTML := coverage.html
-MIN_COVERAGE := 80
FORMAT ?= dots
-# Format code
.PHONY: fmt
-fmt:
+fmt: # Format code
@echo "Formatting code..."
@go fmt ./...
-# Lint code
.PHONY: lint
-lint:
+lint: # Lint code
@echo "Linting code..."
@golangci-lint run --timeout 5m
-# Install dependencies and tools
.PHONY: install
-install:
+install: # Install dependencies
@echo "Installing dependencies..."
@go mod download
@go mod tidy
-# Run tests
+install-tools: # Install tools
+ @echo "Installing tools..."
+ @go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
+ @go install gotest.tools/gotestsum@latest
+
.PHONY: test
-test:
+test: # Run tests
@echo "Running tests..."
- go test -v -cover -race ./... -coverprofile=$(COVERAGE_FILE) -coverpkg=./...
+ @gotestsum --format=$(FORMAT) -- ./...
+
+.PHONY: testcov
+testcov: # Run tests with coverage
+ @echo "Running tests with coverage..."
+ @gotestsum --format=$(FORMAT) -- -coverprofile=$(COVERAGE_FILE) ./...
+ @echo "Total coverage is: " $(shell go tool cover -func=$(COVERAGE_FILE) | grep total | awk '{print $$3}')
-.PHONY: coverage
-coverage:
- @echo "Generating test coverage report..."
+.PHONY: testcov-html
+testcov-html: testcov # Generate coverage HTML report
@go tool cover -html=$(COVERAGE_FILE) -o $(COVERAGE_HTML)
- @go tool cover -func=$(COVERAGE_FILE) | tee coverage.txt
@echo "Coverage HTML report generated: $(COVERAGE_HTML)"
@open $(COVERAGE_HTML)
-.PHONY: coverage-check
-coverage-check:
- @COVERAGE=$$(go tool cover -func=$(COVERAGE_FILE) | grep total | awk '{print $$3}' | sed 's/%//'); \
- RESULT=$$(echo "$$COVERAGE < $(MIN_COVERAGE)" | bc); \
- if [ "$$RESULT" -eq 1 ]; then \
- echo "Coverage is below $(MIN_COVERAGE)%: $$COVERAGE%"; \
- exit 1; \
- else \
- echo "Coverage is sufficient: $$COVERAGE%"; \
- fi
\ No newline at end of file
+.PHONY: help
+help:
+ @echo "Available commands:"
+ @echo " fmt - Format code"
+ @echo " lint - Lint code"
+ @echo " install - Install dependencies"
+ @echo " install-tools- Install development tools"
+ @echo " test - Run tests"
+ @echo " testcov - Run tests with coverage"
+ @echo " testcov-html - Generate coverage HTML report"
+ @echo " help - Show this help message"
\ No newline at end of file
diff --git a/README.md b/README.md
index b6c800f..a725956 100644
--- a/README.md
+++ b/README.md
@@ -1,74 +1,130 @@
-# Go-Schema
-[](https://github.com/afkdevs/go-schema/actions/workflows/ci.yml)
-[](https://goreportcard.com/report/github.com/afkdevs/go-schema)
-[](https://codecov.io/gh/afkdevs/go-schema)
-[](https://pkg.go.dev/github.com/afkdevs/go-schema)
-[](https://golang.org/doc/devel/release.html)
-[](LICENSE)
+# Migris
-`Go-Schema` is a simple Go library for building and running SQL schema (DDL) code in a clean, readable, and migration-friendly way. Inspired by Laravel's Schema Builder, it helps you easily create or change database tablesβand works well with tools like [`goose`](https://github.com/pressly/goose).
+**Migris** is a database migration library for Go, inspired by Laravel's migrations.
+It combines the power of [pressly/goose](https://github.com/pressly/goose) with a fluent schema builder, making migrations easy to write, run, and maintain.
-## Features
+## β¨ Features
-- π Programmatic table and column definitions
-- ποΈ Support for common data types and constraints
-- βοΈ Auto-generates `CREATE TABLE`, `ALTER TABLE`, index and foreign key SQL
-- π Designed to work with database transactions
-- π§ͺ Built-in types and functions make migration code clear and testable
-- π Provides helper functions to get list tables, columns, and indexes
+- π¦ Migration management (`up`, `down`, `reset`, `status`, `create`)
+- ποΈ Fluent schema builder (similar to Laravel migrations)
+- ποΈ Supports PostgreSQL, MySQL, and MariaDB
+- π Transaction-based migrations
+- π οΈ Integration with Go projects (no external CLI required)
-## Supported Databases
+## π Installation
-Currently, `schema` is tested and optimized for:
+```bash
+go get -u github.com/akfaiz/migris
+```
-* PostgreSQL
-* MySQL / MariaDB
-* SQLite (TODO)
+## π Usage
-## Installation
+### 1. Create a Migration
-```bash
-go get github.com/afkdevs/go-schema
-```
+Migrations are defined in Go files using the schema builder:
-## Integration Example (with goose)
```go
package migrations
import (
- "context"
+ "github.com/akfaiz/migris"
+ "github.com/akfaiz/migris/schema"
+)
+
+func init() {
+ migris.AddMigrationContext(upCreateUsersTable, downCreateUsersTable)
+}
+
+func upCreateUsersTable(c *schema.Context) error {
+ return schema.Create(c, "users", func(table *schema.Blueprint) {
+ table.ID()
+ table.String("name")
+ table.String("email")
+ table.Timestamp("email_verified_at").Nullable()
+ table.String("password")
+ table.Timestamps()
+ })
+}
+
+func downCreateUsersTable(c *schema.Context) error {
+ return schema.DropIfExists(c, "users")
+}
+```
+
+This creates a `users` table with common fields.
+
+### 2. Run Migrations
+
+You can manage migrations directly from Go code:
+
+```go
+package migrate
+
+import (
"database/sql"
+ "fmt"
- "github.com/afkdevs/go-schema"
- "github.com/pressly/goose/v3"
+ "github.com/akfaiz/migris"
+ _ "migrations" // Import migrations
+ _ "github.com/lib/pq" // PostgreSQL driver
)
-func init() {
- goose.AddMigrationContext(upCreateUsersTable, downCreateUsersTable)
+func Up() error {
+ m, err := newMigrate()
+ if err != nil {
+ return err
+ }
+ return m.Up()
}
-func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error {
- return schema.Create(ctx, tx, "users", func(table *schema.Blueprint) {
- table.ID()
- table.String("name")
- table.String("email")
- table.Timestamp("email_verified_at").Nullable()
- table.String("password")
- table.Timestamps()
- })
+func Create(name string) error {
+ m, err := newMigrate()
+ if err != nil {
+ return err
+ }
+ return m.Create(name)
}
-func downCreateUsersTable(ctx context.Context, tx *sql.Tx) error {
- return schema.Drop(ctx, tx, "users")
+func Reset() error {
+ m, err := newMigrate()
+ if err != nil {
+ return err
+ }
+ return m.Reset()
+}
+
+func Down() error {
+ m, err := newMigrate()
+ if err != nil {
+ return err
+ }
+ return m.Down()
+}
+
+func Status() error {
+ m, err := newMigrate()
+ if err != nil {
+ return err
+ }
+ return m.Status()
+}
+
+func newMigrate() (*migris.Migrate, error) {
+ dsn := "postgres://user:pass@localhost:5432/mydb?sslmode=disable"
+ db, err := sql.Open("postgres", dsn)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open database: %w", err)
+ }
+ return migris.New("postgres", migris.WithDB(db), migris.WithMigrationDir("migrations")), nil
}
```
-For more examples, check out the [examples](examples/basic) directory.
-## Documentation
-For detailed documentation, please refer to the [GoDoc](https://pkg.go.dev/github.com/afkdevs/go-schema) page.
+## π Roadmap
+
+- [ ] Add SQLite support
+- [ ] CLI wrapper for quick usage
-## Contributing
-Contributions are welcome! Please read the [contributing guidelines](CONTRIBUTING.md) and submit a pull request.
+## π License
-## License
-This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
\ No newline at end of file
+MIT License.
+See [LICENSE](./LICENSE) for details.
\ No newline at end of file
diff --git a/blueprint.go b/blueprint.go
deleted file mode 100644
index 0a3f7d6..0000000
--- a/blueprint.go
+++ /dev/null
@@ -1,887 +0,0 @@
-package schema
-
-import (
- "fmt"
- "slices"
-)
-
-type columnType uint8
-
-const (
- columnTypeBoolean columnType = iota
- columnTypeChar
- columnTypeString
- columnTypeLongText
- columnTypeMediumText
- columnTypeText
- columnTypeTinyText
- columnTypeBigInteger
- columnTypeInteger
- columnTypeMediumInteger
- columnTypeSmallInteger
- columnTypeTinyInteger
- columnTypeDecimal
- columnTypeDouble
- columnTypeFloat
- columnTypeDateTime
- columnTypeDateTimeTz
- columnTypeDate
- columnTypeTime
- columnTypeTimestamp
- columnTypeTimestampTz
- columnTypeYear
- columnTypeBinary
- columnTypeJSON
- columnTypeJSONB
- columnTypeGeography
- columnTypeGeometry
- columnTypePoint
- columnTypeUUID
- columnTypeEnum
- columnTypeCustom // Custom type for user-defined types
-)
-
-type indexType int
-
-const (
- indexTypeIndex indexType = iota
- indexTypeUnique
- indexTypePrimary
- indexTypeFullText
-)
-
-// Blueprint represents a schema blueprint for creating or altering a database table.
-type Blueprint struct {
- commands []string // commands to be executed
- name string
- newName string
- charset string
- collation string
- engine string
- columns []*columnDefinition
- indexes []*indexDefinition
- foreignKeys []*foreignKeyDefinition
- dropColumns []string
- renameColumns map[string]string // old column name to new column name
- dropIndexes []string // indexes to be dropped
- dropForeignKeys []string // foreign keys to be dropped
- dropPrimaryKeys []string // primary keys to be dropped
- dropUniqueKeys []string // unique keys to be dropped
- dropFullText []string // fulltext indexes to be dropped
- renameIndexes map[string]string // old index name to new index name
-}
-
-// Charset sets the character set for the table in the blueprint.
-func (b *Blueprint) Charset(charset string) {
- b.charset = charset
-}
-
-// Collation sets the collation for the table in the blueprint.
-func (b *Blueprint) Collation(collation string) {
- b.collation = collation
-}
-
-// Engine sets the storage engine for the table in the blueprint.
-func (b *Blueprint) Engine(engine string) {
- b.engine = engine
-}
-
-// Column creates a new custom column definition in the blueprint with the specified name and type.
-func (b *Blueprint) Column(name string, columnType string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeCustom,
- customColumnType: columnType,
- })
-}
-
-// Boolean creates a new boolean column definition in the blueprint.
-func (b *Blueprint) Boolean(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeBoolean,
- })
-}
-
-// Char creates a new char column definition in the blueprint.
-func (b *Blueprint) Char(name string, length ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeChar,
- length: optional(0, length...),
- })
-}
-
-// String creates a new string column definition in the blueprint.
-func (b *Blueprint) String(name string, length ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeString,
- length: optional(0, length...),
- })
-}
-
-// LongText creates a new long text column definition in the blueprint.
-func (b *Blueprint) LongText(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeLongText,
- })
-}
-
-// Text creates a new text column definition in the blueprint.
-func (b *Blueprint) Text(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeText,
- })
-}
-
-// MediumText creates a new medium text column definition in the blueprint.
-func (b *Blueprint) MediumText(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeMediumText,
- })
-}
-
-// TinyText creates a new tiny text column definition in the blueprint.
-func (b *Blueprint) TinyText(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeTinyText,
- })
-}
-
-// BigIncrements creates a new big increments column definition in the blueprint.
-func (b *Blueprint) BigIncrements(name string) ColumnDefinition {
- return b.UnsignedBigInteger(name, true)
-}
-
-// BigInteger creates a new big integer column definition in the blueprint.
-func (b *Blueprint) BigInteger(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeBigInteger,
- autoIncrement: optional(false, autoIncrement...),
- })
-}
-
-// Decimal creates a new decimal column definition in the blueprint.
-//
-// The total and places parameters are optional.
-//
-// Example:
-//
-// table.Decimal("price", 10, 2) // creates a decimal column with total 10 and places 2
-//
-// table.Decimal("price") // creates a decimal column with default total 8 and places 2
-func (b *Blueprint) Decimal(name string, params ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeDecimal,
- total: optional(8, params...),
- places: optionalAtIndex(1, 2, params...),
- })
-}
-
-// Double creates a new double column definition in the blueprint.
-func (b *Blueprint) Double(name string, params ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeDouble,
- total: optional(0, params...),
- places: optionalAtIndex(1, 0, params...),
- })
-}
-
-// Float creates a new float column definition in the blueprint.
-//
-// The total and places parameters are optional.
-//
-// Example:
-//
-// table.Float("price", 10, 2) // creates a float column with total 10 and places 2
-//
-// table.Float("price") // creates a float column with default total 8 and places 2
-func (b *Blueprint) Float(name string, params ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeFloat,
- total: optional(8, params...),
- places: optionalAtIndex(1, 2, params...),
- })
-}
-
-// ID creates a new big increments column definition in the blueprint with the name "id" or a custom name.
-//
-// If a name is provided, it will be used as the column name; otherwise, "id" will be used.
-func (b *Blueprint) ID(name ...string) ColumnDefinition {
- return b.BigIncrements(optional("id", name...)).Primary()
-}
-
-// Increments create a new increment column definition in the blueprint.
-func (b *Blueprint) Increments(name string) ColumnDefinition {
- return b.UnsignedInteger(name, true)
-}
-
-// Integer creates a new integer column definition in the blueprint.
-func (b *Blueprint) Integer(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeInteger,
- autoIncrement: optional(false, autoIncrement...),
- })
-}
-
-// MediumIncrements creates a new medium increments column definition in the blueprint.
-func (b *Blueprint) MediumIncrements(name string) ColumnDefinition {
- return b.UnsignedMediumInteger(name, true)
-}
-
-func (b *Blueprint) MediumInteger(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeMediumInteger,
- autoIncrement: optional(false, autoIncrement...),
- })
-}
-
-// SmallIncrements creates a new small increments column definition in the blueprint.
-func (b *Blueprint) SmallIncrements(name string) ColumnDefinition {
- return b.UnsignedSmallInteger(name, true)
-}
-
-// SmallInteger creates a new small integer column definition in the blueprint.
-func (b *Blueprint) SmallInteger(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeSmallInteger,
- autoIncrement: optional(false, autoIncrement...),
- })
-}
-
-// TinyIncrements creates a new tiny increments column definition in the blueprint.
-func (b *Blueprint) TinyIncrements(name string) ColumnDefinition {
- return b.UnsignedTinyInteger(name, true)
-}
-
-// TinyInteger creates a new tiny integer column definition in the blueprint.
-func (b *Blueprint) TinyInteger(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeTinyInteger,
- autoIncrement: optional(false, autoIncrement...),
- })
-}
-
-// UnsignedBigInteger creates a new unsigned big integer column definition in the blueprint.
-func (b *Blueprint) UnsignedBigInteger(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeBigInteger,
- autoIncrement: optional(false, autoIncrement...),
- unsigned: true,
- })
-}
-
-// UnsignedInteger creates a new unsigned integer column definition in the blueprint.
-func (b *Blueprint) UnsignedInteger(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeInteger,
- autoIncrement: optional(false, autoIncrement...),
- unsigned: true,
- })
-}
-
-// UnsignedMediumInteger creates a new unsigned medium integer column definition in the blueprint.
-func (b *Blueprint) UnsignedMediumInteger(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeMediumInteger,
- autoIncrement: optional(false, autoIncrement...),
- unsigned: true,
- })
-}
-
-// UnsignedSmallInteger creates a new unsigned small integer column definition in the blueprint.
-func (b *Blueprint) UnsignedSmallInteger(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeSmallInteger,
- autoIncrement: optional(false, autoIncrement...),
- unsigned: true,
- })
-}
-
-// UnsignedTinyInteger creates a new unsigned tiny integer column definition in the blueprint.
-func (b *Blueprint) UnsignedTinyInteger(name string, autoIncrement ...bool) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeTinyInteger,
- autoIncrement: optional(false, autoIncrement...),
- unsigned: true,
- })
-}
-
-// DateTime creates a new date time column definition in the blueprint.
-func (b *Blueprint) DateTime(name string, precision ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeDateTime,
- precision: optional(0, precision...),
- })
-}
-
-// DateTimeTz creates a new date time with a time zone column definition in the blueprint.
-func (b *Blueprint) DateTimeTz(name string, precision ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeDateTimeTz,
- precision: optional(0, precision...),
- })
-}
-
-// Date creates a new date column definition in the blueprint.
-func (b *Blueprint) Date(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeDate,
- })
-}
-
-// Time creates a new time column definition in the blueprint.
-func (b *Blueprint) Time(name string, precission ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeTime,
- precision: optional(0, precission...),
- })
-}
-
-// Timestamp creates a new timestamp column definition in the blueprint.
-// The precision parameter is optional and defaults to 0 if not provided.
-func (b *Blueprint) Timestamp(name string, precision ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeTimestamp,
- precision: optional(0, precision...),
- })
-}
-
-// TimestampTz creates a new timestamp with time zone column definition in the blueprint.
-// The precision parameter is optional and defaults to 0 if not provided.
-func (b *Blueprint) TimestampTz(name string, precision ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeTimestampTz,
- precision: optional(0, precision...),
- })
-}
-
-// Timestamps adds created_at and updated_at timestamp columns to the blueprint.
-func (b *Blueprint) Timestamps() {
- b.Timestamp("created_at").Nullable(false).UseCurrent()
- b.Timestamp("updated_at").Nullable(false).UseCurrent().UseCurrentOnUpdate()
-}
-
-// TimestampsTz adds created_at and updated_at timestamp with time zone columns to the blueprint.
-func (b *Blueprint) TimestampsTz() {
- b.TimestampTz("created_at").Nullable(false).UseCurrent()
- b.TimestampTz("updated_at").Nullable(false).UseCurrent().UseCurrentOnUpdate()
-}
-
-// Year creates a new year column definition in the blueprint.
-func (b *Blueprint) Year(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeYear,
- })
-}
-
-// Binary creates a new binary column definition in the blueprint.
-func (b *Blueprint) Binary(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeBinary,
- })
-}
-
-// JSON creates a new JSON column definition in the blueprint.
-func (b *Blueprint) JSON(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeJSON,
- })
-}
-
-// JSONB creates a new JSONB column definition in the blueprint.
-func (b *Blueprint) JSONB(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeJSONB,
- })
-}
-
-// UUID creates a new UUID column definition in the blueprint.
-func (b *Blueprint) UUID(name string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeUUID,
- })
-}
-
-// Geography creates a new geography column definition in the blueprint.
-// The subType parameter is optional and can be used to specify the type of geography (e.g., "Point", "LineString", "Polygon").
-// The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geography type.
-func (b *Blueprint) Geography(name string, subType string, srid ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeGeography,
- subType: subType,
- srid: optional(4326, srid...),
- })
-}
-
-// Geometry creates a new geometry column definition in the blueprint.
-// The subType parameter is optional and can be used to specify the type of geometry (e.g., "Point", "LineString", "Polygon").
-// The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geometry type.
-func (b *Blueprint) Geometry(name string, subType string, srid ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeGeometry,
- subType: subType,
- srid: optional(0, srid...),
- })
-}
-
-// Point creates a new point column definition in the blueprint.
-func (b *Blueprint) Point(name string, srid ...int) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypePoint,
- srid: optional(4326, srid...),
- })
-}
-
-// Enum creates a new enum column definition in the blueprint.
-// The allowedEnums parameter is a slice of strings that defines the allowed values for the enum column.
-//
-// Example:
-//
-// table.Enum("status", []string{"active", "inactive", "pending"})
-// table.Enum("role", []string{"admin", "user", "guest"}).Comment("User role in the system")
-func (b *Blueprint) Enum(name string, allowedEnums []string) ColumnDefinition {
- return b.addColumn(&columnDefinition{
- name: name,
- columnType: columnTypeEnum,
- allowedEnums: allowedEnums,
- })
-}
-
-// Index creates a new index definition in the blueprint.
-//
-// Example:
-//
-// table.Index("email")
-// table.Index("email", "username") // creates a composite index
-// table.Index("email").Algorithm("btree") // creates a btree index
-func (b *Blueprint) Index(column string, otherColumns ...string) IndexDefinition {
- index := &indexDefinition{
- indexType: indexTypeIndex,
- columns: append([]string{column}, otherColumns...),
- }
- b.indexes = append(b.indexes, index)
- b.addCommand("index")
-
- return index
-}
-
-// Unique creates a new unique index definition in the blueprint.
-//
-// Example:
-//
-// table.Unique("email")
-// table.Unique("email", "username") // creates a composite unique index
-func (b *Blueprint) Unique(column string, otherColumns ...string) IndexDefinition {
- index := &indexDefinition{
- indexType: indexTypeUnique,
- columns: append([]string{column}, otherColumns...),
- }
- b.indexes = append(b.indexes, index)
- b.addCommand("unique")
-
- return index
-}
-
-// Primary creates a new primary key index definition in the blueprint.
-//
-// Example:
-//
-// table.Primary("id")
-// table.Primary("id", "email") // creates a composite primary key
-func (b *Blueprint) Primary(column string, otherColumns ...string) IndexDefinition {
- index := &indexDefinition{
- indexType: indexTypePrimary,
- columns: append([]string{column}, otherColumns...),
- }
- b.indexes = append(b.indexes, index)
- b.addCommand("primary")
- return index
-}
-
-// FullText creates a new fulltext index definition in the blueprint.
-func (b *Blueprint) FullText(column string, otherColumns ...string) IndexDefinition {
- index := &indexDefinition{
- indexType: indexTypeFullText,
- columns: append([]string{column}, otherColumns...),
- }
- b.indexes = append(b.indexes, index)
- b.addCommand("fullText")
-
- return index
-}
-
-// Foreign creates a new foreign key definition in the blueprint.
-//
-// Example:
-//
-// table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE")
-func (b *Blueprint) Foreign(column string) ForeignKeyDefinition {
- fk := &foreignKeyDefinition{
- tableName: b.name,
- column: column,
- }
- b.foreignKeys = append(b.foreignKeys, fk)
- b.addCommand("foreign")
- return fk
-}
-
-// DropColumn adds a column to be dropped from the table.
-//
-// Example:
-//
-// table.DropColumn("old_column")
-// table.DropColumn("old_column", "another_old_column") // drops multiple columns
-func (b *Blueprint) DropColumn(column string, otherColumns ...string) {
- b.dropColumns = append(b.dropColumns, append([]string{column}, otherColumns...)...)
- b.addCommand("dropColumn")
-}
-
-// RenameColumn changes the name of the table in the blueprint.
-//
-// Example:
-//
-// table.RenameColumn("old_table_name", "new_table_name")
-func (b *Blueprint) RenameColumn(oldColumn string, newColumn string) {
- if b.renameColumns == nil {
- b.renameColumns = make(map[string]string)
- }
- b.renameColumns[oldColumn] = newColumn
- b.addCommand("renameColumn")
-}
-
-// DropIndex adds an index to be dropped from the table.
-func (b *Blueprint) DropIndex(indexName string) {
- b.dropIndexes = append(b.dropIndexes, indexName)
- b.addCommand("dropIndex")
-}
-
-// DropForeign adds a foreign key to be dropped from the table.
-func (b *Blueprint) DropForeign(foreignKeyName string) {
- b.dropForeignKeys = append(b.dropForeignKeys, foreignKeyName)
- b.addCommand("dropForeign")
-}
-
-// DropPrimary adds a primary key to be dropped from the table.
-func (b *Blueprint) DropPrimary(primaryKeyName string) {
- b.dropPrimaryKeys = append(b.dropPrimaryKeys, primaryKeyName)
- b.addCommand("dropPrimary")
-}
-
-// DropUnique adds a unique key to be dropped from the table.
-func (b *Blueprint) DropUnique(uniqueKeyName string) {
- b.dropUniqueKeys = append(b.dropUniqueKeys, uniqueKeyName)
- b.addCommand("dropUnique")
-}
-
-func (b *Blueprint) DropFulltext(indexName string) {
- b.dropFullText = append(b.dropFullText, indexName)
- b.addCommand("dropFullText")
-}
-
-// RenameIndex changes the name of an index in the blueprint.
-// Example:
-//
-// table.RenameIndex("old_index_name", "new_index_name")
-func (b *Blueprint) RenameIndex(oldIndexName string, newIndexName string) {
- if b.renameIndexes == nil {
- b.renameIndexes = make(map[string]string)
- }
- b.renameIndexes[oldIndexName] = newIndexName
- b.addCommand("renameIndex")
-}
-
-func (b *Blueprint) getAddedColumns() []*columnDefinition {
- var addedColumns []*columnDefinition
- for _, col := range b.columns {
- if !col.changed {
- addedColumns = append(addedColumns, col)
- }
- }
- return addedColumns
-}
-
-func (b *Blueprint) getChangedColumns() []*columnDefinition {
- var changedColumns []*columnDefinition
- for _, col := range b.columns {
- if col.changed {
- changedColumns = append(changedColumns, col)
- }
- }
- return changedColumns
-}
-
-func (b *Blueprint) create() {
- b.addCommand("create")
-}
-
-func (b *Blueprint) creating() bool {
- for _, command := range b.commands {
- if command == "create" || command == "createIfNotExists" {
- return true
- }
- }
- return false
-}
-
-func (b *Blueprint) createIfNotExists() {
- b.addCommand("createIfNotExists")
-}
-
-func (b *Blueprint) drop() {
- b.addCommand("drop")
-}
-
-func (b *Blueprint) dropIfExists() {
- b.addCommand("dropIfNotExists")
-}
-
-func (b *Blueprint) rename() {
- b.addCommand("rename")
-}
-
-func (b *Blueprint) addImpliedCommands() {
- if len(b.getAddedColumns()) > 0 && !b.creating() {
- b.commands = append([]string{"add"}, b.commands...)
- }
- if len(b.getChangedColumns()) > 0 && !b.creating() {
- b.commands = append([]string{"change"}, b.commands...)
- }
-}
-
-func (b *Blueprint) toSql(grammar grammar) ([]string, error) {
- b.addImpliedCommands()
-
- var statements []string
-
- commandMap := map[string]func(*Blueprint) (string, error){
- "create": grammar.compileCreate,
- "createIfNotExists": grammar.compileCreateIfNotExists,
- "add": grammar.compileAdd,
- "drop": grammar.compileDrop,
- "dropIfNotExists": grammar.compileDropIfExists,
- "rename": grammar.compileRename,
- }
- for _, command := range b.commands {
- if compileFunc, exists := commandMap[command]; exists {
- sql, err := compileFunc(b)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- switch command {
- case "create", "createIfNotExists", "add":
- for _, col := range b.getAddedColumns() {
- if col.index {
- indexDef := &indexDefinition{
- indexType: indexTypeIndex,
- columns: []string{col.name},
- name: col.indexName,
- }
- sql, err := grammar.compileIndex(b, indexDef)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- }
- case "change":
- changedStatements, err := grammar.compileChange(b)
- if err != nil {
- return nil, err
- }
- statements = append(statements, changedStatements...)
- case "index":
- indexStatements, err := b.getIndexStatements(grammar, indexTypeIndex)
- if err != nil {
- return nil, err
- }
- statements = append(statements, indexStatements...)
- case "unique":
- uniqueStatements, err := b.getIndexStatements(grammar, indexTypeUnique)
- if err != nil {
- return nil, err
- }
- statements = append(statements, uniqueStatements...)
- case "primary":
- primaryStatements, err := b.getIndexStatements(grammar, indexTypePrimary)
- if err != nil {
- return nil, err
- }
- statements = append(statements, primaryStatements...)
- case "fullText":
- fulltextStatements, err := b.getIndexStatements(grammar, indexTypeFullText)
- if err != nil {
- return nil, err
- }
- statements = append(statements, fulltextStatements...)
- case "foreign":
- for _, foreignKey := range b.foreignKeys {
- sql, err := grammar.compileForeign(b, foreignKey)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- case "dropColumn":
- sql, err := grammar.compileDropColumn(b)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- case "renameColumn":
- for oldName, newName := range b.renameColumns {
- sql, err := grammar.compileRenameColumn(b, oldName, newName)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- case "dropIndex":
- for _, indexName := range b.dropIndexes {
- sql, err := grammar.compileDropIndex(b, indexName)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- case "dropUnique":
- for _, uniqueKeyName := range b.dropUniqueKeys {
- sql, err := grammar.compileDropUnique(b, uniqueKeyName)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- case "dropFullText":
- for _, fullTextIndexName := range b.dropFullText {
- sql, err := grammar.compileDropFulltext(b, fullTextIndexName)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- case "dropPrimary":
- for _, primaryKeyName := range b.dropPrimaryKeys {
- sql, err := grammar.compileDropPrimary(b, primaryKeyName)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- case "renameIndex":
- for oldName, newName := range b.renameIndexes {
- sql, err := grammar.compileRenameIndex(b, oldName, newName)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- case "dropForeign":
- for _, foreignKeyName := range b.dropForeignKeys {
- sql, err := grammar.compileDropForeign(b, foreignKeyName)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- }
- }
-
- return statements, nil
-}
-
-func (b *Blueprint) getIndexStatements(grammar grammar, idxType indexType) ([]string, error) {
- indexCommandMap := map[indexType]func(*Blueprint, *indexDefinition) (string, error){
- indexTypeIndex: grammar.compileIndex,
- indexTypeUnique: grammar.compileUnique,
- indexTypePrimary: grammar.compilePrimary,
- indexTypeFullText: grammar.compileFullText,
- }
- var statements []string
- for _, index := range b.indexes {
- if index.indexType == idxType {
- compileFunc, exists := indexCommandMap[idxType]
- if !exists {
- return nil, fmt.Errorf("unsupported index type: %d", idxType)
- }
- sql, err := compileFunc(b, index)
- if err != nil {
- return nil, err
- }
- if sql != "" {
- statements = append(statements, sql)
- }
- }
- }
- return statements, nil
-}
-
-func (b *Blueprint) addColumn(col *columnDefinition) *columnDefinition {
- b.columns = append(b.columns, col)
- return col
-}
-
-func (b *Blueprint) addCommand(command string) {
- if command == "" {
- return
- }
- if !slices.Contains(b.commands, command) {
- b.commands = append(b.commands, command)
- }
-}
diff --git a/builder.go b/builder.go
deleted file mode 100644
index e63c1aa..0000000
--- a/builder.go
+++ /dev/null
@@ -1,177 +0,0 @@
-package schema
-
-import (
- "context"
- "database/sql"
- "errors"
-)
-
-// Builder is an interface that defines methods for creating, dropping, and managing database tables.
-type Builder interface {
- // Create creates a new table with the given name and applies the provided blueprint.
- Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error
- // CreateIfNotExists creates a new table with the given name and applies the provided blueprint if it does not already exist.
- CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error
- // Drop removes the table with the given name.
- Drop(ctx context.Context, tx *sql.Tx, name string) error
- // DropIfExists removes the table with the given name if it exists.
- DropIfExists(ctx context.Context, tx *sql.Tx, name string) error
- // GetColumns retrieves the columns of the specified table.
- GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error)
- // GetIndexes retrieves the indexes of the specified table.
- GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error)
- // GetTables retrieves all tables in the database.
- GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error)
- // HasColumn checks if the specified table has the given column.
- HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error)
- // HasColumns checks if the specified table has all the given columns.
- HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error)
- // HasIndex checks if the specified table has the given index.
- HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error)
- // HasTable checks if a table with the given name exists.
- HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error)
- // Rename renames a table from oldName to newName.
- Rename(ctx context.Context, tx *sql.Tx, oldName string, newName string) error
- // Table applies the provided blueprint to the specified table.
- Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error
-}
-
-// NewBuilder creates a new Builder instance based on the specified dialect.
-// It returns an error if the dialect is not supported.
-//
-// Supported dialects are "postgres", "pgx", "mysql", and "mariadb".
-func NewBuilder(dialect string) (Builder, error) {
- dialectVal := dialectFromString(dialect)
- switch dialectVal {
- case dialectMySQL:
- return newMysqlBuilder(), nil
- case dialectPostgres:
- return newPostgresBuilder(), nil
- default:
- return nil, errors.New("unsupported dialect: " + dialect)
- }
-}
-
-type baseBuilder struct {
- grammar grammar
-}
-
-func (b *baseBuilder) validateTxAndName(tx *sql.Tx, name string) error {
- if name == "" {
- return errors.New("table name is empty")
- }
- if tx == nil {
- return errors.New("transaction is nil")
- }
- return nil
-}
-
-func (b *baseBuilder) validateCreateAndAlter(tx *sql.Tx, name string, blueprint func(table *Blueprint)) error {
- if name == "" {
- return errors.New("table name is empty")
- }
- if blueprint == nil {
- return errors.New("blueprint function is nil")
- }
- if tx == nil {
- return errors.New("transaction is nil")
- }
- return nil
-}
-
-func (b *baseBuilder) Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error {
- if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil {
- return err
- }
-
- bp := &Blueprint{name: name}
- bp.create()
- blueprint(bp)
-
- statements, err := bp.toSql(b.grammar)
- if err != nil {
- return err
- }
-
- return execContext(ctx, tx, statements...)
-}
-
-func (b *baseBuilder) CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error {
- if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil {
- return err
- }
-
- bp := &Blueprint{name: name}
- bp.createIfNotExists()
- blueprint(bp)
-
- statements, err := bp.toSql(b.grammar)
- if err != nil {
- return err
- }
-
- return execContext(ctx, tx, statements...)
-}
-
-func (b *baseBuilder) Drop(ctx context.Context, tx *sql.Tx, name string) error {
- if err := b.validateTxAndName(tx, name); err != nil {
- return err
- }
-
- bp := &Blueprint{name: name}
- bp.drop()
- statements, err := bp.toSql(b.grammar)
- if err != nil {
- return err
- }
-
- return execContext(ctx, tx, statements...)
-}
-
-func (b *baseBuilder) DropIfExists(ctx context.Context, tx *sql.Tx, name string) error {
- if err := b.validateTxAndName(tx, name); err != nil {
- return err
- }
-
- bp := &Blueprint{name: name}
- bp.dropIfExists()
- statements, err := bp.toSql(b.grammar)
- if err != nil {
- return err
- }
-
- return execContext(ctx, tx, statements...)
-}
-
-func (b *baseBuilder) Rename(ctx context.Context, tx *sql.Tx, oldName string, newName string) error {
- if oldName == "" || newName == "" {
- return errors.New("old or new table name is empty")
- }
- if tx == nil {
- return errors.New("transaction is nil")
- }
- bp := &Blueprint{name: oldName, newName: newName}
- bp.rename()
- statements, err := bp.toSql(b.grammar)
- if err != nil {
- return err
- }
-
- return execContext(ctx, tx, statements...)
-}
-
-func (b *baseBuilder) Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error {
- if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil {
- return err
- }
-
- bp := &Blueprint{name: name}
- blueprint(bp)
-
- statements, err := bp.toSql(b.grammar)
- if err != nil {
- return err
- }
-
- return execContext(ctx, tx, statements...)
-}
diff --git a/column_definition.go b/column_definition.go
deleted file mode 100644
index 516bea9..0000000
--- a/column_definition.go
+++ /dev/null
@@ -1,137 +0,0 @@
-package schema
-
-import "slices"
-
-// ColumnDefinition defines the interface for defining a column in a database table.
-type ColumnDefinition interface {
- // AutoIncrement sets the column to auto-increment.
- // This is typically used for primary key columns.
- AutoIncrement() ColumnDefinition
- // Change changes the column definition.
- Change() ColumnDefinition
- // Comment adds a comment to the column definition.
- Comment(comment string) ColumnDefinition
- // Default sets a default value for the column.
- Default(value any) ColumnDefinition
- // Index adds an index to the column.
- Index(indexName ...string) ColumnDefinition
- // Nullable sets the column to be nullable or not.
- Nullable(value ...bool) ColumnDefinition
- // Primary sets the column as a primary key.
- Primary() ColumnDefinition
- // Unique sets the column to be unique.
- Unique(indexName ...string) ColumnDefinition
- // Unsigned sets the column to be unsigned (applicable for numeric types).
- Unsigned() ColumnDefinition
- // UseCurrent sets the column to use the current timestamp as default.
- UseCurrent() ColumnDefinition
- // UseCurrentOnUpdate sets the column to use the current timestamp on update.
- UseCurrentOnUpdate() ColumnDefinition
-}
-
-var _ ColumnDefinition = &columnDefinition{}
-
-type columnDefinition struct {
- name string
- columnType columnType
- customColumnType string // for custom column types
- commands []string
- comment string
- defaultValue any
- onUpdateValue string
- nullable bool
- autoIncrement bool
- unsigned bool
- primary bool
- index bool
- indexName string
- unique bool
- uniqueName string
- length int
- precision int
- total int
- places int
- changed bool
- allowedEnums []string // for enum type columns
- subType string // for geography and geometry types
- srid int // for geography and geometry types
-}
-
-func (c *columnDefinition) addCommand(command string) {
- if command == "" {
- return
- }
- if !slices.Contains(c.commands, command) {
- c.commands = append(c.commands, command)
- }
-}
-
-func (c *columnDefinition) hasCommand(command string) bool {
- return slices.Contains(c.commands, command)
-}
-
-func (c *columnDefinition) AutoIncrement() ColumnDefinition {
- c.addCommand("autoIncrement")
- c.autoIncrement = true
- return c
-}
-
-func (c *columnDefinition) Comment(comment string) ColumnDefinition {
- c.addCommand("comment")
- c.comment = comment
- return c
-}
-
-func (c *columnDefinition) Default(value any) ColumnDefinition {
- c.addCommand("default")
- c.defaultValue = value
-
- return c
-}
-
-func (c *columnDefinition) Index(indexName ...string) ColumnDefinition {
- c.index = true
- c.indexName = optional("", indexName...)
- c.addCommand("index")
- return c
-}
-
-func (c *columnDefinition) Nullable(value ...bool) ColumnDefinition {
- c.addCommand("nullable")
- c.nullable = optional(true, value...)
- return c
-}
-
-func (c *columnDefinition) Primary() ColumnDefinition {
- c.addCommand("primary")
- c.primary = true
- return c
-}
-
-func (c *columnDefinition) Unique(indexName ...string) ColumnDefinition {
- c.addCommand("unique")
- c.unique = true
- c.uniqueName = optional("", indexName...)
- return c
-}
-
-func (c *columnDefinition) Change() ColumnDefinition {
- c.addCommand("change")
- c.changed = true
- return c
-}
-
-func (c *columnDefinition) Unsigned() ColumnDefinition {
- c.unsigned = true
- return c
-}
-
-func (c *columnDefinition) UseCurrent() ColumnDefinition {
- c.Default("CURRENT_TIMESTAMP")
- return c
-}
-
-func (c *columnDefinition) UseCurrentOnUpdate() ColumnDefinition {
- c.onUpdateValue = "CURRENT_TIMESTAMP"
- return c
-}
diff --git a/create.go b/create.go
new file mode 100644
index 0000000..f348cb4
--- /dev/null
+++ b/create.go
@@ -0,0 +1,99 @@
+package migris
+
+import (
+ "text/template"
+
+ "github.com/akfaiz/migris/internal/parser"
+ "github.com/pressly/goose/v3"
+)
+
+// Create creates a new migration file with the given name in the specified directory.
+func (m *Migrate) Create(name string) error {
+ tmpl := getMigrationTemplate(name)
+ return goose.CreateWithTemplate(nil, m.migrationDir, tmpl, name, "go")
+}
+
+func getMigrationTemplate(name string) *template.Template {
+ tableName, create := parser.ParseMigrationName(name)
+ if create {
+ return migrationCreateTemplate(tableName)
+ }
+ if tableName != "" {
+ return migrationUpdateTemplate(tableName)
+ }
+ return migrationTemplate
+}
+
+var migrationTemplate = template.Must(template.New("migrator.go-migration").Parse(`package migrations
+
+import (
+ "github.com/akfaiz/migris"
+ "github.com/akfaiz/migris/schema"
+)
+
+func init() {
+ migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}})
+}
+
+func up{{.CamelName}}(c *schema.Context) error {
+ // This code is executed when the migration is applied.
+ return nil
+}
+
+func down{{.CamelName}}(c *schema.Context) error {
+ // This code is executed when the migration is rolled back.
+ return nil
+}
+`))
+
+func migrationCreateTemplate(table string) *template.Template {
+ tmpl := `package migrations
+
+import (
+ "github.com/akfaiz/migris"
+ "github.com/akfaiz/migris/schema"
+)
+
+func init() {
+ migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}})
+}
+
+func up{{.CamelName}}(c *schema.Context) error {
+ return schema.Create(c, "` + table + `", func(table *schema.Blueprint) {
+ // Define your table schema here
+ })
+}
+
+func down{{.CamelName}}(c *schema.Context) error {
+ return schema.DropIfExists(c, "` + table + `")
+}
+`
+ return template.Must(template.New("migration-create").Parse(tmpl))
+}
+
+func migrationUpdateTemplate(table string) *template.Template {
+ tmpl := `package migrations
+
+import (
+ "github.com/akfaiz/migris"
+ "github.com/akfaiz/migris/schema"
+)
+
+func init() {
+ migris.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}})
+}
+
+func up{{.CamelName}}(c *schema.Context) error {
+ return schema.Table(c, "` + table + `", func(table *schema.Blueprint) {
+ // Define your table schema changes here
+ })
+}
+
+func down{{.CamelName}}(c *schema.Context) error {
+ return schema.Table(c, "` + table + `", func(table *schema.Blueprint) {
+ // Define your table schema changes here
+ })
+}
+`
+ return template.Must(template.New("migration-update").Parse(tmpl))
+}
diff --git a/dialect.go b/dialect.go
deleted file mode 100644
index a355a9d..0000000
--- a/dialect.go
+++ /dev/null
@@ -1,53 +0,0 @@
-package schema
-
-import (
- "fmt"
-)
-
-type dialect uint8
-
-const (
- dialectUnknown dialect = iota
- dialectMySQL
- dialectPostgres
-)
-
-func (d dialect) String() string {
- switch d {
- case dialectMySQL:
- return "mysql"
- case dialectPostgres:
- return "postgres"
- default:
- return ""
- }
-}
-
-var dialectValue dialect = dialectUnknown
-var debug = false
-
-// SetDialect sets the current dialect for the schema package.
-func SetDialect(d string) error {
- dialectValue = dialectFromString(d)
- if dialectValue == dialectUnknown {
- return fmt.Errorf("unknown dialect: %s", d)
- }
-
- return nil
-}
-
-func dialectFromString(d string) dialect {
- switch d {
- case "mysql", "mariadb":
- return dialectMySQL
- case "postgres", "pgx":
- return dialectPostgres
- default:
- return dialectUnknown
- }
-}
-
-// SetDebug enables or disables debug mode for the schema package.
-func SetDebug(d bool) {
- debug = d
-}
diff --git a/dialect_test.go b/dialect_test.go
deleted file mode 100644
index 5889903..0000000
--- a/dialect_test.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package schema
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
-)
-
-func TestSetDialect(t *testing.T) {
- testCases := []struct {
- dialect string
- expectError bool
- }{
- {"postgres", false},
- {"pgx", false},
- {"mysql", false},
- {"mariadb", false},
- {"sqlite", true},
- }
- for _, tc := range testCases {
- t.Run(tc.dialect, func(t *testing.T) {
- err := SetDialect(tc.dialect)
- if tc.expectError {
- assert.Error(t, err, "Expected error for unsupported dialect %s", tc.dialect)
- } else {
- assert.NoError(t, err, "Did not expect error for supported dialect %s", tc.dialect)
- }
- })
- }
-}
diff --git a/down.go b/down.go
new file mode 100644
index 0000000..9140a19
--- /dev/null
+++ b/down.go
@@ -0,0 +1,78 @@
+package migris
+
+import (
+ "context"
+ "errors"
+
+ "github.com/akfaiz/migris/internal/logger"
+ "github.com/pressly/goose/v3"
+)
+
+// Down rolls back the last migration.
+func (m *Migrate) Down() error {
+ ctx := context.Background()
+ return m.DownContext(ctx)
+}
+
+// DownContext rolls back the last migration.
+func (m *Migrate) DownContext(ctx context.Context) error {
+ provider, err := m.newProvider()
+ if err != nil {
+ return err
+ }
+ currentVersion, err := provider.GetDBVersion(ctx)
+ if err != nil {
+ return err
+ }
+ if currentVersion == 0 {
+ logger.Info("Nothing to rollback.")
+ return nil
+ }
+ logger.Info("Rolling back migrations.\n")
+ result, err := provider.Down(ctx)
+ if err != nil {
+ var partialErr *goose.PartialError
+ if errors.As(err, &partialErr) {
+ logger.PrintResult(partialErr.Failed)
+ }
+ return err
+ }
+ if result != nil {
+ logger.PrintResult(result)
+ }
+ return nil
+}
+
+// DownTo rolls back the migrations to the specified version.
+func (m *Migrate) DownTo(version int64) error {
+ ctx := context.Background()
+ return m.DownToContext(ctx, version)
+}
+
+// DownToContext rolls back the migrations to the specified version.
+func (m *Migrate) DownToContext(ctx context.Context, version int64) error {
+ provider, err := m.newProvider()
+ if err != nil {
+ return err
+ }
+ currentVersion, err := provider.GetDBVersion(ctx)
+ if err != nil {
+ return err
+ }
+ if currentVersion == 0 {
+ logger.Info("Nothing to rollback.")
+ return nil
+ }
+ logger.Info("Rolling back migrations.\n")
+ results, err := provider.DownTo(ctx, version)
+ if err != nil {
+ var partialErr *goose.PartialError
+ if errors.As(err, &partialErr) {
+ logger.PrintResults(partialErr.Applied)
+ logger.PrintResult(partialErr.Failed)
+ }
+ return err
+ }
+ logger.PrintResults(results)
+ return nil
+}
diff --git a/examples/basic/cmd/migrate/cmd.go b/examples/basic/cmd/migrate/cmd.go
new file mode 100644
index 0000000..9b39d89
--- /dev/null
+++ b/examples/basic/cmd/migrate/cmd.go
@@ -0,0 +1,60 @@
+package migrate
+
+import (
+ "context"
+
+ "github.com/urfave/cli/v3"
+)
+
+func Command() *cli.Command {
+ cmd := &cli.Command{
+ Name: "migrate",
+ Usage: "Migration tool",
+ Commands: []*cli.Command{
+ {
+ Name: "create",
+ Usage: "Create a new migration file",
+ Flags: []cli.Flag{
+ &cli.StringFlag{
+ Name: "name",
+ Aliases: []string{"n"},
+ Usage: "Name of the migration",
+ Required: true,
+ },
+ },
+ Action: func(ctx context.Context, c *cli.Command) error {
+ return Create(c.String("name"))
+ },
+ },
+ {
+ Name: "up",
+ Usage: "Run all pending migrations",
+ Action: func(ctx context.Context, c *cli.Command) error {
+ return Up()
+ },
+ },
+ {
+ Name: "reset",
+ Usage: "Reset all migrations",
+ Action: func(ctx context.Context, c *cli.Command) error {
+ return Reset()
+ },
+ },
+ {
+ Name: "down",
+ Usage: "Revert the last migration",
+ Action: func(ctx context.Context, c *cli.Command) error {
+ return Down()
+ },
+ },
+ {
+ Name: "status",
+ Usage: "Show the status of migrations",
+ Action: func(ctx context.Context, c *cli.Command) error {
+ return Status()
+ },
+ },
+ },
+ }
+ return cmd
+}
diff --git a/examples/basic/cmd/migrate/migrate.go b/examples/basic/cmd/migrate/migrate.go
index d55c33b..7981229 100644
--- a/examples/basic/cmd/migrate/migrate.go
+++ b/examples/basic/cmd/migrate/migrate.go
@@ -4,85 +4,73 @@ import (
"database/sql"
"fmt"
- "github.com/afkdevs/go-schema"
- "github.com/afkdevs/go-schema/examples/basic/config"
- _ "github.com/afkdevs/go-schema/examples/basic/migrations"
+ "github.com/akfaiz/migris"
+ "github.com/akfaiz/migris/examples/basic/config"
+ _ "github.com/akfaiz/migris/examples/basic/migrations"
_ "github.com/lib/pq" // PostgreSQL driver
- "github.com/pressly/goose/v3"
)
func Up() error {
- m, err := newMigrator()
+ m, err := newMigrate()
if err != nil {
return err
}
- return goose.Up(m.db, m.dir)
+ return m.Up()
}
func Create(name string) error {
- m, err := newMigrator()
+ m, err := newMigrate()
if err != nil {
return err
}
- return goose.Create(m.db, m.dir, name, m.migrationType)
+ return m.Create(name)
+}
+
+func Reset() error {
+ m, err := newMigrate()
+ if err != nil {
+ return err
+ }
+ return m.Reset()
}
func Down() error {
- m, err := newMigrator()
+ m, err := newMigrate()
if err != nil {
return err
}
- return goose.Reset(m.db, m.dir)
+ return m.Down()
}
-type migrator struct {
- dir string
- dialect string
- tableName string
- migrationType string
- db *sql.DB
+func Status() error {
+ m, err := newMigrate()
+ if err != nil {
+ return err
+ }
+ return m.Status()
}
-func newMigrator() (*migrator, error) {
- db, err := newDatabase(config.GetDatabase())
+func newMigrate() (*migris.Migrate, error) {
+ cfg, err := config.Load()
if err != nil {
return nil, err
}
- m := &migrator{
- dir: "migrations",
- dialect: "postgres",
- tableName: "schema_migrations",
- migrationType: "go",
- db: db,
- }
- if err := m.init(); err != nil {
+ db, err := openDatabase(cfg.Database)
+ if err != nil {
return nil, err
}
-
- return m, nil
+ migrate, err := migris.New("postgres", migris.WithDB(db))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create migris instance: %w", err)
+ }
+ return migrate, nil
}
-func newDatabase(cfg config.Database) (*sql.DB, error) {
+func openDatabase(cfg config.Database) (*sql.DB, error) {
dsn := cfg.DSN()
db, err := sql.Open("postgres", dsn)
if err != nil {
- return nil, fmt.Errorf("failed to connect to database: %w", err)
- }
-
- if err := db.Ping(); err != nil {
- return nil, fmt.Errorf("failed to ping database: %w", err)
+ return nil, fmt.Errorf("failed to open database: %w", err)
}
-
return db, nil
}
-
-func (m *migrator) init() error {
- goose.SetTableName(m.tableName)
- if err := goose.SetDialect(m.dialect); err != nil {
- return err
- }
- if err := schema.SetDialect(m.dialect); err != nil {
- return err
- }
- return nil
-}
diff --git a/examples/basic/cmd/root.go b/examples/basic/cmd/root.go
new file mode 100644
index 0000000..ee336c3
--- /dev/null
+++ b/examples/basic/cmd/root.go
@@ -0,0 +1,20 @@
+package cmd
+
+import (
+ "context"
+
+ "github.com/akfaiz/migris/examples/basic/cmd/migrate"
+ "github.com/urfave/cli/v3"
+)
+
+var cmd = &cli.Command{
+ Name: "schema-example",
+ Usage: "A simple schema example",
+ Commands: []*cli.Command{
+ migrate.Command(),
+ },
+}
+
+func Execute(args []string) error {
+ return cmd.Run(context.Background(), args)
+}
diff --git a/examples/basic/config/config.go b/examples/basic/config/config.go
new file mode 100644
index 0000000..1b028be
--- /dev/null
+++ b/examples/basic/config/config.go
@@ -0,0 +1,20 @@
+package config
+
+import "github.com/joho/godotenv"
+
+type Config struct {
+ Database Database
+}
+
+func Load() (Config, error) {
+ var config Config
+ err := godotenv.Load()
+ if err != nil {
+ return config, err
+ }
+ config = Config{
+ Database: getDatabaseConfig(),
+ }
+
+ return config, nil
+}
diff --git a/examples/basic/config/database.go b/examples/basic/config/database.go
index 722d0ec..8065182 100644
--- a/examples/basic/config/database.go
+++ b/examples/basic/config/database.go
@@ -16,13 +16,13 @@ func (db Database) DSN() string {
db.User, db.Password, db.Host, db.Port, db.Database, db.SSLMode)
}
-func GetDatabase() Database {
+func getDatabaseConfig() Database {
return Database{
- Host: "localhost",
- Port: 5432,
- User: "root",
- Password: "password",
- Database: "schema_example",
- SSLMode: "disable",
+ Host: getEnv("DB_HOST"),
+ Port: getEnvInt("DB_PORT"),
+ User: getEnv("DB_USER"),
+ Password: getEnv("DB_PASSWORD"),
+ Database: getEnv("DB_NAME"),
+ SSLMode: getEnv("DB_SSLMODE"),
}
}
diff --git a/examples/basic/config/util.go b/examples/basic/config/util.go
new file mode 100644
index 0000000..d2b5c43
--- /dev/null
+++ b/examples/basic/config/util.go
@@ -0,0 +1,23 @@
+package config
+
+import (
+ "os"
+ "strconv"
+)
+
+func getEnv(key string) string {
+ value := os.Getenv(key)
+ return value
+}
+
+func getEnvInt(key string) int {
+ value := os.Getenv(key)
+ if value == "" {
+ return 0
+ }
+ intValue, err := strconv.Atoi(value)
+ if err != nil {
+ return 0
+ }
+ return intValue
+}
diff --git a/examples/basic/go.mod b/examples/basic/go.mod
index a4c82dd..6d2324f 100644
--- a/examples/basic/go.mod
+++ b/examples/basic/go.mod
@@ -1,19 +1,25 @@
-module github.com/afkdevs/go-schema/examples/basic
+module github.com/akfaiz/migris/examples/basic
go 1.23.0
require (
- github.com/afkdevs/go-schema v0.0.0-000000000000-00010101000000
+ github.com/akfaiz/migris v0.0.0
+ github.com/joho/godotenv v1.5.1
github.com/lib/pq v1.10.9
- github.com/pressly/goose/v3 v3.24.3
github.com/urfave/cli/v3 v3.3.8
)
-replace github.com/afkdevs/go-schema => ../..
-
require (
+ github.com/fatih/color v1.18.0 // indirect
+ github.com/mattn/go-colorable v0.1.13 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mfridman/interpolate v0.0.2 // indirect
+ github.com/pressly/goose/v3 v3.25.0 // indirect
github.com/sethvargo/go-retry v0.3.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
- golang.org/x/sync v0.14.0 // indirect
+ golang.org/x/sync v0.16.0 // indirect
+ golang.org/x/sys v0.35.0 // indirect
+ golang.org/x/term v0.34.0 // indirect
)
+
+replace github.com/akfaiz/migris => ../..
diff --git a/examples/basic/go.sum b/examples/basic/go.sum
index 69cac7b..1b92f73 100644
--- a/examples/basic/go.sum
+++ b/examples/basic/go.sum
@@ -4,12 +4,19 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
+github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
+github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
+github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
+github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
@@ -18,31 +25,35 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/pressly/goose/v3 v3.24.3 h1:DSWWNwwggVUsYZ0X2VitiAa9sKuqtBfe+Jr9zFGwWlM=
-github.com/pressly/goose/v3 v3.24.3/go.mod h1:v9zYL4xdViLHCUUJh/mhjnm6JrK7Eul8AS93IxiZM4E=
+github.com/pressly/goose/v3 v3.25.0 h1:6WeYhMWGRCzpyd89SpODFnCBCKz41KrVbRT58nVjGng=
+github.com/pressly/goose/v3 v3.25.0/go.mod h1:4hC1KrritdCxtuFsqgs1R4AU5bWtTAf+cnWvfhf2DNY=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
-github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
-github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/urfave/cli/v3 v3.3.8 h1:BzolUExliMdet9NlJ/u4m5vHSotJ3PzEqSAZ1oPMa/E=
github.com/urfave/cli/v3 v3.3.8/go.mod h1:FJSKtM/9AiiTOJL4fJ6TbMUkxBXn7GO9guZqoZtpYpo=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
-golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
-golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
-golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
-golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
-golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
-golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
+golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
+golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
+golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
+golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
+golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-modernc.org/libc v1.65.0 h1:e183gLDnAp9VJh6gWKdTy0CThL9Pt7MfcR/0bgb7Y1Y=
-modernc.org/libc v1.65.0/go.mod h1:7m9VzGq7APssBTydds2zBcxGREwvIGpuUBaKTXdm2Qs=
+modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
+modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
-modernc.org/memory v1.10.0 h1:fzumd51yQ1DxcOxSO+S6X7+QTuVU+n8/Aj7swYjFfC4=
-modernc.org/memory v1.10.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
-modernc.org/sqlite v1.37.0 h1:s1TMe7T3Q3ovQiK2Ouz4Jwh7dw4ZDqbebSDTlSJdfjI=
-modernc.org/sqlite v1.37.0/go.mod h1:5YiWv+YviqGMuGw4V+PNplcyaJ5v+vQd7TQOgkACoJM=
+modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
+modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
+modernc.org/sqlite v1.38.2 h1:Aclu7+tgjgcQVShZqim41Bbw9Cho0y/7WzYptXqkEek=
+modernc.org/sqlite v1.38.2/go.mod h1:cPTJYSlgg3Sfg046yBShXENNtPrWrDX8bsbAQBzgQ5E=
diff --git a/examples/basic/main.go b/examples/basic/main.go
index 6a5590b..050e9ea 100644
--- a/examples/basic/main.go
+++ b/examples/basic/main.go
@@ -1,60 +1,14 @@
package main
import (
- "context"
"log"
"os"
- "github.com/afkdevs/go-schema/examples/basic/cmd/migrate"
- "github.com/urfave/cli/v3"
+ "github.com/akfaiz/migris/examples/basic/cmd"
)
func main() {
- app := &cli.Command{
- Name: "schema-example",
- Usage: "A simple schema example",
- Commands: []*cli.Command{
- {
- Name: "migrate",
- Aliases: []string{"m"},
- Usage: "Run database migrations",
- Commands: []*cli.Command{
- {
- Name: "create",
- Usage: "Create a new migration file",
- Flags: []cli.Flag{
- &cli.StringFlag{
- Name: "name",
- Aliases: []string{"n"},
- Usage: "Name of the migration",
- Required: true,
- },
- },
- Action: func(ctx context.Context, c *cli.Command) error {
- return migrate.Create(c.String("name"))
- },
- },
- {
- Name: "up",
- Aliases: []string{"u"},
- Usage: "Run all pending migrations",
- Action: func(ctx context.Context, c *cli.Command) error {
- return migrate.Up()
- },
- },
- {
- Name: "down",
- Aliases: []string{"d"},
- Usage: "Down all migrations",
- Action: func(ctx context.Context, c *cli.Command) error {
- return migrate.Down()
- },
- },
- },
- },
- },
- }
- if err := app.Run(context.Background(), os.Args); err != nil {
+ if err := cmd.Execute(os.Args); err != nil {
log.Printf("Error running app: %v\n", err)
os.Exit(1)
}
diff --git a/examples/basic/migrations/20250626235117_create_users_table.go b/examples/basic/migrations/20250626235117_create_users_table.go
deleted file mode 100644
index 23d318e..0000000
--- a/examples/basic/migrations/20250626235117_create_users_table.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package migrations
-
-import (
- "context"
- "database/sql"
-
- "github.com/afkdevs/go-schema"
- "github.com/pressly/goose/v3"
-)
-
-func init() {
- goose.AddMigrationContext(upCreateUsersTable, downCreateUsersTable)
-}
-
-func upCreateUsersTable(ctx context.Context, tx *sql.Tx) error {
- return schema.Create(ctx, tx, "users", func(table *schema.Blueprint) {
- table.ID()
- table.String("name")
- table.String("email")
- table.Timestamp("email_verified_at").Nullable()
- table.String("password")
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
- })
-}
-
-func downCreateUsersTable(ctx context.Context, tx *sql.Tx) error {
- return schema.DropIfExists(ctx, tx, "users")
-}
diff --git a/examples/basic/migrations/20250628092119_create_roles_table.go b/examples/basic/migrations/20250628092119_create_roles_table.go
deleted file mode 100644
index d857cc5..0000000
--- a/examples/basic/migrations/20250628092119_create_roles_table.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package migrations
-
-import (
- "context"
- "database/sql"
-
- "github.com/afkdevs/go-schema"
- "github.com/pressly/goose/v3"
-)
-
-func init() {
- goose.AddMigrationContext(upCreateRolesTable, downCreateRolesTable)
-}
-
-func upCreateRolesTable(ctx context.Context, tx *sql.Tx) error {
- return schema.Create(ctx, tx, "roles", func(table *schema.Blueprint) {
- table.Increments("id").Primary()
- table.String("name").Unique().Nullable(false)
- })
-}
-
-func downCreateRolesTable(ctx context.Context, tx *sql.Tx) error {
- return schema.DropIfExists(ctx, tx, "roles")
-}
diff --git a/examples/basic/migrations/20250628092223_create_user_roles_table.go b/examples/basic/migrations/20250628092223_create_user_roles_table.go
deleted file mode 100644
index 713023b..0000000
--- a/examples/basic/migrations/20250628092223_create_user_roles_table.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package migrations
-
-import (
- "context"
- "database/sql"
-
- "github.com/afkdevs/go-schema"
- "github.com/pressly/goose/v3"
-)
-
-func init() {
- goose.AddMigrationContext(upCreateUserRolesTable, downCreateUserRolesTable)
-}
-
-func upCreateUserRolesTable(ctx context.Context, tx *sql.Tx) error {
- return schema.Create(ctx, tx, "user_roles", func(table *schema.Blueprint) {
- table.BigInteger("user_id")
- table.Integer("role_id")
- table.Primary("user_id", "role_id")
- table.Foreign("user_id").References("id").On("users")
- table.Foreign("role_id").References("id").On("roles")
- })
-}
-
-func downCreateUserRolesTable(ctx context.Context, tx *sql.Tx) error {
- return schema.DropIfExists(ctx, tx, "user_roles")
-}
diff --git a/examples/basic/migrations/20250830103612_create_users_table.go b/examples/basic/migrations/20250830103612_create_users_table.go
new file mode 100644
index 0000000..158e385
--- /dev/null
+++ b/examples/basic/migrations/20250830103612_create_users_table.go
@@ -0,0 +1,25 @@
+package migrations
+
+import (
+ "github.com/akfaiz/migris"
+ "github.com/akfaiz/migris/schema"
+)
+
+func init() {
+ migris.AddMigrationContext(upCreateUsersTable, downCreateUsersTable)
+}
+
+func upCreateUsersTable(c *schema.Context) error {
+ return schema.Create(c, "users", func(table *schema.Blueprint) {
+ table.ID()
+ table.String("name")
+ table.String("email")
+ table.Timestamp("email_verified_at").Nullable()
+ table.String("password")
+ table.Timestamps()
+ })
+}
+
+func downCreateUsersTable(c *schema.Context) error {
+ return schema.DropIfExists(c, "users")
+}
diff --git a/examples/basic/migrations/20250830103653_create_roles_table.go b/examples/basic/migrations/20250830103653_create_roles_table.go
new file mode 100644
index 0000000..1e961ce
--- /dev/null
+++ b/examples/basic/migrations/20250830103653_create_roles_table.go
@@ -0,0 +1,21 @@
+package migrations
+
+import (
+ "github.com/akfaiz/migris"
+ "github.com/akfaiz/migris/schema"
+)
+
+func init() {
+ migris.AddMigrationContext(upCreateRolesTable, downCreateRolesTable)
+}
+
+func upCreateRolesTable(c *schema.Context) error {
+ return schema.Create(c, "roles", func(table *schema.Blueprint) {
+ table.Increments("id").Primary()
+ table.String("name").Unique().Nullable(false)
+ })
+}
+
+func downCreateRolesTable(c *schema.Context) error {
+ return schema.DropIfExists(c, "roles")
+}
diff --git a/examples/basic/migrations/20250830103714_create_user_roles_table.go b/examples/basic/migrations/20250830103714_create_user_roles_table.go
new file mode 100644
index 0000000..764a76b
--- /dev/null
+++ b/examples/basic/migrations/20250830103714_create_user_roles_table.go
@@ -0,0 +1,24 @@
+package migrations
+
+import (
+ "github.com/akfaiz/migris"
+ "github.com/akfaiz/migris/schema"
+)
+
+func init() {
+ migris.AddMigrationContext(upCreateUserRolesTable, downCreateUserRolesTable)
+}
+
+func upCreateUserRolesTable(c *schema.Context) error {
+ return schema.Create(c, "user_roles", func(table *schema.Blueprint) {
+ table.BigInteger("user_id")
+ table.Integer("role_id")
+ table.Primary("user_id", "role_id")
+ table.Foreign("user_id").References("id").On("users")
+ table.Foreign("role_id").References("id").On("roles")
+ })
+}
+
+func downCreateUserRolesTable(c *schema.Context) error {
+ return schema.DropIfExists(c, "user_roles")
+}
diff --git a/foreign_key_definition.go b/foreign_key_definition.go
deleted file mode 100644
index 4c6828b..0000000
--- a/foreign_key_definition.go
+++ /dev/null
@@ -1,127 +0,0 @@
-package schema
-
-// ForeignKeyDefinition defines the interface for defining a foreign key constraint in a database table.
-type ForeignKeyDefinition interface {
- // CascadeOnDelete sets the foreign key to cascade on delete.
- CascadeOnDelete() ForeignKeyDefinition
- // CascadeOnUpdate sets the foreign key to cascade on update.
- CascadeOnUpdate() ForeignKeyDefinition
- // Deferrable sets the foreign key as deferrable.
- Deferrable(value ...bool) ForeignKeyDefinition
- // InitiallyImmediate sets the foreign key to be initially immediate.
- InitiallyImmediate(value ...bool) ForeignKeyDefinition
- // Name sets the name of the foreign key constraint.
- // This is optional and can be used to give a specific name to the foreign key.
- Name(name string) ForeignKeyDefinition
- // NoActionOnDelete sets the foreign key to do nothing on delete.
- NoActionOnDelete() ForeignKeyDefinition
- // NoActionOnUpdate sets the foreign key to do nothing on update.
- NoActionOnUpdate() ForeignKeyDefinition
- // NullOnDelete sets the foreign key to set the column to NULL on delete.
- NullOnDelete() ForeignKeyDefinition
- // NullOnUpdate sets the foreign key to set the column to NULL on update.
- NullOnUpdate() ForeignKeyDefinition
- // On sets the table that this foreign key references.
- On(table string) ForeignKeyDefinition
- // OnDelete sets the action to take when the referenced row is deleted.
- OnDelete(action string) ForeignKeyDefinition
- // OnUpdate sets the action to take when the referenced row is updated.
- OnUpdate(action string) ForeignKeyDefinition
- // References sets the column that this foreign key references in the other table.
- References(column string) ForeignKeyDefinition
- // RestrictOnDelete sets the foreign key to restrict deletion of the referenced row.
- RestrictOnDelete() ForeignKeyDefinition
- // RestrictOnUpdate sets the foreign key to restrict updating of the referenced row.
- RestrictOnUpdate() ForeignKeyDefinition
-}
-
-var _ ForeignKeyDefinition = &foreignKeyDefinition{}
-
-type foreignKeyDefinition struct {
- tableName string
- column string
- constaintName string // name of the foreign key constraint
- references string
- on string
- onDelete string
- onUpdate string
- deferrable *bool
- initiallyImmediate *bool
-}
-
-func (fk *foreignKeyDefinition) CascadeOnDelete() ForeignKeyDefinition {
- fk.onDelete = "CASCADE"
- return fk
-}
-
-func (fk *foreignKeyDefinition) CascadeOnUpdate() ForeignKeyDefinition {
- fk.onUpdate = "CASCADE"
- return fk
-}
-
-func (fk *foreignKeyDefinition) Deferrable(value ...bool) ForeignKeyDefinition {
- val := optional(true, value...)
- fk.deferrable = &val
- return fk
-}
-
-func (fk *foreignKeyDefinition) InitiallyImmediate(value ...bool) ForeignKeyDefinition {
- val := optional(true, value...)
- fk.initiallyImmediate = &val
- return fk
-}
-
-func (fk *foreignKeyDefinition) Name(name string) ForeignKeyDefinition {
- fk.constaintName = name
- return fk
-}
-
-func (fk *foreignKeyDefinition) NoActionOnDelete() ForeignKeyDefinition {
- fk.onDelete = "NO ACTION"
- return fk
-}
-
-func (fk *foreignKeyDefinition) NoActionOnUpdate() ForeignKeyDefinition {
- fk.onUpdate = "NO ACTION"
- return fk
-}
-
-func (fk *foreignKeyDefinition) NullOnDelete() ForeignKeyDefinition {
- fk.onDelete = "SET NULL"
- return fk
-}
-
-func (fk *foreignKeyDefinition) NullOnUpdate() ForeignKeyDefinition {
- fk.onUpdate = "SET NULL"
- return fk
-}
-
-func (fk *foreignKeyDefinition) On(table string) ForeignKeyDefinition {
- fk.on = table
- return fk
-}
-
-func (fk *foreignKeyDefinition) OnDelete(action string) ForeignKeyDefinition {
- fk.onDelete = action
- return fk
-}
-
-func (fk *foreignKeyDefinition) OnUpdate(action string) ForeignKeyDefinition {
- fk.onUpdate = action
- return fk
-}
-
-func (fk *foreignKeyDefinition) References(column string) ForeignKeyDefinition {
- fk.references = column
- return fk
-}
-
-func (fk *foreignKeyDefinition) RestrictOnDelete() ForeignKeyDefinition {
- fk.onDelete = "RESTRICT"
- return fk
-}
-
-func (fk *foreignKeyDefinition) RestrictOnUpdate() ForeignKeyDefinition {
- fk.onUpdate = "RESTRICT"
- return fk
-}
diff --git a/go.mod b/go.mod
index 1f0d73d..241fe24 100644
--- a/go.mod
+++ b/go.mod
@@ -1,19 +1,26 @@
-module github.com/afkdevs/go-schema
+module github.com/akfaiz/migris
-go 1.23
+go 1.23.0
require (
+ github.com/fatih/color v1.18.0
github.com/go-sql-driver/mysql v1.9.3
github.com/lib/pq v1.10.9
- github.com/stretchr/testify v1.10.0
+ github.com/pressly/goose/v3 v3.25.0
+ github.com/stretchr/testify v1.11.1
+ golang.org/x/term v0.34.0
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
- github.com/kr/pretty v0.3.1 // indirect
+ github.com/mattn/go-colorable v0.1.13 // indirect
+ github.com/mattn/go-isatty v0.0.20 // indirect
+ github.com/mfridman/interpolate v0.0.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
- github.com/rogpeppe/go-internal v1.13.1 // indirect
- gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
+ github.com/sethvargo/go-retry v0.3.0 // indirect
+ go.uber.org/multierr v1.11.0 // indirect
+ golang.org/x/sync v0.16.0 // indirect
+ golang.org/x/sys v0.35.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index d87df1b..e53afdf 100644
--- a/go.sum
+++ b/go.sum
@@ -1,29 +1,57 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
-github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
+github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
+github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
+github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
-github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
-github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
-github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
-github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
-github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
-github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
+github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
+github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
+github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
+github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mfridman/interpolate v0.0.2 h1:pnuTK7MQIxxFz1Gr+rjSIx9u7qVjf5VOoM/u6BbAxPY=
+github.com/mfridman/interpolate v0.0.2/go.mod h1:p+7uk6oE07mpE/Ik1b8EckO0O4ZXiGAfshKBWLUM9Xg=
+github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
+github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
-github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
-github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
-github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
-github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/pressly/goose/v3 v3.25.0 h1:6WeYhMWGRCzpyd89SpODFnCBCKz41KrVbRT58nVjGng=
+github.com/pressly/goose/v3 v3.25.0/go.mod h1:4hC1KrritdCxtuFsqgs1R4AU5bWtTAf+cnWvfhf2DNY=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/sethvargo/go-retry v0.3.0 h1:EEt31A35QhrcRZtrYFDTBg91cqZVnFL2navjDrah2SE=
+github.com/sethvargo/go-retry v0.3.0/go.mod h1:mNX17F0C/HguQMyMyJxcnU471gOZGxCLyYaFyAZraas=
+github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
+github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
+go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
+go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
+golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
+golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
+golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
+golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
+golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
+golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
+modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
+modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
+modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
+modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
+modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
+modernc.org/sqlite v1.38.2 h1:Aclu7+tgjgcQVShZqim41Bbw9Cho0y/7WzYptXqkEek=
+modernc.org/sqlite v1.38.2/go.mod h1:cPTJYSlgg3Sfg046yBShXENNtPrWrDX8bsbAQBzgQ5E=
diff --git a/grammar.go b/grammar.go
deleted file mode 100644
index ccdce5e..0000000
--- a/grammar.go
+++ /dev/null
@@ -1,138 +0,0 @@
-package schema
-
-import (
- "fmt"
- "slices"
- "strings"
-)
-
-type grammar interface {
- compileCreate(bp *Blueprint) (string, error)
- compileCreateIfNotExists(bp *Blueprint) (string, error)
- compileAdd(bp *Blueprint) (string, error)
- compileChange(bp *Blueprint) ([]string, error)
- compileDrop(bp *Blueprint) (string, error)
- compileDropIfExists(bp *Blueprint) (string, error)
- compileRename(bp *Blueprint) (string, error)
- compileDropColumn(blueprint *Blueprint) (string, error)
- compileRenameColumn(blueprint *Blueprint, oldName, newName string) (string, error)
- compileIndex(blueprint *Blueprint, index *indexDefinition) (string, error)
- compileUnique(blueprint *Blueprint, index *indexDefinition) (string, error)
- compilePrimary(blueprint *Blueprint, index *indexDefinition) (string, error)
- compileFullText(blueprint *Blueprint, index *indexDefinition) (string, error)
- compileDropIndex(blueprint *Blueprint, indexName string) (string, error)
- compileDropUnique(blueprint *Blueprint, indexName string) (string, error)
- compileDropFulltext(blueprint *Blueprint, indexName string) (string, error)
- compileDropPrimary(blueprint *Blueprint, indexName string) (string, error)
- compileRenameIndex(blueprint *Blueprint, oldName, newName string) (string, error)
- compileForeign(blueprint *Blueprint, foreignKey *foreignKeyDefinition) (string, error)
- compileDropForeign(blueprint *Blueprint, foreignKeyName string) (string, error)
-}
-
-type baseGrammar struct{}
-
-func (g *baseGrammar) compileForeign(blueprint *Blueprint, foreignKey *foreignKeyDefinition) (string, error) {
- if foreignKey.column == "" || foreignKey.on == "" || foreignKey.references == "" {
- return "", fmt.Errorf("foreign key definition is incomplete: column, on, and references must be set")
- }
- onDelete := ""
- if foreignKey.onDelete != "" {
- onDelete = fmt.Sprintf(" ON DELETE %s", foreignKey.onDelete)
- }
- onUpdate := ""
- if foreignKey.onUpdate != "" {
- onUpdate = fmt.Sprintf(" ON UPDATE %s", foreignKey.onUpdate)
- }
- containtName := foreignKey.constaintName
- if containtName == "" {
- containtName = g.createForeignKeyName(blueprint, foreignKey)
- }
-
- return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s(%s)%s%s",
- blueprint.name,
- containtName,
- foreignKey.column,
- foreignKey.on,
- foreignKey.references,
- onDelete,
- onUpdate,
- ), nil
-}
-
-func (g *baseGrammar) quoteString(s string) string {
- return "'" + s + "'"
-}
-
-func (g *baseGrammar) prefixArray(prefix string, items []string) []string {
- prefixed := make([]string, len(items))
- for i, item := range items {
- prefixed[i] = fmt.Sprintf("%s%s", prefix, item)
- }
- return prefixed
-}
-
-func (g *baseGrammar) columnize(columns []string) string {
- if len(columns) == 0 {
- return ""
- }
- return strings.Join(columns, ", ")
-}
-
-func (g *baseGrammar) getDefaultValue(col *columnDefinition) string {
- if col.defaultValue == nil {
- return "NULL"
- }
- useQuote := slices.Contains([]columnType{columnTypeString, columnTypeChar, columnTypeEnum}, col.columnType)
-
- switch v := col.defaultValue.(type) {
- case string:
- if useQuote {
- return g.quoteString(v)
- }
- return v
- case int, int64, float64:
- return fmt.Sprintf("%v", v)
- case bool:
- if v {
- return "true"
- }
- return "false"
- default:
- return fmt.Sprintf("'%v'", v) // Fallback for other types
- }
-}
-
-func (g *baseGrammar) createIndexName(blueprint *Blueprint, index *indexDefinition) string {
- tableName := blueprint.name
- if strings.Contains(tableName, ".") {
- parts := strings.Split(tableName, ".")
- tableName = parts[len(parts)-1] // Use the last part as the table name
- }
-
- switch index.indexType {
- case indexTypePrimary:
- return fmt.Sprintf("pk_%s", tableName)
- case indexTypeUnique:
- return fmt.Sprintf("uk_%s_%s", tableName, strings.Join(index.columns, "_"))
- case indexTypeIndex:
- return fmt.Sprintf("idx_%s_%s", tableName, strings.Join(index.columns, "_"))
- case indexTypeFullText:
- return fmt.Sprintf("ft_%s_%s", tableName, strings.Join(index.columns, "_"))
- default:
- return ""
- }
-}
-
-func (g *baseGrammar) createForeignKeyName(blueprint *Blueprint, foreignKey *foreignKeyDefinition) string {
- tableName := blueprint.name
- if strings.Contains(tableName, ".") {
- parts := strings.Split(tableName, ".")
- tableName = parts[len(parts)-1] // Use the last part as the table name
- }
- on := foreignKey.on
- if strings.Contains(on, ".") {
- parts := strings.Split(on, ".")
- on = parts[len(parts)-1] // Use the last part as the column name
- }
- return fmt.Sprintf("fk_%s_%s", tableName, on)
-}
diff --git a/internal/config/config.go b/internal/config/config.go
new file mode 100644
index 0000000..856cfd6
--- /dev/null
+++ b/internal/config/config.go
@@ -0,0 +1,29 @@
+package config
+
+import (
+ "sync/atomic"
+
+ "github.com/akfaiz/migris/internal/dialect"
+)
+
+type Config struct {
+ Dialect dialect.Dialect
+}
+
+var config = atomic.Pointer[Config]{}
+
+func init() {
+ config.Store(&Config{
+ Dialect: dialect.Unknown,
+ })
+}
+
+func SetDialect(dialect dialect.Dialect) {
+ cfg := config.Load()
+ cfg.Dialect = dialect
+ config.Store(cfg)
+}
+
+func GetDialect() dialect.Dialect {
+ return config.Load().Dialect
+}
diff --git a/internal/dialect/dialect.go b/internal/dialect/dialect.go
new file mode 100644
index 0000000..c9f2fb4
--- /dev/null
+++ b/internal/dialect/dialect.go
@@ -0,0 +1,38 @@
+package dialect
+
+import "github.com/pressly/goose/v3/database"
+
+// Dialect is the type of database dialect.
+type Dialect string
+
+const (
+ MySQL Dialect = "mysql"
+ Postgres Dialect = "postgres"
+ Unknown Dialect = ""
+)
+
+func (d Dialect) String() string {
+ return string(d)
+}
+
+func (d Dialect) GooseDialect() database.Dialect {
+ switch d {
+ case MySQL:
+ return database.DialectMySQL
+ case Postgres:
+ return database.DialectPostgres
+ default:
+ return database.DialectCustom
+ }
+}
+
+func FromString(dialect string) Dialect {
+ switch dialect {
+ case "mysql", "mariadb":
+ return MySQL
+ case "postgres", "pgx":
+ return Postgres
+ default:
+ return Unknown
+ }
+}
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
new file mode 100644
index 0000000..232e71e
--- /dev/null
+++ b/internal/logger/logger.go
@@ -0,0 +1,101 @@
+package logger
+
+import (
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/fatih/color"
+ "github.com/pressly/goose/v3"
+ "golang.org/x/term"
+)
+
+var (
+ grey = color.New(color.FgHiBlack).SprintFunc()
+ greenBold = color.New(color.FgGreen, color.Bold).SprintFunc()
+ yellowBold = color.New(color.FgYellow, color.Bold).SprintFunc()
+ redBold = color.New(color.FgRed, color.Bold).SprintFunc()
+)
+
+func Info(msg string) {
+ Infof("%s", msg)
+}
+
+func Infof(format string, args ...any) {
+ msg := fmt.Sprintf(format, args...)
+
+ whiteBgBlue := color.New(color.FgWhite, color.BgBlue).SprintFunc()
+ fmt.Printf("%s %s\n", whiteBgBlue(" INFO "), msg)
+}
+
+func PrintResults(results []*goose.MigrationResult) {
+ for _, result := range results {
+ PrintResult(result)
+ }
+}
+
+func PrintResult(result *goose.MigrationResult) {
+ // Get terminal width
+ width, _, err := term.GetSize(int(os.Stdout.Fd()))
+ if err != nil || width <= 0 {
+ width = 80 // fallback
+ }
+
+ durText := fmt.Sprintf(" %.2fms", result.Duration.Seconds()*1000)
+ statusText := " DONE"
+ if result.Error != nil {
+ statusText = " FAIL"
+ }
+
+ name := result.Source.Path
+
+ // Calculate how many dots to add
+ fillLen := width - len(name) - len(durText) - len(statusText) - 1
+ if fillLen < 0 {
+ fillLen = 0
+ }
+ dots := strings.Repeat(".", fillLen)
+
+ // Print
+ fmt.Printf("%s %s%s", name, grey(dots), grey(durText))
+ if result.Error != nil {
+ fmt.Printf("%s\n", redBold(statusText))
+ } else {
+ fmt.Printf("%s\n", greenBold(statusText))
+ }
+}
+
+func PrintStatuses(statuses []*goose.MigrationStatus) {
+ for _, status := range statuses {
+ PrintStatus(status)
+ }
+}
+
+func PrintStatus(status *goose.MigrationStatus) {
+ width, _, err := term.GetSize(int(os.Stdout.Fd()))
+ if err != nil || width <= 0 {
+ width = 80 // fallback
+ }
+
+ var statusText string
+ if status.State == goose.StateApplied {
+ statusText = " Applied"
+ } else {
+ statusText = " Pending"
+ }
+
+ name := status.Source.Path
+
+ fillLen := width - len(name) - len(statusText) - 1
+ if fillLen < 0 {
+ fillLen = 0
+ }
+ dots := strings.Repeat(".", fillLen)
+
+ fmt.Printf("%s %s", name, grey(dots))
+ if status.State == goose.StateApplied {
+ fmt.Printf("%s\n", greenBold(statusText))
+ } else {
+ fmt.Printf("%s\n", yellowBold(statusText))
+ }
+}
diff --git a/internal/parser/migration.go b/internal/parser/migration.go
new file mode 100644
index 0000000..5e89a25
--- /dev/null
+++ b/internal/parser/migration.go
@@ -0,0 +1,21 @@
+package parser
+
+import "regexp"
+
+func ParseMigrationName(filename string) (tableName string, create bool) {
+ // Regex patterns for common migration styles
+ createPattern := regexp.MustCompile(`^create_(?P
[a-z0-9_]+?)(?:_table)?$`)
+ addColPattern := regexp.MustCompile(`^add_(?P[a-z0-9_]+?)_to_(?P[a-z0-9_]+?)(?:_table)?$`)
+ removeColPattern := regexp.MustCompile(`^remove_(?P[a-z0-9_]+?)_from_(?P[a-z0-9_]+?)(?:_table)?$`)
+
+ switch {
+ case createPattern.MatchString(filename):
+ return createPattern.ReplaceAllString(filename, "${table}"), true
+ case addColPattern.MatchString(filename):
+ return addColPattern.ReplaceAllString(filename, "${table}"), false
+ case removeColPattern.MatchString(filename):
+ return removeColPattern.ReplaceAllString(filename, "${table}"), false
+ default:
+ return "", false // Unknown migration style
+ }
+}
diff --git a/internal/parser/migration_test.go b/internal/parser/migration_test.go
new file mode 100644
index 0000000..a8c2de5
--- /dev/null
+++ b/internal/parser/migration_test.go
@@ -0,0 +1,132 @@
+package parser_test
+
+import (
+ "testing"
+
+ "github.com/akfaiz/migris/internal/parser"
+)
+
+func TestParseMigrationName(t *testing.T) {
+ tests := []struct {
+ name string
+ filename string
+ wantTable string
+ wantCreate bool
+ }{
+ // Create table patterns
+ {
+ name: "create table with table suffix",
+ filename: "create_users_table",
+ wantTable: "users",
+ wantCreate: true,
+ },
+ {
+ name: "create table without table suffix",
+ filename: "create_posts",
+ wantTable: "posts",
+ wantCreate: true,
+ },
+ {
+ name: "create table with underscores",
+ filename: "create_user_profiles_table",
+ wantTable: "user_profiles",
+ wantCreate: true,
+ },
+ {
+ name: "create table with numbers",
+ filename: "create_table_v2",
+ wantTable: "table_v2",
+ wantCreate: true,
+ },
+
+ // Add column patterns
+ {
+ name: "add column with table suffix",
+ filename: "add_email_to_users_table",
+ wantTable: "users",
+ wantCreate: false,
+ },
+ {
+ name: "add column without table suffix",
+ filename: "add_name_to_posts",
+ wantTable: "posts",
+ wantCreate: false,
+ },
+ {
+ name: "add multiple columns",
+ filename: "add_first_name_last_name_to_users_table",
+ wantTable: "users",
+ wantCreate: false,
+ },
+ {
+ name: "add column with underscores",
+ filename: "add_created_at_to_user_profiles",
+ wantTable: "user_profiles",
+ wantCreate: false,
+ },
+
+ // Remove column patterns
+ {
+ name: "remove column with table suffix",
+ filename: "remove_email_from_users_table",
+ wantTable: "users",
+ wantCreate: false,
+ },
+ {
+ name: "remove column without table suffix",
+ filename: "remove_name_from_posts",
+ wantTable: "posts",
+ wantCreate: false,
+ },
+ {
+ name: "remove multiple columns",
+ filename: "remove_old_field_from_users_table",
+ wantTable: "users",
+ wantCreate: false,
+ },
+ {
+ name: "remove column with underscores",
+ filename: "remove_updated_at_from_user_profiles",
+ wantTable: "user_profiles",
+ wantCreate: false,
+ },
+
+ // Unknown patterns
+ {
+ name: "unknown migration pattern",
+ filename: "update_users_table",
+ wantTable: "",
+ wantCreate: false,
+ },
+ {
+ name: "empty filename",
+ filename: "",
+ wantTable: "",
+ wantCreate: false,
+ },
+ {
+ name: "invalid format",
+ filename: "invalid_migration_name",
+ wantTable: "",
+ wantCreate: false,
+ },
+ {
+ name: "uppercase letters",
+ filename: "create_Users_table",
+ wantTable: "",
+ wantCreate: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotTable, gotCreate := parser.ParseMigrationName(tt.filename)
+ if gotTable != tt.wantTable {
+ t.Errorf("ParseMigrationName() gotTable = %v, want %v", gotTable, tt.wantTable)
+ }
+ if gotCreate != tt.wantCreate {
+ t.Errorf("ParseMigrationName() gotCreate = %v, want %v", gotCreate, tt.wantCreate)
+ }
+ })
+ }
+}
diff --git a/internal/util/util.go b/internal/util/util.go
new file mode 100644
index 0000000..61a4ca3
--- /dev/null
+++ b/internal/util/util.go
@@ -0,0 +1,33 @@
+package util
+
+func Optional[T any](defaultValue T, values ...T) T {
+ if len(values) > 0 {
+ return values[0]
+ }
+ return defaultValue
+}
+
+func OptionalPtr[T any](defaultValue T, values ...T) *T {
+ if len(values) > 0 {
+ return &values[0]
+ }
+ return &defaultValue
+}
+
+func OptionalNil[T any](values ...T) *T {
+ if len(values) > 0 {
+ return &values[0]
+ }
+ return nil
+}
+
+func PtrOf[T any](value T) *T {
+ return &value
+}
+
+func Ternary[T any](condition bool, trueValue, falseValue T) T {
+ if condition {
+ return trueValue
+ }
+ return falseValue
+}
diff --git a/migrate.go b/migrate.go
new file mode 100644
index 0000000..581564c
--- /dev/null
+++ b/migrate.go
@@ -0,0 +1,63 @@
+package migris
+
+import (
+ "database/sql"
+ "errors"
+ "os"
+
+ "github.com/akfaiz/migris/internal/config"
+ "github.com/akfaiz/migris/internal/dialect"
+ "github.com/pressly/goose/v3"
+ "github.com/pressly/goose/v3/database"
+)
+
+// Migrate handles database migrations
+type Migrate struct {
+ dialect dialect.Dialect
+ db *sql.DB
+ migrationDir string
+ tableName string
+}
+
+// New creates a new Migrate instance
+func New(dialectValue string, opts ...Option) (*Migrate, error) {
+ dialectVal := dialect.FromString(dialectValue)
+ if dialectVal == dialect.Unknown {
+ return nil, errors.New("unknown database dialect")
+ }
+ config.SetDialect(dialectVal)
+
+ m := &Migrate{
+ dialect: dialectVal,
+ migrationDir: "migrations",
+ tableName: "migris_db_version",
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ if m.db == nil {
+ return nil, errors.New("database connection is not set, please call WithDB option")
+ }
+ return m, nil
+}
+
+func (m *Migrate) newProvider() (*goose.Provider, error) {
+ val := config.GetDialect()
+ if val == dialect.Unknown {
+ return nil, errors.New("unknown database dialect")
+ }
+ gooseDialect := val.GooseDialect()
+ store, err := database.NewStore(gooseDialect, m.tableName)
+ if err != nil {
+ return nil, err
+ }
+ provider, err := goose.NewProvider(database.DialectCustom, m.db, os.DirFS(m.migrationDir),
+ goose.WithStore(store),
+ goose.WithDisableGlobalRegistry(true),
+ goose.WithGoMigrations(gooseMigrations()...),
+ )
+ if err != nil {
+ return nil, err
+ }
+ return provider, nil
+}
diff --git a/mysql_grammar.go b/mysql_grammar.go
deleted file mode 100644
index 7437bca..0000000
--- a/mysql_grammar.go
+++ /dev/null
@@ -1,441 +0,0 @@
-package schema
-
-import (
- "fmt"
- "slices"
- "strings"
-)
-
-type mysqlGrammar struct {
- baseGrammar
-}
-
-var _ grammar = (*mysqlGrammar)(nil)
-
-func newMysqlGrammar() *mysqlGrammar {
- return &mysqlGrammar{}
-}
-
-func (g *mysqlGrammar) compileCurrentDatabase() string {
- return "SELECT DATABASE()"
-}
-
-func (g *mysqlGrammar) compileTableExists(database string, table string) (string, error) {
- return fmt.Sprintf(
- "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'",
- g.quoteString(database),
- g.quoteString(table),
- ), nil
-}
-
-func (g *mysqlGrammar) compileTables(database string) (string, error) {
- return fmt.Sprintf(
- "select table_name as `name`, (data_length + index_length) as `size`, "+
- "table_comment as `comment`, engine as `engine`, table_collation as `collation` "+
- "from information_schema.tables where table_schema = %s and table_type in ('BASE TABLE', 'SYSTEM VERSIONED') "+
- "order by table_name",
- g.quoteString(database),
- ), nil
-}
-
-func (g *mysqlGrammar) compileColumns(database, table string) (string, error) {
- return fmt.Sprintf(
- "select column_name as `name`, data_type as `type_name`, column_type as `type`, "+
- "collation_name as `collation`, is_nullable as `nullable`, "+
- "column_default as `default`, column_comment as `comment`, extra as `extra` "+
- "from information_schema.columns where table_schema = %s and table_name = %s "+
- "order by ordinal_position asc",
- g.quoteString(database),
- g.quoteString(table),
- ), nil
-}
-
-func (g *mysqlGrammar) compileIndexes(database, table string) (string, error) {
- return fmt.Sprintf(
- "select index_name as `name`, group_concat(column_name order by seq_in_index) as `columns`, "+
- "index_type as `type`, not non_unique as `unique` "+
- "from information_schema.statistics where table_schema = %s and table_name = %s "+
- "group by index_name, index_type, non_unique",
- g.quoteString(database),
- g.quoteString(table),
- ), nil
-}
-
-func (g *mysqlGrammar) compileCreate(blueprint *Blueprint) (string, error) {
- sql, err := g.compileCreateTable(blueprint)
- if err != nil {
- return "", err
- }
- sql = g.compileCreateEncoding(sql, blueprint)
-
- return g.compileCreateEngine(sql, blueprint), nil
-}
-
-func (g *mysqlGrammar) compileCreateTable(blueprint *Blueprint) (string, error) {
- columns, err := g.getColumns(blueprint)
- if err != nil {
- return "", err
- }
-
- constraints := g.getConstraints(blueprint)
- columns = append(columns, constraints...)
-
- return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil
-}
-
-func (g *mysqlGrammar) compileCreateEncoding(sql string, blueprint *Blueprint) string {
- if blueprint.charset != "" {
- sql += fmt.Sprintf(" DEFAULT CHARACTER SET %s", blueprint.charset)
- }
- if blueprint.collation != "" {
- sql += fmt.Sprintf(" COLLATE %s", blueprint.collation)
- }
-
- return sql
-}
-
-func (g *mysqlGrammar) compileCreateEngine(sql string, blueprint *Blueprint) string {
- if blueprint.engine != "" {
- sql += fmt.Sprintf(" ENGINE = %s", blueprint.engine)
- }
- return sql
-}
-
-func (g *mysqlGrammar) compileCreateIfNotExists(blueprint *Blueprint) (string, error) {
- return g.compileCreate(blueprint)
-}
-
-func (g *mysqlGrammar) compileAdd(blueprint *Blueprint) (string, error) {
- if len(blueprint.getAddedColumns()) == 0 {
- return "", nil
- }
-
- columns, err := g.getColumns(blueprint)
- if err != nil {
- return "", err
- }
- columns = g.prefixArray("ADD COLUMN ", columns)
- constraints := g.getConstraints(blueprint)
- constraints = g.prefixArray("ADD ", constraints)
- columns = append(columns, constraints...)
-
- return fmt.Sprintf("ALTER TABLE %s %s",
- blueprint.name,
- strings.Join(columns, ", "),
- ), nil
-}
-
-func (g *mysqlGrammar) compileChange(bp *Blueprint) ([]string, error) {
- if len(bp.getChangedColumns()) == 0 {
- return nil, nil
- }
-
- var sqls []string
- for _, col := range bp.getChangedColumns() {
- if col.name == "" {
- return nil, fmt.Errorf("column name cannot be empty for change operation")
- }
- sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s %s", bp.name, col.name, g.getType(col))
-
- if col.hasCommand("nullable") {
- if col.nullable {
- sql += " NULL"
- } else {
- sql += " NOT NULL"
- }
- }
- if col.hasCommand("default") {
- if col.defaultValue != nil {
- sql += fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col))
- } else {
- sql += " DEFAULT NULL"
- }
- }
- if col.hasCommand("comment") {
- if col.comment != "" {
- sql += fmt.Sprintf(" COMMENT '%s'", col.comment)
- } else {
- sql += " COMMENT ''"
- }
- }
-
- sqls = append(sqls, sql)
- }
-
- return sqls, nil
-}
-
-func (g *mysqlGrammar) compileRename(blueprint *Blueprint) (string, error) {
- if blueprint.newName == "" {
- return "", fmt.Errorf("new table name cannot be empty")
- }
- return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, blueprint.newName), nil
-}
-
-func (g *mysqlGrammar) compileDrop(blueprint *Blueprint) (string, error) {
- if blueprint.name == "" {
- return "", fmt.Errorf("table name cannot be empty")
- }
- return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil
-}
-
-func (g *mysqlGrammar) compileDropIfExists(blueprint *Blueprint) (string, error) {
- if blueprint.name == "" {
- return "", fmt.Errorf("table name cannot be empty")
- }
- return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil
-}
-
-func (g *mysqlGrammar) compileDropColumn(blueprint *Blueprint) (string, error) {
- if len(blueprint.dropColumns) == 0 {
- return "", fmt.Errorf("no columns to drop")
- }
- columns := make([]string, len(blueprint.dropColumns))
- for i, col := range blueprint.dropColumns {
- if col == "" {
- return "", fmt.Errorf("column name cannot be empty")
- }
- columns[i] = col
- }
- columns = g.prefixArray("DROP COLUMN ", columns)
- return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil
-}
-
-func (g *mysqlGrammar) compileRenameColumn(blueprint *Blueprint, oldName, newName string) (string, error) {
- if oldName == "" || newName == "" {
- return "", fmt.Errorf("old and new column names cannot be empty")
- }
- return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, oldName, newName), nil
-}
-
-func (g *mysqlGrammar) compileIndex(blueprint *Blueprint, index *indexDefinition) (string, error) {
- if slices.Contains(index.columns, "") {
- return "", fmt.Errorf("index column cannot be empty")
- }
-
- indexName := index.name
- if indexName == "" {
- indexName = g.createIndexName(blueprint, index)
- }
-
- sql := fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(index.columns))
- if index.algorithm != "" {
- sql += fmt.Sprintf(" USING %s", index.algorithm)
- }
-
- return sql, nil
-}
-
-func (g *mysqlGrammar) compileUnique(blueprint *Blueprint, index *indexDefinition) (string, error) {
- if slices.Contains(index.columns, "") {
- return "", fmt.Errorf("unique column cannot be empty")
- }
-
- indexName := index.name
- if indexName == "" {
- indexName = g.createIndexName(blueprint, index)
- }
- sql := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(index.columns))
- if index.algorithm != "" {
- sql += fmt.Sprintf(" USING %s", index.algorithm)
- }
-
- return sql, nil
-}
-
-func (g *mysqlGrammar) compileFullText(blueprint *Blueprint, index *indexDefinition) (string, error) {
- if slices.Contains(index.columns, "") {
- return "", fmt.Errorf("fulltext index column cannot be empty")
- }
-
- indexName := index.name
- if indexName == "" {
- indexName = g.createIndexName(blueprint, index)
- }
-
- return fmt.Sprintf("CREATE FULLTEXT INDEX %s ON %s (%s)", indexName, blueprint.name, g.columnize(index.columns)), nil
-}
-
-func (g *mysqlGrammar) compilePrimary(blueprint *Blueprint, index *indexDefinition) (string, error) {
- if slices.Contains(index.columns, "") {
- return "", fmt.Errorf("primary key column cannot be empty")
- }
-
- indexName := index.name
- if indexName == "" {
- indexName = g.createIndexName(blueprint, index)
- }
-
- return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(index.columns)), nil
-}
-
-func (g *mysqlGrammar) compileDropIndex(blueprint *Blueprint, indexName string) (string, error) {
- if indexName == "" {
- return "", fmt.Errorf("index name cannot be empty")
- }
- return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, indexName), nil
-}
-
-func (g *mysqlGrammar) compileDropUnique(blueprint *Blueprint, indexName string) (string, error) {
- if indexName == "" {
- return "", fmt.Errorf("unique index name cannot be empty")
- }
- return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, indexName), nil
-}
-
-func (g *mysqlGrammar) compileDropFulltext(blueprint *Blueprint, indexName string) (string, error) {
- return g.compileDropIndex(blueprint, indexName)
-}
-
-func (g *mysqlGrammar) compileDropPrimary(blueprint *Blueprint, indexName string) (string, error) {
- if indexName == "" {
- return "", fmt.Errorf("primary key index name cannot be empty")
- }
- return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", blueprint.name), nil
-}
-
-func (g *mysqlGrammar) compileRenameIndex(blueprint *Blueprint, oldName, newName string) (string, error) {
- if oldName == "" || newName == "" {
- return "", fmt.Errorf("old and new index names cannot be empty")
- }
- return fmt.Sprintf("ALTER TABLE %s RENAME INDEX %s TO %s", blueprint.name, oldName, newName), nil
-}
-
-func (g *mysqlGrammar) compileDropForeign(blueprint *Blueprint, foreignKeyName string) (string, error) {
- if foreignKeyName == "" {
- return "", fmt.Errorf("foreign key name cannot be empty")
- }
- return fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s", blueprint.name, foreignKeyName), nil
-}
-
-func (g *mysqlGrammar) getColumns(blueprint *Blueprint) ([]string, error) {
- var columns []string
- for _, col := range blueprint.getAddedColumns() {
- if col.name == "" {
- return nil, fmt.Errorf("column name cannot be empty")
- }
- sql := col.name + " " + g.getType(col)
- if col.hasCommand("default") {
- if col.defaultValue != nil {
- sql += fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col))
- } else {
- sql += " DEFAULT NULL"
- }
- }
- if col.onUpdateValue != "" {
- sql += fmt.Sprintf(" ON UPDATE %s", col.onUpdateValue)
- }
- if col.nullable {
- sql += " NULL"
- } else {
- sql += " NOT NULL"
- }
- if col.comment != "" {
- sql += fmt.Sprintf(" COMMENT '%s'", col.comment)
- }
- columns = append(columns, sql)
- }
-
- return columns, nil
-}
-
-func (g *mysqlGrammar) getConstraints(blueprint *Blueprint) []string {
- var constrains []string
- for _, col := range blueprint.getAddedColumns() {
- if col.primary {
- pkConstraintName := g.createIndexName(blueprint, &indexDefinition{indexType: indexTypePrimary})
- sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")"
- constrains = append(constrains, sql)
- continue
- }
- if col.unique {
- uniqueName := col.uniqueName
- if uniqueName == "" {
- uniqueName = g.createIndexName(blueprint, &indexDefinition{
- indexType: indexTypeUnique,
- columns: []string{col.name},
- })
- }
- sql := fmt.Sprintf("CONSTRAINT %s UNIQUE (%s)", uniqueName, col.name)
- constrains = append(constrains, sql)
- }
- }
-
- return constrains
-}
-
-func (g *mysqlGrammar) getType(col *columnDefinition) string {
- switch col.columnType {
- case columnTypeCustom:
- return col.customColumnType
- case columnTypeBoolean:
- return "TINYINT(1)"
- case columnTypeChar:
- return fmt.Sprintf("CHAR(%d)", col.length)
- case columnTypeString:
- return fmt.Sprintf("VARCHAR(%d)", col.length)
- case columnTypeDecimal:
- return fmt.Sprintf("DECIMAL(%d, %d)", col.total, col.places)
- case columnTypeDouble, columnTypeFloat:
- if col.total > 0 && col.places > 0 {
- return fmt.Sprintf("DOUBLE(%d, %d)", col.total, col.places)
- }
- return "DOUBLE"
- case columnTypeBigInteger:
- return g.modifyUnsignedAndAutoIncrement("BIGINT", col)
- case columnTypeInteger:
- return g.modifyUnsignedAndAutoIncrement("INT", col)
- case columnTypeSmallInteger:
- return g.modifyUnsignedAndAutoIncrement("SMALLINT", col)
- case columnTypeMediumInteger:
- return g.modifyUnsignedAndAutoIncrement("MEDIUMINT", col)
- case columnTypeTinyInteger:
- return g.modifyUnsignedAndAutoIncrement("TINYINT", col)
- case columnTypeTime:
- return fmt.Sprintf("TIME(%d)", col.precision)
- case columnTypeDateTime, columnTypeDateTimeTz:
- if col.precision > 0 {
- return fmt.Sprintf("DATETIME(%d)", col.precision)
- }
- return "DATETIME"
- case columnTypeTimestamp, columnTypeTimestampTz:
- if col.precision > 0 {
- return fmt.Sprintf("TIMESTAMP(%d)", col.precision)
- }
- return "TIMESTAMP"
- case columnTypeGeography:
- return fmt.Sprintf("GEOGRAPHY(%s, %d)", col.subType, col.srid)
- case columnTypeEnum:
- return fmt.Sprintf("ENUM(%s)", g.quoteString(strings.Join(col.allowedEnums, "','")))
- default:
- colType, ok := map[columnType]string{
- columnTypeBoolean: "BOOLEAN",
- columnTypeLongText: "LONGTEXT",
- columnTypeText: "TEXT",
- columnTypeMediumText: "MEDIUMTEXT",
- columnTypeTinyText: "TINYTEXT",
- columnTypeDate: "DATE",
- columnTypeYear: "YEAR",
- columnTypeJSON: "JSON",
- columnTypeJSONB: "JSON",
- columnTypeUUID: "UUID",
- columnTypeBinary: "BLOB",
- columnTypeGeometry: "GEOMETRY",
- columnTypePoint: "POINT",
- }[col.columnType]
- if !ok {
- return "ERROR: Unknown column type " + string(col.columnType)
- }
- return colType
- }
-}
-
-func (g *mysqlGrammar) modifyUnsignedAndAutoIncrement(sql string, col *columnDefinition) string {
- if col.unsigned {
- sql += " UNSIGNED"
- }
- if col.autoIncrement {
- sql += " AUTO_INCREMENT"
- }
- return sql
-}
diff --git a/options.go b/options.go
new file mode 100644
index 0000000..8da4f71
--- /dev/null
+++ b/options.go
@@ -0,0 +1,26 @@
+package migris
+
+import "database/sql"
+
+type Option func(*Migrate)
+
+// WithTableName sets the table name for the migration.
+func WithTableName(name string) Option {
+ return func(m *Migrate) {
+ m.tableName = name
+ }
+}
+
+// WithMigrationDir sets the directory for the migration files.
+func WithMigrationDir(dir string) Option {
+ return func(m *Migrate) {
+ m.migrationDir = dir
+ }
+}
+
+// WithDB sets the database connection for the migration.
+func WithDB(db *sql.DB) Option {
+ return func(m *Migrate) {
+ m.db = db
+ }
+}
diff --git a/postgres_grammar.go b/postgres_grammar.go
deleted file mode 100644
index 4e02e7b..0000000
--- a/postgres_grammar.go
+++ /dev/null
@@ -1,450 +0,0 @@
-package schema
-
-import (
- "fmt"
- "slices"
- "strings"
-)
-
-type pgGrammar struct {
- baseGrammar
-}
-
-var _ grammar = (*pgGrammar)(nil)
-
-func newPgGrammar() *pgGrammar {
- return &pgGrammar{}
-}
-
-func (g *pgGrammar) compileTableExists(schema string, table string) (string, error) {
- return fmt.Sprintf(
- "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'",
- g.quoteString(schema),
- g.quoteString(table),
- ), nil
-}
-
-func (g *pgGrammar) compileTables() (string, error) {
- return "select c.relname as name, n.nspname as schema, pg_total_relation_size(c.oid) as size, " +
- "obj_description(c.oid, 'pg_class') as comment from pg_class c, pg_namespace n " +
- "where c.relkind in ('r', 'p') and n.oid = c.relnamespace and n.nspname not in ('pg_catalog', 'information_schema') " +
- "order by c.relname", nil
-}
-
-func (g *pgGrammar) compileColumns(schema, table string) (string, error) {
- return fmt.Sprintf(
- "select a.attname as name, t.typname as type_name, format_type(a.atttypid, a.atttypmod) as type, "+
- "(select tc.collcollate from pg_catalog.pg_collation tc where tc.oid = a.attcollation) as collation, "+
- "not a.attnotnull as nullable, "+
- "(select pg_get_expr(adbin, adrelid) from pg_attrdef where c.oid = pg_attrdef.adrelid and pg_attrdef.adnum = a.attnum) as default, "+
- "col_description(c.oid, a.attnum) as comment "+
- "from pg_attribute a, pg_class c, pg_type t, pg_namespace n "+
- "where c.relname = %s and n.nspname = %s and a.attnum > 0 and a.attrelid = c.oid and a.atttypid = t.oid and n.oid = c.relnamespace "+
- "order by a.attnum",
- g.quoteString(table),
- g.quoteString(schema),
- ), nil
-}
-
-func (g *pgGrammar) compileIndexes(schema, table string) (string, error) {
- return fmt.Sprintf(
- "select ic.relname as name, string_agg(a.attname, ',' order by indseq.ord) as columns, "+
- "am.amname as \"type\", i.indisunique as \"unique\", i.indisprimary as \"primary\" "+
- "from pg_index i "+
- "join pg_class tc on tc.oid = i.indrelid "+
- "join pg_namespace tn on tn.oid = tc.relnamespace "+
- "join pg_class ic on ic.oid = i.indexrelid "+
- "join pg_am am on am.oid = ic.relam "+
- "join lateral unnest(i.indkey) with ordinality as indseq(num, ord) on true "+
- "left join pg_attribute a on a.attrelid = i.indrelid and a.attnum = indseq.num "+
- "where tc.relname = %s and tn.nspname = %s "+
- "group by ic.relname, am.amname, i.indisunique, i.indisprimary",
- g.quoteString(table),
- g.quoteString(schema),
- ), nil
-}
-
-func (g *pgGrammar) compileCreate(blueprint *Blueprint) (string, error) {
- columns, err := g.getColumns(blueprint)
- if err != nil {
- return "", err
- }
- constraints := g.getConstraints(blueprint)
- columns = append(columns, constraints...)
- return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil
-}
-
-func (g *pgGrammar) compileCreateIfNotExists(blueprint *Blueprint) (string, error) {
- columns, err := g.getColumns(blueprint)
- if err != nil {
- return "", err
- }
- constraints := g.getConstraints(blueprint)
- columns = append(columns, constraints...)
- return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil
-}
-
-func (g *pgGrammar) compileAdd(blueprint *Blueprint) (string, error) {
- if len(blueprint.getAddedColumns()) == 0 {
- return "", nil
- }
-
- columns, err := g.getColumns(blueprint)
- if err != nil {
- return "", err
- }
- columns = g.prefixArray("ADD COLUMN ", columns)
-
- constraints := g.getConstraints(blueprint)
- constraints = g.prefixArray("ADD ", constraints)
-
- columns = append(columns, constraints...)
-
- return fmt.Sprintf("ALTER TABLE %s %s",
- blueprint.name,
- strings.Join(columns, ", "),
- ), nil
-}
-
-func (g *pgGrammar) compileChange(bp *Blueprint) ([]string, error) {
- if len(bp.getChangedColumns()) == 0 {
- return nil, nil
- }
-
- var queries []string
- for _, col := range bp.getChangedColumns() {
- if col.name == "" {
- return nil, fmt.Errorf("column name cannot be empty for change operation")
- }
- var statements []string
-
- statements = append(statements, fmt.Sprintf("ALTER COLUMN %s TYPE %s", col.name, g.getType(col)))
- if col.hasCommand("default") {
- if col.defaultValue != nil {
- statements = append(statements, fmt.Sprintf("ALTER COLUMN %s SET DEFAULT %s", col.name, g.getDefaultValue(col)))
- } else {
- statements = append(statements, fmt.Sprintf("ALTER COLUMN %s DROP DEFAULT", col.name))
- }
- }
- if col.hasCommand("nullable") {
- if col.nullable {
- statements = append(statements, fmt.Sprintf("ALTER COLUMN %s DROP NOT NULL", col.name))
- } else {
- statements = append(statements, fmt.Sprintf("ALTER COLUMN %s SET NOT NULL", col.name))
- }
- }
- sql := "ALTER TABLE " + bp.name + " " + strings.Join(statements, ", ")
- queries = append(queries, sql)
-
- if col.hasCommand("comment") {
- if col.comment != "" {
- queries = append(queries, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", bp.name, col.name, col.comment))
- } else {
- queries = append(queries, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS NULL", bp.name, col.name))
- }
- }
- }
-
- return queries, nil
-}
-
-func (g *pgGrammar) compileDrop(blueprint *Blueprint) (string, error) {
- return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil
-}
-
-func (g *pgGrammar) compileDropIfExists(blueprint *Blueprint) (string, error) {
- return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil
-}
-
-func (g *pgGrammar) compileRename(blueprint *Blueprint) (string, error) {
- return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, blueprint.newName), nil
-}
-
-func (g *pgGrammar) compileDropColumn(blueprint *Blueprint) (string, error) {
- if len(blueprint.dropColumns) == 0 {
- return "", nil
- }
- columns := g.prefixArray("DROP COLUMN ", blueprint.dropColumns)
-
- return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil
-}
-
-func (g *pgGrammar) compileRenameColumn(blueprint *Blueprint, oldName, newName string) (string, error) {
- if oldName == "" || newName == "" {
- return "", fmt.Errorf("table name, old column name, and new column name cannot be empty for rename operation")
- }
- return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, oldName, newName), nil
-}
-
-func (g *pgGrammar) compileIndex(blueprint *Blueprint, index *indexDefinition) (string, error) {
- if slices.Contains(index.columns, "") {
- return "", fmt.Errorf("index column cannot be empty")
- }
- indexName := index.name
- if indexName == "" {
- indexName = g.createIndexName(blueprint, index)
- }
-
- sql := fmt.Sprintf("CREATE INDEX %s ON %s", indexName, blueprint.name)
- if index.algorithm != "" {
- sql += fmt.Sprintf(" USING %s", index.algorithm)
- }
- return fmt.Sprintf("%s (%s)", sql, g.columnize(index.columns)), nil
-}
-
-func (g *pgGrammar) compileUnique(blueprint *Blueprint, index *indexDefinition) (string, error) {
- if slices.Contains(index.columns, "") {
- return "", fmt.Errorf("unique index column cannot be empty")
- }
- indexName := index.name
- if indexName == "" {
- indexName = g.createIndexName(blueprint, index)
- }
- sql := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)",
- blueprint.name,
- indexName,
- g.columnize(index.columns),
- )
-
- if index.deferrable != nil {
- if *index.deferrable {
- sql += " DEFERRABLE"
- } else {
- sql += " NOT DEFERRABLE"
- }
- }
- if index.deferrable != nil && *index.deferrable && index.initiallyImmediate != nil {
- if *index.initiallyImmediate {
- sql += " INITIALLY IMMEDIATE"
- } else {
- sql += " INITIALLY DEFERRED"
- }
- }
-
- return sql, nil
-}
-
-func (g *pgGrammar) compilePrimary(blueprint *Blueprint, index *indexDefinition) (string, error) {
- if slices.Contains(index.columns, "") {
- return "", fmt.Errorf("primary key index column cannot be empty")
- }
- indexName := index.name
- if indexName == "" {
- indexName = g.createIndexName(blueprint, index)
- }
- return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.columnize(index.columns)), nil
-}
-
-func (g *pgGrammar) compileFullText(blueprint *Blueprint, index *indexDefinition) (string, error) {
- if slices.Contains(index.columns, "") {
- return "", fmt.Errorf("fulltext index column cannot be empty")
- }
- indexName := index.name
- if indexName == "" {
- indexName = g.createIndexName(blueprint, index)
- }
- language := index.language
- if language == "" {
- language = "english" // Default language for full-text search
- }
- var columns []string
- for _, col := range index.columns {
- columns = append(columns, fmt.Sprintf("to_tsvector(%s, %s)", g.quoteString(language), col))
- }
-
- return fmt.Sprintf("CREATE INDEX %s ON %s USING GIN (%s)", indexName, blueprint.name, strings.Join(columns, " || ")), nil
-}
-
-func (g *pgGrammar) compileDropIndex(_ *Blueprint, indexName string) (string, error) {
- if indexName == "" {
- return "", fmt.Errorf("index name cannot be empty for drop operation")
- }
- return fmt.Sprintf("DROP INDEX %s", indexName), nil
-}
-
-func (g *pgGrammar) compileDropUnique(blueprint *Blueprint, indexName string) (string, error) {
- if indexName == "" {
- return "", fmt.Errorf("index name cannot be empty for drop operation")
- }
- return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, indexName), nil
-}
-
-func (g *pgGrammar) compileDropFulltext(blueprint *Blueprint, indexName string) (string, error) {
- return g.compileDropIndex(blueprint, indexName)
-}
-
-func (g *pgGrammar) compileRenameIndex(_ *Blueprint, oldName, newName string) (string, error) {
- if oldName == "" || newName == "" {
- return "", fmt.Errorf("index names for rename operation cannot be empty: oldName=%s, newName=%s", oldName, newName)
- }
- return fmt.Sprintf("ALTER INDEX %s RENAME TO %s", oldName, newName), nil
-}
-
-func (g *pgGrammar) compileDropPrimary(blueprint *Blueprint, indexName string) (string, error) {
- if indexName == "" {
- indexName = g.createIndexName(blueprint, &indexDefinition{indexType: indexTypePrimary})
- }
- return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, indexName), nil
-}
-
-func (g *pgGrammar) compileForeign(blueprint *Blueprint, foreignKey *foreignKeyDefinition) (string, error) {
- sql, err := g.baseGrammar.compileForeign(blueprint, foreignKey)
- if err != nil {
- return "", err
- }
-
- if foreignKey.deferrable != nil {
- if *foreignKey.deferrable {
- sql += " DEFERRABLE"
- } else {
- sql += " NOT DEFERRABLE"
- }
- }
- if foreignKey.deferrable != nil && *foreignKey.deferrable && foreignKey.initiallyImmediate != nil {
- if *foreignKey.initiallyImmediate {
- sql += " INITIALLY IMMEDIATE"
- } else {
- sql += " INITIALLY DEFERRED"
- }
- }
-
- return sql, nil
-}
-
-func (g *pgGrammar) compileDropForeign(blueprint *Blueprint, foreignKeyName string) (string, error) {
- if foreignKeyName == "" {
- return "", fmt.Errorf("foreign key name cannot be empty for drop operation")
- }
- return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, foreignKeyName), nil
-}
-
-func (g *pgGrammar) getColumns(blueprint *Blueprint) ([]string, error) {
- var columns []string
- for _, col := range blueprint.getAddedColumns() {
- if col.name == "" {
- return nil, fmt.Errorf("column name cannot be empty")
- }
- sql := col.name + " " + g.getType(col)
- if col.hasCommand("default") {
- if col.defaultValue != nil {
- sql += fmt.Sprintf(" DEFAULT %s", g.getDefaultValue(col))
- } else {
- sql += " DEFAULT NULL"
- }
- }
- if col.nullable {
- sql += " NULL"
- } else {
- sql += " NOT NULL"
- }
- if col.comment != "" {
- sql += fmt.Sprintf(" COMMENT '%s'", col.comment)
- }
- columns = append(columns, sql)
- }
-
- return columns, nil
-}
-
-func (g *pgGrammar) getConstraints(blueprint *Blueprint) []string {
- var constrains []string
- for _, col := range blueprint.getAddedColumns() {
- if col.primary {
- pkConstraintName := g.createIndexName(blueprint, &indexDefinition{indexType: indexTypePrimary})
- sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")"
- constrains = append(constrains, sql)
- continue
- }
- if col.unique {
- uniqueName := col.uniqueName
- if uniqueName == "" {
- uniqueName = g.createIndexName(blueprint, &indexDefinition{
- indexType: indexTypeUnique,
- columns: []string{col.name},
- })
- }
- sql := fmt.Sprintf("CONSTRAINT %s UNIQUE (%s)", uniqueName, col.name)
- constrains = append(constrains, sql)
- }
- }
-
- return constrains
-}
-
-func (g *pgGrammar) getType(col *columnDefinition) string {
- switch col.columnType {
- case columnTypeCustom:
- return col.customColumnType
- case columnTypeChar:
- if col.length > 0 {
- return fmt.Sprintf("CHAR(%d)", col.length)
- }
- return "CHAR"
- case columnTypeString:
- if col.length > 0 {
- return fmt.Sprintf("VARCHAR(%d)", col.length)
- }
- return "VARCHAR"
- case columnTypeDecimal:
- return fmt.Sprintf("DECIMAL(%d, %d)", col.total, col.places)
- case columnTypeBigInteger:
- if col.autoIncrement {
- return "BIGSERIAL"
- }
- return "BIGINT"
- case columnTypeInteger, columnTypeMediumInteger:
- if col.autoIncrement {
- return "SERIAL"
- }
- return "INTEGER"
- case columnTypeSmallInteger, columnTypeTinyInteger:
- if col.autoIncrement {
- return "SMALLSERIAL"
- }
- return "SMALLINT"
- case columnTypeTime:
- return fmt.Sprintf("TIME(%d)", col.precision)
- case columnTypeTimestamp, columnTypeDateTime:
- return fmt.Sprintf("TIMESTAMP(%d)", col.precision)
- case columnTypeTimestampTz, columnTypeDateTimeTz:
- return fmt.Sprintf("TIMESTAMPTZ(%d)", col.precision)
- case columnTypeGeography:
- return fmt.Sprintf("GEOGRAPHY(%s, %d)", col.subType, col.srid)
- case columnTypeGeometry:
- if col.srid > 0 {
- return fmt.Sprintf("GEOMETRY(%s, %d)", col.subType, col.srid)
- }
- return fmt.Sprintf("GEOMETRY(%s)", col.subType)
- case columnTypePoint:
- return fmt.Sprintf("POINT(%d)", col.srid)
- case columnTypeEnum:
- if len(col.allowedEnums) == 0 {
- return "VARCHAR(255)" // Fallback to VARCHAR if no enums are defined
- }
- enumValues := make([]string, len(col.allowedEnums))
- for i, v := range col.allowedEnums {
- enumValues[i] = fmt.Sprintf("'%s'", v)
- }
- return "VARCHAR(255) CHECK (" + col.name + " IN (" + strings.Join(enumValues, ", ") + "))"
- default:
- colType, ok := map[columnType]string{
- columnTypeBoolean: "BOOLEAN",
- columnTypeLongText: "TEXT",
- columnTypeText: "TEXT",
- columnTypeMediumText: "TEXT",
- columnTypeTinyText: "TEXT",
- columnTypeDouble: "DOUBLE PRECISION",
- columnTypeFloat: "REAL",
- columnTypeDate: "DATE",
- columnTypeYear: "INTEGER", // PostgreSQL does not have a YEAR type, using INTEGER instead
- columnTypeJSON: "JSON",
- columnTypeJSONB: "JSONB",
- columnTypeUUID: "UUID",
- columnTypeBinary: "BYTEA",
- }[col.columnType]
- if !ok {
- return "ERROR: Unknown column type " + string(col.columnType)
- }
- return colType
- }
-}
diff --git a/register.go b/register.go
new file mode 100644
index 0000000..96c148d
--- /dev/null
+++ b/register.go
@@ -0,0 +1,90 @@
+package migris
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "path"
+ "runtime"
+
+ "github.com/akfaiz/migris/schema"
+ "github.com/pressly/goose/v3"
+)
+
+var (
+ registeredVersions = make(map[int64]string)
+ registeredMigrations = make([]*Migration, 0)
+)
+
+type Migration struct {
+ version int64
+ source string
+ upFnContext, downFnContext MigrationContext
+}
+
+// MigrationContext is a Go migration func that is run within a transaction and receives a
+// context.
+type MigrationContext func(ctx *schema.Context) error
+
+func (m MigrationContext) runTxFunc(source string) func(ctx context.Context, tx *sql.Tx) error {
+ return func(ctx context.Context, tx *sql.Tx) error {
+ filename := path.Base(source)
+ c := schema.NewContext(ctx, tx, schema.WithFilename(filename))
+ return m(c)
+ }
+}
+
+// AddMigrationContext adds Go migrations.
+func AddMigrationContext(up, down MigrationContext) {
+ _, filename, _, _ := runtime.Caller(1)
+ AddNamedMigrationContext(filename, up, down)
+}
+
+// AddNamedMigrationContext adds named Go migrations.
+func AddNamedMigrationContext(source string, up, down MigrationContext) {
+ if err := register(
+ source,
+ up,
+ down,
+ ); err != nil {
+ panic(err)
+ }
+}
+
+func register(source string, up, down MigrationContext) error {
+ v, _ := goose.NumericComponent(source)
+ if existing, ok := registeredVersions[v]; ok {
+ return fmt.Errorf("failed to add migration %q: version %d conflicts with %q",
+ source,
+ v,
+ existing,
+ )
+ }
+ // Add to global as a registered migration.
+ m := &Migration{
+ version: v,
+ source: source,
+ upFnContext: up,
+ downFnContext: down,
+ }
+ registeredVersions[v] = source
+ registeredMigrations = append(registeredMigrations, m)
+ return nil
+}
+
+func gooseMigrations() []*goose.Migration {
+ migrations := make([]*goose.Migration, 0, len(registeredMigrations))
+ for _, m := range registeredMigrations {
+ upFunc := &goose.GoFunc{
+ RunTx: m.upFnContext.runTxFunc(m.source),
+ Mode: goose.TransactionEnabled,
+ }
+ downFunc := &goose.GoFunc{
+ RunTx: m.downFnContext.runTxFunc(m.source),
+ Mode: goose.TransactionEnabled,
+ }
+ gm := goose.NewGoMigration(m.version, upFunc, downFunc)
+ migrations = append(migrations, gm)
+ }
+ return migrations
+}
diff --git a/reset.go b/reset.go
new file mode 100644
index 0000000..124bcb5
--- /dev/null
+++ b/reset.go
@@ -0,0 +1,44 @@
+package migris
+
+import (
+ "context"
+ "errors"
+
+ "github.com/akfaiz/migris/internal/logger"
+ "github.com/pressly/goose/v3"
+)
+
+// Reset rolls back all migrations.
+func (m *Migrate) Reset() error {
+ ctx := context.Background()
+ return m.ResetContext(ctx)
+}
+
+// ResetContext rolls back all migrations.
+func (m *Migrate) ResetContext(ctx context.Context) error {
+ provider, err := m.newProvider()
+ if err != nil {
+ return err
+ }
+ currentVersion, err := provider.GetDBVersion(ctx)
+ if err != nil {
+ return err
+ }
+ if currentVersion == 0 {
+ logger.Info("Nothing to rollback.")
+ return nil
+ }
+ logger.Info("Rolling back migrations.\n")
+ results, err := provider.DownTo(ctx, 0)
+ if err != nil {
+ var partialErr *goose.PartialError
+ if errors.As(err, &partialErr) {
+ logger.PrintResults(partialErr.Applied)
+ logger.PrintResult(partialErr.Failed)
+ }
+
+ return err
+ }
+ logger.PrintResults(results)
+ return nil
+}
diff --git a/schema/blueprint.go b/schema/blueprint.go
new file mode 100644
index 0000000..b0d6e55
--- /dev/null
+++ b/schema/blueprint.go
@@ -0,0 +1,712 @@
+package schema
+
+import (
+ "fmt"
+
+ "github.com/akfaiz/migris/internal/dialect"
+ "github.com/akfaiz/migris/internal/util"
+)
+
+const (
+ columnTypeBoolean string = "boolean"
+ columnTypeChar string = "char"
+ columnTypeString string = "string"
+ columnTypeLongText string = "longText"
+ columnTypeMediumText string = "mediumText"
+ columnTypeText string = "text"
+ columnTypeTinyText string = "tinyText"
+ columnTypeBigInteger string = "bigInteger"
+ columnTypeInteger string = "integer"
+ columnTypeMediumInteger string = "mediumInteger"
+ columnTypeSmallInteger string = "smallInteger"
+ columnTypeTinyInteger string = "tinyInteger"
+ columnTypeDecimal string = "decimal"
+ columnTypeDouble string = "double"
+ columnTypeFloat string = "float"
+ columnTypeDateTime string = "dateTime"
+ columnTypeDateTimeTz string = "dateTimeTz"
+ columnTypeDate string = "date"
+ columnTypeTime string = "time"
+ columnTypeTimeTz string = "timeTz"
+ columnTypeTimestamp string = "timestamp"
+ columnTypeTimestampTz string = "timestampTz"
+ columnTypeYear string = "year"
+ columnTypeBinary string = "binary"
+ columnTypeJson string = "json"
+ columnTypeJsonb string = "jsonb"
+ columnTypeGeography string = "geography"
+ columnTypeGeometry string = "geometry"
+ columnTypePoint string = "point"
+ columnTypeUuid string = "uuid"
+ columnTypeEnum string = "enum"
+)
+
+const (
+ defaultStringLength int = 255
+ defaultTimePrecision int = 0
+)
+
+// Blueprint represents a schema blueprint for creating or altering a database table.
+type Blueprint struct {
+ dialect dialect.Dialect
+ columns []*columnDefinition
+ commands []*command
+ grammar grammar
+ name string
+ charset string
+ collation string
+ engine string
+}
+
+// Charset sets the character set for the table in the blueprint.
+func (b *Blueprint) Charset(charset string) {
+ b.charset = charset
+}
+
+// Collation sets the collation for the table in the blueprint.
+func (b *Blueprint) Collation(collation string) {
+ b.collation = collation
+}
+
+// Engine sets the storage engine for the table in the blueprint.
+func (b *Blueprint) Engine(engine string) {
+ b.engine = engine
+}
+
+// Column creates a new custom column definition in the blueprint with the specified name and type.
+func (b *Blueprint) Column(name string, columnType string) ColumnDefinition {
+ return b.addColumn(columnType, name)
+}
+
+// Boolean creates a new boolean column definition in the blueprint.
+func (b *Blueprint) Boolean(name string) ColumnDefinition {
+ return b.addColumn(columnTypeBoolean, name)
+}
+
+// Char creates a new char column definition in the blueprint.
+func (b *Blueprint) Char(name string, length ...int) ColumnDefinition {
+ return b.addColumn(columnTypeChar, name, &columnDefinition{
+ length: util.OptionalPtr(defaultStringLength, length...),
+ })
+}
+
+// String creates a new string column definition in the blueprint.
+func (b *Blueprint) String(name string, length ...int) ColumnDefinition {
+ return b.addColumn(columnTypeString, name, &columnDefinition{
+ length: util.OptionalPtr(defaultStringLength, length...),
+ })
+}
+
+// LongText creates a new long text column definition in the blueprint.
+func (b *Blueprint) LongText(name string) ColumnDefinition {
+ return b.addColumn(columnTypeLongText, name)
+}
+
+// Text creates a new text column definition in the blueprint.
+func (b *Blueprint) Text(name string) ColumnDefinition {
+ return b.addColumn(columnTypeText, name)
+}
+
+// MediumText creates a new medium text column definition in the blueprint.
+func (b *Blueprint) MediumText(name string) ColumnDefinition {
+ return b.addColumn(columnTypeMediumText, name)
+}
+
+// TinyText creates a new tiny text column definition in the blueprint.
+func (b *Blueprint) TinyText(name string) ColumnDefinition {
+ return b.addColumn(columnTypeTinyText, name)
+}
+
+// BigIncrements creates a new big increments column definition in the blueprint.
+func (b *Blueprint) BigIncrements(name string) ColumnDefinition {
+ return b.UnsignedBigInteger(name).AutoIncrement()
+}
+
+// BigInteger creates a new big integer column definition in the blueprint.
+func (b *Blueprint) BigInteger(name string) ColumnDefinition {
+ return b.addColumn(columnTypeBigInteger, name)
+}
+
+// Decimal creates a new decimal column definition in the blueprint.
+//
+// The total and places parameters are optional.
+//
+// Example:
+//
+// table.Decimal("price", 10, 2) // creates a decimal column with total 10 and places 2
+//
+// table.Decimal("price") // creates a decimal column with default total 8 and places 2
+func (b *Blueprint) Decimal(name string, params ...int) ColumnDefinition {
+ defaultPlaces := 2
+ if len(params) > 1 {
+ defaultPlaces = params[1]
+ }
+ return b.addColumn(columnTypeDecimal, name, &columnDefinition{
+ total: util.OptionalPtr(8, params...),
+ places: util.PtrOf(defaultPlaces),
+ })
+}
+
+// Double creates a new double column definition in the blueprint.
+func (b *Blueprint) Double(name string) ColumnDefinition {
+ return b.addColumn(columnTypeDouble, name)
+}
+
+// Float creates a new float column definition in the blueprint.
+func (b *Blueprint) Float(name string, precision ...int) ColumnDefinition {
+ return b.addColumn(columnTypeFloat, name, &columnDefinition{
+ precision: util.OptionalPtr(53, precision...),
+ })
+}
+
+// ID creates a new big increments column definition in the blueprint with the name "id" or a custom name.
+//
+// If a name is provided, it will be used as the column name; otherwise, "id" will be used.
+func (b *Blueprint) ID(name ...string) ColumnDefinition {
+ return b.BigIncrements(util.Optional("id", name...)).Primary()
+}
+
+// Increments create a new increment column definition in the blueprint.
+func (b *Blueprint) Increments(name string) ColumnDefinition {
+ return b.UnsignedInteger(name).AutoIncrement()
+}
+
+// Integer creates a new integer column definition in the blueprint.
+func (b *Blueprint) Integer(name string) ColumnDefinition {
+ return b.addColumn(columnTypeInteger, name)
+}
+
+// MediumIncrements creates a new medium increments column definition in the blueprint.
+func (b *Blueprint) MediumIncrements(name string) ColumnDefinition {
+ return b.UnsignedMediumInteger(name).AutoIncrement()
+}
+
+func (b *Blueprint) MediumInteger(name string) ColumnDefinition {
+ return b.addColumn(columnTypeMediumInteger, name)
+}
+
+// SmallIncrements creates a new small increments column definition in the blueprint.
+func (b *Blueprint) SmallIncrements(name string) ColumnDefinition {
+ return b.UnsignedSmallInteger(name).AutoIncrement()
+}
+
+// SmallInteger creates a new small integer column definition in the blueprint.
+func (b *Blueprint) SmallInteger(name string) ColumnDefinition {
+ return b.addColumn(columnTypeSmallInteger, name)
+}
+
+// TinyIncrements creates a new tiny increments column definition in the blueprint.
+func (b *Blueprint) TinyIncrements(name string) ColumnDefinition {
+ return b.UnsignedTinyInteger(name).AutoIncrement()
+}
+
+// TinyInteger creates a new tiny integer column definition in the blueprint.
+func (b *Blueprint) TinyInteger(name string) ColumnDefinition {
+ return b.addColumn(columnTypeTinyInteger, name)
+}
+
+// UnsignedBigInteger creates a new unsigned big integer column definition in the blueprint.
+func (b *Blueprint) UnsignedBigInteger(name string) ColumnDefinition {
+ return b.BigInteger(name).Unsigned()
+}
+
+// UnsignedInteger creates a new unsigned integer column definition in the blueprint.
+func (b *Blueprint) UnsignedInteger(name string) ColumnDefinition {
+ return b.Integer(name).Unsigned()
+}
+
+// UnsignedMediumInteger creates a new unsigned medium integer column definition in the blueprint.
+func (b *Blueprint) UnsignedMediumInteger(name string) ColumnDefinition {
+ return b.MediumInteger(name).Unsigned()
+}
+
+// UnsignedSmallInteger creates a new unsigned small integer column definition in the blueprint.
+func (b *Blueprint) UnsignedSmallInteger(name string) ColumnDefinition {
+ return b.SmallInteger(name).Unsigned()
+}
+
+// UnsignedTinyInteger creates a new unsigned tiny integer column definition in the blueprint.
+func (b *Blueprint) UnsignedTinyInteger(name string) ColumnDefinition {
+ return b.TinyInteger(name).Unsigned()
+}
+
+// DateTime creates a new date time column definition in the blueprint.
+func (b *Blueprint) DateTime(name string, precision ...int) ColumnDefinition {
+ return b.addColumn(columnTypeDateTime, name, &columnDefinition{
+ precision: util.OptionalPtr(defaultTimePrecision, precision...),
+ })
+}
+
+// DateTimeTz creates a new date time with a time zone column definition in the blueprint.
+func (b *Blueprint) DateTimeTz(name string, precision ...int) ColumnDefinition {
+ return b.addColumn(columnTypeDateTimeTz, name, &columnDefinition{
+ precision: util.OptionalPtr(defaultTimePrecision, precision...),
+ })
+}
+
+// Date creates a new date column definition in the blueprint.
+func (b *Blueprint) Date(name string) ColumnDefinition {
+ return b.addColumn(columnTypeDate, name)
+}
+
+// Time creates a new time column definition in the blueprint.
+func (b *Blueprint) Time(name string, precision ...int) ColumnDefinition {
+ return b.addColumn(columnTypeTime, name, &columnDefinition{
+ precision: util.OptionalPtr(defaultTimePrecision, precision...),
+ })
+}
+
+// TimeTz creates a new time with time zone column definition in the blueprint.
+func (b *Blueprint) TimeTz(name string, precision ...int) ColumnDefinition {
+ return b.addColumn(columnTypeTimeTz, name, &columnDefinition{
+ precision: util.OptionalPtr(defaultTimePrecision, precision...),
+ })
+}
+
+// Timestamp creates a new timestamp column definition in the blueprint.
+// The precision parameter is optional and defaults to 0 if not provided.
+func (b *Blueprint) Timestamp(name string, precision ...int) ColumnDefinition {
+ return b.addColumn(columnTypeTimestamp, name, &columnDefinition{
+ precision: util.OptionalPtr(defaultTimePrecision, precision...),
+ })
+}
+
+// TimestampTz creates a new timestamp with time zone column definition in the blueprint.
+// The precision parameter is optional and defaults to 0 if not provided.
+func (b *Blueprint) TimestampTz(name string, precision ...int) ColumnDefinition {
+ return b.addColumn(columnTypeTimestampTz, name, &columnDefinition{
+ precision: util.OptionalPtr(defaultTimePrecision, precision...),
+ })
+}
+
+// Timestamps adds created_at and updated_at timestamp columns to the blueprint.
+func (b *Blueprint) Timestamps(precision ...int) {
+ b.Timestamp("created_at", precision...).UseCurrent()
+ b.Timestamp("updated_at", precision...).UseCurrent().UseCurrentOnUpdate()
+}
+
+// TimestampsTz adds created_at and updated_at timestamp with time zone columns to the blueprint.
+func (b *Blueprint) TimestampsTz(precision ...int) {
+ b.TimestampTz("created_at", precision...).UseCurrent()
+ b.TimestampTz("updated_at", precision...).UseCurrent().UseCurrentOnUpdate()
+}
+
+// Year creates a new year column definition in the blueprint.
+func (b *Blueprint) Year(name string) ColumnDefinition {
+ return b.addColumn(columnTypeYear, name)
+}
+
+// Binary creates a new binary column definition in the blueprint.
+func (b *Blueprint) Binary(name string, length ...int) ColumnDefinition {
+ return b.addColumn(columnTypeBinary, name, &columnDefinition{
+ length: util.OptionalNil(length...),
+ })
+}
+
+// JSON creates a new JSON column definition in the blueprint.
+func (b *Blueprint) JSON(name string) ColumnDefinition {
+ return b.addColumn(columnTypeJson, name)
+}
+
+// JSONB creates a new JSONB column definition in the blueprint.
+func (b *Blueprint) JSONB(name string) ColumnDefinition {
+ return b.addColumn(columnTypeJsonb, name)
+}
+
+// UUID creates a new UUID column definition in the blueprint.
+func (b *Blueprint) UUID(name string) ColumnDefinition {
+ return b.addColumn(columnTypeUuid, name)
+}
+
+// Geography creates a new geography column definition in the blueprint.
+// The subType parameter is optional and can be used to specify the type of geography (e.g., "Point", "LineString", "Polygon").
+// The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geography type.
+func (b *Blueprint) Geography(name string, subtype string, srid ...int) ColumnDefinition {
+ return b.addColumn(columnTypeGeography, name, &columnDefinition{
+ subtype: util.OptionalPtr("", subtype),
+ srid: util.OptionalPtr(4326, srid...),
+ })
+}
+
+// Geometry creates a new geometry column definition in the blueprint.
+// The subType parameter is optional and can be used to specify the type of geometry (e.g., "Point", "LineString", "Polygon").
+// The srid parameter is optional and specifies the Spatial Reference Identifier (SRID) for the geometry type.
+func (b *Blueprint) Geometry(name string, subtype string, srid ...int) ColumnDefinition {
+ return b.addColumn(columnTypeGeometry, name, &columnDefinition{
+ subtype: util.OptionalPtr("", subtype),
+ srid: util.OptionalNil(srid...),
+ })
+}
+
+// Point creates a new point column definition in the blueprint.
+func (b *Blueprint) Point(name string, srid ...int) ColumnDefinition {
+ return b.addColumn(columnTypePoint, name, &columnDefinition{
+ srid: util.OptionalPtr(4326, srid...),
+ })
+}
+
+// Enum creates a new enum column definition in the blueprint.
+// The allowedEnums parameter is a slice of strings that defines the allowed values for the enum column.
+//
+// Example:
+//
+// table.Enum("status", []string{"active", "inactive", "pending"})
+// table.Enum("role", []string{"admin", "user", "guest"}).Comment("User role in the system")
+func (b *Blueprint) Enum(name string, allowed []string) ColumnDefinition {
+ return b.addColumn(columnTypeEnum, name, &columnDefinition{
+ allowed: allowed,
+ })
+}
+
+// DropTimestamps removes the created_at and updated_at timestamp columns from the blueprint.
+func (b *Blueprint) DropTimestamps() {
+ b.DropColumn("created_at", "updated_at")
+}
+
+// DropTimestampsTz removes the created_at and updated_at timestamp with time zone columns from the blueprint.
+func (b *Blueprint) DropTimestampsTz() {
+ b.DropTimestamps()
+}
+
+// Index creates a new index definition in the blueprint.
+//
+// Example:
+//
+// table.Index("email")
+// table.Index("email", "username") // creates a composite index
+// table.Index("email").Algorithm("btree") // creates a btree index
+func (b *Blueprint) Index(column string, otherColumns ...string) IndexDefinition {
+ return b.indexCommand(commandIndex, append([]string{column}, otherColumns...)...)
+}
+
+// Unique creates a new unique index definition in the blueprint.
+//
+// Example:
+//
+// table.Unique("email")
+// table.Unique("email", "username") // creates a composite unique index
+func (b *Blueprint) Unique(column string, otherColumns ...string) IndexDefinition {
+ return b.indexCommand(commandUnique, append([]string{column}, otherColumns...)...)
+}
+
+// Primary creates a new primary key index definition in the blueprint.
+//
+// Example:
+//
+// table.Primary("id")
+// table.Primary("id", "email") // creates a composite primary key
+func (b *Blueprint) Primary(column string, otherColumns ...string) IndexDefinition {
+ return b.indexCommand(commandPrimary, append([]string{column}, otherColumns...)...)
+}
+
+// FullText creates a new fulltext index definition in the blueprint.
+func (b *Blueprint) FullText(column string, otherColumns ...string) IndexDefinition {
+ return b.indexCommand(commandFullText, append([]string{column}, otherColumns...)...)
+}
+
+// Foreign creates a new foreign key definition in the blueprint.
+//
+// Example:
+//
+// table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE")
+func (b *Blueprint) Foreign(column string) ForeignKeyDefinition {
+ command := b.addCommand(commandForeign, &command{
+ columns: []string{column},
+ })
+ return &foreignKeyDefinition{command: command}
+}
+
+// DropColumn adds a column to be dropped from the table.
+//
+// Example:
+//
+// table.DropColumn("old_column")
+// table.DropColumn("old_column", "another_old_column") // drops multiple columns
+func (b *Blueprint) DropColumn(column string, otherColumns ...string) {
+ b.addCommand(commandDropColumn, &command{
+ columns: append([]string{column}, otherColumns...),
+ })
+}
+
+// RenameColumn changes the name of the table in the blueprint.
+//
+// Example:
+//
+// table.RenameColumn("old_table_name", "new_table_name")
+func (b *Blueprint) RenameColumn(oldColumn string, newColumn string) {
+ b.addCommand(commandRenameColumn, &command{
+ from: oldColumn,
+ to: newColumn,
+ })
+}
+
+// DropIndex adds an index to be dropped from the table.
+func (b *Blueprint) DropIndex(index any) {
+ b.dropIndexCommand(commandDropIndex, commandIndex, index)
+}
+
+// DropForeign adds a foreign key to be dropped from the table.
+func (b *Blueprint) DropForeign(index any) {
+ b.dropIndexCommand(commandDropForeign, commandForeign, index)
+}
+
+// DropPrimary adds a primary key to be dropped from the table.
+func (b *Blueprint) DropPrimary(index any) {
+ b.dropIndexCommand(commandDropPrimary, commandPrimary, index)
+}
+
+// DropUnique adds a unique key to be dropped from the table.
+func (b *Blueprint) DropUnique(index any) {
+ b.dropIndexCommand(commandDropUnique, commandUnique, index)
+}
+
+func (b *Blueprint) DropFulltext(index any) {
+ b.dropIndexCommand(commandDropFullText, commandFullText, index)
+}
+
+// RenameIndex changes the name of an index in the blueprint.
+// Example:
+//
+// table.RenameIndex("old_index_name", "new_index_name")
+func (b *Blueprint) RenameIndex(oldIndexName string, newIndexName string) {
+ b.addCommand(commandRenameIndex, &command{
+ from: oldIndexName,
+ to: newIndexName,
+ })
+}
+
+func (b *Blueprint) getAddedColumns() []*columnDefinition {
+ var addedColumns []*columnDefinition
+ for _, col := range b.columns {
+ if !col.change {
+ addedColumns = append(addedColumns, col)
+ }
+ }
+ return addedColumns
+}
+
+func (b *Blueprint) getChangedColumns() []*columnDefinition {
+ var changedColumns []*columnDefinition
+ for _, col := range b.columns {
+ if col.change {
+ changedColumns = append(changedColumns, col)
+ }
+ }
+ return changedColumns
+}
+
+func (b *Blueprint) create() {
+ b.addCommand(commandCreate)
+}
+
+func (b *Blueprint) creating() bool {
+ for _, command := range b.commands {
+ if command.name == commandCreate {
+ return true
+ }
+ }
+ return false
+}
+
+func (b *Blueprint) drop() {
+ b.addCommand(commandDrop)
+}
+
+func (b *Blueprint) dropIfExists() {
+ b.addCommand(commandDropIfExists)
+}
+
+func (b *Blueprint) rename(to string) {
+ b.addCommand(commandRename, &command{
+ to: to,
+ })
+}
+
+func (b *Blueprint) addImpliedCommands() {
+ b.addFluentIndexes()
+
+ if !b.creating() {
+ if len(b.getAddedColumns()) > 0 {
+ b.commands = append([]*command{{name: commandAdd}}, b.commands...)
+ }
+ if len(b.getChangedColumns()) > 0 {
+ changedCommands := make([]*command, 0, len(b.getChangedColumns()))
+ for _, col := range b.getChangedColumns() {
+ changedCommands = append(changedCommands, &command{name: commandChange, column: col})
+ }
+ b.commands = append(changedCommands, b.commands...)
+ }
+ }
+}
+
+func (b *Blueprint) addFluentIndexes() {
+ for _, col := range b.columns {
+ if col.primary != nil {
+ if b.dialect == dialect.MySQL {
+ continue
+ }
+ if !*col.primary && col.change {
+ b.DropPrimary([]string{col.name})
+ col.primary = nil
+ }
+ }
+ if col.index != nil {
+ if *col.index {
+ b.Index(col.name).Name(col.indexName)
+ col.index = nil
+ } else if !*col.index && col.change {
+ b.DropIndex([]string{col.name})
+ col.index = nil
+ }
+ }
+
+ if col.unique != nil {
+ if *col.unique {
+ b.Unique(col.name).Name(col.uniqueName)
+ col.unique = nil
+ } else if !*col.unique && col.change {
+ b.DropUnique([]string{col.name})
+ col.unique = nil
+ }
+ }
+ }
+}
+
+func (b *Blueprint) getFluentStatements() []string {
+ var statements []string
+ for _, column := range b.columns {
+ for _, fluentCommand := range b.grammar.GetFluentCommands() {
+ if statement := fluentCommand(b, &command{column: column}); statement != "" {
+ statements = append(statements, statement)
+ }
+ }
+ }
+ return statements
+}
+
+func (b *Blueprint) build(ctx *Context) error {
+ statements, err := b.toSql()
+ if err != nil {
+ return err
+ }
+ for _, statement := range statements {
+ // if b.verbose {
+ // log.Println(statement)
+ // }
+ if _, err := ctx.Exec(statement); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (b *Blueprint) toSql() ([]string, error) {
+ b.addImpliedCommands()
+
+ var statements []string
+
+ mainCommandMap := map[string]func(blueprint *Blueprint) (string, error){
+ commandCreate: b.grammar.CompileCreate,
+ commandAdd: b.grammar.CompileAdd,
+ commandDrop: b.grammar.CompileDrop,
+ commandDropIfExists: b.grammar.CompileDropIfExists,
+ }
+ secondaryCommandMap := map[string]func(blueprint *Blueprint, command *command) (string, error){
+ commandChange: b.grammar.CompileChange,
+ commandDropColumn: b.grammar.CompileDropColumn,
+ commandDropIndex: b.grammar.CompileDropIndex,
+ commandDropForeign: b.grammar.CompileDropForeign,
+ commandDropFullText: b.grammar.CompileDropFulltext,
+ commandDropPrimary: b.grammar.CompileDropPrimary,
+ commandDropUnique: b.grammar.CompileDropUnique,
+ commandForeign: b.grammar.CompileForeign,
+ commandFullText: b.grammar.CompileFullText,
+ commandIndex: b.grammar.CompileIndex,
+ commandPrimary: b.grammar.CompilePrimary,
+ commandRename: b.grammar.CompileRename,
+ commandRenameColumn: b.grammar.CompileRenameColumn,
+ commandRenameIndex: b.grammar.CompileRenameIndex,
+ commandUnique: b.grammar.CompileUnique,
+ }
+ for _, cmd := range b.commands {
+ if compileFunc, exists := mainCommandMap[cmd.name]; exists {
+ sql, err := compileFunc(b)
+ if err != nil {
+ return nil, err
+ }
+ if sql != "" {
+ statements = append(statements, sql)
+ }
+ continue
+ }
+ if compileFunc, exists := secondaryCommandMap[cmd.name]; exists {
+ sql, err := compileFunc(b, cmd)
+ if err != nil {
+ return nil, err
+ }
+ if sql != "" {
+ statements = append(statements, sql)
+ }
+ continue
+ }
+ return nil, fmt.Errorf("unknown command: %s", cmd.name)
+ }
+
+ statements = append(statements, b.getFluentStatements()...)
+
+ return statements, nil
+}
+
+func (b *Blueprint) addColumn(colType string, name string, columnDefs ...*columnDefinition) *columnDefinition {
+ var col *columnDefinition
+ if len(columnDefs) > 0 {
+ col = columnDefs[0]
+ } else {
+ col = &columnDefinition{}
+ }
+ col.columnType = colType
+ col.name = name
+
+ return b.addColumnDefinition(col)
+}
+
+func (b *Blueprint) addColumnDefinition(col *columnDefinition) *columnDefinition {
+ b.columns = append(b.columns, col)
+ return col
+}
+
+func (b *Blueprint) indexCommand(name string, columns ...string) IndexDefinition {
+ command := b.addCommand(name, &command{
+ columns: columns,
+ })
+ return &indexDefinition{command}
+}
+
+func (b *Blueprint) dropIndexCommand(name string, indexType string, index any) {
+ switch index := index.(type) {
+ case string:
+ b.addCommand(name, &command{
+ index: index,
+ })
+ case []string:
+ indexName := b.grammar.CreateIndexName(b, indexType, index...)
+ b.addCommand(name, &command{
+ index: indexName,
+ })
+ default:
+ panic(fmt.Sprintf("unsupported index type: %T", index))
+ }
+}
+
+func (b *Blueprint) addCommand(name string, parameters ...*command) *command {
+ var parameter *command
+ if len(parameters) > 0 {
+ parameter = parameters[0]
+ } else {
+ parameter = &command{}
+ }
+ parameter.name = name
+ b.commands = append(b.commands, parameter)
+
+ return parameter
+}
diff --git a/schema/builder.go b/schema/builder.go
new file mode 100644
index 0000000..35f56c0
--- /dev/null
+++ b/schema/builder.go
@@ -0,0 +1,135 @@
+package schema
+
+import (
+ "errors"
+
+ "github.com/akfaiz/migris/internal/dialect"
+)
+
+// Builder is an interface that defines methods for creating, dropping, and managing database tables.
+type Builder interface {
+ // Create creates a new table with the given name and applies the provided blueprint.
+ Create(c *Context, name string, blueprint func(table *Blueprint)) error
+ // Drop removes the table with the given name.
+ Drop(c *Context, name string) error
+ // DropIfExists removes the table with the given name if it exists.
+ DropIfExists(c *Context, name string) error
+ // GetColumns retrieves the columns of the specified table.
+ GetColumns(c *Context, tableName string) ([]*Column, error)
+ // GetIndexes retrieves the indexes of the specified table.
+ GetIndexes(c *Context, tableName string) ([]*Index, error)
+ // GetTables retrieves all tables in the database.
+ GetTables(c *Context) ([]*TableInfo, error)
+ // HasColumn checks if the specified table has the given column.
+ HasColumn(c *Context, tableName string, columnName string) (bool, error)
+ // HasColumns checks if the specified table has all the given columns.
+ HasColumns(c *Context, tableName string, columnNames []string) (bool, error)
+ // HasIndex checks if the specified table has the given index.
+ HasIndex(c *Context, tableName string, indexes []string) (bool, error)
+ // HasTable checks if a table with the given name exists.
+ HasTable(c *Context, name string) (bool, error)
+ // Rename renames a table from oldName to newName.
+ Rename(c *Context, oldName string, newName string) error
+ // Table applies the provided blueprint to the specified table.
+ Table(c *Context, name string, blueprint func(table *Blueprint)) error
+}
+
+// NewBuilder creates a new Builder instance based on the specified dialect.
+// It returns an error if the dialect is not supported.
+//
+// Supported dialects are "postgres", "pgx", "mysql", and "mariadb".
+func NewBuilder(dialectValue string) (Builder, error) {
+ dialectVal := dialect.FromString(dialectValue)
+ switch dialectVal {
+ case dialect.MySQL:
+ return newMysqlBuilder(), nil
+ case dialect.Postgres:
+ return newPostgresBuilder(), nil
+ default:
+ return nil, errors.New("unsupported dialect: " + dialectValue)
+ }
+}
+
+type baseBuilder struct {
+ grammar grammar
+}
+
+func (b *baseBuilder) newBlueprint(name string) *Blueprint {
+ return &Blueprint{name: name, grammar: b.grammar}
+}
+
+func (b *baseBuilder) Create(c *Context, name string, blueprint func(table *Blueprint)) error {
+ if c == nil || name == "" || blueprint == nil {
+ return errors.New("invalid arguments: context, name, or blueprint is nil/empty")
+ }
+
+ bp := b.newBlueprint(name)
+ bp.create()
+ blueprint(bp)
+
+ if err := bp.build(c); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (b *baseBuilder) Drop(c *Context, name string) error {
+ if c == nil || name == "" {
+ return errors.New("invalid arguments: context is nil or name is empty")
+ }
+
+ bp := b.newBlueprint(name)
+ bp.drop()
+
+ if err := bp.build(c); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (b *baseBuilder) DropIfExists(c *Context, name string) error {
+ if c == nil || name == "" {
+ return errors.New("invalid arguments: context is nil or name is empty")
+ }
+
+ bp := b.newBlueprint(name)
+ bp.dropIfExists()
+
+ if err := bp.build(c); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (b *baseBuilder) Rename(c *Context, oldName string, newName string) error {
+ if c == nil || oldName == "" || newName == "" {
+ return errors.New("invalid arguments: context is nil or old/new table name is empty")
+ }
+
+ bp := b.newBlueprint(oldName)
+ bp.rename(newName)
+
+ if err := bp.build(c); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (b *baseBuilder) Table(c *Context, name string, blueprint func(table *Blueprint)) error {
+ if c == nil || name == "" || blueprint == nil {
+ return errors.New("invalid arguments: context is nil or name/blueprint is empty")
+ }
+
+ bp := b.newBlueprint(name)
+ blueprint(bp)
+
+ if err := bp.build(c); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/schema/column_definition.go b/schema/column_definition.go
new file mode 100644
index 0000000..639ed1a
--- /dev/null
+++ b/schema/column_definition.go
@@ -0,0 +1,198 @@
+package schema
+
+import (
+ "slices"
+
+ "github.com/akfaiz/migris/internal/util"
+)
+
+// ColumnDefinition defines the interface for defining a column in a database table.
+type ColumnDefinition interface {
+ // AutoIncrement sets the column to auto-increment.
+ // This is typically used for primary key columns.
+ AutoIncrement() ColumnDefinition
+ // Change changes the column definition.
+ Change() ColumnDefinition
+ // Charset sets the character set for the column.
+ Charset(charset string) ColumnDefinition
+ // Collation sets the collation for the column.
+ Collation(collation string) ColumnDefinition
+ // Comment adds a comment to the column definition.
+ Comment(comment string) ColumnDefinition
+ // Default sets a default value for the column.
+ Default(value any) ColumnDefinition
+ // Index adds an index to the column.
+ Index(params ...any) ColumnDefinition
+ // Nullable sets the column to be nullable or not.
+ Nullable(value ...bool) ColumnDefinition
+ // OnUpdate sets the value to be used when the column is updated.
+ OnUpdate(value any) ColumnDefinition
+ // Primary sets the column as a primary key.
+ Primary(value ...bool) ColumnDefinition
+ // Unique sets the column to be unique.
+ Unique(params ...any) ColumnDefinition
+ // Unsigned sets the column to be unsigned (applicable for numeric types).
+ Unsigned() ColumnDefinition
+ // UseCurrent sets the column to use the current timestamp as default.
+ UseCurrent() ColumnDefinition
+ // UseCurrentOnUpdate sets the column to use the current timestamp on update.
+ UseCurrentOnUpdate() ColumnDefinition
+}
+
+type columnDefinition struct {
+ commands []string
+ name string
+ columnType string
+ charset *string
+ collation *string
+ comment *string
+ defaultValue any
+ onUpdateValue any
+ useCurrent bool
+ useCurrentOnUpdate bool
+ nullable *bool
+ autoIncrement *bool
+ unsigned *bool
+ primary *bool
+ index *bool
+ indexName string
+ unique *bool
+ uniqueName string
+ length *int
+ precision *int
+ total *int
+ places *int
+ change bool
+ allowed []string // for enum type columns
+ subtype *string // for geography and geometry types
+ srid *int // for geography and geometry types
+}
+
+// Expression is a type for expressions that can be used as default values for columns.
+//
+// Example:
+//
+// schema.Timestamp("created_at").Default(schema.Expression("CURRENT_TIMESTAMP"))
+type Expression string
+
+func (e Expression) String() string {
+ return string(e)
+}
+
+var _ ColumnDefinition = &columnDefinition{}
+
+func (c *columnDefinition) addCommand(command string) {
+ c.commands = append(c.commands, command)
+}
+
+func (c *columnDefinition) hasCommand(command string) bool {
+ return slices.Contains(c.commands, command)
+}
+
+func (c *columnDefinition) SetDefault(value any) {
+ c.addCommand("default")
+ c.defaultValue = value
+}
+
+func (c *columnDefinition) SetOnUpdate(value any) {
+ c.addCommand("onUpdate")
+ c.onUpdateValue = value
+}
+
+func (c *columnDefinition) SetSubtype(value *string) {
+ c.subtype = value
+}
+
+func (c *columnDefinition) AutoIncrement() ColumnDefinition {
+ c.autoIncrement = util.PtrOf(true)
+ return c
+}
+
+func (c *columnDefinition) Charset(charset string) ColumnDefinition {
+ c.charset = &charset
+ return c
+}
+
+func (c *columnDefinition) Change() ColumnDefinition {
+ c.change = true
+ return c
+}
+
+func (c *columnDefinition) Collation(collation string) ColumnDefinition {
+ c.collation = &collation
+ return c
+}
+
+func (c *columnDefinition) Comment(comment string) ColumnDefinition {
+ c.addCommand("comment")
+ c.comment = &comment
+ return c
+}
+
+func (c *columnDefinition) Default(value any) ColumnDefinition {
+ c.addCommand("default")
+ c.defaultValue = value
+
+ return c
+}
+
+func (c *columnDefinition) Index(params ...any) ColumnDefinition {
+ index := true
+ for _, param := range params {
+ switch v := param.(type) {
+ case bool:
+ index = v
+ case string:
+ c.indexName = v
+ }
+ }
+ c.index = &index
+ return c
+}
+
+func (c *columnDefinition) Nullable(value ...bool) ColumnDefinition {
+ c.addCommand("nullable")
+ c.nullable = util.OptionalPtr(true, value...)
+ return c
+}
+
+func (c *columnDefinition) OnUpdate(value any) ColumnDefinition {
+ c.addCommand("onUpdate")
+ c.onUpdateValue = value
+ return c
+}
+
+func (c *columnDefinition) Primary(value ...bool) ColumnDefinition {
+ val := util.Optional(true, value...)
+ c.primary = &val
+ return c
+}
+
+func (c *columnDefinition) Unique(params ...any) ColumnDefinition {
+ unique := true
+ for _, param := range params {
+ switch v := param.(type) {
+ case bool:
+ unique = v
+ case string:
+ c.uniqueName = v
+ }
+ }
+ c.unique = &unique
+ return c
+}
+
+func (c *columnDefinition) Unsigned() ColumnDefinition {
+ c.unsigned = util.PtrOf(true)
+ return c
+}
+
+func (c *columnDefinition) UseCurrent() ColumnDefinition {
+ c.useCurrent = true
+ return c
+}
+
+func (c *columnDefinition) UseCurrentOnUpdate() ColumnDefinition {
+ c.useCurrentOnUpdate = true
+ return c
+}
diff --git a/schema/command.go b/schema/command.go
new file mode 100644
index 0000000..907798a
--- /dev/null
+++ b/schema/command.go
@@ -0,0 +1,40 @@
+package schema
+
+const (
+ commandAdd string = "add"
+ commandChange string = "change"
+ commandCreate string = "create"
+ commandDrop string = "drop"
+ commandDropIfExists string = "dropIfExists"
+ commandDropColumn string = "dropColumn"
+ commandDropForeign string = "dropForeign"
+ commandDropFullText string = "dropFullText"
+ commandDropIndex string = "dropIndex"
+ commandDropPrimary string = "dropPrimary"
+ commandDropUnique string = "dropUnique"
+ commandForeign string = "foreign"
+ commandFullText string = "fullText"
+ commandIndex string = "index"
+ commandPrimary string = "primary"
+ commandRename string = "rename"
+ commandRenameColumn string = "renameColumn"
+ commandRenameIndex string = "renameIndex"
+ commandUnique string = "unique"
+)
+
+type command struct {
+ column *columnDefinition
+ deferrable *bool
+ initiallyImmediate *bool
+ algorithm string
+ from string
+ index string
+ language string
+ name string
+ on string
+ onDelete string
+ onUpdate string
+ to string
+ columns []string
+ references []string
+}
diff --git a/schema/context.go b/schema/context.go
new file mode 100644
index 0000000..aa74605
--- /dev/null
+++ b/schema/context.go
@@ -0,0 +1,43 @@
+package schema
+
+import (
+ "context"
+ "database/sql"
+)
+
+type Context struct {
+ ctx context.Context
+ tx *sql.Tx
+ filename string
+}
+
+type ContextOptions func(*Context)
+
+func WithFilename(filename string) ContextOptions {
+ return func(c *Context) {
+ c.filename = filename
+ }
+}
+
+func NewContext(ctx context.Context, tx *sql.Tx, opts ...ContextOptions) *Context {
+ c := &Context{
+ ctx: ctx,
+ tx: tx,
+ }
+ for _, opt := range opts {
+ opt(c)
+ }
+ return c
+}
+
+func (c *Context) Exec(query string, args ...any) (sql.Result, error) {
+ return c.tx.ExecContext(c.ctx, query, args...)
+}
+
+func (c *Context) Query(query string, args ...any) (*sql.Rows, error) {
+ return c.tx.QueryContext(c.ctx, query, args...)
+}
+
+func (c *Context) QueryRow(query string, args ...any) *sql.Row {
+ return c.tx.QueryRowContext(c.ctx, query, args...)
+}
diff --git a/schema/foreign_key_definition.go b/schema/foreign_key_definition.go
new file mode 100644
index 0000000..49ad922
--- /dev/null
+++ b/schema/foreign_key_definition.go
@@ -0,0 +1,111 @@
+package schema
+
+import "github.com/akfaiz/migris/internal/util"
+
+// ForeignKeyDefinition defines the interface for defining a foreign key constraint in a database table.
+type ForeignKeyDefinition interface {
+ // CascadeOnDelete sets the foreign key to cascade on delete.
+ CascadeOnDelete() ForeignKeyDefinition
+ // CascadeOnUpdate sets the foreign key to cascade on update.
+ CascadeOnUpdate() ForeignKeyDefinition
+ // Deferrable sets the foreign key as deferrable.
+ Deferrable(value ...bool) ForeignKeyDefinition
+ // InitiallyImmediate sets the foreign key to be initially immediate.
+ InitiallyImmediate(value ...bool) ForeignKeyDefinition
+ // Name sets the name of the foreign key constraint.
+ // This is optional and can be used to give a specific name to the foreign key.
+ Name(name string) ForeignKeyDefinition
+ // NoActionOnDelete set the foreign key to do nothing on delete.
+ NoActionOnDelete() ForeignKeyDefinition
+ // NoActionOnUpdate set the foreign key to do nothing on the update.
+ NoActionOnUpdate() ForeignKeyDefinition
+ // NullOnDelete set the foreign key to set the column to NULL on delete.
+ NullOnDelete() ForeignKeyDefinition
+ // NullOnUpdate set the foreign key to set the column to NULL on update.
+ NullOnUpdate() ForeignKeyDefinition
+ // On sets the table that these foreign key references.
+ On(table string) ForeignKeyDefinition
+ // OnDelete set the action to take when the referenced row is deleted.
+ OnDelete(action string) ForeignKeyDefinition
+ // OnUpdate set the action to take when the referenced row is updated.
+ OnUpdate(action string) ForeignKeyDefinition
+ // References set the column that this foreign key references in the other table.
+ References(column string) ForeignKeyDefinition
+ // RestrictOnDelete set the foreign key to restrict deletion of the referenced row.
+ RestrictOnDelete() ForeignKeyDefinition
+ // RestrictOnUpdate set the foreign key to restrict updating of the referenced row.
+ RestrictOnUpdate() ForeignKeyDefinition
+}
+
+type foreignKeyDefinition struct {
+ *command
+}
+
+func (fd *foreignKeyDefinition) CascadeOnDelete() ForeignKeyDefinition {
+ return fd.OnDelete("CASCADE")
+}
+
+func (fd *foreignKeyDefinition) CascadeOnUpdate() ForeignKeyDefinition {
+ return fd.OnUpdate("CASCADE")
+}
+
+func (fd *foreignKeyDefinition) Deferrable(value ...bool) ForeignKeyDefinition {
+ val := util.Optional(true, value...)
+ fd.deferrable = &val
+ return fd
+}
+
+func (fd *foreignKeyDefinition) InitiallyImmediate(value ...bool) ForeignKeyDefinition {
+ val := util.Optional(true, value...)
+ fd.initiallyImmediate = &val
+ return fd
+}
+
+func (fd *foreignKeyDefinition) Name(name string) ForeignKeyDefinition {
+ fd.index = name
+ return fd
+}
+
+func (fd *foreignKeyDefinition) NoActionOnDelete() ForeignKeyDefinition {
+ return fd.OnDelete("NO ACTION")
+}
+
+func (fd *foreignKeyDefinition) NoActionOnUpdate() ForeignKeyDefinition {
+ return fd.OnUpdate("NO ACTION")
+}
+
+func (fd *foreignKeyDefinition) NullOnDelete() ForeignKeyDefinition {
+ return fd.OnDelete("SET NULL")
+}
+
+func (fd *foreignKeyDefinition) NullOnUpdate() ForeignKeyDefinition {
+ return fd.OnUpdate("SET NULL")
+}
+
+func (fd *foreignKeyDefinition) On(table string) ForeignKeyDefinition {
+ fd.on = table
+ return fd
+}
+
+func (fd *foreignKeyDefinition) OnDelete(action string) ForeignKeyDefinition {
+ fd.onDelete = action
+ return fd
+}
+
+func (fd *foreignKeyDefinition) OnUpdate(action string) ForeignKeyDefinition {
+ fd.onUpdate = action
+ return fd
+}
+
+func (fd *foreignKeyDefinition) References(columns string) ForeignKeyDefinition {
+ fd.references = []string{columns}
+ return fd
+}
+
+func (fd *foreignKeyDefinition) RestrictOnDelete() ForeignKeyDefinition {
+ return fd.OnDelete("RESTRICT")
+}
+
+func (fd *foreignKeyDefinition) RestrictOnUpdate() ForeignKeyDefinition {
+ return fd.OnUpdate("RESTRICT")
+}
diff --git a/schema/grammar.go b/schema/grammar.go
new file mode 100644
index 0000000..78d7c3a
--- /dev/null
+++ b/schema/grammar.go
@@ -0,0 +1,141 @@
+package schema
+
+import (
+ "fmt"
+ "slices"
+ "strings"
+
+ "github.com/akfaiz/migris/internal/util"
+)
+
+type grammar interface {
+ CompileCreate(bp *Blueprint) (string, error)
+ CompileAdd(bp *Blueprint) (string, error)
+ CompileChange(bp *Blueprint, command *command) (string, error)
+ CompileDrop(bp *Blueprint) (string, error)
+ CompileDropIfExists(bp *Blueprint) (string, error)
+ CompileRename(bp *Blueprint, command *command) (string, error)
+ CompileDropColumn(blueprint *Blueprint, command *command) (string, error)
+ CompileRenameColumn(blueprint *Blueprint, command *command) (string, error)
+ CompileIndex(blueprint *Blueprint, command *command) (string, error)
+ CompileUnique(blueprint *Blueprint, command *command) (string, error)
+ CompilePrimary(blueprint *Blueprint, command *command) (string, error)
+ CompileFullText(blueprint *Blueprint, command *command) (string, error)
+ CompileDropIndex(blueprint *Blueprint, command *command) (string, error)
+ CompileDropUnique(blueprint *Blueprint, command *command) (string, error)
+ CompileDropFulltext(blueprint *Blueprint, command *command) (string, error)
+ CompileDropPrimary(blueprint *Blueprint, command *command) (string, error)
+ CompileRenameIndex(blueprint *Blueprint, command *command) (string, error)
+ CompileForeign(blueprint *Blueprint, command *command) (string, error)
+ CompileDropForeign(blueprint *Blueprint, command *command) (string, error)
+ GetFluentCommands() []func(blueprint *Blueprint, command *command) string
+ CreateIndexName(blueprint *Blueprint, idxType string, columns ...string) string
+}
+
+type baseGrammar struct{}
+
+func (g *baseGrammar) CompileForeign(blueprint *Blueprint, command *command) (string, error) {
+ if len(command.columns) == 0 || slices.Contains(command.columns, "") || command.on == "" ||
+ len(command.references) == 0 || slices.Contains(command.references, "") {
+ return "", fmt.Errorf("foreign key definition is incomplete: column, on, and references must be set")
+ }
+ onDelete := ""
+ if command.onDelete != "" {
+ onDelete = fmt.Sprintf(" ON DELETE %s", command.onDelete)
+ }
+ onUpdate := ""
+ if command.onUpdate != "" {
+ onUpdate = fmt.Sprintf(" ON UPDATE %s", command.onUpdate)
+ }
+ index := command.index
+ if index == "" {
+ index = g.CreateForeignKeyName(blueprint, command)
+ }
+
+ return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s(%s)%s%s",
+ blueprint.name,
+ index,
+ command.columns[0],
+ command.on,
+ command.references[0],
+ onDelete,
+ onUpdate,
+ ), nil
+}
+
+func (g *baseGrammar) CreateIndexName(blueprint *Blueprint, idxType string, columns ...string) string {
+ tableName := blueprint.name
+ if strings.Contains(tableName, ".") {
+ parts := strings.Split(tableName, ".")
+ tableName = parts[len(parts)-1] // Use the last part as the table name
+ }
+
+ switch idxType {
+ case "primary":
+ return fmt.Sprintf("pk_%s", tableName)
+ case "unique":
+ return fmt.Sprintf("uk_%s_%s", tableName, strings.Join(columns, "_"))
+ case "index":
+ return fmt.Sprintf("idx_%s_%s", tableName, strings.Join(columns, "_"))
+ case "fulltext":
+ return fmt.Sprintf("ft_%s_%s", tableName, strings.Join(columns, "_"))
+ default:
+ return ""
+ }
+}
+
+func (g *baseGrammar) CreateForeignKeyName(blueprint *Blueprint, command *command) string {
+ tableName := blueprint.name
+ if strings.Contains(tableName, ".") {
+ parts := strings.Split(tableName, ".")
+ tableName = parts[len(parts)-1] // Use the last part as the table name
+ }
+ on := command.on
+ if strings.Contains(on, ".") {
+ parts := strings.Split(on, ".")
+ on = parts[len(parts)-1] // Use the last part as the column name
+ }
+ return fmt.Sprintf("fk_%s_%s", tableName, on)
+}
+
+func (g *baseGrammar) QuoteString(s string) string {
+ return "'" + s + "'"
+}
+
+func (g *baseGrammar) PrefixArray(prefix string, items []string) []string {
+ prefixed := make([]string, len(items))
+ for i, item := range items {
+ prefixed[i] = fmt.Sprintf("%s%s", prefix, item)
+ }
+ return prefixed
+}
+
+func (g *baseGrammar) Columnize(columns []string) string {
+ if len(columns) == 0 {
+ return ""
+ }
+ return strings.Join(columns, ", ")
+}
+
+func (g *baseGrammar) GetValue(value any) string {
+ switch v := value.(type) {
+ case Expression:
+ return v.String()
+ default:
+ return fmt.Sprintf("'%v'", v)
+ }
+}
+
+func (g *baseGrammar) GetDefaultValue(value any) string {
+ if value == nil {
+ return "NULL"
+ }
+ switch v := value.(type) {
+ case Expression:
+ return v.String()
+ case bool:
+ return util.Ternary(v, "'1'", "'0'")
+ default:
+ return fmt.Sprintf("'%v'", v)
+ }
+}
diff --git a/index_definition.go b/schema/index_definition.go
similarity index 79%
rename from index_definition.go
rename to schema/index_definition.go
index 2798584..cbcb14d 100644
--- a/index_definition.go
+++ b/schema/index_definition.go
@@ -1,5 +1,7 @@
package schema
+import "github.com/akfaiz/migris/internal/util"
+
// IndexDefinition defines the interface for defining an index in a database table.
type IndexDefinition interface {
// Algorithm sets the algorithm for the index.
@@ -15,16 +17,8 @@ type IndexDefinition interface {
Name(name string) IndexDefinition
}
-var _ IndexDefinition = &indexDefinition{}
-
type indexDefinition struct {
- name string
- indexType indexType
- algorithm string
- columns []string
- language string
- deferrable *bool
- initiallyImmediate *bool
+ *command
}
func (id *indexDefinition) Algorithm(algorithm string) IndexDefinition {
@@ -33,13 +27,13 @@ func (id *indexDefinition) Algorithm(algorithm string) IndexDefinition {
}
func (id *indexDefinition) Deferrable(value ...bool) IndexDefinition {
- val := optional(true, value...)
+ val := util.Optional(true, value...)
id.deferrable = &val
return id
}
func (id *indexDefinition) InitiallyImmediate(value ...bool) IndexDefinition {
- val := optional(true, value...)
+ val := util.Optional(true, value...)
id.initiallyImmediate = &val
return id
}
@@ -50,6 +44,6 @@ func (id *indexDefinition) Language(language string) IndexDefinition {
}
func (id *indexDefinition) Name(name string) IndexDefinition {
- id.name = name
+ id.index = name
return id
}
diff --git a/mysql_builder.go b/schema/mysql_builder.go
similarity index 57%
rename from mysql_builder.go
rename to schema/mysql_builder.go
index 653b58c..a02b984 100644
--- a/mysql_builder.go
+++ b/schema/mysql_builder.go
@@ -1,7 +1,6 @@
package schema
import (
- "context"
"database/sql"
"errors"
"strings"
@@ -12,6 +11,8 @@ type mysqlBuilder struct {
grammar *mysqlGrammar
}
+var _ Builder = (*mysqlBuilder)(nil)
+
func newMysqlBuilder() Builder {
grammar := newMysqlGrammar()
@@ -21,13 +22,9 @@ func newMysqlBuilder() Builder {
}
}
-func (b *mysqlBuilder) getCurrentDatabase(ctx context.Context, tx *sql.Tx) (string, error) {
- if tx == nil {
- return "", errors.New("transaction is nil")
- }
-
- query := b.grammar.compileCurrentDatabase()
- row := queryRowContext(ctx, tx, query)
+func (b *mysqlBuilder) getCurrentDatabase(c *Context) (string, error) {
+ query := b.grammar.CompileCurrentDatabase()
+ row := c.QueryRow(query)
var dbName string
if err := row.Scan(&dbName); err != nil {
return "", err
@@ -35,47 +32,22 @@ func (b *mysqlBuilder) getCurrentDatabase(ctx context.Context, tx *sql.Tx) (stri
return dbName, nil
}
-func (b *mysqlBuilder) CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error {
- if err := b.validateCreateAndAlter(tx, name, blueprint); err != nil {
- return err
- }
-
- exist, err := b.HasTable(ctx, tx, name)
- if err != nil {
- return err
- }
- if exist {
- return nil // Table already exists, no need to create it
- }
-
- bp := &Blueprint{name: name}
- bp.createIfNotExists()
- blueprint(bp)
-
- statements, err := bp.toSql(b.grammar)
- if err != nil {
- return err
- }
-
- return execContext(ctx, tx, statements...)
-}
-
-func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) {
- if err := b.validateTxAndName(tx, tableName); err != nil {
- return nil, err
+func (b *mysqlBuilder) GetColumns(c *Context, tableName string) ([]*Column, error) {
+ if c == nil || tableName == "" {
+ return nil, errors.New("invalid arguments: context is nil or table name is empty")
}
- database, err := b.getCurrentDatabase(ctx, tx)
+ database, err := b.getCurrentDatabase(c)
if err != nil {
return nil, err
}
- query, err := b.grammar.compileColumns(database, tableName)
+ query, err := b.grammar.CompileColumns(database, tableName)
if err != nil {
return nil, err
}
- rows, err := queryContext(ctx, tx, query)
+ rows, err := c.Query(query)
if err != nil {
return nil, err
}
@@ -101,22 +73,22 @@ func (b *mysqlBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName str
return columns, nil
}
-func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) {
- if err := b.validateTxAndName(tx, tableName); err != nil {
- return nil, err
+func (b *mysqlBuilder) GetIndexes(c *Context, tableName string) ([]*Index, error) {
+ if c == nil || tableName == "" {
+ return nil, errors.New("invalid arguments: context is nil or table name is empty")
}
- database, err := b.getCurrentDatabase(ctx, tx)
+ database, err := b.getCurrentDatabase(c)
if err != nil {
return nil, err
}
- query, err := b.grammar.compileIndexes(database, tableName)
+ query, err := b.grammar.CompileIndexes(database, tableName)
if err != nil {
return nil, err
}
- rows, err := queryContext(ctx, tx, query)
+ rows, err := c.Query(query)
if err != nil {
return nil, err
}
@@ -135,17 +107,21 @@ func (b *mysqlBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName str
return indexes, nil
}
-func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) {
- database, err := b.getCurrentDatabase(ctx, tx)
+func (b *mysqlBuilder) GetTables(c *Context) ([]*TableInfo, error) {
+ if c == nil {
+ return nil, errors.New("invalid arguments: context is nil")
+ }
+
+ database, err := b.getCurrentDatabase(c)
if err != nil {
return nil, err
}
- query, err := b.grammar.compileTables(database)
+ query, err := b.grammar.CompileTables(database)
if err != nil {
return nil, err
}
- rows, err := queryContext(ctx, tx, query)
+ rows, err := c.Query(query)
if err != nil {
return nil, err
}
@@ -162,22 +138,22 @@ func (b *mysqlBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo,
return tables, nil
}
-func (b *mysqlBuilder) HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) {
- if columnName == "" {
- return false, errors.New("column name cannot be empty")
+func (b *mysqlBuilder) HasColumn(c *Context, tableName string, columnName string) (bool, error) {
+ if c == nil || columnName == "" {
+ return false, errors.New("invalid arguments: context is nil or column name is empty")
}
- return b.HasColumns(ctx, tx, tableName, []string{columnName})
+ return b.HasColumns(c, tableName, []string{columnName})
}
-func (b *mysqlBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) {
- if err := b.validateTxAndName(tx, tableName); err != nil {
- return false, err
+func (b *mysqlBuilder) HasColumns(c *Context, tableName string, columnNames []string) (bool, error) {
+ if c == nil || tableName == "" {
+ return false, errors.New("invalid arguments: context is nil or table name is empty")
}
if len(columnNames) == 0 {
return false, errors.New("no column names provided")
}
- columns, err := b.GetColumns(ctx, tx, tableName)
+ columns, err := b.GetColumns(c, tableName)
if err != nil {
return false, err
}
@@ -193,12 +169,12 @@ func (b *mysqlBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName str
return true, nil // All specified columns exist
}
-func (b *mysqlBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) {
- if err := b.validateTxAndName(tx, tableName); err != nil {
- return false, err
+func (b *mysqlBuilder) HasIndex(c *Context, tableName string, indexes []string) (bool, error) {
+ if c == nil || tableName == "" {
+ return false, errors.New("invalid arguments: context is nil or table name is empty")
}
- existingIndexes, err := b.GetIndexes(ctx, tx, tableName)
+ existingIndexes, err := b.GetIndexes(c, tableName)
if err != nil {
return false, err
}
@@ -237,22 +213,22 @@ func (b *mysqlBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName strin
return false, nil // If no specified index exists, return false
}
-func (b *mysqlBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) {
- if err := b.validateTxAndName(tx, name); err != nil {
- return false, err
+func (b *mysqlBuilder) HasTable(c *Context, name string) (bool, error) {
+ if c == nil || name == "" {
+ return false, errors.New("invalid arguments: context is nil or table name is empty")
}
- database, err := b.getCurrentDatabase(ctx, tx)
+ database, err := b.getCurrentDatabase(c)
if err != nil {
return false, err
}
- query, err := b.grammar.compileTableExists(database, name)
+ query, err := b.grammar.CompileTableExists(database, name)
if err != nil {
return false, err
}
- row := queryRowContext(ctx, tx, query)
+ row := c.QueryRow(query)
var exists bool
if err := row.Scan(&exists); err != nil {
if errors.Is(err, sql.ErrNoRows) {
diff --git a/mysql_builder_test.go b/schema/mysql_builder_test.go
similarity index 60%
rename from mysql_builder_test.go
rename to schema/mysql_builder_test.go
index 39f50c6..0d8a643 100644
--- a/mysql_builder_test.go
+++ b/schema/mysql_builder_test.go
@@ -6,7 +6,7 @@ import (
"fmt"
"testing"
- "github.com/afkdevs/go-schema"
+ "github.com/akfaiz/migris/schema"
_ "github.com/go-sql-driver/mysql" // MySQL driver
"github.com/stretchr/testify/suite"
)
@@ -17,8 +17,9 @@ func TestMysqlBuilderSuite(t *testing.T) {
type mysqlBuilderSuite struct {
suite.Suite
- ctx context.Context
- db *sql.DB
+ ctx context.Context
+ db *sql.DB
+ builder schema.Builder
}
func (s *mysqlBuilderSuite) SetupSuite() {
@@ -42,7 +43,8 @@ func (s *mysqlBuilderSuite) SetupSuite() {
s.Require().NoError(err)
s.db = db
- schema.SetDebug(false)
+ s.builder, err = schema.NewBuilder("mysql")
+ s.Require().NoError(err)
}
func (s *mysqlBuilderSuite) TearDownSuite() {
@@ -50,13 +52,14 @@ func (s *mysqlBuilderSuite) TearDownSuite() {
}
func (s *mysqlBuilderSuite) AfterTest(_, _ string) {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
- tables, err := builder.GetTables(s.ctx, tx)
+ c := schema.NewContext(s.ctx, tx)
+ tables, err := builder.GetTables(c)
s.Require().NoError(err)
for _, table := range tables {
- err := builder.DropIfExists(s.ctx, tx, table.Name)
+ err := builder.DropIfExists(c, table.Name)
if err != nil {
s.T().Logf("error dropping table %s: %v", table.Name, err)
}
@@ -66,29 +69,31 @@ func (s *mysqlBuilderSuite) AfterTest(_, _ string) {
}
func (s *mysqlBuilderSuite) TestCreate() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- err := builder.Create(s.ctx, nil, "test_table", func(table *schema.Blueprint) {
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.Create(nil, "test_table", func(table *schema.Blueprint) {
table.String("name")
})
s.Error(err)
})
s.Run("when table name is empty, should return error", func() {
- err := builder.Create(s.ctx, tx, "", func(table *schema.Blueprint) {
+ err := builder.Create(c, "", func(table *schema.Blueprint) {
table.String("name")
})
s.Error(err)
})
s.Run("when blueprint is nil, should return error", func() {
- err := builder.Create(s.ctx, tx, "test_table", nil)
+ err := builder.Create(c, "test_table", nil)
s.Error(err)
})
s.Run("when all parameters are valid, should create table successfully", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
@@ -98,7 +103,7 @@ func (s *mysqlBuilderSuite) TestCreate() {
s.NoError(err, "expected no error when creating table with valid parameters")
})
s.Run("when have composite primary key should create it successfully", func() {
- err = builder.Create(context.Background(), tx, "user_roles", func(table *schema.Blueprint) {
+ err = builder.Create(c, "user_roles", func(table *schema.Blueprint) {
table.Integer("user_id")
table.Integer("role_id")
@@ -107,30 +112,30 @@ func (s *mysqlBuilderSuite) TestCreate() {
s.NoError(err, "expected no error when creating table with composite primary key")
})
s.Run("when have foreign key should create it successfully", func() {
- err = builder.Create(context.Background(), tx, "orders", func(table *schema.Blueprint) {
+ err = builder.Create(c, "orders", func(table *schema.Blueprint) {
table.ID()
table.UnsignedBigInteger("user_id")
table.String("order_id", 255).Unique()
table.Decimal("amount", 10, 2)
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamp("created_at").UseCurrent()
table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE")
})
s.NoError(err, "expected no error when creating table with foreign key")
})
s.Run("when have custom index should create it successfully", func() {
- err = builder.Create(context.Background(), tx, "orders_2", func(table *schema.Blueprint) {
+ err = builder.Create(c, "orders_2", func(table *schema.Blueprint) {
table.ID()
table.String("order_id", 255).Unique("uk_orders_2_order_id")
table.Decimal("amount", 10, 2)
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamp("created_at").UseCurrent()
table.Index("created_at").Name("idx_orders_created_at").Algorithm("BTREE")
})
s.NoError(err, "expected no error when creating table with custom index")
})
s.Run("when table already exists, should return error", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
@@ -139,250 +144,212 @@ func (s *mysqlBuilderSuite) TestCreate() {
})
}
-func (s *mysqlBuilderSuite) TestCreateIfNotExists() {
- builder, _ := schema.NewBuilder("mysql")
- tx, err := s.db.BeginTx(s.ctx, nil)
- s.Require().NoError(err)
- defer tx.Rollback() //nolint:errcheck
-
- s.Run("when tx is nil, should return error", func() {
- err := builder.CreateIfNotExists(s.ctx, nil, "test_table", func(table *schema.Blueprint) {
- table.String("name")
- })
- s.Error(err)
- })
- s.Run("when table name is empty, should return error", func() {
- err := builder.CreateIfNotExists(s.ctx, tx, "", func(table *schema.Blueprint) {
- table.String("name")
- })
- s.Error(err)
- })
- s.Run("when blueprint is nil, should return error", func() {
- err := builder.CreateIfNotExists(s.ctx, tx, "test_table", nil)
- s.Error(err)
- })
- s.Run("when all parameters are valid, should create table successfully", func() {
- err = builder.CreateIfNotExists(context.Background(), tx, "users", func(table *schema.Blueprint) {
- table.ID()
- table.String("name", 255)
- table.String("email", 255).Unique()
- table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
- })
- s.NoError(err, "expected no error when creating table with valid parameters")
- })
- s.Run("when table already exists, should not return error", func() {
- err = builder.CreateIfNotExists(context.Background(), tx, "users", func(table *schema.Blueprint) {
- table.ID()
- table.String("name", 255)
- table.String("email", 255).Unique()
- })
- s.NoError(err, "expected no error when creating table that already exists with CreateIfNotExists")
- })
-}
-
func (s *mysqlBuilderSuite) TestDrop() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- err := builder.Drop(s.ctx, nil, "test_table")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.Drop(nil, "test_table")
s.Error(err)
})
s.Run("when table name is empty, should return error", func() {
- err := builder.Drop(s.ctx, tx, "")
+ err := builder.Drop(c, "")
s.Error(err)
})
s.Run("when all parameters are valid, should drop table successfully", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before dropping it")
- err = builder.Drop(context.Background(), tx, "users")
+ err = builder.Drop(c, "users")
s.NoError(err, "expected no error when dropping table with valid parameters")
})
s.Run("when table does not exist, should return error", func() {
- err = builder.Drop(context.Background(), tx, "non_existent_table")
+ err = builder.Drop(c, "non_existent_table")
s.Error(err, "expected error when dropping a table that does not exist")
})
}
func (s *mysqlBuilderSuite) TestDropIfExists() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- err := builder.DropIfExists(s.ctx, nil, "test_table")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.DropIfExists(nil, "test_table")
s.Error(err)
})
s.Run("when table name is empty, should return error", func() {
- err := builder.DropIfExists(s.ctx, tx, "")
+ err := builder.DropIfExists(c, "")
s.Error(err)
})
s.Run("when all parameters are valid, should drop table successfully", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before dropping it")
- err = builder.DropIfExists(context.Background(), tx, "users")
+ err = builder.DropIfExists(c, "users")
s.NoError(err, "expected no error when dropping table with valid parameters")
})
s.Run("when table does not exist, should not return error", func() {
- err = builder.DropIfExists(context.Background(), tx, "non_existent_table")
+ err = builder.DropIfExists(c, "non_existent_table")
s.NoError(err, "expected no error when dropping a table that does not exist with DropIfExists")
})
}
func (s *mysqlBuilderSuite) TestRename() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- err := builder.Rename(s.ctx, nil, "old_table", "new_table")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.Rename(nil, "old_table", "new_table")
s.Error(err)
})
s.Run("when old table name is empty, should return error", func() {
- err := builder.Rename(s.ctx, tx, "", "new_table")
+ err := builder.Rename(c, "", "new_table")
s.Error(err)
})
s.Run("when new table name is empty, should return error", func() {
- err := builder.Rename(s.ctx, tx, "old_table", "")
+ err := builder.Rename(c, "old_table", "")
s.Error(err)
})
s.Run("when all parameters are valid, should rename table successfully", func() {
- err = builder.Create(context.Background(), tx, "old_table", func(table *schema.Blueprint) {
+ err = builder.Create(c, "old_table", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
})
s.NoError(err, "expected no error when creating old_table before renaming it")
- err = builder.Rename(context.Background(), tx, "old_table", "new_table")
+ err = builder.Rename(c, "old_table", "new_table")
s.NoError(err, "expected no error when renaming table with valid parameters")
})
}
func (s *mysqlBuilderSuite) TestTable() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- err := builder.Table(s.ctx, nil, "test_table", func(table *schema.Blueprint) {
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.Table(nil, "test_table", func(table *schema.Blueprint) {
table.String("name")
})
s.Error(err)
})
s.Run("when table name is empty, should return error", func() {
- err := builder.Table(s.ctx, tx, "", func(table *schema.Blueprint) {
+ err := builder.Table(c, "", func(table *schema.Blueprint) {
table.String("name")
})
s.Error(err)
})
s.Run("when blueprint is nil, should return error", func() {
- err := builder.Table(s.ctx, tx, "test_table", nil)
+ err := builder.Table(c, "test_table", nil)
s.Error(err)
})
s.Run("when all parameters are valid, should modify table successfully", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique("uk_users_email")
table.String("password", 255).Nullable()
table.Text("bio").Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
table.FullText("bio")
})
s.NoError(err, "expected no error when creating table before modifying it")
s.Run("should add new columns and modify existing ones", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.String("address", 255).Nullable()
table.String("phone", 20).Nullable().Unique("uk_users_phone")
})
s.NoError(err, "expected no error when modifying table with valid parameters")
})
s.Run("should modify existing column", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.String("email", 255).Nullable().Change()
})
s.NoError(err, "expected no error when modifying existing column")
})
s.Run("should drop column and rename existing one", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropColumn("password")
table.RenameColumn("name", "full_name")
})
s.NoError(err, "expected no error when dropping column and renaming existing one")
})
s.Run("should add index", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.Index("phone").Name("idx_users_phone").Algorithm("BTREE")
})
s.NoError(err, "expected no error when adding index to table")
})
s.Run("should rename index", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.RenameIndex("idx_users_phone", "idx_users_contact")
})
s.NoError(err, "expected no error when renaming index in table")
})
s.Run("should drop index", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropIndex("idx_users_contact")
})
s.NoError(err, "expected no error when dropping index from table")
})
s.Run("should drop unique constraint", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropUnique("uk_users_email")
})
s.NoError(err, "expected no error when dropping unique constraint from table")
})
s.Run("should drop fulltext index", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropFulltext("ft_users_bio")
})
s.NoError(err, "expected no error when dropping fulltext index from table")
})
s.Run("should add foreign key", func() {
- err = builder.Create(s.ctx, tx, "roles", func(table *schema.Blueprint) {
+ err = builder.Create(c, "roles", func(table *schema.Blueprint) {
table.UnsignedInteger("id").Primary()
table.String("role_name", 255).Unique("uk_roles_role_name")
})
s.NoError(err, "expected no error when creating roles table before adding foreign key")
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.UnsignedInteger("role_id").Nullable()
table.Foreign("role_id").References("id").On("roles").OnDelete("SET NULL").OnUpdate("CASCADE")
})
s.NoError(err, "expected no error when adding foreign key to users table")
})
s.Run("should drop foreign key", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropForeign("fk_users_roles")
})
s.NoError(err, "expected no error when dropping foreign key from users table")
})
s.Run("should drop primary key", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.UnsignedBigInteger("id").Change()
table.DropPrimary("users_pkey")
})
@@ -392,38 +359,39 @@ func (s *mysqlBuilderSuite) TestTable() {
}
func (s *mysqlBuilderSuite) TestGetColumns() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- columns, err := builder.GetColumns(s.ctx, nil, "test_table")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ columns, err := builder.GetColumns(nil, "test_table")
s.Error(err)
s.Nil(columns)
})
s.Run("when table name is empty, should return error", func() {
- columns, err := builder.GetColumns(s.ctx, tx, "")
+ columns, err := builder.GetColumns(c, "")
s.Error(err)
s.Nil(columns)
})
s.Run("when table does not exist, should return empty slice", func() {
- columns, err := builder.GetColumns(s.ctx, tx, "non_existent_table")
+ columns, err := builder.GetColumns(c, "non_existent_table")
s.NoError(err)
s.Empty(columns)
})
s.Run("when table exists, should return columns successfully", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before getting columns")
- columns, err := builder.GetColumns(context.Background(), tx, "users")
+ columns, err := builder.GetColumns(c, "users")
s.NoError(err, "expected no error when getting columns from existing table")
s.NotEmpty(columns)
s.Len(columns, 6, "expected 6 columns in the users table")
@@ -431,67 +399,69 @@ func (s *mysqlBuilderSuite) TestGetColumns() {
}
func (s *mysqlBuilderSuite) TestGetIndexes() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- _, err := builder.GetIndexes(s.ctx, nil, "users_indexes")
- s.Error(err, "expected error when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ _, err := builder.GetIndexes(nil, "users_indexes")
+ s.Error(err, "expected error when context is nil")
})
s.Run("when table name is empty, should return error", func() {
- _, err := builder.GetIndexes(s.ctx, tx, "")
+ _, err := builder.GetIndexes(c, "")
s.Error(err, "expected error when table name is empty")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
table.Index("name").Name("idx_users_name")
})
s.NoError(err, "expected no error when creating table before getting indexes")
- indexes, err := builder.GetIndexes(s.ctx, tx, "users")
+ indexes, err := builder.GetIndexes(c, "users")
s.NoError(err, "expected no error when getting indexes with valid parameters")
s.Len(indexes, 3, "expected 3 index to be returned")
})
s.Run("when table does not exist, should return empty indexes", func() {
- indexes, err := builder.GetIndexes(s.ctx, tx, "non_existent_table")
+ indexes, err := builder.GetIndexes(c, "non_existent_table")
s.NoError(err, "expected no error when getting indexes of non-existent table")
s.Empty(indexes, "expected empty indexes for non-existent table")
})
}
func (s *mysqlBuilderSuite) TestGetTables() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when tx is nil, should return error", func() {
- tables, err := builder.GetTables(s.ctx, nil)
+ tables, err := builder.GetTables(nil)
s.Error(err, "expected error when transaction is nil")
s.Nil(tables)
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before getting tables")
- tables, err := builder.GetTables(context.Background(), tx)
+ tables, err := builder.GetTables(c)
s.NoError(err, "expected no error when getting tables after creating one")
s.NotEmpty(tables, "expected non-empty tables slice after creating a table")
found := false
@@ -506,165 +476,170 @@ func (s *mysqlBuilderSuite) TestGetTables() {
}
func (s *mysqlBuilderSuite) TestHasColumn() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- exists, err := builder.HasColumn(s.ctx, nil, "users", "name")
- s.Error(err, "expected error when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ exists, err := builder.HasColumn(nil, "users", "name")
+ s.Error(err, "expected error when context is nil")
s.False(exists)
})
s.Run("when table name is empty, should return error", func() {
- exists, err := builder.HasColumn(s.ctx, tx, "", "name")
+ exists, err := builder.HasColumn(c, "", "name")
s.Error(err, "expected error when table name is empty")
s.False(exists)
})
s.Run("when column name is empty, should return error", func() {
- exists, err := builder.HasColumn(s.ctx, tx, "users", "")
+ exists, err := builder.HasColumn(c, "users", "")
s.Error(err, "expected error when column name is empty")
s.False(exists)
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before checking for column existence")
- exists, err := builder.HasColumn(context.Background(), tx, "users", "name")
+ exists, err := builder.HasColumn(c, "users", "name")
s.NoError(err, "expected no error when checking for existing column")
s.True(exists, "expected 'name' column to exist in users table")
- exists, err = builder.HasColumn(context.Background(), tx, "users", "non_existent_column")
+ exists, err = builder.HasColumn(c, "users", "non_existent_column")
s.NoError(err, "expected no error when checking for non-existing column")
s.False(exists, "expected 'non_existent_column' to not exist in users table")
})
}
func (s *mysqlBuilderSuite) TestHasColumns() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- exists, err := builder.HasColumns(s.ctx, nil, "users", []string{"name"})
- s.Error(err, "expected error when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ exists, err := builder.HasColumns(nil, "users", []string{"name"})
+ s.Error(err, "expected error when context is nil")
s.False(exists)
})
s.Run("when table name is empty, should return error", func() {
- exists, err := builder.HasColumns(s.ctx, tx, "", []string{"name"})
+ exists, err := builder.HasColumns(c, "", []string{"name"})
s.Error(err, "expected error when table name is empty")
s.False(exists)
})
s.Run("when column names are empty, should return error", func() {
- exists, err := builder.HasColumns(s.ctx, tx, "users", []string{})
+ exists, err := builder.HasColumns(c, "users", []string{})
s.Error(err, "expected error when column names are empty")
s.False(exists)
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before checking for columns existence")
- exists, err := builder.HasColumns(context.Background(), tx, "users", []string{"name", "email"})
+ exists, err := builder.HasColumns(c, "users", []string{"name", "email"})
s.NoError(err, "expected no error when checking for existing columns")
s.True(exists, "expected 'name' and 'email' columns to exist in users_has_columns table")
- exists, err = builder.HasColumns(context.Background(), tx, "users", []string{"name", "non_existent_column"})
+ exists, err = builder.HasColumns(c, "users", []string{"name", "non_existent_column"})
s.NoError(err, "expected no error when checking for mixed existing and non-existing columns")
s.False(exists, "expected 'non_existent_column' to not exist in users_has_columns table")
})
}
func (s *mysqlBuilderSuite) TestHasIndex() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- exists, err := builder.HasIndex(s.ctx, nil, "orders", []string{"idx_users_name"})
- s.Error(err, "expected error when transaction is nil")
- s.False(exists, "expected exists to be false when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ exists, err := builder.HasIndex(nil, "orders", []string{"idx_users_name"})
+ s.Error(err, "expected error when context is nil")
+ s.False(exists, "expected exists to be false when context is nil")
})
s.Run("when table name is empty, should return error", func() {
- exists, err := builder.HasIndex(s.ctx, tx, "", []string{"idx_users_name"})
+ exists, err := builder.HasIndex(c, "", []string{"idx_users_name"})
s.Error(err, "expected error when table name is empty")
s.False(exists, "expected exists to be false when table name is empty")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "orders", func(table *schema.Blueprint) {
+ err = builder.Create(c, "orders", func(table *schema.Blueprint) {
table.ID()
table.Integer("company_id")
table.Integer("user_id")
table.String("order_id", 255)
table.Decimal("amount", 10, 2)
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
table.Index("company_id", "user_id")
table.Unique("order_id").Name("uk_orders3_order_id").Algorithm("BTREE")
})
s.NoError(err, "expected no error when creating table with index")
- exists, err := builder.HasIndex(s.ctx, tx, "orders", []string{"uk_orders3_order_id"})
+ exists, err := builder.HasIndex(c, "orders", []string{"uk_orders3_order_id"})
s.NoError(err, "expected no error when checking if index exists with valid parameters")
s.True(exists, "expected exists to be true for existing index")
- exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"company_id", "user_id"})
+ exists, err = builder.HasIndex(c, "orders", []string{"company_id", "user_id"})
s.NoError(err, "expected no error when checking non-existent index")
s.True(exists, "expected exists to be true for existing composite index")
- exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"non_existent_index"})
+ exists, err = builder.HasIndex(c, "orders", []string{"non_existent_index"})
s.NoError(err, "expected no error when checking non-existent index")
s.False(exists, "expected exists to be false for non-existent index")
})
}
func (s *mysqlBuilderSuite) TestHasTable() {
- builder, _ := schema.NewBuilder("mysql")
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- exists, err := builder.HasTable(s.ctx, nil, "users")
- s.Error(err, "expected error when transaction is nil")
- s.False(exists, "expected exists to be false when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ exists, err := builder.HasTable(nil, "users")
+ s.Error(err, "expected error when context is nil")
+ s.False(exists, "expected exists to be false when context is nil")
})
s.Run("when table name is empty, should return error", func() {
- exists, err := builder.HasTable(s.ctx, tx, "")
+ exists, err := builder.HasTable(c, "")
s.Error(err, "expected error when table name is empty")
s.False(exists, "expected exists to be false when table name is empty")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before checking if it exists")
- exists, err := builder.HasTable(s.ctx, tx, "users")
+ exists, err := builder.HasTable(c, "users")
s.NoError(err, "expected no error when checking if table exists with valid parameters")
s.True(exists, "expected exists to be true for existing table")
- exists, err = builder.HasTable(s.ctx, tx, "non_existent_table")
+ exists, err = builder.HasTable(c, "non_existent_table")
s.NoError(err, "expected no error when checking non-existent table")
s.False(exists, "expected exists to be false for non-existent table")
})
diff --git a/schema/mysql_grammar.go b/schema/mysql_grammar.go
new file mode 100644
index 0000000..eccbf12
--- /dev/null
+++ b/schema/mysql_grammar.go
@@ -0,0 +1,608 @@
+package schema
+
+import (
+ "fmt"
+ "slices"
+ "strings"
+
+ "github.com/akfaiz/migris/internal/util"
+)
+
+type mysqlGrammar struct {
+ baseGrammar
+
+ serials []string
+}
+
+func newMysqlGrammar() *mysqlGrammar {
+ return &mysqlGrammar{
+ serials: []string{
+ "bigInteger", "integer", "mediumInteger", "smallInteger",
+ "tinyInteger",
+ },
+ }
+}
+
+func (g *mysqlGrammar) CompileCurrentDatabase() string {
+ return "SELECT DATABASE()"
+}
+
+func (g *mysqlGrammar) CompileTableExists(database string, table string) (string, error) {
+ return fmt.Sprintf(
+ "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'",
+ g.QuoteString(database),
+ g.QuoteString(table),
+ ), nil
+}
+
+func (g *mysqlGrammar) CompileTables(database string) (string, error) {
+ return fmt.Sprintf(
+ "select table_name as `name`, (data_length + index_length) as `size`, "+
+ "table_comment as `comment`, engine as `engine`, table_collation as `collation` "+
+ "from information_schema.tables where table_schema = %s and table_type in ('BASE TABLE', 'SYSTEM VERSIONED') "+
+ "order by table_name",
+ g.QuoteString(database),
+ ), nil
+}
+
+func (g *mysqlGrammar) CompileColumns(database, table string) (string, error) {
+ return fmt.Sprintf(
+ "select column_name as `name`, data_type as `type_name`, column_type as `type`, "+
+ "collation_name as `collation`, is_nullable as `nullable`, "+
+ "column_default as `default`, column_comment as `comment`, extra as `extra` "+
+ "from information_schema.columns where table_schema = %s and table_name = %s "+
+ "order by ordinal_position asc",
+ g.QuoteString(database),
+ g.QuoteString(table),
+ ), nil
+}
+
+func (g *mysqlGrammar) CompileIndexes(database, table string) (string, error) {
+ return fmt.Sprintf(
+ "select index_name as `name`, group_concat(column_name order by seq_in_index) as `columns`, "+
+ "index_type as `type`, not non_unique as `unique` "+
+ "from information_schema.statistics where table_schema = %s and table_name = %s "+
+ "group by index_name, index_type, non_unique",
+ g.QuoteString(database),
+ g.QuoteString(table),
+ ), nil
+}
+
+func (g *mysqlGrammar) CompileCreate(blueprint *Blueprint) (string, error) {
+ sql, err := g.compileCreateTable(blueprint)
+ if err != nil {
+ return "", err
+ }
+ sql = g.compileCreateEncoding(sql, blueprint)
+
+ return g.compileCreateEngine(sql, blueprint), nil
+}
+
+func (g *mysqlGrammar) compileCreateTable(blueprint *Blueprint) (string, error) {
+ columns, err := g.getColumns(blueprint)
+ if err != nil {
+ return "", err
+ }
+
+ constraints := g.getConstraints(blueprint)
+ columns = append(columns, constraints...)
+
+ return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil
+}
+
+func (g *mysqlGrammar) compileCreateEncoding(sql string, blueprint *Blueprint) string {
+ if blueprint.charset != "" {
+ sql += fmt.Sprintf(" DEFAULT CHARACTER SET %s", blueprint.charset)
+ }
+ if blueprint.collation != "" {
+ sql += fmt.Sprintf(" COLLATE %s", blueprint.collation)
+ }
+
+ return sql
+}
+
+func (g *mysqlGrammar) compileCreateEngine(sql string, blueprint *Blueprint) string {
+ if blueprint.engine != "" {
+ sql += fmt.Sprintf(" ENGINE = %s", blueprint.engine)
+ }
+ return sql
+}
+
+func (g *mysqlGrammar) CompileAdd(blueprint *Blueprint) (string, error) {
+ if len(blueprint.getAddedColumns()) == 0 {
+ return "", nil
+ }
+
+ columns, err := g.getColumns(blueprint)
+ if err != nil {
+ return "", err
+ }
+ columns = g.PrefixArray("ADD COLUMN ", columns)
+ constraints := g.getConstraints(blueprint)
+ constraints = g.PrefixArray("ADD ", constraints)
+ columns = append(columns, constraints...)
+
+ return fmt.Sprintf("ALTER TABLE %s %s",
+ blueprint.name,
+ strings.Join(columns, ", "),
+ ), nil
+}
+
+func (g *mysqlGrammar) CompileChange(bp *Blueprint, command *command) (string, error) {
+ column := command.column
+ if column.name == "" {
+ return "", fmt.Errorf("column name cannot be empty for change operation")
+ }
+
+ sql := fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s %s", bp.name, column.name, g.getType(column))
+ for _, modifier := range g.modifiers() {
+ sql += modifier(column)
+ }
+
+ return sql, nil
+}
+
+func (g *mysqlGrammar) CompileRename(blueprint *Blueprint, command *command) (string, error) {
+ return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, command.to), nil
+}
+
+func (g *mysqlGrammar) CompileDrop(blueprint *Blueprint) (string, error) {
+ if blueprint.name == "" {
+ return "", fmt.Errorf("table name cannot be empty")
+ }
+ return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil
+}
+
+func (g *mysqlGrammar) CompileDropIfExists(blueprint *Blueprint) (string, error) {
+ if blueprint.name == "" {
+ return "", fmt.Errorf("table name cannot be empty")
+ }
+ return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil
+}
+
+func (g *mysqlGrammar) CompileDropColumn(blueprint *Blueprint, command *command) (string, error) {
+ if len(command.columns) == 0 {
+ return "", fmt.Errorf("no columns to drop")
+ }
+ columns := make([]string, len(command.columns))
+ for i, col := range command.columns {
+ if col == "" {
+ return "", fmt.Errorf("column name cannot be empty")
+ }
+ columns[i] = col
+ }
+ columns = g.PrefixArray("DROP COLUMN ", columns)
+ return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil
+}
+
+func (g *mysqlGrammar) CompileRenameColumn(blueprint *Blueprint, command *command) (string, error) {
+ if command.from == "" || command.to == "" {
+ return "", fmt.Errorf("old and new column names cannot be empty")
+ }
+ return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, command.from, command.to), nil
+}
+
+func (g *mysqlGrammar) CompileIndex(blueprint *Blueprint, command *command) (string, error) {
+ if slices.Contains(command.columns, "") {
+ return "", fmt.Errorf("index column cannot be empty")
+ }
+
+ indexName := command.index
+ if indexName == "" {
+ indexName = g.CreateIndexName(blueprint, "index", command.columns...)
+ }
+
+ sql := fmt.Sprintf("CREATE INDEX %s ON %s (%s)", indexName, blueprint.name, g.Columnize(command.columns))
+ if command.algorithm != "" {
+ sql += fmt.Sprintf(" USING %s", command.algorithm)
+ }
+
+ return sql, nil
+}
+
+func (g *mysqlGrammar) CompileUnique(blueprint *Blueprint, command *command) (string, error) {
+ if slices.Contains(command.columns, "") {
+ return "", fmt.Errorf("unique column cannot be empty")
+ }
+
+ indexName := command.index
+ if indexName == "" {
+ indexName = g.CreateIndexName(blueprint, "unique", command.columns...)
+ }
+ sql := fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (%s)", indexName, blueprint.name, g.Columnize(command.columns))
+ if command.algorithm != "" {
+ sql += fmt.Sprintf(" USING %s", command.algorithm)
+ }
+
+ return sql, nil
+}
+
+func (g *mysqlGrammar) CompileFullText(blueprint *Blueprint, command *command) (string, error) {
+ if slices.Contains(command.columns, "") {
+ return "", fmt.Errorf("fulltext index column cannot be empty")
+ }
+
+ indexName := command.index
+ if indexName == "" {
+ indexName = g.CreateIndexName(blueprint, "fulltext", command.columns...)
+ }
+
+ return fmt.Sprintf("CREATE FULLTEXT INDEX %s ON %s (%s)", indexName, blueprint.name, g.Columnize(command.columns)), nil
+}
+
+func (g *mysqlGrammar) CompilePrimary(blueprint *Blueprint, command *command) (string, error) {
+ if slices.Contains(command.columns, "") {
+ return "", fmt.Errorf("primary key column cannot be empty")
+ }
+
+ indexName := command.index
+ if indexName == "" {
+ indexName = g.CreateIndexName(blueprint, "primary", command.columns...)
+ }
+
+ return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.Columnize(command.columns)), nil
+}
+
+func (g *mysqlGrammar) CompileDropIndex(blueprint *Blueprint, command *command) (string, error) {
+ if command.index == "" {
+ return "", fmt.Errorf("index name cannot be empty")
+ }
+ return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, command.index), nil
+}
+
+func (g *mysqlGrammar) CompileDropUnique(blueprint *Blueprint, command *command) (string, error) {
+ if command.index == "" {
+ return "", fmt.Errorf("unique index name cannot be empty")
+ }
+ return fmt.Sprintf("ALTER TABLE %s DROP INDEX %s", blueprint.name, command.index), nil
+}
+
+func (g *mysqlGrammar) CompileDropFulltext(blueprint *Blueprint, command *command) (string, error) {
+ return g.CompileDropIndex(blueprint, command)
+}
+
+func (g *mysqlGrammar) CompileDropPrimary(blueprint *Blueprint, _ *command) (string, error) {
+ return fmt.Sprintf("ALTER TABLE %s DROP PRIMARY KEY", blueprint.name), nil
+}
+
+func (g *mysqlGrammar) CompileRenameIndex(blueprint *Blueprint, command *command) (string, error) {
+ if command.from == "" || command.to == "" {
+ return "", fmt.Errorf("old and new index names cannot be empty")
+ }
+ return fmt.Sprintf("ALTER TABLE %s RENAME INDEX %s TO %s", blueprint.name, command.from, command.to), nil
+}
+
+func (g *mysqlGrammar) CompileDropForeign(blueprint *Blueprint, command *command) (string, error) {
+ if command.index == "" {
+ return "", fmt.Errorf("foreign key name cannot be empty")
+ }
+ return fmt.Sprintf("ALTER TABLE %s DROP FOREIGN KEY %s", blueprint.name, command.index), nil
+}
+
+func (g *mysqlGrammar) GetFluentCommands() []func(*Blueprint, *command) string {
+ return []func(*Blueprint, *command) string{}
+}
+
+func (g *mysqlGrammar) getColumns(blueprint *Blueprint) ([]string, error) {
+ var columns []string
+ for _, col := range blueprint.getAddedColumns() {
+ if col.name == "" {
+ return nil, fmt.Errorf("column name cannot be empty")
+ }
+ sql := col.name + " " + g.getType(col)
+ sql += g.modifyUnsigned(col)
+ sql += g.modifyIncrement(col)
+ sql += g.modifyDefault(col)
+ sql += g.modifyOnUpdate(col)
+ sql += g.modifyCharset(col)
+ sql += g.modifyCollate(col)
+ sql += g.modifyNullable(col)
+ sql += g.modifyComment(col)
+
+ columns = append(columns, sql)
+ }
+
+ return columns, nil
+}
+
+func (g *mysqlGrammar) getConstraints(blueprint *Blueprint) []string {
+ var constrains []string
+ for _, col := range blueprint.getAddedColumns() {
+ if col.primary != nil && *col.primary {
+ pkConstraintName := g.CreateIndexName(blueprint, "primary")
+ sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")"
+ constrains = append(constrains, sql)
+ continue
+ }
+ }
+
+ return constrains
+}
+
+func (g *mysqlGrammar) getType(col *columnDefinition) string {
+ typeFuncMap := map[string]func(*columnDefinition) string{
+ columnTypeChar: g.typeChar,
+ columnTypeString: g.typeString,
+ columnTypeTinyText: g.typeTinyText,
+ columnTypeText: g.typeText,
+ columnTypeMediumText: g.typeMediumText,
+ columnTypeLongText: g.typeLongText,
+ columnTypeInteger: g.typeInteger,
+ columnTypeBigInteger: g.typeBigInteger,
+ columnTypeMediumInteger: g.typeMediumInteger,
+ columnTypeSmallInteger: g.typeSmallInteger,
+ columnTypeTinyInteger: g.typeTinyInteger,
+ columnTypeFloat: g.typeFloat,
+ columnTypeDouble: g.typeDouble,
+ columnTypeDecimal: g.typeDecimal,
+ columnTypeBoolean: g.typeBoolean,
+ columnTypeEnum: g.typeEnum,
+ columnTypeJson: g.typeJson,
+ columnTypeJsonb: g.typeJsonb,
+ columnTypeDate: g.typeDate,
+ columnTypeDateTime: g.typeDateTime,
+ columnTypeDateTimeTz: g.typeDateTimeTz,
+ columnTypeTime: g.typeTime,
+ columnTypeTimeTz: g.typeTimeTz,
+ columnTypeTimestamp: g.typeTimestamp,
+ columnTypeTimestampTz: g.typeTimestampTz,
+ columnTypeYear: g.typeYear,
+ columnTypeBinary: g.typeBinary,
+ columnTypeUuid: g.typeUuid,
+ columnTypeGeography: g.typeGeography,
+ columnTypeGeometry: g.typeGeometry,
+ columnTypePoint: g.typePoint,
+ }
+ if fn, ok := typeFuncMap[col.columnType]; ok {
+ return fn(col)
+ }
+ return col.columnType
+}
+
+func (g *mysqlGrammar) typeChar(col *columnDefinition) string {
+ return fmt.Sprintf("CHAR(%d)", *col.length)
+}
+
+func (g *mysqlGrammar) typeString(col *columnDefinition) string {
+ return fmt.Sprintf("VARCHAR(%d)", *col.length)
+}
+
+func (g *mysqlGrammar) typeTinyText(col *columnDefinition) string {
+ return "TINYTEXT"
+}
+
+func (g *mysqlGrammar) typeText(col *columnDefinition) string {
+ return "TEXT"
+}
+
+func (g *mysqlGrammar) typeMediumText(col *columnDefinition) string {
+ return "MEDIUMTEXT"
+}
+
+func (g *mysqlGrammar) typeLongText(col *columnDefinition) string {
+ return "LONGTEXT"
+}
+
+func (g *mysqlGrammar) typeBigInteger(col *columnDefinition) string {
+ return "BIGINT"
+}
+
+func (g *mysqlGrammar) typeInteger(col *columnDefinition) string {
+ return "INT"
+}
+
+func (g *mysqlGrammar) typeMediumInteger(col *columnDefinition) string {
+ return "MEDIUMINT"
+}
+
+func (g *mysqlGrammar) typeSmallInteger(col *columnDefinition) string {
+ return "SMALLINT"
+}
+
+func (g *mysqlGrammar) typeTinyInteger(col *columnDefinition) string {
+ return "TINYINT"
+}
+
+func (g *mysqlGrammar) typeFloat(col *columnDefinition) string {
+ if col.precision != nil && *col.precision > 0 {
+ return fmt.Sprintf("FLOAT(%d)", *col.precision)
+ }
+ return "FLOAT"
+}
+
+func (g *mysqlGrammar) typeDouble(col *columnDefinition) string {
+ return "DOUBLE"
+}
+
+func (g *mysqlGrammar) typeDecimal(col *columnDefinition) string {
+ return fmt.Sprintf("DECIMAL(%d, %d)", *col.total, *col.places)
+}
+
+func (g *mysqlGrammar) typeBoolean(col *columnDefinition) string {
+ return "TINYINT(1)"
+}
+
+func (g *mysqlGrammar) typeEnum(col *columnDefinition) string {
+ allowedValues := make([]string, len(col.allowed))
+ for i, e := range col.allowed {
+ allowedValues[i] = g.QuoteString(e)
+ }
+ return fmt.Sprintf("ENUM(%s)", strings.Join(allowedValues, ", "))
+}
+
+func (g *mysqlGrammar) typeJson(col *columnDefinition) string {
+ return "JSON"
+}
+
+func (g *mysqlGrammar) typeJsonb(col *columnDefinition) string {
+ return "JSON"
+}
+
+func (g *mysqlGrammar) typeDate(col *columnDefinition) string {
+ return "DATE"
+}
+
+func (g *mysqlGrammar) typeDateTime(col *columnDefinition) string {
+ current := "CURRENT_TIMESTAMP"
+ if col.precision != nil && *col.precision > 0 {
+ current = fmt.Sprintf("CURRENT_TIMESTAMP(%d)", *col.precision)
+ }
+ if col.useCurrent {
+ col.SetDefault(Expression(current))
+ }
+ if col.useCurrentOnUpdate {
+ col.SetOnUpdate(Expression(current))
+ }
+ if col.precision != nil && *col.precision > 0 {
+ return fmt.Sprintf("DATETIME(%d)", *col.precision)
+ }
+ return "DATETIME"
+}
+
+func (g *mysqlGrammar) typeDateTimeTz(col *columnDefinition) string {
+ return g.typeDateTime(col)
+}
+
+func (g *mysqlGrammar) typeTime(col *columnDefinition) string {
+ if col.precision != nil && *col.precision > 0 {
+ return fmt.Sprintf("TIME(%d)", *col.precision)
+ }
+ return "TIME"
+}
+
+func (g *mysqlGrammar) typeTimeTz(col *columnDefinition) string {
+ return g.typeTime(col)
+}
+
+func (g *mysqlGrammar) typeTimestamp(col *columnDefinition) string {
+ current := "CURRENT_TIMESTAMP"
+ if col.precision != nil && *col.precision > 0 {
+ current = fmt.Sprintf("CURRENT_TIMESTAMP(%d)", *col.precision)
+ }
+ if col.useCurrent {
+ col.SetDefault(Expression(current))
+ }
+ if col.useCurrentOnUpdate {
+ col.SetOnUpdate(Expression(current))
+ }
+ if col.precision != nil && *col.precision > 0 {
+ return fmt.Sprintf("TIMESTAMP(%d)", *col.precision)
+ }
+ return "TIMESTAMP"
+}
+
+func (g *mysqlGrammar) typeTimestampTz(col *columnDefinition) string {
+ return g.typeTimestamp(col)
+}
+
+func (g *mysqlGrammar) typeYear(col *columnDefinition) string {
+ return "YEAR"
+}
+
+func (g *mysqlGrammar) typeBinary(col *columnDefinition) string {
+ if col.length != nil && *col.length > 0 {
+ return fmt.Sprintf("BINARY(%d)", *col.length)
+ }
+ return "BLOB"
+}
+
+func (g *mysqlGrammar) typeUuid(col *columnDefinition) string {
+ return "CHAR(36)" // Default UUID length
+}
+
+func (g *mysqlGrammar) typeGeometry(col *columnDefinition) string {
+ subtype := util.Ternary(col.subtype != nil, util.PtrOf(strings.ToUpper(*col.subtype)), nil)
+ if subtype != nil {
+ if !slices.Contains([]string{"POINT", "LINESTRING", "POLYGON", "GEOMETRYCOLLECTION", "MULTIPOINT", "MULTILINESTRING"}, *subtype) {
+ subtype = nil
+ }
+ }
+
+ if subtype == nil {
+ subtype = util.PtrOf("GEOMETRY")
+ }
+ if col.srid != nil && *col.srid > 0 {
+ return fmt.Sprintf("%s SRID %d", *subtype, *col.srid)
+ }
+ return *subtype
+}
+
+func (g *mysqlGrammar) typeGeography(col *columnDefinition) string {
+ return g.typeGeometry(col)
+}
+
+func (g *mysqlGrammar) typePoint(col *columnDefinition) string {
+ col.SetSubtype(util.PtrOf("POINT"))
+ return g.typeGeometry(col)
+}
+
+func (g *mysqlGrammar) modifiers() []func(*columnDefinition) string {
+ return []func(*columnDefinition) string{
+ g.modifyUnsigned,
+ g.modifyCharset,
+ g.modifyCollate,
+ g.modifyNullable,
+ g.modifyDefault,
+ g.modifyOnUpdate,
+ g.modifyIncrement,
+ g.modifyComment,
+ }
+}
+
+func (g *mysqlGrammar) modifyCharset(col *columnDefinition) string {
+ if col.charset != nil && *col.charset != "" {
+ return fmt.Sprintf(" CHARACTER SET %s", *col.charset)
+ }
+ return ""
+}
+
+func (g *mysqlGrammar) modifyCollate(col *columnDefinition) string {
+ if col.collation != nil && *col.collation != "" {
+ return fmt.Sprintf(" COLLATE %s", *col.collation)
+ }
+ return ""
+}
+
+func (g *mysqlGrammar) modifyComment(col *columnDefinition) string {
+ if col.comment != nil {
+ return fmt.Sprintf(" COMMENT '%s'", *col.comment)
+ }
+ return ""
+}
+
+func (g *mysqlGrammar) modifyDefault(col *columnDefinition) string {
+ if col.hasCommand("default") {
+ return fmt.Sprintf(" DEFAULT %s", g.GetDefaultValue(col.defaultValue))
+ }
+ return ""
+}
+
+func (g *mysqlGrammar) modifyIncrement(col *columnDefinition) string {
+ if slices.Contains(g.serials, col.columnType) &&
+ col.autoIncrement != nil && *col.autoIncrement &&
+ col.primary != nil && *col.primary {
+ return " AUTO_INCREMENT"
+ }
+ return ""
+}
+
+func (g *mysqlGrammar) modifyNullable(col *columnDefinition) string {
+ if col.nullable != nil && *col.nullable {
+ return " NULL"
+ }
+ return " NOT NULL"
+}
+
+func (g *mysqlGrammar) modifyOnUpdate(col *columnDefinition) string {
+ if col.hasCommand("onUpdate") {
+ return fmt.Sprintf(" ON UPDATE %s", g.GetValue(col.onUpdateValue))
+ }
+ return ""
+}
+
+func (g *mysqlGrammar) modifyUnsigned(col *columnDefinition) string {
+ if col.unsigned != nil && *col.unsigned {
+ return " UNSIGNED"
+ }
+ return ""
+}
diff --git a/mysql_grammar_test.go b/schema/mysql_grammar_test.go
similarity index 90%
rename from mysql_grammar_test.go
rename to schema/mysql_grammar_test.go
index 583c18e..91534f2 100644
--- a/mysql_grammar_test.go
+++ b/schema/mysql_grammar_test.go
@@ -3,6 +3,7 @@ package schema
import (
"testing"
+ "github.com/akfaiz/migris/internal/dialect"
"github.com/stretchr/testify/assert"
)
@@ -82,7 +83,7 @@ func TestMysqlGrammar_CompileCreate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := g.compileCreate(bp)
+ got, err := g.CompileCreate(bp)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -172,7 +173,7 @@ func TestMysqlGrammar_CompileAdd(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := g.compileAdd(bp)
+ got, err := g.CompileAdd(bp)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -206,7 +207,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) {
blueprint: func(table *Blueprint) {
table.Integer("age").Change()
},
- want: []string{"ALTER TABLE users MODIFY COLUMN age INT"},
+ want: []string{"ALTER TABLE users MODIFY COLUMN age INT NOT NULL"},
wantErr: false,
},
{
@@ -233,16 +234,16 @@ func TestMysqlGrammar_CompileChange(t *testing.T) {
blueprint: func(table *Blueprint) {
table.String("status", 50).Default("active").Change()
},
- want: []string{"ALTER TABLE users MODIFY COLUMN status VARCHAR(50) DEFAULT 'active'"},
+ want: []string{"ALTER TABLE users MODIFY COLUMN status VARCHAR(50) NOT NULL DEFAULT 'active'"},
wantErr: false,
},
{
name: "change column with null default",
table: "users",
blueprint: func(table *Blueprint) {
- table.Text("description").Default(nil).Change()
+ table.Text("description").Nullable().Default(nil).Change()
},
- want: []string{"ALTER TABLE users MODIFY COLUMN description TEXT DEFAULT NULL"},
+ want: []string{"ALTER TABLE users MODIFY COLUMN description TEXT NULL DEFAULT NULL"},
wantErr: false,
},
{
@@ -251,7 +252,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) {
blueprint: func(table *Blueprint) {
table.Integer("age").Comment("User age in years").Change()
},
- want: []string{"ALTER TABLE users MODIFY COLUMN age INT COMMENT 'User age in years'"},
+ want: []string{"ALTER TABLE users MODIFY COLUMN age INT NOT NULL COMMENT 'User age in years'"},
wantErr: false,
},
{
@@ -260,7 +261,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) {
blueprint: func(table *Blueprint) {
table.Text("notes").Comment("").Change()
},
- want: []string{"ALTER TABLE users MODIFY COLUMN notes TEXT COMMENT ''"},
+ want: []string{"ALTER TABLE users MODIFY COLUMN notes TEXT NOT NULL COMMENT ''"},
wantErr: false,
},
{
@@ -284,7 +285,7 @@ func TestMysqlGrammar_CompileChange(t *testing.T) {
table.SmallInteger("age").Nullable().Change()
},
want: []string{
- "ALTER TABLE users MODIFY COLUMN name VARCHAR(200)",
+ "ALTER TABLE users MODIFY COLUMN name VARCHAR(200) NOT NULL",
"ALTER TABLE users MODIFY COLUMN age SMALLINT NULL",
},
wantErr: false,
@@ -301,15 +302,15 @@ func TestMysqlGrammar_CompileChange(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- bp := &Blueprint{name: tt.table}
+ bp := &Blueprint{name: tt.table, grammar: g, dialect: dialect.MySQL}
tt.blueprint(bp)
- got, err := g.compileChange(bp)
+ statements, err := bp.toSql()
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
}
assert.NoError(t, err, "Did not expect error for test case: %s", tt.name)
- assert.Equal(t, tt.want, got, "Expected SQL to match for test case: %s", tt.name)
+ assert.Equal(t, tt.want, statements, "Expected SQL to match for test case: %s", tt.name)
})
}
}
@@ -345,21 +346,15 @@ func TestMysqlGrammar_CompileRename(t *testing.T) {
want: "ALTER TABLE table1 RENAME TO table2",
wantErr: false,
},
- {
- name: "empty new name should return error",
- table: "users",
- newName: "",
- wantErr: true,
- },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{
- name: tt.table,
- newName: tt.newName,
+ name: tt.table,
}
- got, err := g.compileRename(bp)
+ bp.rename(tt.newName)
+ got, err := g.CompileRename(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -413,7 +408,7 @@ func TestMysqlGrammar_CompileDrop(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := g.compileDrop(bp)
+ got, err := g.CompileDrop(bp)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -467,7 +462,7 @@ func TestMysqlGrammar_CompileDropIfExists(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := g.compileDropIfExists(bp)
+ got, err := g.CompileDropIfExists(bp)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -506,12 +501,6 @@ func TestMysqlGrammar_CompileDropColumn(t *testing.T) {
want: "ALTER TABLE users DROP COLUMN email, DROP COLUMN phone, DROP COLUMN address",
wantErr: false,
},
- {
- name: "no columns to drop should return error",
- table: "users",
- blueprint: func(table *Blueprint) {},
- wantErr: true,
- },
{
name: "empty column name should return error",
table: "users",
@@ -536,7 +525,7 @@ func TestMysqlGrammar_CompileDropColumn(t *testing.T) {
name: tt.table,
}
tt.blueprint(bp)
- got, err := g.compileDropColumn(bp)
+ got, err := g.CompileDropColumn(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -592,7 +581,8 @@ func TestMysqlGrammar_CompileRenameColumn(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := g.compileRenameColumn(bp, tt.oldName, tt.newName)
+ command := &command{from: tt.oldName, to: tt.newName}
+ got, err := g.CompileRenameColumn(bp, command)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -714,7 +704,7 @@ func TestMysqlGrammar_CompileForeign(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := g.compileForeign(bp, bp.foreignKeys[0])
+ got, err := g.CompileForeign(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -753,7 +743,8 @@ func TestMysqlGrammar_CompileDropForeign(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := g.compileDropForeign(bp, tt.fkName)
+ command := &command{index: tt.fkName}
+ got, err := g.CompileDropForeign(bp, command)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -833,7 +824,7 @@ func TestMysqlGrammar_CompileIndex(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := g.compileIndex(bp, bp.indexes[0])
+ got, err := g.CompileIndex(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -921,7 +912,7 @@ func TestMysqlGrammar_CompileUnique(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := g.compileUnique(bp, bp.indexes[0])
+ got, err := g.CompileUnique(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -1017,7 +1008,7 @@ func TestMysqlGrammar_CompilePrimary(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := g.compilePrimary(bp, bp.indexes[0])
+ got, err := g.CompilePrimary(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -1113,7 +1104,7 @@ func TestMysqlGrammar_CompileFullText(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := g.compileFullText(bp, bp.indexes[0])
+ got, err := g.CompileFullText(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -1152,7 +1143,8 @@ func TestMysqlGrammar_CompileDropIndex(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := g.compileDropIndex(bp, tt.indexName)
+ command := &command{index: tt.indexName}
+ got, err := g.CompileDropIndex(bp, command)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -1190,7 +1182,8 @@ func TestMysqlGrammar_CompileDropUnique(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := g.compileDropUnique(bp, tt.indexName)
+ command := &command{index: tt.indexName}
+ got, err := g.CompileDropUnique(bp, command)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -1228,7 +1221,8 @@ func TestMysqlGrammar_CompileDropFulltext(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := g.compileDropFulltext(bp, tt.indexName)
+ command := &command{index: tt.indexName}
+ got, err := g.CompileDropFulltext(bp, command)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -1284,18 +1278,13 @@ func TestMysqlGrammar_CompileDropPrimary(t *testing.T) {
want: "ALTER TABLE orders DROP PRIMARY KEY",
wantErr: false,
},
- {
- name: "empty primary key index name should return error",
- table: "users",
- indexName: "",
- wantErr: true,
- },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := g.compileDropPrimary(bp, tt.indexName)
+ command := &command{index: tt.indexName}
+ got, err := g.CompileDropPrimary(bp, command)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -1383,7 +1372,8 @@ func TestMysqlGrammar_CompileRenameIndex(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := g.compileRenameIndex(bp, tt.oldName, tt.newName)
+ command := &command{from: tt.oldName, to: tt.newName}
+ got, err := g.CompileRenameIndex(bp, command)
if tt.wantErr {
assert.Error(t, err, "Expected error for test case: %s", tt.name)
return
@@ -1438,32 +1428,25 @@ func TestMysqlGrammar_GetType(t *testing.T) {
want: "DECIMAL(10, 2)",
},
{
- name: "double column type with precision",
- blueprint: func(table *Blueprint) {
- table.Double("value", 8, 2)
- },
- want: "DOUBLE(8, 2)",
- },
- {
- name: "double column type without precision",
+ name: "double column type",
blueprint: func(table *Blueprint) {
- table.Double("value", 0, 0)
+ table.Double("value")
},
want: "DOUBLE",
},
{
name: "float column type with precision",
blueprint: func(table *Blueprint) {
- table.Float("value", 6, 2)
+ table.Float("value", 6)
},
- want: "DOUBLE(6, 2)",
+ want: "FLOAT(6)",
},
{
name: "float column type without precision",
blueprint: func(table *Blueprint) {
- table.Float("value", 0, 0)
+ table.Float("value")
},
- want: "DOUBLE",
+ want: "FLOAT(53)",
},
{
name: "big integer column type",
@@ -1472,27 +1455,6 @@ func TestMysqlGrammar_GetType(t *testing.T) {
},
want: "BIGINT",
},
- {
- name: "big integer unsigned",
- blueprint: func(table *Blueprint) {
- table.BigInteger("id").Unsigned()
- },
- want: "BIGINT UNSIGNED",
- },
- {
- name: "big integer auto increment",
- blueprint: func(table *Blueprint) {
- table.BigInteger("id").AutoIncrement()
- },
- want: "BIGINT AUTO_INCREMENT",
- },
- {
- name: "big integer unsigned auto increment",
- blueprint: func(table *Blueprint) {
- table.BigInteger("id").Unsigned().AutoIncrement()
- },
- want: "BIGINT UNSIGNED AUTO_INCREMENT",
- },
{
name: "integer column type",
blueprint: func(table *Blueprint) {
@@ -1500,27 +1462,6 @@ func TestMysqlGrammar_GetType(t *testing.T) {
},
want: "INT",
},
- {
- name: "integer unsigned",
- blueprint: func(table *Blueprint) {
- table.Integer("count").Unsigned()
- },
- want: "INT UNSIGNED",
- },
- {
- name: "integer auto increment",
- blueprint: func(table *Blueprint) {
- table.Integer("id").AutoIncrement()
- },
- want: "INT AUTO_INCREMENT",
- },
- {
- name: "integer unsigned auto increment",
- blueprint: func(table *Blueprint) {
- table.Increments("id")
- },
- want: "INT UNSIGNED AUTO_INCREMENT",
- },
{
name: "small integer column type",
blueprint: func(table *Blueprint) {
@@ -1528,13 +1469,6 @@ func TestMysqlGrammar_GetType(t *testing.T) {
},
want: "SMALLINT",
},
- {
- name: "small integer unsigned auto increment",
- blueprint: func(table *Blueprint) {
- table.SmallInteger("id").Unsigned().AutoIncrement()
- },
- want: "SMALLINT UNSIGNED AUTO_INCREMENT",
- },
{
name: "medium integer column type",
blueprint: func(table *Blueprint) {
@@ -1542,13 +1476,6 @@ func TestMysqlGrammar_GetType(t *testing.T) {
},
want: "MEDIUMINT",
},
- {
- name: "medium integer unsigned",
- blueprint: func(table *Blueprint) {
- table.UnsignedMediumInteger("value")
- },
- want: "MEDIUMINT UNSIGNED",
- },
{
name: "small integer column type",
blueprint: func(table *Blueprint) {
@@ -1556,13 +1483,6 @@ func TestMysqlGrammar_GetType(t *testing.T) {
},
want: "SMALLINT",
},
- {
- name: "small integer unsigned auto increment",
- blueprint: func(table *Blueprint) {
- table.SmallIncrements("id")
- },
- want: "SMALLINT UNSIGNED AUTO_INCREMENT",
- },
{
name: "tiny integer column type",
blueprint: func(table *Blueprint) {
@@ -1570,26 +1490,12 @@ func TestMysqlGrammar_GetType(t *testing.T) {
},
want: "TINYINT",
},
- {
- name: "tiny integer auto increment",
- blueprint: func(table *Blueprint) {
- table.TinyInteger("id").AutoIncrement()
- },
- want: "TINYINT AUTO_INCREMENT",
- },
- {
- name: "tiny integer unsigned auto increment",
- blueprint: func(table *Blueprint) {
- table.TinyIncrements("id")
- },
- want: "TINYINT UNSIGNED AUTO_INCREMENT",
- },
{
name: "time column type",
blueprint: func(table *Blueprint) {
table.Time("created_at")
},
- want: "TIME(0)",
+ want: "TIME",
},
{
name: "datetime column type with precision",
@@ -1647,19 +1553,12 @@ func TestMysqlGrammar_GetType(t *testing.T) {
},
want: "TIMESTAMP",
},
- {
- name: "geography column type",
- blueprint: func(table *Blueprint) {
- table.Geography("location", "POINT", 4326)
- },
- want: "GEOGRAPHY(POINT, 4326)",
- },
{
name: "enum column type",
blueprint: func(table *Blueprint) {
table.Enum("status", []string{"active", "inactive", "pending"})
},
- want: "ENUM('active','inactive','pending')",
+ want: "ENUM('active', 'inactive', 'pending')",
},
{
name: "long text column type",
@@ -1722,7 +1621,7 @@ func TestMysqlGrammar_GetType(t *testing.T) {
blueprint: func(table *Blueprint) {
table.UUID("uuid")
},
- want: "UUID",
+ want: "CHAR(36)",
},
{
name: "binary column type",
@@ -1731,19 +1630,26 @@ func TestMysqlGrammar_GetType(t *testing.T) {
},
want: "BLOB",
},
+ {
+ name: "geography column type",
+ blueprint: func(table *Blueprint) {
+ table.Geography("location", "LINESTRING", 4326)
+ },
+ want: "LINESTRING SRID 4326",
+ },
{
name: "geometry column type",
blueprint: func(table *Blueprint) {
- table.Geometry("shape", "GEOMETRY", 4326)
+ table.Geometry("shape", "", 4326)
},
- want: "GEOMETRY",
+ want: "GEOMETRY SRID 4326",
},
{
name: "point column type",
blueprint: func(table *Blueprint) {
table.Point("location")
},
- want: "POINT",
+ want: "POINT SRID 4326",
},
}
@@ -1791,14 +1697,6 @@ func TestMysqlGrammar_GetColumns(t *testing.T) {
want: []string{"status VARCHAR(50) DEFAULT 'active' NOT NULL"},
wantErr: false,
},
- {
- name: "column with on update value",
- blueprint: func(table *Blueprint) {
- table.Timestamp("updated_at", 0).UseCurrentOnUpdate()
- },
- want: []string{"updated_at TIMESTAMP ON UPDATE CURRENT_TIMESTAMP NOT NULL"},
- wantErr: false,
- },
{
name: "nullable column",
blueprint: func(table *Blueprint) {
@@ -1850,21 +1748,13 @@ func TestMysqlGrammar_GetColumns(t *testing.T) {
want: []string{"id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL"},
wantErr: false,
},
- {
- name: "timestamp with default",
- blueprint: func(table *Blueprint) {
- table.Timestamp("created_at", 0).UseCurrent()
- },
- want: []string{"created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL"},
- wantErr: false,
- },
{
name: "multiple columns with different attributes",
blueprint: func(table *Blueprint) {
table.BigInteger("id").Unsigned().AutoIncrement().Primary()
table.String("name", 255).Comment("User name")
table.String("email", 255).Nullable()
- table.Timestamp("created_at", 0).Default("CURRENT_TIMESTAMP")
+ table.Timestamp("created_at", 0).UseCurrent()
},
want: []string{
"id BIGINT UNSIGNED AUTO_INCREMENT NOT NULL",
diff --git a/postgres_builder.go b/schema/postgres_builder.go
similarity index 68%
rename from postgres_builder.go
rename to schema/postgres_builder.go
index f46b9fd..be214e3 100644
--- a/postgres_builder.go
+++ b/schema/postgres_builder.go
@@ -1,7 +1,6 @@
package schema
import (
- "context"
"database/sql"
"errors"
"strings"
@@ -9,11 +8,11 @@ import (
type postgresBuilder struct {
baseBuilder
- grammar *pgGrammar
+ grammar *postgresGrammar
}
func newPostgresBuilder() Builder {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
return &postgresBuilder{
baseBuilder: baseBuilder{grammar: grammar},
@@ -29,21 +28,21 @@ func (b *postgresBuilder) parseSchemaAndTable(name string) (string, string) {
return "", names[0]
}
-func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) {
- if err := b.validateTxAndName(tx, tableName); err != nil {
- return nil, err
+func (b *postgresBuilder) GetColumns(c *Context, tableName string) ([]*Column, error) {
+ if c == nil || tableName == "" {
+ return nil, errors.New("invalid arguments: context is nil or table name is empty")
}
schema, name := b.parseSchemaAndTable(tableName)
if schema == "" {
schema = "public" // Default schema for PostgreSQL
}
- query, err := b.grammar.compileColumns(schema, name)
+ query, err := b.grammar.CompileColumns(schema, name)
if err != nil {
return nil, err
}
- rows, err := queryContext(ctx, tx, query)
+ rows, err := c.Query(query)
if err != nil {
return nil, err
}
@@ -64,19 +63,19 @@ func (b *postgresBuilder) GetColumns(ctx context.Context, tx *sql.Tx, tableName
return columns, nil
}
-func (b *postgresBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) {
- if err := b.validateTxAndName(tx, tableName); err != nil {
- return nil, err
+func (b *postgresBuilder) GetIndexes(c *Context, tableName string) ([]*Index, error) {
+ if c == nil || tableName == "" {
+ return nil, errors.New("invalid arguments: context is nil or table name is empty")
}
schema, name := b.parseSchemaAndTable(tableName)
if schema == "" {
schema = "public" // Default schema for PostgreSQL
}
- query, err := b.grammar.compileIndexes(schema, name)
+ query, err := b.grammar.CompileIndexes(schema, name)
if err != nil {
return nil, err
}
- rows, err := queryContext(ctx, tx, query)
+ rows, err := c.Query(query)
if err != nil {
return nil, err
}
@@ -96,17 +95,17 @@ func (b *postgresBuilder) GetIndexes(ctx context.Context, tx *sql.Tx, tableName
return indexes, nil
}
-func (b *postgresBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) {
- if tx == nil {
- return nil, errors.New("transaction is nil")
+func (b *postgresBuilder) GetTables(c *Context) ([]*TableInfo, error) {
+ if c == nil {
+ return nil, errors.New("invalid arguments: context is nil")
}
- query, err := b.grammar.compileTables()
+ query, err := b.grammar.CompileTables()
if err != nil {
return nil, err
}
- rows, err := queryContext(ctx, tx, query)
+ rows, err := c.Query(query)
if err != nil {
return nil, err
}
@@ -124,15 +123,18 @@ func (b *postgresBuilder) GetTables(ctx context.Context, tx *sql.Tx) ([]*TableIn
return tables, nil
}
-func (b *postgresBuilder) HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) {
- return b.HasColumns(ctx, tx, tableName, []string{columnName})
+func (b *postgresBuilder) HasColumn(c *Context, tableName string, columnName string) (bool, error) {
+ return b.HasColumns(c, tableName, []string{columnName})
}
-func (b *postgresBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) {
+func (b *postgresBuilder) HasColumns(c *Context, tableName string, columnNames []string) (bool, error) {
+ if c == nil || tableName == "" {
+ return false, errors.New("invalid arguments: context is nil or table name is empty")
+ }
if len(columnNames) == 0 {
return false, errors.New("no column names provided")
}
- existingColumns, err := b.GetColumns(ctx, tx, tableName)
+ existingColumns, err := b.GetColumns(c, tableName)
if err != nil {
return false, err
}
@@ -157,12 +159,12 @@ func (b *postgresBuilder) HasColumns(ctx context.Context, tx *sql.Tx, tableName
return true, nil // All specified columns exist
}
-func (b *postgresBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) {
- if err := b.validateTxAndName(tx, tableName); err != nil {
- return false, err
+func (b *postgresBuilder) HasIndex(c *Context, tableName string, indexes []string) (bool, error) {
+ if c == nil || tableName == "" {
+ return false, errors.New("invalid arguments: context is nil or table name is empty")
}
- existingIndexes, err := b.GetIndexes(ctx, tx, tableName)
+ existingIndexes, err := b.GetIndexes(c, tableName)
if err != nil {
return false, err
}
@@ -201,22 +203,22 @@ func (b *postgresBuilder) HasIndex(ctx context.Context, tx *sql.Tx, tableName st
return false, nil // If no specified index exists, return false
}
-func (b *postgresBuilder) HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) {
- if err := b.validateTxAndName(tx, name); err != nil {
- return false, err
+func (b *postgresBuilder) HasTable(c *Context, name string) (bool, error) {
+ if c == nil || name == "" {
+ return false, errors.New("invalid arguments: context is nil or table name is empty")
}
schema, name := b.parseSchemaAndTable(name)
if schema == "" {
schema = "public" // Default schema for PostgreSQL
}
- query, err := b.grammar.compileTableExists(schema, name)
+ query, err := b.grammar.CompileTableExists(schema, name)
if err != nil {
return false, err
}
var exists bool
- if err := queryRowContext(ctx, tx, query).Scan(&exists); err != nil {
+ if err := c.QueryRow(query).Scan(&exists); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false, nil // Table does not exist
}
diff --git a/postgres_builder_test.go b/schema/postgres_builder_test.go
similarity index 59%
rename from postgres_builder_test.go
rename to schema/postgres_builder_test.go
index 334aea7..2b5b0b5 100644
--- a/postgres_builder_test.go
+++ b/schema/postgres_builder_test.go
@@ -7,7 +7,7 @@ import (
"os"
"testing"
- "github.com/afkdevs/go-schema"
+ "github.com/akfaiz/migris/schema"
_ "github.com/lib/pq"
"github.com/stretchr/testify/suite"
)
@@ -16,12 +16,6 @@ func TestPostgresBuilderSuite(t *testing.T) {
suite.Run(t, new(postgresBuilderSuite))
}
-type postgresBuilderSuite struct {
- suite.Suite
- ctx context.Context
- db *sql.DB
-}
-
type dbConfig struct {
Database string
Username string
@@ -43,6 +37,13 @@ func parseTestConfig() dbConfig {
}
}
+type postgresBuilderSuite struct {
+ suite.Suite
+ ctx context.Context
+ db *sql.DB
+ builder schema.Builder
+}
+
func (s *postgresBuilderSuite) SetupSuite() {
s.ctx = context.Background()
@@ -57,7 +58,8 @@ func (s *postgresBuilderSuite) SetupSuite() {
s.Require().NoError(err)
s.db = db
- schema.SetDebug(true)
+ s.builder, err = schema.NewBuilder("postgres")
+ s.Require().NoError(err)
}
func (s *postgresBuilderSuite) TearDownSuite() {
@@ -65,48 +67,49 @@ func (s *postgresBuilderSuite) TearDownSuite() {
}
func (s *postgresBuilderSuite) TestCreate() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- err := builder.Create(s.ctx, nil, "test_table", func(table *schema.Blueprint) {})
- s.Error(err, "expected error when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.Create(nil, "test_table", func(table *schema.Blueprint) {})
+ s.Error(err, "expected error when context is nil")
})
s.Run("when table name is empty, should return error", func() {
- err := builder.Create(s.ctx, tx, "", func(table *schema.Blueprint) {})
+ err := builder.Create(c, "", func(table *schema.Blueprint) {})
s.Error(err, "expected error when table name is empty")
})
s.Run("when blueprint is nil, should return error", func() {
- err := builder.Create(s.ctx, tx, "test_table", nil)
+ err := builder.Create(c, "test_table", nil)
s.Error(err, "expected error when blueprint is nil")
})
s.Run("when all parameters are valid, should create table successfully", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
- table.String("name", 255)
- table.String("email", 255).Unique()
- table.String("password", 255).Nullable()
+ table.String("name")
+ table.String("email").Unique()
+ table.String("password").Nullable()
table.Timestamps()
})
s.NoError(err, "expected no error when creating table with valid parameters")
})
s.Run("when use custom schema should create it successfully", func() {
- _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_publics")
+ _, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_public")
s.NoError(err, "expected no error when creating custom schema")
- err = builder.Create(context.Background(), tx, "custom_publics.users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "custom_public.users", func(table *schema.Blueprint) {
table.ID()
- table.String("name", 255)
- table.String("email", 255).Unique()
- table.String("password", 255).Nullable()
+ table.String("name")
+ table.String("email").Unique()
+ table.String("password").Nullable()
table.TimestampsTz()
})
s.NoError(err, "expected no error when creating table with custom schema")
})
s.Run("when have composite primary key should create it successfully", func() {
- err = builder.Create(context.Background(), tx, "user_roles", func(table *schema.Blueprint) {
+ err = builder.Create(c, "user_roles", func(table *schema.Blueprint) {
table.Integer("user_id")
table.Integer("role_id")
@@ -115,280 +118,249 @@ func (s *postgresBuilderSuite) TestCreate() {
s.NoError(err, "expected no error when creating table with composite primary key")
})
s.Run("when have foreign key should create it successfully", func() {
- err = builder.Create(context.Background(), tx, "orders", func(table *schema.Blueprint) {
+ err = builder.Create(c, "orders", func(table *schema.Blueprint) {
table.ID()
table.BigInteger("user_id")
- table.String("order_id", 255).Unique()
+ table.String("order_id").Unique()
table.Decimal("amount", 10, 2)
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamp("created_at").UseCurrent()
table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE")
})
s.NoError(err, "expected no error when creating table with foreign key")
})
s.Run("when have custom index should create it successfully", func() {
- err = builder.Create(context.Background(), tx, "orders_2", func(table *schema.Blueprint) {
+ err = builder.Create(c, "orders_2", func(table *schema.Blueprint) {
table.ID()
- table.String("order_id", 255).Unique("uk_orders_2_order_id")
+ table.String("order_id").Unique("uk_orders_2_order_id")
table.Decimal("amount", 10, 2)
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamp("created_at").UseCurrent()
table.Index("created_at").Name("idx_orders_created_at").Algorithm("BTREE")
})
s.NoError(err, "expected no error when creating table with custom index")
})
s.Run("when table already exists, should return error", func() {
- err = builder.Create(context.Background(), tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
- table.String("name", 255)
- table.String("email", 255).Unique()
+ table.String("name")
+ table.String("email").Unique()
})
s.Error(err, "expected error when creating table that already exists")
})
}
-func (s *postgresBuilderSuite) TestCreateIfNotExists() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
- tx, err := s.db.BeginTx(s.ctx, nil)
- s.Require().NoError(err)
- defer tx.Rollback() //nolint:errcheck
-
- s.Run("when tx is nil, should return error", func() {
- err := builder.CreateIfNotExists(s.ctx, nil, "test_table", func(table *schema.Blueprint) {})
- s.Error(err, "expected error when transaction is nil")
- })
- s.Run("when table name is empty, should return error", func() {
- err := builder.CreateIfNotExists(s.ctx, tx, "", func(table *schema.Blueprint) {})
- s.Error(err, "expected error when table name is empty")
- })
- s.Run("when blueprint is nil, should return error", func() {
- err := builder.CreateIfNotExists(s.ctx, tx, "test_table", nil)
- s.Error(err, "expected error when blueprint is nil")
- })
- s.Run("when all parameters are valid, should create table successfully", func() {
- err = builder.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) {
- table.ID()
- table.String("name", 255)
- table.String("email", 255).Unique()
- table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
- })
- s.NoError(err, "expected no error when creating table with valid parameters")
- })
- s.Run("when table already exists, should not return error", func() {
- err = builder.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) {
- table.ID()
- table.String("name", 255)
- table.String("email", 255)
- })
- s.NoError(err, "expected no error when creating table that already exists")
- })
-}
-
func (s *postgresBuilderSuite) TestDrop() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.Drop(nil, "test_table")
+ s.Error(err, "expected error when context is nil")
+ })
s.Run("when table name is empty, should return error", func() {
- err := builder.Drop(s.ctx, nil, "")
+ err := builder.Drop(c, "")
s.Error(err, "expected error when table name is empty")
})
s.Run("when all parameters are valid, should drop table successfully", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
- table.String("name", 255)
- table.String("email", 255).Unique()
- table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.String("name")
+ table.String("email").Unique()
+ table.String("password").Nullable()
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before dropping it")
- err = builder.Drop(s.ctx, tx, "users")
+ err = builder.Drop(c, "users")
s.NoError(err, "expected no error when dropping table with valid parameters")
})
s.Run("when table does not exist, should return error", func() {
- err = builder.Drop(s.ctx, tx, "non_existent_table")
+ err = builder.Drop(c, "non_existent_table")
s.Error(err, "expected error when dropping table that does not exist")
})
}
func (s *postgresBuilderSuite) TestDropIfExists() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.DropIfExists(nil, "test_table")
+ s.Error(err, "expected error when context is nil")
+ })
s.Run("when table name is empty, should return error", func() {
- err := builder.DropIfExists(s.ctx, nil, "")
+ err := builder.DropIfExists(c, "")
s.Error(err, "expected error when table name is empty")
})
- s.Run("when tx is nil, should return error", func() {
- err = builder.DropIfExists(s.ctx, nil, "test_table")
- s.Error(err, "expected error when transaction is nil")
+ s.Run("when context is nil, should return error", func() {
+ err = builder.DropIfExists(nil, "test_table")
+ s.Error(err, "expected error when context is nil")
})
s.Run("when all parameters are valid, should drop table successfully", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
- table.String("name", 255)
- table.String("email", 255).Unique()
- table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.String("name")
+ table.String("email").Unique()
+ table.String("password").Nullable()
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before dropping it")
- err = builder.DropIfExists(s.ctx, tx, "users")
+ err = builder.DropIfExists(c, "users")
s.NoError(err, "expected no error when dropping table with valid parameters")
})
s.Run("when table does not exist, should not return error", func() {
- err = builder.DropIfExists(s.ctx, tx, "non_existent_table")
+ err = builder.DropIfExists(c, "non_existent_table")
s.NoError(err, "expected no error when dropping non-existent table with IF EXISTS clause")
})
}
func (s *postgresBuilderSuite) TestRename() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- err := builder.Rename(s.ctx, nil, "old_table", "new_table")
- s.Error(err, "expected error when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.Rename(nil, "old_table", "new_table")
+ s.Error(err, "expected error when context is nil")
})
s.Run("when old table name is empty, should return error", func() {
- err := builder.Rename(s.ctx, tx, "", "new_table")
+ err := builder.Rename(c, "", "new_table")
s.Error(err, "expected error when old table name is empty")
})
s.Run("when new table name is empty, should return error", func() {
- err := builder.Rename(s.ctx, tx, "old_table", "")
+ err := builder.Rename(c, "old_table", "")
s.Error(err, "expected error when new table name is empty")
})
s.Run("when all parameters are valid, should rename table successfully", func() {
- err = builder.Create(s.ctx, tx, "old_table", func(table *schema.Blueprint) {
+ err = builder.Create(c, "old_table", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
})
s.NoError(err, "expected no error when creating table before renaming it")
- err = builder.Rename(s.ctx, tx, "old_table", "new_table")
+ err = builder.Rename(c, "old_table", "new_table")
s.NoError(err, "expected no error when renaming table with valid parameters")
})
s.Run("when renaming non-existent table, should return error", func() {
- err = builder.Rename(s.ctx, tx, "non_existent_table", "new_table")
+ err = builder.Rename(c, "non_existent_table", "new_table")
s.Error(err, "expected error when renaming non-existent table")
s.ErrorContains(err, "does not exist", "expected error message to contain 'does not exist'")
})
}
func (s *postgresBuilderSuite) TestTable() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- err := builder.Table(s.ctx, nil, "test_table", func(table *schema.Blueprint) {})
- s.Error(err, "expected error when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ err := builder.Table(nil, "test_table", func(table *schema.Blueprint) {})
+ s.Error(err, "expected error when context is nil")
})
s.Run("when table name is empty, should return error", func() {
- err := builder.Table(s.ctx, tx, "", func(table *schema.Blueprint) {})
+ err := builder.Table(c, "", func(table *schema.Blueprint) {})
s.Error(err, "expected error when table name is empty")
})
s.Run("when blueprint is nil, should return error", func() {
- err := builder.Table(s.ctx, tx, "test_table", nil)
+ err := builder.Table(c, "test_table", nil)
s.Error(err, "expected error when blueprint is nil")
})
s.Run("when all parameters are valid, should modify table successfully", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique("uk_users_email")
table.String("password", 255).Nullable()
table.Text("bio").Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
table.FullText("bio")
})
s.NoError(err, "expected no error when creating table before modifying it")
s.Run("should add new columns and modify existing ones", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.String("address", 255).Nullable()
table.String("phone", 20).Nullable().Unique("uk_users_phone")
})
s.NoError(err, "expected no error when modifying table with valid parameters")
})
s.Run("should modify existing column", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.String("email", 255).Nullable().Change()
})
s.NoError(err, "expected no error when modifying existing column")
})
s.Run("should drop column and rename existing one", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropColumn("password")
table.RenameColumn("name", "full_name")
})
s.NoError(err, "expected no error when dropping column and renaming existing one")
})
s.Run("should add index", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.Index("phone").Name("idx_users_phone").Algorithm("BTREE")
})
s.NoError(err, "expected no error when adding index to table")
})
s.Run("should rename index", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.RenameIndex("idx_users_phone", "idx_users_contact")
})
s.NoError(err, "expected no error when renaming index in table")
})
s.Run("should drop index", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropIndex("idx_users_contact")
})
s.NoError(err, "expected no error when dropping index from table")
})
s.Run("should drop unique constraint", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
- table.DropUnique("uk_users_email")
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
+ table.DropUnique([]string{"email"})
})
s.NoError(err, "expected no error when dropping unique constraint from table")
})
s.Run("should drop fulltext index", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropFulltext("ft_users_bio")
})
s.NoError(err, "expected no error when dropping fulltext index from table")
})
s.Run("should add foreign key", func() {
- err = builder.Create(s.ctx, tx, "roles", func(table *schema.Blueprint) {
+ err = builder.Create(c, "roles", func(table *schema.Blueprint) {
table.ID()
table.String("role_name", 255).Unique("uk_roles_role_name")
})
s.NoError(err, "expected no error when creating roles table before adding foreign key")
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.Integer("role_id").Nullable()
table.Foreign("role_id").References("id").On("roles").OnDelete("SET NULL").OnUpdate("CASCADE")
})
s.NoError(err, "expected no error when adding foreign key to users table")
})
s.Run("should drop foreign key", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropForeign("fk_users_roles")
})
s.NoError(err, "expected no error when dropping foreign key from users table")
})
s.Run("should drop primary key", func() {
- err = builder.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Table(c, "users", func(table *schema.Blueprint) {
table.DropPrimary("pk_users")
})
s.NoError(err, "expected no error when dropping primary key from users table")
@@ -397,105 +369,105 @@ func (s *postgresBuilderSuite) TestTable() {
}
func (s *postgresBuilderSuite) TestGetColumns() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- _, err := builder.GetColumns(s.ctx, nil, "users")
- s.Error(err, "expected error when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ _, err := builder.GetColumns(nil, "users")
+ s.Error(err, "expected error when context is nil")
})
s.Run("when table name is empty, should return error", func() {
- _, err := builder.GetColumns(s.ctx, tx, "")
+ _, err := builder.GetColumns(c, "")
s.Error(err, "expected error when table name is empty")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before getting columns")
- columns, err := builder.GetColumns(s.ctx, tx, "users")
+ columns, err := builder.GetColumns(c, "users")
s.NoError(err, "expected no error when getting columns with valid parameters")
s.Len(columns, 6, "expected 6 columns to be returned")
})
s.Run("when table does not exist, should return empty columns", func() {
- columns, err := builder.GetColumns(s.ctx, tx, "non_existent_table")
+ columns, err := builder.GetColumns(c, "non_existent_table")
s.NoError(err, "expected no error when getting columns of non-existent table")
s.Empty(columns, "expected empty columns for non-existent table")
})
}
func (s *postgresBuilderSuite) TestGetIndexes() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- _, err := builder.GetIndexes(s.ctx, nil, "users")
- s.Error(err, "expected error when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ _, err := builder.GetIndexes(nil, "users")
+ s.Error(err, "expected error when contexts is nil")
})
s.Run("when table name is empty, should return error", func() {
- _, err := builder.GetIndexes(s.ctx, tx, "")
+ _, err := builder.GetIndexes(c, "")
s.Error(err, "expected error when table name is empty")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
table.Index("name").Name("idx_users_name")
})
s.NoError(err, "expected no error when creating table before getting indexes")
- indexes, err := builder.GetIndexes(s.ctx, tx, "users")
+ indexes, err := builder.GetIndexes(c, "users")
s.NoError(err, "expected no error when getting indexes with valid parameters")
s.Len(indexes, 3, "expected 3 index to be returned")
})
s.Run("when table does not exist, should return empty indexes", func() {
- indexes, err := builder.GetIndexes(s.ctx, tx, "non_existent_table")
+ indexes, err := builder.GetIndexes(c, "non_existent_table")
s.NoError(err, "expected no error when getting indexes of non-existent table")
s.Empty(indexes, "expected empty indexes for non-existent table")
})
}
func (s *postgresBuilderSuite) TestGetTables() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- _, err := builder.GetTables(s.ctx, nil)
- s.Error(err, "expected error when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ _, err := builder.GetTables(nil)
+ s.Error(err, "expected error when context is nil")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before getting tables")
- tables, err := builder.GetTables(s.ctx, tx)
+ tables, err := builder.GetTables(c)
s.NoError(err, "expected no error when getting tables with valid parameters")
s.Len(tables, 1, "expected 1 table to be returned")
userTable := tables[0]
@@ -506,159 +478,160 @@ func (s *postgresBuilderSuite) TestGetTables() {
}
func (s *postgresBuilderSuite) TestHasColumn() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when tx is nil, should return error", func() {
- exists, err := builder.HasColumn(s.ctx, nil, "users", "name")
+ exists, err := builder.HasColumn(nil, "users", "name")
s.Error(err, "expected error when transaction is nil")
s.False(exists, "expected exists to be false when transaction is nil")
})
s.Run("when table name is empty, should return error", func() {
- exists, err := builder.HasColumn(s.ctx, tx, "", "name")
+ exists, err := builder.HasColumn(c, "", "name")
s.Error(err, "expected error when table name is empty")
s.False(exists, "expected exists to be false when table name is empty")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before checking column existence")
- exists, err := builder.HasColumn(s.ctx, tx, "users", "name")
+ exists, err := builder.HasColumn(c, "users", "name")
s.NoError(err, "expected no error when checking if column exists with valid parameters")
s.True(exists, "expected exists to be true for existing column")
- exists, err = builder.HasColumn(s.ctx, tx, "users", "non_existent_column")
+ exists, err = builder.HasColumn(c, "users", "non_existent_column")
s.NoError(err, "expected no error when checking non-existent column")
s.False(exists, "expected exists to be false for non-existent column")
})
}
func (s *postgresBuilderSuite) TestHasColumns() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when tx is nil, should return error", func() {
- exists, err := builder.HasColumns(s.ctx, nil, "users", []string{"name"})
+ exists, err := builder.HasColumns(nil, "users", []string{"name"})
s.Error(err, "expected error when transaction is nil")
s.False(exists, "expected exists to be false when transaction is nil")
})
s.Run("when table name is empty, should return error", func() {
- exists, err := builder.HasColumns(s.ctx, tx, "", []string{"name"})
+ exists, err := builder.HasColumns(c, "", []string{"name"})
s.Error(err, "expected error when table name is empty")
s.False(exists, "expected exists to be false when table name is empty")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before checking column existence")
- exists, err := builder.HasColumns(s.ctx, tx, "users", []string{"name", "email"})
+ exists, err := builder.HasColumns(c, "users", []string{"name", "email"})
s.NoError(err, "expected no error when checking if columns exist with valid parameters")
s.True(exists, "expected exists to be true for existing columns")
- exists, err = builder.HasColumns(s.ctx, tx, "users", []string{"name", "non_existent_column"})
+ exists, err = builder.HasColumns(c, "users", []string{"name", "non_existent_column"})
s.NoError(err, "expected no error when checking mixed existing and non-existent columns")
s.False(exists, "expected exists to be false for mixed existing and non-existent columns")
})
}
func (s *postgresBuilderSuite) TestHasIndex() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- exists, err := builder.HasIndex(s.ctx, nil, "users", []string{"idx_users_name"})
- s.Error(err, "expected error when transaction is nil")
- s.False(exists, "expected exists to be false when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ exists, err := builder.HasIndex(nil, "users", []string{"idx_users_name"})
+ s.Error(err, "expected error when context is nil")
+ s.False(exists, "expected exists to be false when context is nil")
})
s.Run("when table name is empty, should return error", func() {
- exists, err := builder.HasIndex(s.ctx, tx, "", []string{"idx_users_name"})
+ exists, err := builder.HasIndex(c, "", []string{"idx_users_name"})
s.Error(err, "expected error when table name is empty")
s.False(exists, "expected exists to be false when table name is empty")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "orders", func(table *schema.Blueprint) {
+ err = builder.Create(c, "orders", func(table *schema.Blueprint) {
table.ID()
table.Integer("company_id")
table.Integer("user_id")
table.String("order_id", 255)
table.Decimal("amount", 10, 2)
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamp("created_at").UseCurrent()
table.Index("company_id", "user_id")
table.Unique("order_id").Name("uk_orders_order_id")
})
s.Require().NoError(err, "expected no error when creating table with index")
- exists, err := builder.HasIndex(s.ctx, tx, "orders", []string{"uk_orders_order_id"})
+ exists, err := builder.HasIndex(c, "orders", []string{"uk_orders_order_id"})
s.NoError(err, "expected no error when checking if index exists with valid parameters")
s.True(exists, "expected exists to be true for existing index")
- exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"company_id", "user_id"})
+ exists, err = builder.HasIndex(c, "orders", []string{"company_id", "user_id"})
s.NoError(err, "expected no error when checking non-existent index")
s.True(exists, "expected exists to be true for existing composite index")
- exists, err = builder.HasIndex(s.ctx, tx, "orders", []string{"non_existent_index"})
+ exists, err = builder.HasIndex(c, "orders", []string{"non_existent_index"})
s.NoError(err, "expected no error when checking non-existent index")
s.False(exists, "expected exists to be false for non-existent index")
})
}
func (s *postgresBuilderSuite) TestHasTable() {
- builder, err := schema.NewBuilder("postgres")
- s.Require().NoError(err)
+ builder := s.builder
tx, err := s.db.BeginTx(s.ctx, nil)
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
- s.Run("when tx is nil, should return error", func() {
- exists, err := builder.HasTable(s.ctx, nil, "users")
- s.Error(err, "expected error when transaction is nil")
- s.False(exists, "expected exists to be false when transaction is nil")
+ c := schema.NewContext(s.ctx, tx)
+
+ s.Run("when context is nil, should return error", func() {
+ exists, err := builder.HasTable(nil, "users")
+ s.Error(err, "expected error when context is nil")
+ s.False(exists, "expected exists to be false when context is nil")
})
s.Run("when table name is empty, should return error", func() {
- exists, err := builder.HasTable(s.ctx, tx, "")
+ exists, err := builder.HasTable(c, "")
s.Error(err, "expected error when table name is empty")
s.False(exists, "expected exists to be false when table name is empty")
})
s.Run("when all parameters are valid", func() {
- err = builder.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table before checking existence")
- exists, err := builder.HasTable(s.ctx, tx, "users")
+ exists, err := builder.HasTable(c, "users")
s.NoError(err, "expected no error when checking if table exists with valid parameters")
s.True(exists, "expected exists to be true for existing table")
- exists, err = builder.HasTable(s.ctx, tx, "non_existent_table")
+ exists, err = builder.HasTable(c, "non_existent_table")
s.NoError(err, "expected no error when checking non-existent table")
s.False(exists, "expected exists to be false for non-existent table")
})
@@ -666,20 +639,19 @@ func (s *postgresBuilderSuite) TestHasTable() {
_, err = tx.Exec("CREATE SCHEMA IF NOT EXISTS custom_publics")
s.NoError(err, "expected no error when creating custom schema")
- err = builder.Create(s.ctx, tx, "custom_publics.users", func(table *schema.Blueprint) {
+ err = builder.Create(c, "custom_publics.users", func(table *schema.Blueprint) {
table.ID()
table.String("name", 255)
table.String("email", 255).Unique()
table.String("password", 255).Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamps()
})
s.NoError(err, "expected no error when creating table with custom schema")
- exists, err := builder.HasTable(s.ctx, tx, "custom_publics.users")
+ exists, err := builder.HasTable(c, "custom_publics.users")
s.NoError(err, "expected no error when checking if table with custom schema exists")
s.True(exists, "expected exists to be true for existing table with custom schema")
- exists, err = builder.HasTable(s.ctx, tx, "custom_publics.non_existent_table")
+ exists, err = builder.HasTable(c, "custom_publics.non_existent_table")
s.NoError(err, "expected no error when checking non-existent table with custom schema")
s.False(exists, "expected exists to be false for non-existent table with custom schema")
})
diff --git a/schema/postgres_grammar.go b/schema/postgres_grammar.go
new file mode 100644
index 0000000..a5fdb25
--- /dev/null
+++ b/schema/postgres_grammar.go
@@ -0,0 +1,581 @@
+package schema
+
+import (
+ "fmt"
+ "slices"
+ "strings"
+)
+
+type postgresGrammar struct {
+ baseGrammar
+}
+
+func newPostgresGrammar() *postgresGrammar {
+ return &postgresGrammar{}
+}
+
+func (g *postgresGrammar) CompileTableExists(schema string, table string) (string, error) {
+ return fmt.Sprintf(
+ "SELECT 1 FROM information_schema.tables WHERE table_schema = %s AND table_name = %s AND table_type = 'BASE TABLE'",
+ g.QuoteString(schema),
+ g.QuoteString(table),
+ ), nil
+}
+
+func (g *postgresGrammar) CompileTables() (string, error) {
+ return "select c.relname as name, n.nspname as schema, pg_total_relation_size(c.oid) as size, " +
+ "obj_description(c.oid, 'pg_class') as comment from pg_class c, pg_namespace n " +
+ "where c.relkind in ('r', 'p') and n.oid = c.relnamespace and n.nspname not in ('pg_catalog', 'information_schema') " +
+ "order by c.relname", nil
+}
+
+func (g *postgresGrammar) CompileColumns(schema, table string) (string, error) {
+ return fmt.Sprintf(
+ "select a.attname as name, t.typname as type_name, format_type(a.atttypid, a.atttypmod) as type, "+
+ "(select tc.collcollate from pg_catalog.pg_collation tc where tc.oid = a.attcollation) as collation, "+
+ "not a.attnotnull as nullable, "+
+ "(select pg_get_expr(adbin, adrelid) from pg_attrdef where c.oid = pg_attrdef.adrelid and pg_attrdef.adnum = a.attnum) as default, "+
+ "col_description(c.oid, a.attnum) as comment "+
+ "from pg_attribute a, pg_class c, pg_type t, pg_namespace n "+
+ "where c.relname = %s and n.nspname = %s and a.attnum > 0 and a.attrelid = c.oid and a.atttypid = t.oid and n.oid = c.relnamespace "+
+ "order by a.attnum",
+ g.QuoteString(table),
+ g.QuoteString(schema),
+ ), nil
+}
+
+func (g *postgresGrammar) CompileIndexes(schema, table string) (string, error) {
+ return fmt.Sprintf(
+ "select ic.relname as name, string_agg(a.attname, ',' order by indseq.ord) as columns, "+
+ "am.amname as \"type\", i.indisunique as \"unique\", i.indisprimary as \"primary\" "+
+ "from pg_index i "+
+ "join pg_class tc on tc.oid = i.indrelid "+
+ "join pg_namespace tn on tn.oid = tc.relnamespace "+
+ "join pg_class ic on ic.oid = i.indexrelid "+
+ "join pg_am am on am.oid = ic.relam "+
+ "join lateral unnest(i.indkey) with ordinality as indseq(num, ord) on true "+
+ "left join pg_attribute a on a.attrelid = i.indrelid and a.attnum = indseq.num "+
+ "where tc.relname = %s and tn.nspname = %s "+
+ "group by ic.relname, am.amname, i.indisunique, i.indisprimary",
+ g.QuoteString(table),
+ g.QuoteString(schema),
+ ), nil
+}
+
+func (g *postgresGrammar) CompileCreate(blueprint *Blueprint) (string, error) {
+ columns, err := g.getColumns(blueprint)
+ if err != nil {
+ return "", err
+ }
+ columns = append(columns, g.getConstraints(blueprint)...)
+ return fmt.Sprintf("CREATE TABLE %s (%s)", blueprint.name, strings.Join(columns, ", ")), nil
+}
+
+func (g *postgresGrammar) CompileAdd(blueprint *Blueprint) (string, error) {
+ if len(blueprint.getAddedColumns()) == 0 {
+ return "", nil
+ }
+
+ columns, err := g.getColumns(blueprint)
+ if err != nil {
+ return "", err
+ }
+ columns = g.PrefixArray("ADD COLUMN ", columns)
+ constraints := g.getConstraints(blueprint)
+ if len(constraints) > 0 {
+ constraints = g.PrefixArray("ADD ", constraints)
+ columns = append(columns, constraints...)
+ }
+
+ return fmt.Sprintf("ALTER TABLE %s %s",
+ blueprint.name,
+ strings.Join(columns, ", "),
+ ), nil
+}
+
+func (g *postgresGrammar) CompileChange(bp *Blueprint, command *command) (string, error) {
+ column := command.column
+ if column.name == "" {
+ return "", fmt.Errorf("column name cannot be empty for change operation")
+ }
+
+ var changes []string
+ changes = append(changes, fmt.Sprintf("TYPE %s", g.getType(command.column)))
+ for _, modifier := range g.modifiers() {
+ change := modifier(command.column)
+ if change != "" {
+ changes = append(changes, strings.TrimSpace(change))
+ }
+ }
+
+ return fmt.Sprintf("ALTER TABLE %s %s",
+ bp.name,
+ strings.Join(g.PrefixArray(fmt.Sprintf("ALTER COLUMN %s ", column.name), changes), ", "),
+ ), nil
+}
+
+func (g *postgresGrammar) CompileDrop(blueprint *Blueprint) (string, error) {
+ return fmt.Sprintf("DROP TABLE %s", blueprint.name), nil
+}
+
+func (g *postgresGrammar) CompileDropIfExists(blueprint *Blueprint) (string, error) {
+ return fmt.Sprintf("DROP TABLE IF EXISTS %s", blueprint.name), nil
+}
+
+func (g *postgresGrammar) CompileRename(blueprint *Blueprint, command *command) (string, error) {
+ return fmt.Sprintf("ALTER TABLE %s RENAME TO %s", blueprint.name, command.to), nil
+}
+
+func (g *postgresGrammar) CompileDropColumn(blueprint *Blueprint, command *command) (string, error) {
+ if len(command.columns) == 0 {
+ return "", nil
+ }
+ columns := g.PrefixArray("DROP COLUMN ", command.columns)
+
+ return fmt.Sprintf("ALTER TABLE %s %s", blueprint.name, strings.Join(columns, ", ")), nil
+}
+
+func (g *postgresGrammar) CompileRenameColumn(blueprint *Blueprint, command *command) (string, error) {
+ if command.from == "" || command.to == "" {
+ return "", fmt.Errorf("table name, old column name, and new column name cannot be empty for rename operation")
+ }
+ return fmt.Sprintf("ALTER TABLE %s RENAME COLUMN %s TO %s", blueprint.name, command.from, command.to), nil
+}
+
+func (g *postgresGrammar) CompileFullText(blueprint *Blueprint, command *command) (string, error) {
+ if slices.Contains(command.columns, "") {
+ return "", fmt.Errorf("fulltext index column cannot be empty")
+ }
+ indexName := command.index
+ if indexName == "" {
+ indexName = g.CreateIndexName(blueprint, "fulltext", command.columns...)
+ }
+ language := command.language
+ if language == "" {
+ language = "english" // Default language for full-text search
+ }
+ var columns []string
+ for _, col := range command.columns {
+ columns = append(columns, fmt.Sprintf("to_tsvector(%s, %s)", g.QuoteString(language), col))
+ }
+
+ return fmt.Sprintf("CREATE INDEX %s ON %s USING GIN (%s)", indexName, blueprint.name, strings.Join(columns, " || ")), nil
+}
+
+func (g *postgresGrammar) CompileIndex(blueprint *Blueprint, command *command) (string, error) {
+ if slices.Contains(command.columns, "") {
+ return "", fmt.Errorf("index column cannot be empty")
+ }
+ indexName := command.index
+ if indexName == "" {
+ indexName = g.CreateIndexName(blueprint, "index", command.columns...)
+ }
+
+ sql := fmt.Sprintf("CREATE INDEX %s ON %s", indexName, blueprint.name)
+ if command.algorithm != "" {
+ sql += fmt.Sprintf(" USING %s", command.algorithm)
+ }
+ return fmt.Sprintf("%s (%s)", sql, g.Columnize(command.columns)), nil
+}
+
+func (g *postgresGrammar) CompileUnique(blueprint *Blueprint, command *command) (string, error) {
+ if slices.Contains(command.columns, "") {
+ return "", fmt.Errorf("unique index column cannot be empty")
+ }
+ indexName := command.index
+ if indexName == "" {
+ indexName = g.CreateIndexName(blueprint, "unique", command.columns...)
+ }
+ sql := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)",
+ blueprint.name,
+ indexName,
+ g.Columnize(command.columns),
+ )
+
+ if command.deferrable != nil {
+ if *command.deferrable {
+ sql += " DEFERRABLE"
+ } else {
+ sql += " NOT DEFERRABLE"
+ }
+ }
+ if command.deferrable != nil && *command.deferrable && command.initiallyImmediate != nil {
+ if *command.initiallyImmediate {
+ sql += " INITIALLY IMMEDIATE"
+ } else {
+ sql += " INITIALLY DEFERRED"
+ }
+ }
+
+ return sql, nil
+}
+
+func (g *postgresGrammar) CompilePrimary(blueprint *Blueprint, command *command) (string, error) {
+ if slices.Contains(command.columns, "") {
+ return "", fmt.Errorf("primary key index column cannot be empty")
+ }
+ indexName := command.index
+ if indexName == "" {
+ indexName = g.CreateIndexName(blueprint, "primary", command.columns...)
+ }
+ return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", blueprint.name, indexName, g.Columnize(command.columns)), nil
+}
+
+func (g *postgresGrammar) CompileDropIndex(_ *Blueprint, command *command) (string, error) {
+ if command.index == "" {
+ return "", fmt.Errorf("index name cannot be empty for drop operation")
+ }
+ return fmt.Sprintf("DROP INDEX %s", command.index), nil
+}
+
+func (g *postgresGrammar) CompileDropFulltext(blueprint *Blueprint, command *command) (string, error) {
+ return g.CompileDropIndex(blueprint, command)
+}
+
+func (g *postgresGrammar) CompileDropUnique(blueprint *Blueprint, command *command) (string, error) {
+ if command.index == "" {
+ return "", fmt.Errorf("index name cannot be empty for drop operation")
+ }
+ return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil
+}
+
+func (g *postgresGrammar) CompileDropPrimary(blueprint *Blueprint, command *command) (string, error) {
+ index := command.index
+ if index == "" {
+ index = g.CreateIndexName(blueprint, "primary", command.columns...)
+ }
+ return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, index), nil
+}
+
+func (g *postgresGrammar) CompileRenameIndex(_ *Blueprint, command *command) (string, error) {
+ if command.from == "" || command.to == "" {
+ return "", fmt.Errorf("index names for rename operation cannot be empty: oldName=%s, newName=%s", command.from, command.to)
+ }
+ return fmt.Sprintf("ALTER INDEX %s RENAME TO %s", command.from, command.to), nil
+}
+
+func (g *postgresGrammar) CompileForeign(blueprint *Blueprint, command *command) (string, error) {
+ sql, err := g.baseGrammar.CompileForeign(blueprint, command)
+ if err != nil {
+ return "", err
+ }
+
+ if command.deferrable != nil {
+ if *command.deferrable {
+ sql += " DEFERRABLE"
+ } else {
+ sql += " NOT DEFERRABLE"
+ }
+ }
+ if command.deferrable != nil && *command.deferrable && command.initiallyImmediate != nil {
+ if *command.initiallyImmediate {
+ sql += " INITIALLY IMMEDIATE"
+ } else {
+ sql += " INITIALLY DEFERRED"
+ }
+ }
+
+ return sql, nil
+}
+
+func (g *postgresGrammar) CompileDropForeign(blueprint *Blueprint, command *command) (string, error) {
+ if command.index == "" {
+ return "", fmt.Errorf("foreign key name cannot be empty for drop operation")
+ }
+ return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", blueprint.name, command.index), nil
+}
+
+func (g *postgresGrammar) GetFluentCommands() []func(blueprint *Blueprint, command *command) string {
+ return []func(blueprint *Blueprint, command *command) string{
+ g.CompileComment,
+ }
+}
+
+func (g *postgresGrammar) CompileComment(blueprint *Blueprint, command *command) string {
+ if command.column.comment != nil {
+ sql := fmt.Sprintf("COMMENT ON COLUMN %s.%s IS ", blueprint.name, command.column.name)
+ if command.column.comment == nil {
+ return sql + "NULL"
+ } else {
+ return sql + fmt.Sprintf("'%s'", *command.column.comment)
+ }
+ }
+ return ""
+}
+
+func (g *postgresGrammar) getColumns(blueprint *Blueprint) ([]string, error) {
+ var columns []string
+ for _, col := range blueprint.getAddedColumns() {
+ if col.name == "" {
+ return nil, fmt.Errorf("column name cannot be empty")
+ }
+ sql := col.name + " " + g.getType(col)
+ for _, modifier := range g.modifiers() {
+ sql += modifier(col)
+ }
+ columns = append(columns, sql)
+ }
+
+ return columns, nil
+}
+
+func (g *postgresGrammar) getConstraints(blueprint *Blueprint) []string {
+ var constrains []string
+ for _, col := range blueprint.getAddedColumns() {
+ if col.primary != nil && *col.primary {
+ pkConstraintName := g.CreateIndexName(blueprint, "primary")
+ sql := "CONSTRAINT " + pkConstraintName + " PRIMARY KEY (" + col.name + ")"
+ constrains = append(constrains, sql)
+ continue
+ }
+ }
+
+ return constrains
+}
+
+func (g *postgresGrammar) getType(col *columnDefinition) string {
+ typeMapFunc := map[string]func(*columnDefinition) string{
+ columnTypeChar: g.typeChar,
+ columnTypeString: g.typeString,
+ columnTypeTinyText: g.typeTinyText,
+ columnTypeText: g.typeText,
+ columnTypeMediumText: g.typeMediumText,
+ columnTypeLongText: g.typeLongText,
+ columnTypeInteger: g.typeInteger,
+ columnTypeBigInteger: g.typeBigInteger,
+ columnTypeMediumInteger: g.typeMediumInteger,
+ columnTypeSmallInteger: g.typeSmallInteger,
+ columnTypeTinyInteger: g.typeTinyInteger,
+ columnTypeFloat: g.typeFloat,
+ columnTypeDouble: g.typeDouble,
+ columnTypeDecimal: g.typeDecimal,
+ columnTypeBoolean: g.typeBoolean,
+ columnTypeEnum: g.typeEnum,
+ columnTypeJson: g.typeJson,
+ columnTypeJsonb: g.typeJsonb,
+ columnTypeDate: g.typeDate,
+ columnTypeDateTime: g.typeDateTime,
+ columnTypeDateTimeTz: g.typeDateTimeTz,
+ columnTypeTime: g.typeTime,
+ columnTypeTimeTz: g.typeTimeTz,
+ columnTypeTimestamp: g.typeTimestamp,
+ columnTypeTimestampTz: g.typeTimestampTz,
+ columnTypeYear: g.typeYear,
+ columnTypeBinary: g.typeBinary,
+ columnTypeUuid: g.typeUuid,
+ columnTypeGeography: g.typeGeography,
+ columnTypeGeometry: g.typeGeometry,
+ columnTypePoint: g.typePoint,
+ }
+ if fn, ok := typeMapFunc[col.columnType]; ok {
+ return fn(col)
+ }
+ return col.columnType
+}
+
+func (g *postgresGrammar) typeChar(col *columnDefinition) string {
+ if col.length != nil && *col.length > 0 {
+ return fmt.Sprintf("CHAR(%d)", *col.length)
+ }
+ return "CHAR"
+}
+
+func (g *postgresGrammar) typeString(col *columnDefinition) string {
+ if col.length != nil && *col.length > 0 {
+ return fmt.Sprintf("VARCHAR(%d)", *col.length)
+ }
+ return "VARCHAR"
+}
+
+func (g *postgresGrammar) typeTinyText(_ *columnDefinition) string {
+ return "VARCHAR(255)"
+}
+
+func (g *postgresGrammar) typeText(_ *columnDefinition) string {
+ return "TEXT"
+}
+
+func (g *postgresGrammar) typeMediumText(_ *columnDefinition) string {
+ return "TEXT"
+}
+
+func (g *postgresGrammar) typeLongText(_ *columnDefinition) string {
+ return "TEXT"
+}
+
+func (g *postgresGrammar) typeBigInteger(col *columnDefinition) string {
+ if col.autoIncrement != nil && *col.autoIncrement {
+ return "BIGSERIAL"
+ }
+ return "BIGINT"
+}
+
+func (g *postgresGrammar) typeInteger(col *columnDefinition) string {
+ if col.autoIncrement != nil && *col.autoIncrement {
+ return "SERIAL"
+ }
+ return "INTEGER"
+}
+
+func (g *postgresGrammar) typeMediumInteger(col *columnDefinition) string {
+ return g.typeInteger(col)
+}
+
+func (g *postgresGrammar) typeSmallInteger(col *columnDefinition) string {
+ if col.autoIncrement != nil && *col.autoIncrement {
+ return "SMALLSERIAL"
+ }
+ return "SMALLINT"
+}
+
+func (g *postgresGrammar) typeTinyInteger(col *columnDefinition) string {
+ return g.typeSmallInteger(col)
+}
+
+func (g *postgresGrammar) typeFloat(_ *columnDefinition) string {
+ return "REAL"
+}
+
+func (g *postgresGrammar) typeDouble(_ *columnDefinition) string {
+ return "DOUBLE PRECISION"
+}
+
+func (g *postgresGrammar) typeDecimal(col *columnDefinition) string {
+ return fmt.Sprintf("DECIMAL(%d, %d)", *col.total, *col.places)
+}
+
+func (g *postgresGrammar) typeBoolean(_ *columnDefinition) string {
+ return "BOOLEAN"
+}
+
+func (g *postgresGrammar) typeEnum(col *columnDefinition) string {
+ enumValues := make([]string, len(col.allowed))
+ for i, v := range col.allowed {
+ enumValues[i] = g.QuoteString(v)
+ }
+ return "VARCHAR(255) CHECK (" + col.name + " IN (" + strings.Join(enumValues, ", ") + "))"
+}
+
+func (g *postgresGrammar) typeJson(_ *columnDefinition) string {
+ return "JSON"
+}
+
+func (g *postgresGrammar) typeJsonb(_ *columnDefinition) string {
+ return "JSONB"
+}
+
+func (g *postgresGrammar) typeDate(_ *columnDefinition) string {
+ return "DATE"
+}
+
+func (g *postgresGrammar) typeDateTime(col *columnDefinition) string {
+ return g.typeTimestamp(col)
+}
+
+func (g *postgresGrammar) typeDateTimeTz(col *columnDefinition) string {
+ return g.typeTimestampTz(col)
+}
+
+func (g *postgresGrammar) typeTime(col *columnDefinition) string {
+ if col.precision != nil {
+ return fmt.Sprintf("TIME(%d)", *col.precision)
+ }
+ return "TIME"
+}
+
+func (g *postgresGrammar) typeTimeTz(col *columnDefinition) string {
+ if col.precision != nil {
+ return fmt.Sprintf("TIMETZ(%d)", *col.precision)
+ }
+ return "TIMETZ"
+}
+
+func (g *postgresGrammar) typeTimestamp(col *columnDefinition) string {
+ if col.useCurrent {
+ col.SetDefault(Expression("CURRENT_TIMESTAMP"))
+ }
+ if col.precision != nil {
+ return fmt.Sprintf("TIMESTAMP(%d)", *col.precision)
+ }
+ return "TIMESTAMP"
+}
+
+func (g *postgresGrammar) typeTimestampTz(col *columnDefinition) string {
+ if col.useCurrent {
+ col.SetDefault(Expression("CURRENT_TIMESTAMP"))
+ }
+ if col.precision != nil {
+ return fmt.Sprintf("TIMESTAMPTZ(%d)", *col.precision)
+ }
+ return "TIMESTAMPTZ"
+}
+
+func (g *postgresGrammar) typeYear(_ *columnDefinition) string {
+ return "INTEGER"
+}
+
+func (g *postgresGrammar) typeBinary(_ *columnDefinition) string {
+ return "BYTEA"
+}
+
+func (g *postgresGrammar) typeUuid(_ *columnDefinition) string {
+ return "UUID"
+}
+
+func (g *postgresGrammar) typeGeography(col *columnDefinition) string {
+ if col.subtype != nil && col.srid != nil {
+ return fmt.Sprintf("GEOGRAPHY(%s, %d)", *col.subtype, *col.srid)
+ } else if col.subtype != nil {
+ return fmt.Sprintf("GEOGRAPHY(%s)", *col.subtype)
+ }
+ return "GEOGRAPHY"
+}
+
+func (g *postgresGrammar) typeGeometry(col *columnDefinition) string {
+ if col.subtype != nil && col.srid != nil {
+ return fmt.Sprintf("GEOMETRY(%s, %d)", *col.subtype, *col.srid)
+ } else if col.subtype != nil {
+ return fmt.Sprintf("GEOMETRY(%s)", *col.subtype)
+ }
+ return "GEOMETRY"
+}
+
+func (g *postgresGrammar) typePoint(col *columnDefinition) string {
+ if col.srid != nil {
+ return fmt.Sprintf("POINT(%d)", *col.srid)
+ }
+ return "POINT"
+}
+
+func (g *postgresGrammar) modifiers() []func(*columnDefinition) string {
+ return []func(*columnDefinition) string{
+ g.modifyDefault,
+ g.modifyNullable,
+ }
+}
+
+func (g *postgresGrammar) modifyNullable(col *columnDefinition) string {
+ if col.change {
+ if col.nullable == nil {
+ return ""
+ }
+ if *col.nullable {
+ return " DROP NOT NULL"
+ }
+ return " SET NOT NULL"
+ }
+ if col.nullable != nil && *col.nullable {
+ return " NULL"
+ }
+ return " NOT NULL"
+}
+
+func (g *postgresGrammar) modifyDefault(col *columnDefinition) string {
+ if col.hasCommand("default") {
+ if col.change {
+ return fmt.Sprintf(" SET DEFAULT %s", g.GetDefaultValue(col.defaultValue))
+ }
+ return fmt.Sprintf(" DEFAULT %s", g.GetDefaultValue(col.defaultValue))
+ }
+ return ""
+}
diff --git a/postgres_grammar_test.go b/schema/postgres_grammar_test.go
similarity index 89%
rename from postgres_grammar_test.go
rename to schema/postgres_grammar_test.go
index 843f791..2a805f0 100644
--- a/postgres_grammar_test.go
+++ b/schema/postgres_grammar_test.go
@@ -7,7 +7,7 @@ import (
)
func TestPgGrammar_CompileCreate(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -21,13 +21,13 @@ func TestPgGrammar_CompileCreate(t *testing.T) {
table: "users",
blueprint: func(table *Blueprint) {
table.ID()
- table.String("name", 255)
- table.String("email", 255)
+ table.String("name")
+ table.String("email")
table.String("password").Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
+ table.Timestamp("created_at").UseCurrent()
+ table.Timestamp("updated_at").UseCurrent()
},
- want: "CREATE TABLE users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))",
+ want: "CREATE TABLE users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR(255) NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))",
},
{
name: "Create table with foreign key",
@@ -35,7 +35,7 @@ func TestPgGrammar_CompileCreate(t *testing.T) {
blueprint: func(table *Blueprint) {
table.ID()
table.Integer("user_id")
- table.String("title", 255)
+ table.String("title")
table.Text("content").Nullable()
table.Foreign("user_id").References("id").On("users").OnDelete("CASCADE").OnUpdate("CASCADE")
},
@@ -45,7 +45,7 @@ func TestPgGrammar_CompileCreate(t *testing.T) {
name: "Create table with column name is empty",
table: "empty_column_table",
blueprint: func(table *Blueprint) {
- table.String("", 255) // Intentionally empty column name
+ table.String("") // Intentionally empty column name
},
wantErr: true,
},
@@ -55,57 +55,7 @@ func TestPgGrammar_CompileCreate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := grammar.compileCreate(bp)
- if tt.wantErr {
- assert.Error(t, err)
- return
- }
-
- assert.NoError(t, err)
- assert.Equal(t, tt.want, got, "SQL statement mismatch for %s", tt.name)
- })
- }
-}
-
-func TestPgGrammar_CompileCreateIfNotExists(t *testing.T) {
- grammar := newPgGrammar()
-
- tests := []struct {
- name string
- table string
- blueprint func(table *Blueprint)
- want string
- wantErr bool
- }{
- {
- name: "Create simple table if not exists",
- table: "users",
- blueprint: func(table *Blueprint) {
- table.ID()
- table.String("name", 255)
- table.String("email", 255)
- table.String("password").Nullable()
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP")
- },
- want: "CREATE TABLE IF NOT EXISTS users (id BIGSERIAL NOT NULL, name VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, password VARCHAR NULL, created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, CONSTRAINT pk_users PRIMARY KEY (id))",
- wantErr: false,
- },
- {
- name: "Create table with column name is empty",
- table: "empty_column_table",
- blueprint: func(table *Blueprint) {
- table.String("", 255) // Intentionally empty column name
- },
- wantErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- bp := &Blueprint{name: tt.table}
- tt.blueprint(bp)
- got, err := grammar.compileCreateIfNotExists(bp)
+ got, err := grammar.CompileCreate(bp)
if tt.wantErr {
assert.Error(t, err)
return
@@ -118,7 +68,7 @@ func TestPgGrammar_CompileCreateIfNotExists(t *testing.T) {
}
func TestPgGrammar_CompileAdd(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -153,7 +103,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) {
blueprint: func(table *Blueprint) {
table.Boolean("active").Default(true)
},
- want: "ALTER TABLE users ADD COLUMN active BOOLEAN DEFAULT true NOT NULL",
+ want: "ALTER TABLE users ADD COLUMN active BOOLEAN DEFAULT '1' NOT NULL",
wantErr: false,
},
{
@@ -162,7 +112,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) {
blueprint: func(table *Blueprint) {
table.String("notes", 500).Comment("User notes")
},
- want: "ALTER TABLE users ADD COLUMN notes VARCHAR(500) NOT NULL COMMENT 'User notes'",
+ want: "ALTER TABLE users ADD COLUMN notes VARCHAR(500) NOT NULL",
wantErr: false,
},
{
@@ -187,17 +137,17 @@ func TestPgGrammar_CompileAdd(t *testing.T) {
name: "Add complex column with all attributes",
table: "products",
blueprint: func(table *Blueprint) {
- table.Decimal("price", 10, 2).Default(0).Comment("Product price")
+ table.Decimal("price", 10, 2).Default(0)
},
- want: "ALTER TABLE products ADD COLUMN price DECIMAL(10, 2) DEFAULT 0 NOT NULL COMMENT 'Product price'",
+ want: "ALTER TABLE products ADD COLUMN price DECIMAL(10, 2) DEFAULT '0' NOT NULL",
wantErr: false,
},
{
name: "Add timestamp columns",
table: "orders",
blueprint: func(table *Blueprint) {
- table.Timestamp("created_at").Default("CURRENT_TIMESTAMP")
- table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP").Nullable()
+ table.Timestamp("created_at").UseCurrent()
+ table.Timestamp("updated_at").UseCurrent().Nullable()
},
want: "ALTER TABLE orders ADD COLUMN created_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NOT NULL, ADD COLUMN updated_at TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP NULL",
wantErr: false,
@@ -253,7 +203,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := grammar.compileAdd(bp)
+ got, err := grammar.CompileAdd(bp)
if tt.wantErr {
assert.Error(t, err)
return
@@ -266,7 +216,7 @@ func TestPgGrammar_CompileAdd(t *testing.T) {
}
func TestPgGrammar_CompileChange(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -309,7 +259,7 @@ func TestPgGrammar_CompileChange(t *testing.T) {
blueprint: func(table *Blueprint) {
table.String("email", 500).Default(nil).Change()
},
- want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500), ALTER COLUMN email DROP DEFAULT"},
+ want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500), ALTER COLUMN email SET DEFAULT NULL"},
},
{
name: "Add comment to column",
@@ -328,7 +278,7 @@ func TestPgGrammar_CompileChange(t *testing.T) {
blueprint: func(table *Blueprint) {
table.String("email", 500).Comment("").Change()
},
- want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500)", "COMMENT ON COLUMN users.email IS NULL"},
+ want: []string{"ALTER TABLE users ALTER COLUMN email TYPE VARCHAR(500)", "COMMENT ON COLUMN users.email IS ''"},
},
{
name: "Set column to not nullable",
@@ -356,9 +306,9 @@ func TestPgGrammar_CompileChange(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- bp := &Blueprint{name: tt.table}
+ bp := &Blueprint{name: tt.table, grammar: grammar}
tt.blueprint(bp)
- got, err := grammar.compileChange(bp)
+ got, err := bp.toSql()
if tt.wantErr {
assert.Error(t, err)
return
@@ -371,7 +321,7 @@ func TestPgGrammar_CompileChange(t *testing.T) {
}
func TestPgGrammar_CompileDrop(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -390,7 +340,7 @@ func TestPgGrammar_CompileDrop(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := grammar.compileDrop(bp)
+ got, err := grammar.CompileDrop(bp)
if tt.wantErr {
assert.Error(t, err)
return
@@ -403,7 +353,7 @@ func TestPgGrammar_CompileDrop(t *testing.T) {
}
func TestPgGrammar_CompileDropIfExists(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -422,7 +372,7 @@ func TestPgGrammar_CompileDropIfExists(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := grammar.compileDropIfExists(bp)
+ got, err := grammar.CompileDropIfExists(bp)
if tt.wantErr {
assert.Error(t, err)
return
@@ -435,7 +385,7 @@ func TestPgGrammar_CompileDropIfExists(t *testing.T) {
}
func TestPgGrammar_CompileRename(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -455,8 +405,9 @@ func TestPgGrammar_CompileRename(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- bp := &Blueprint{name: tt.oldName, newName: tt.newName}
- got, err := grammar.compileRename(bp)
+ bp := &Blueprint{name: tt.oldName}
+ bp.rename(tt.newName)
+ got, err := grammar.CompileRename(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err)
return
@@ -469,7 +420,7 @@ func TestPgGrammar_CompileRename(t *testing.T) {
}
func TestPgGrammar_GetColumns(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -503,14 +454,7 @@ func TestPgGrammar_GetColumns(t *testing.T) {
blueprint: func(table *Blueprint) {
table.Boolean("active").Default(true)
},
- want: []string{"active BOOLEAN DEFAULT true NOT NULL"},
- },
- {
- name: "Column with comment",
- blueprint: func(table *Blueprint) {
- table.String("description", 500).Comment("User description")
- },
- want: []string{"description VARCHAR(500) NOT NULL COMMENT 'User description'"},
+ want: []string{"active BOOLEAN DEFAULT '1' NOT NULL"},
},
{
name: "Primary key column",
@@ -543,13 +487,13 @@ func TestPgGrammar_GetColumns(t *testing.T) {
}
func TestPgGrammar_CompileDropColumn(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
table string
blueprint func(table *Blueprint)
- want string
+ wants []string
wantErr bool
}{
{
@@ -558,45 +502,46 @@ func TestPgGrammar_CompileDropColumn(t *testing.T) {
blueprint: func(table *Blueprint) {
table.DropColumn("email")
},
- want: "ALTER TABLE users DROP COLUMN email",
+ wants: []string{"ALTER TABLE users DROP COLUMN email"},
wantErr: false,
},
{
name: "Drop multiple columns",
table: "users",
blueprint: func(table *Blueprint) {
- table.DropColumn("email", "phone", "address")
+ table.DropColumn("email", "phone")
+ table.DropColumn("address")
},
- want: "ALTER TABLE users DROP COLUMN email, DROP COLUMN phone, DROP COLUMN address",
+ wants: []string{"ALTER TABLE users DROP COLUMN email, DROP COLUMN phone", "ALTER TABLE users DROP COLUMN address"},
wantErr: false,
},
{
name: "No columns to drop",
table: "users",
blueprint: func(table *Blueprint) {},
- want: "",
+ wants: nil,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- bp := &Blueprint{name: tt.table}
+ bp := &Blueprint{name: tt.table, grammar: grammar}
tt.blueprint(bp)
- got, err := grammar.compileDropColumn(bp)
+ got, err := bp.toSql()
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
- assert.Equal(t, tt.want, got)
+ assert.Equal(t, tt.wants, got)
})
}
}
func TestPgGrammar_CompileRenameColumn(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -640,7 +585,8 @@ func TestPgGrammar_CompileRenameColumn(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := grammar.compileRenameColumn(bp, tt.oldName, tt.newName)
+ command := &command{from: tt.oldName, to: tt.newName}
+ got, err := grammar.CompileRenameColumn(bp, command)
if tt.wantErr {
assert.Error(t, err)
return
@@ -653,7 +599,7 @@ func TestPgGrammar_CompileRenameColumn(t *testing.T) {
}
func TestPgGrammar_CompileDropIndex(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -684,7 +630,8 @@ func TestPgGrammar_CompileDropIndex(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{}
- got, err := grammar.compileDropIndex(bp, tt.indexName)
+ command := &command{index: tt.indexName}
+ got, err := grammar.CompileDropIndex(bp, command)
if tt.wantErr {
assert.Error(t, err)
assert.Empty(t, got)
@@ -698,7 +645,7 @@ func TestPgGrammar_CompileDropIndex(t *testing.T) {
}
func TestPgGrammar_CompileDropPrimary(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -729,7 +676,8 @@ func TestPgGrammar_CompileDropPrimary(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got, err := grammar.compileDropPrimary(tt.blueprint, tt.indexName)
+ command := &command{index: tt.indexName}
+ got, err := grammar.CompileDropPrimary(tt.blueprint, command)
if tt.wantErr {
assert.Error(t, err)
return
@@ -742,7 +690,7 @@ func TestPgGrammar_CompileDropPrimary(t *testing.T) {
}
func TestPgGrammar_CompileRenameIndex(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -794,7 +742,8 @@ func TestPgGrammar_CompileRenameIndex(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := grammar.compileRenameIndex(bp, tt.oldName, tt.newName)
+ command := &command{from: tt.oldName, to: tt.newName}
+ got, err := grammar.CompileRenameIndex(bp, command)
if tt.wantErr {
assert.Error(t, err)
return
@@ -807,7 +756,7 @@ func TestPgGrammar_CompileRenameIndex(t *testing.T) {
}
func TestPgGrammar_CompileForeign(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -972,7 +921,7 @@ func TestPgGrammar_CompileForeign(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := grammar.compileForeign(bp, bp.foreignKeys[0])
+ got, err := grammar.CompileForeign(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err)
return
@@ -985,7 +934,7 @@ func TestPgGrammar_CompileForeign(t *testing.T) {
}
func TestPgGrammar_CompileDropForeign(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -1020,7 +969,8 @@ func TestPgGrammar_CompileDropForeign(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
- got, err := grammar.compileDropForeign(bp, tt.foreignKeyName)
+ command := &command{index: tt.foreignKeyName}
+ got, err := grammar.CompileDropForeign(bp, command)
if tt.wantErr {
assert.Error(t, err)
return
@@ -1033,7 +983,7 @@ func TestPgGrammar_CompileDropForeign(t *testing.T) {
}
func TestPgGrammar_CompileIndex(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -1100,7 +1050,7 @@ func TestPgGrammar_CompileIndex(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := grammar.compileIndex(bp, bp.indexes[0])
+ got, err := grammar.CompileIndex(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err)
return
@@ -1113,7 +1063,7 @@ func TestPgGrammar_CompileIndex(t *testing.T) {
}
func TestPgGrammar_CompileUnique(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -1208,7 +1158,7 @@ func TestPgGrammar_CompileUnique(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := grammar.compileUnique(bp, bp.indexes[0])
+ got, err := grammar.CompileUnique(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err)
return
@@ -1221,7 +1171,7 @@ func TestPgGrammar_CompileUnique(t *testing.T) {
}
func TestPgGrammar_CompileFullText(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -1306,7 +1256,7 @@ func TestPgGrammar_CompileFullText(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := grammar.compileFullText(bp, bp.indexes[0])
+ got, err := grammar.CompileFullText(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err)
return
@@ -1319,7 +1269,7 @@ func TestPgGrammar_CompileFullText(t *testing.T) {
}
func TestPgGrammar_CompileDropUnique(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -1356,7 +1306,8 @@ func TestPgGrammar_CompileDropUnique(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{}
- got, err := grammar.compileDropUnique(bp, tt.indexName)
+ command := &command{index: tt.indexName}
+ got, err := grammar.CompileDropUnique(bp, command)
if tt.wantErr {
assert.Error(t, err)
assert.Empty(t, got)
@@ -1370,7 +1321,7 @@ func TestPgGrammar_CompileDropUnique(t *testing.T) {
}
func TestPgGrammar_CompileDropFulltext(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -1413,7 +1364,8 @@ func TestPgGrammar_CompileDropFulltext(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{}
- got, err := grammar.compileDropFulltext(bp, tt.indexName)
+ command := &command{index: tt.indexName}
+ got, err := grammar.CompileDropFulltext(bp, command)
if tt.wantErr {
assert.Error(t, err)
assert.Empty(t, got)
@@ -1427,7 +1379,7 @@ func TestPgGrammar_CompileDropFulltext(t *testing.T) {
}
func TestPgGrammar_CompilePrimary(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -1494,7 +1446,7 @@ func TestPgGrammar_CompilePrimary(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
bp := &Blueprint{name: tt.table}
tt.blueprint(bp)
- got, err := grammar.compilePrimary(bp, bp.indexes[0])
+ got, err := grammar.CompilePrimary(bp, bp.commands[0])
if tt.wantErr {
assert.Error(t, err)
return
@@ -1507,7 +1459,7 @@ func TestPgGrammar_CompilePrimary(t *testing.T) {
}
func TestPgGrammar_GetType(t *testing.T) {
- grammar := newPgGrammar()
+ grammar := newPostgresGrammar()
tests := []struct {
name string
@@ -1540,7 +1492,7 @@ func TestPgGrammar_GetType(t *testing.T) {
blueprint: func(table *Blueprint) {
table.Char("code")
},
- want: "CHAR",
+ want: "CHAR(255)",
},
{
name: "string column type",
@@ -1557,16 +1509,9 @@ func TestPgGrammar_GetType(t *testing.T) {
want: "DECIMAL(10, 2)",
},
{
- name: "double column type with precision",
- blueprint: func(table *Blueprint) {
- table.Double("value", 8, 2)
- },
- want: "DOUBLE PRECISION",
- },
- {
- name: "double column type without precision",
+ name: "double column type",
blueprint: func(table *Blueprint) {
- table.Double("value", 0, 0)
+ table.Double("value")
},
want: "DOUBLE PRECISION",
},
@@ -1743,7 +1688,7 @@ func TestPgGrammar_GetType(t *testing.T) {
blueprint: func(table *Blueprint) {
table.TinyText("notes")
},
- want: "TEXT",
+ want: "VARCHAR(255)",
},
{
name: "date column type",
@@ -1822,13 +1767,6 @@ func TestPgGrammar_GetType(t *testing.T) {
},
want: "VARCHAR(255) CHECK (status IN ('active', 'inactive'))",
},
- {
- name: "Enum with empty values",
- blueprint: func(table *Blueprint) {
- table.Enum("status", []string{})
- },
- want: "VARCHAR(255)",
- },
}
for _, tt := range tests {
diff --git a/schema.go b/schema/schema.go
similarity index 71%
rename from schema.go
rename to schema/schema.go
index ea067f5..edc0350 100644
--- a/schema.go
+++ b/schema/schema.go
@@ -1,9 +1,11 @@
package schema
import (
- "context"
"database/sql"
"errors"
+
+ "github.com/akfaiz/migris/internal/config"
+ "github.com/akfaiz/migris/internal/dialect"
)
// Column represents a database column with its properties.
@@ -39,11 +41,12 @@ type TableInfo struct {
}
func newBuilder() (Builder, error) {
- if dialectValue == dialectUnknown {
+ dialectVal := config.GetDialect()
+ if dialectVal == dialect.Unknown {
return nil, errors.New("schema dialect is not set, please call schema.SetDialect() before using schema functions")
}
- builder, err := NewBuilder(dialectValue.String())
+ builder, err := NewBuilder(dialectVal.String())
if err != nil {
return nil, err
}
@@ -65,36 +68,13 @@ func newBuilder() (Builder, error) {
// table.Timestamp("created_at").Default("CURRENT_TIMESTAMP").Nullable(false)
// table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP").Nullable(false)
// })
-func Create(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error {
- builder, err := newBuilder()
- if err != nil {
- return err
- }
-
- return builder.Create(ctx, tx, name, blueprint)
-}
-
-// CreateIfNotExists creates a new table with the given name and blueprint if it does not already exist.
-// The blueprint function is used to define the structure of the table.
-// It returns an error if the table creation fails.
-//
-// Example:
-//
-// err := schema.CreateIfNotExists(ctx, tx, "users", func(table *schema.Blueprint) {
-// table.ID()
-// table.String("name").Nullable(false)
-// table.String("email").Unique().Nullable(false)
-// table.String("password").Nullable()
-// table.Timestamp("created_at").Default("CURRENT_TIMESTAMP").Nullable(false)
-// table.Timestamp("updated_at").Default("CURRENT_TIMESTAMP").Nullable(false)
-// })
-func CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error {
+func Create(c *Context, name string, blueprint func(table *Blueprint)) error {
builder, err := newBuilder()
if err != nil {
return err
}
- return builder.CreateIfNotExists(ctx, tx, name, blueprint)
+ return builder.Create(c, name, blueprint)
}
// Drop removes the table with the given name.
@@ -103,13 +83,13 @@ func CreateIfNotExists(ctx context.Context, tx *sql.Tx, name string, blueprint f
// Example:
//
// err := schema.Drop(ctx, tx, "users")
-func Drop(ctx context.Context, tx *sql.Tx, name string) error {
+func Drop(c *Context, name string) error {
builder, err := newBuilder()
if err != nil {
return err
}
- return builder.Drop(ctx, tx, name)
+ return builder.Drop(c, name)
}
// DropIfExists removes the table with the given name if it exists.
@@ -118,13 +98,13 @@ func Drop(ctx context.Context, tx *sql.Tx, name string) error {
// Example:
//
// err := schema.DropIfExists(ctx, tx, "users")
-func DropIfExists(ctx context.Context, tx *sql.Tx, name string) error {
+func DropIfExists(c *Context, name string) error {
builder, err := newBuilder()
if err != nil {
return err
}
- return builder.DropIfExists(ctx, tx, name)
+ return builder.DropIfExists(c, name)
}
// GetColumns retrieves the columns of the specified table.
@@ -133,13 +113,13 @@ func DropIfExists(ctx context.Context, tx *sql.Tx, name string) error {
// Example:
//
// columns, err := schema.GetColumns(ctx, tx, "users")
-func GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, error) {
+func GetColumns(c *Context, tableName string) ([]*Column, error) {
builder, err := newBuilder()
if err != nil {
return nil, err
}
- return builder.GetColumns(ctx, tx, tableName)
+ return builder.GetColumns(c, tableName)
}
// GetIndexes retrieves the indexes of the specified table.
@@ -148,13 +128,13 @@ func GetColumns(ctx context.Context, tx *sql.Tx, tableName string) ([]*Column, e
// Example:
//
// indexes, err := schema.GetIndexes(ctx, tx, "users")
-func GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, error) {
+func GetIndexes(c *Context, tableName string) ([]*Index, error) {
builder, err := newBuilder()
if err != nil {
return nil, err
}
- return builder.GetIndexes(ctx, tx, tableName)
+ return builder.GetIndexes(c, tableName)
}
// GetTables retrieves all tables in the database.
@@ -163,13 +143,13 @@ func GetIndexes(ctx context.Context, tx *sql.Tx, tableName string) ([]*Index, er
// Example:
//
// tables, err := schema.GetTables(ctx, tx)
-func GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) {
+func GetTables(c *Context) ([]*TableInfo, error) {
builder, err := newBuilder()
if err != nil {
return nil, err
}
- return builder.GetTables(ctx, tx)
+ return builder.GetTables(c)
}
// HasColumn checks if a column with the given name exists in the specified table.
@@ -178,13 +158,13 @@ func GetTables(ctx context.Context, tx *sql.Tx) ([]*TableInfo, error) {
// Example:
//
// exists, err := schema.HasColumn(ctx, tx, "users", "email")
-func HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName string) (bool, error) {
+func HasColumn(c *Context, tableName string, columnName string) (bool, error) {
builder, err := newBuilder()
if err != nil {
return false, err
}
- return builder.HasColumn(ctx, tx, tableName, columnName)
+ return builder.HasColumn(c, tableName, columnName)
}
// HasColumns checks if the specified columns exist in the given table.
@@ -195,13 +175,13 @@ func HasColumn(ctx context.Context, tx *sql.Tx, tableName string, columnName str
// exists, err := schema.HasColumns(ctx, tx, "users", []string{"email", "name"})
//
// If any of the specified columns do not exist, it returns false.
-func HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames []string) (bool, error) {
+func HasColumns(c *Context, tableName string, columnNames []string) (bool, error) {
builder, err := newBuilder()
if err != nil {
return false, err
}
- return builder.HasColumns(ctx, tx, tableName, columnNames)
+ return builder.HasColumns(c, tableName, columnNames)
}
// HasIndex checks if an index with the given name exists in the specified table.
@@ -212,13 +192,13 @@ func HasColumns(ctx context.Context, tx *sql.Tx, tableName string, columnNames [
// exists, err := schema.HasIndex(ctx, tx, "users", []string{"uk_users_email"}) // Checks if the index with name "uk_users_email" exists in the "users" table.
//
// exists, err := schema.HasIndex(ctx, tx, "users", []string{"email", "name"}) // Checks if a composite index exists on the "email" and "name" columns in the "users" table.
-func HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []string) (bool, error) {
+func HasIndex(c *Context, tableName string, indexes []string) (bool, error) {
builder, err := newBuilder()
if err != nil {
return false, err
}
- return builder.HasIndex(ctx, tx, tableName, indexes)
+ return builder.HasIndex(c, tableName, indexes)
}
// HasTable checks if a table with the given name exists in the database.
@@ -228,13 +208,13 @@ func HasIndex(ctx context.Context, tx *sql.Tx, tableName string, indexes []strin
// Example:
//
// exists, err := schema.HasTable(ctx, tx, "users")
-func HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) {
+func HasTable(c *Context, name string) (bool, error) {
builder, err := newBuilder()
if err != nil {
return false, err
}
- return builder.HasTable(ctx, tx, name)
+ return builder.HasTable(c, name)
}
// Rename changes the name of the table from name to newName.
@@ -243,13 +223,13 @@ func HasTable(ctx context.Context, tx *sql.Tx, name string) (bool, error) {
// Example:
//
// err := schema.Rename(ctx, tx, "users", "people")
-func Rename(ctx context.Context, tx *sql.Tx, name string, newName string) error {
+func Rename(c *Context, name string, newName string) error {
builder, err := newBuilder()
if err != nil {
return err
}
- return builder.Rename(ctx, tx, name, newName)
+ return builder.Rename(c, name, newName)
}
// Table modifies an existing table with the given name and blueprint.
@@ -263,11 +243,11 @@ func Rename(ctx context.Context, tx *sql.Tx, name string, newName string) error
// table.DropColumn("password")
// table.RenameColumn("email", "contact_email")
// })
-func Table(ctx context.Context, tx *sql.Tx, name string, blueprint func(table *Blueprint)) error {
+func Table(c *Context, name string, blueprint func(table *Blueprint)) error {
builder, err := newBuilder()
if err != nil {
return err
}
- return builder.Table(ctx, tx, name, blueprint)
+ return builder.Table(c, name, blueprint)
}
diff --git a/schema_test.go b/schema/schema_test.go
similarity index 54%
rename from schema_test.go
rename to schema/schema_test.go
index 5a7bfcb..399181c 100644
--- a/schema_test.go
+++ b/schema/schema_test.go
@@ -6,7 +6,9 @@ import (
"fmt"
"testing"
- "github.com/afkdevs/go-schema"
+ "github.com/akfaiz/migris/internal/config"
+ "github.com/akfaiz/migris/internal/dialect"
+ "github.com/akfaiz/migris/schema"
"github.com/stretchr/testify/suite"
)
@@ -21,6 +23,7 @@ type schemaTestSuite struct {
}
func (s *schemaTestSuite) SetupSuite() {
+ config.SetDialect(dialect.Postgres)
ctx := context.Background()
s.ctx = ctx
@@ -33,38 +36,7 @@ func (s *schemaTestSuite) SetupSuite() {
err = db.Ping()
s.Require().NoError(err)
-
s.db = db
- schema.SetDebug(false)
-
- s.Run("when dialect is not set should return error", func() {
- err := schema.SetDialect("")
- s.Error(err)
- s.ErrorContains(err, "unknown dialect")
-
- builderFuncs := []func() error{
- func() error { return schema.Create(ctx, nil, "", nil) },
- func() error { return schema.CreateIfNotExists(ctx, nil, "", nil) },
- func() error { return schema.Drop(ctx, nil, "") },
- func() error { return schema.DropIfExists(ctx, nil, "") },
- func() error { _, err := schema.GetColumns(ctx, nil, ""); return err },
- func() error { _, err := schema.GetIndexes(ctx, nil, ""); return err },
- func() error { _, err := schema.GetTables(ctx, nil); return err },
- func() error { _, err := schema.HasColumn(ctx, nil, "", ""); return err },
- func() error { _, err := schema.HasColumns(ctx, nil, "", nil); return err },
- func() error { _, err := schema.HasIndex(ctx, nil, "", nil); return err },
- func() error { _, err := schema.HasTable(ctx, nil, ""); return err },
- func() error { return schema.Rename(ctx, nil, "", "") },
- func() error { return schema.Table(ctx, nil, "", nil) },
- }
- for _, fn := range builderFuncs {
- s.Error(fn(), "Expected error when dialect is not set")
- }
- })
- s.Run("when dialect is set to postgres should not return error", func() {
- err = schema.SetDialect("postgres")
- s.Require().NoError(err)
- })
}
func (s *schemaTestSuite) TearDownSuite() {
@@ -76,8 +48,10 @@ func (s *schemaTestSuite) TestCreate() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when parameters are valid should create table", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -88,78 +62,23 @@ func (s *schemaTestSuite) TestCreate() {
s.NoError(err)
})
s.Run("when table already exists should return error", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
})
s.Error(err)
s.ErrorContains(err, "\"users\" already exists")
})
s.Run("when table name is empty should return error", func() {
- err := schema.Create(s.ctx, tx, "", func(table *schema.Blueprint) {
+ err := schema.Create(c, "", func(table *schema.Blueprint) {
table.ID()
})
s.Error(err)
- s.ErrorContains(err, "table name is empty")
+ s.ErrorContains(err, "invalid arguments")
})
s.Run("when blueprint function is nil should return error", func() {
- err := schema.Create(s.ctx, tx, "test", nil)
+ err := schema.Create(c, "test", nil)
s.Error(err)
- s.ErrorContains(err, "blueprint function is nil")
- })
- s.Run("when transaction is nil should return error", func() {
- err := schema.Create(s.ctx, nil, "test", func(table *schema.Blueprint) {
- table.ID()
- })
- s.Error(err)
- s.ErrorContains(err, "transaction is nil")
- })
-}
-
-func (s *schemaTestSuite) TestCreateIfNotExists() {
- tx, err := s.db.BeginTx(s.ctx, nil)
- s.Require().NoError(err)
- defer tx.Rollback() //nolint:errcheck
-
- s.Run("when parameters are valid should create table", func() {
- err := schema.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) {
- table.ID()
- table.String("name")
- table.String("email").Unique()
- table.String("password")
- table.Timestamp("created_at").UseCurrent()
- table.Timestamp("updated_at").UseCurrent()
- })
- s.NoError(err)
- })
- s.Run("when table already exists should return no error", func() {
- err := schema.CreateIfNotExists(s.ctx, tx, "users", func(table *schema.Blueprint) {
- table.ID()
- table.String("name")
- table.String("email")
- table.String("password")
- table.Timestamp("created_at").UseCurrent()
- table.Timestamp("updated_at").UseCurrent()
- })
- s.NoError(err)
- })
- s.Run("when table name is empty should return error", func() {
- err := schema.CreateIfNotExists(s.ctx, tx, "", func(table *schema.Blueprint) {
- table.ID()
- })
- s.Error(err)
- s.ErrorContains(err, "table name is empty")
- })
- s.Run("when blueprint function is nil should return error", func() {
- err := schema.CreateIfNotExists(s.ctx, tx, "test", nil)
- s.Error(err)
- s.ErrorContains(err, "blueprint function is nil")
- })
- s.Run("when transaction is nil should return error", func() {
- err := schema.CreateIfNotExists(s.ctx, nil, "test", func(table *schema.Blueprint) {
- table.ID()
- })
- s.Error(err)
- s.ErrorContains(err, "transaction is nil")
+ s.ErrorContains(err, "invalid arguments")
})
}
@@ -168,8 +87,10 @@ func (s *schemaTestSuite) TestDrop() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when parameters are valid should drop table", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -178,22 +99,22 @@ func (s *schemaTestSuite) TestDrop() {
table.Timestamp("updated_at").UseCurrent()
})
s.Require().NoError(err)
- err = schema.Drop(s.ctx, tx, "users")
+ err = schema.Drop(c, "users")
s.NoError(err)
})
s.Run("when table does not exist should return error", func() {
- err := schema.Drop(s.ctx, tx, "non_existing_table")
+ err := schema.Drop(c, "non_existing_table")
s.Error(err)
})
s.Run("when table name is empty should return error", func() {
- err := schema.Drop(s.ctx, tx, "")
+ err := schema.Drop(c, "")
s.Error(err)
- s.ErrorContains(err, "table name is empty")
+ s.ErrorContains(err, "invalid arguments")
})
- s.Run("when transaction is nil should return error", func() {
- err := schema.Drop(s.ctx, nil, "test")
+ s.Run("when context is nil should return error", func() {
+ err := schema.Drop(nil, "test")
s.Error(err)
- s.ErrorContains(err, "transaction is nil")
+ s.ErrorContains(err, "invalid arguments")
})
}
@@ -202,8 +123,10 @@ func (s *schemaTestSuite) TestDropIfExists() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when parameters are valid should drop table", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -212,22 +135,22 @@ func (s *schemaTestSuite) TestDropIfExists() {
table.Timestamp("updated_at").UseCurrent()
})
s.Require().NoError(err)
- err = schema.DropIfExists(s.ctx, tx, "users")
+ err = schema.DropIfExists(c, "users")
s.NoError(err)
})
s.Run("when table does not exist should return no error", func() {
- err := schema.DropIfExists(s.ctx, tx, "non_existing_table")
+ err := schema.DropIfExists(c, "non_existing_table")
s.NoError(err)
})
s.Run("when table name is empty should return error", func() {
- err := schema.DropIfExists(s.ctx, tx, "")
+ err := schema.DropIfExists(c, "")
s.Error(err)
- s.ErrorContains(err, "table name is empty")
+ s.ErrorContains(err, "invalid arguments")
})
- s.Run("when transaction is nil should return error", func() {
- err := schema.DropIfExists(s.ctx, nil, "test")
+ s.Run("when context is nil should return error", func() {
+ err := schema.DropIfExists(nil, "test")
s.Error(err)
- s.ErrorContains(err, "transaction is nil")
+ s.ErrorContains(err, "invalid arguments")
})
}
@@ -236,8 +159,10 @@ func (s *schemaTestSuite) TestGetColumns() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when parameters are valid should return columns", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -246,23 +171,23 @@ func (s *schemaTestSuite) TestGetColumns() {
table.Timestamp("updated_at").UseCurrent()
})
s.Require().NoError(err)
- columns, err := schema.GetColumns(s.ctx, tx, "users")
+ columns, err := schema.GetColumns(c, "users")
s.NoError(err)
s.NotEmpty(columns)
s.Len(columns, 6)
})
s.Run("when table does not exist should empty columns", func() {
- columns, err := schema.GetColumns(s.ctx, tx, "non_existing_table")
+ columns, err := schema.GetColumns(c, "non_existing_table")
s.NoError(err)
s.Nil(columns)
})
s.Run("when table name is empty should return error", func() {
- columns, err := schema.GetColumns(s.ctx, tx, "")
+ columns, err := schema.GetColumns(c, "")
s.Error(err)
s.Nil(columns)
})
- s.Run("when transaction is nil should return error", func() {
- columns, err := schema.GetColumns(s.ctx, nil, "test")
+ s.Run("when context is nil should return error", func() {
+ columns, err := schema.GetColumns(nil, "test")
s.Error(err)
s.Nil(columns)
})
@@ -273,8 +198,10 @@ func (s *schemaTestSuite) TestGetIndexes() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when parameters are valid should return indexes", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -283,23 +210,23 @@ func (s *schemaTestSuite) TestGetIndexes() {
table.Timestamp("updated_at").UseCurrent()
})
s.Require().NoError(err)
- indexes, err := schema.GetIndexes(s.ctx, tx, "users")
+ indexes, err := schema.GetIndexes(c, "users")
s.NoError(err)
s.NotEmpty(indexes)
s.Len(indexes, 2) // Expecting the unique index on email and the primary key index on id
})
s.Run("when table does not exist should return empty indexes", func() {
- indexes, err := schema.GetIndexes(s.ctx, tx, "non_existing_table")
+ indexes, err := schema.GetIndexes(c, "non_existing_table")
s.NoError(err)
s.Nil(indexes)
})
s.Run("when table name is empty should return error", func() {
- indexes, err := schema.GetIndexes(s.ctx, tx, "")
+ indexes, err := schema.GetIndexes(c, "")
s.Error(err)
s.Nil(indexes)
})
- s.Run("when transaction is nil should return error", func() {
- indexes, err := schema.GetIndexes(s.ctx, nil, "test")
+ s.Run("when context is nil should return error", func() {
+ indexes, err := schema.GetIndexes(nil, "test")
s.Error(err)
s.Nil(indexes)
})
@@ -310,13 +237,15 @@ func (s *schemaTestSuite) TestGetTables() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when no tables exist should return empty", func() {
- tables, err := schema.GetTables(s.ctx, tx)
+ tables, err := schema.GetTables(c)
s.NoError(err)
s.Empty(tables)
})
s.Run("when transaction is valid should return tables", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -325,13 +254,13 @@ func (s *schemaTestSuite) TestGetTables() {
table.Timestamp("updated_at").UseCurrent()
})
s.Require().NoError(err)
- tables, err := schema.GetTables(s.ctx, tx)
+ tables, err := schema.GetTables(c)
s.NoError(err)
s.NotEmpty(tables)
s.Len(tables, 1) // Expecting at least the 'users' table created in previous tests
})
- s.Run("when transaction is nil should return error", func() {
- tables, err := schema.GetTables(s.ctx, nil)
+ s.Run("when context is nil should return error", func() {
+ tables, err := schema.GetTables(nil)
s.Error(err)
s.Nil(tables)
})
@@ -342,8 +271,10 @@ func (s *schemaTestSuite) TestHasColumn() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when column exists should return true", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -353,19 +284,19 @@ func (s *schemaTestSuite) TestHasColumn() {
})
s.Require().NoError(err)
- exists, err := schema.HasColumn(s.ctx, tx, "users", "email")
+ exists, err := schema.HasColumn(c, "users", "email")
s.NoError(err)
s.True(exists)
})
s.Run("when column does not exist should return false", func() {
- exists, err := schema.HasColumn(s.ctx, tx, "users", "non_existing_column")
+ exists, err := schema.HasColumn(c, "users", "non_existing_column")
s.NoError(err)
s.False(exists)
})
- s.Run("when transaction is nil should return error", func() {
- exists, err := schema.HasColumn(s.ctx, nil, "users", "email")
+ s.Run("when context is nil should return error", func() {
+ exists, err := schema.HasColumn(nil, "users", "email")
s.Error(err)
s.False(exists)
})
@@ -376,8 +307,10 @@ func (s *schemaTestSuite) TestHasColumns() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when all columns exist should return true", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -387,17 +320,17 @@ func (s *schemaTestSuite) TestHasColumns() {
})
s.Require().NoError(err)
- exists, err := schema.HasColumns(s.ctx, tx, "users", []string{"email", "name"})
+ exists, err := schema.HasColumns(c, "users", []string{"email", "name"})
s.NoError(err)
s.True(exists)
})
s.Run("when some columns do not exist should return false", func() {
- exists, err := schema.HasColumns(s.ctx, tx, "users", []string{"email", "non_existing_column"})
+ exists, err := schema.HasColumns(c, "users", []string{"email", "non_existing_column"})
s.NoError(err)
s.False(exists)
})
s.Run("when no columns are provided should return error", func() {
- exists, err := schema.HasColumns(s.ctx, tx, "users", []string{})
+ exists, err := schema.HasColumns(c, "users", []string{})
s.Error(err)
s.False(exists)
})
@@ -408,8 +341,10 @@ func (s *schemaTestSuite) TestHasIndex() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when index exists should return true", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.Integer("company_id")
table.String("name")
@@ -422,20 +357,20 @@ func (s *schemaTestSuite) TestHasIndex() {
})
s.Require().NoError(err)
- exists, err := schema.HasIndex(s.ctx, tx, "users", []string{"email"})
+ exists, err := schema.HasIndex(c, "users", []string{"email"})
s.NoError(err)
s.True(exists)
- exists, err = schema.HasIndex(s.ctx, tx, "users", []string{"company_id", "id"})
+ exists, err = schema.HasIndex(c, "users", []string{"company_id", "id"})
s.NoError(err)
s.True(exists)
- exists, err = schema.HasIndex(s.ctx, tx, "users", []string{"uk_users_email"})
+ exists, err = schema.HasIndex(c, "users", []string{"uk_users_email"})
s.NoError(err)
s.True(exists)
})
s.Run("when index does not exist should return false", func() {
- exists, err := schema.HasIndex(s.ctx, tx, "users", []string{"non_existing_index"})
+ exists, err := schema.HasIndex(c, "users", []string{"non_existing_index"})
s.NoError(err)
s.False(exists)
})
@@ -446,8 +381,10 @@ func (s *schemaTestSuite) TestHasTable() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when table exists should return true", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -457,22 +394,22 @@ func (s *schemaTestSuite) TestHasTable() {
})
s.Require().NoError(err)
- exists, err := schema.HasTable(s.ctx, tx, "users")
+ exists, err := schema.HasTable(c, "users")
s.NoError(err)
s.True(exists)
})
s.Run("when table does not exist should return false", func() {
- exists, err := schema.HasTable(s.ctx, tx, "non_existing_table")
+ exists, err := schema.HasTable(c, "non_existing_table")
s.NoError(err)
s.False(exists)
})
s.Run("when table name is empty should return error", func() {
- exists, err := schema.HasTable(s.ctx, tx, "")
+ exists, err := schema.HasTable(c, "")
s.Error(err)
s.False(exists)
})
- s.Run("when transaction is nil should return error", func() {
- exists, err := schema.HasTable(s.ctx, nil, "users")
+ s.Run("when context is nil should return error", func() {
+ exists, err := schema.HasTable(nil, "users")
s.Error(err)
s.False(exists)
})
@@ -483,8 +420,10 @@ func (s *schemaTestSuite) TestRename() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when parameters are valid should rename table", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -494,27 +433,27 @@ func (s *schemaTestSuite) TestRename() {
})
s.Require().NoError(err)
- err = schema.Rename(s.ctx, tx, "users", "members")
+ err = schema.Rename(c, "users", "members")
s.NoError(err)
- columns, err := schema.GetColumns(s.ctx, tx, "members")
+ columns, err := schema.GetColumns(c, "members")
s.NoError(err)
s.NotEmpty(columns)
s.Len(columns, 6)
})
s.Run("when table does not exist should return error", func() {
- err := schema.Rename(s.ctx, tx, "non_existing_table", "new_name")
+ err := schema.Rename(c, "non_existing_table", "new_name")
s.Error(err)
})
s.Run("when new name is empty should return error", func() {
- err := schema.Rename(s.ctx, tx, "users", "")
+ err := schema.Rename(c, "users", "")
s.Error(err)
})
- s.Run("when transaction is nil should return error", func() {
- err := schema.Rename(s.ctx, nil, "users", "new_name")
+ s.Run("when context is nil should return error", func() {
+ err := schema.Rename(nil, "users", "new_name")
s.Error(err)
})
}
@@ -524,8 +463,10 @@ func (s *schemaTestSuite) TestTable() {
s.Require().NoError(err)
defer tx.Rollback() //nolint:errcheck
+ c := schema.NewContext(s.ctx, tx)
+
s.Run("when parameters are valid should alter table", func() {
- err := schema.Create(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err := schema.Create(c, "users", func(table *schema.Blueprint) {
table.ID()
table.String("name")
table.String("email").Unique()
@@ -535,25 +476,25 @@ func (s *schemaTestSuite) TestTable() {
})
s.Require().NoError(err)
- err = schema.Table(s.ctx, tx, "users", func(table *schema.Blueprint) {
+ err = schema.Table(c, "users", func(table *schema.Blueprint) {
table.String("phone").Nullable()
})
s.NoError(err)
- columns, err := schema.GetColumns(s.ctx, tx, "users")
+ columns, err := schema.GetColumns(c, "users")
s.NoError(err)
s.Len(columns, 7) // Expecting the new 'phone' column to be added
})
s.Run("when table does not exist should return error", func() {
- err := schema.Table(s.ctx, tx, "non_existing_table", func(table *schema.Blueprint) {
+ err := schema.Table(c, "non_existing_table", func(table *schema.Blueprint) {
table.String("new_column")
})
s.Error(err)
})
- s.Run("when transaction is nil should return error", func() {
- err := schema.Table(s.ctx, nil, "users", func(table *schema.Blueprint) {
+ s.Run("when context is nil should return error", func() {
+ err := schema.Table(nil, "users", func(table *schema.Blueprint) {
table.String("new_column")
})
s.Error(err)
diff --git a/status.go b/status.go
new file mode 100644
index 0000000..7fa50d9
--- /dev/null
+++ b/status.go
@@ -0,0 +1,27 @@
+package migris
+
+import (
+ "context"
+
+ "github.com/akfaiz/migris/internal/logger"
+)
+
+// Status returns the status of the migrations.
+func (m *Migrate) Status() error {
+ ctx := context.Background()
+ return m.StatusContext(ctx)
+}
+
+// StatusContext returns the status of the migrations.
+func (m *Migrate) StatusContext(ctx context.Context) error {
+ provider, err := m.newProvider()
+ if err != nil {
+ return err
+ }
+ migrations, err := provider.Status(ctx)
+ if err != nil {
+ return err
+ }
+ logger.PrintStatuses(migrations)
+ return nil
+}
diff --git a/up.go b/up.go
new file mode 100644
index 0000000..0c230e7
--- /dev/null
+++ b/up.go
@@ -0,0 +1,57 @@
+package migris
+
+import (
+ "context"
+ "errors"
+
+ "github.com/akfaiz/migris/internal/logger"
+ "github.com/pressly/goose/v3"
+)
+
+// Up applies the migrations in the specified directory.
+func (m *Migrate) Up() error {
+ ctx := context.Background()
+ return m.UpContext(ctx)
+}
+
+// UpContext applies the migrations in the specified directory.
+func (m *Migrate) UpContext(ctx context.Context) error {
+ return m.UpToContext(ctx, goose.MaxVersion)
+}
+
+// UpTo applies the migrations up to the specified version.
+func (m *Migrate) UpTo(version int64) error {
+ ctx := context.Background()
+ return m.UpToContext(ctx, version)
+}
+
+// UpToContext applies the migrations up to the specified version.
+func (m *Migrate) UpToContext(ctx context.Context, version int64) error {
+ provider, err := m.newProvider()
+ if err != nil {
+ return err
+ }
+ hasPending, err := provider.HasPending(ctx)
+ if err != nil {
+ return err
+ }
+ if !hasPending {
+ logger.Info("Nothing to migrate.")
+ return nil
+ }
+
+ logger.Infof("Running migrations.\n")
+ results, err := provider.UpTo(ctx, version)
+ if err != nil {
+ var partialErr *goose.PartialError
+ if errors.As(err, &partialErr) {
+ logger.PrintResults(partialErr.Applied)
+ logger.PrintResult(partialErr.Failed)
+ }
+
+ return err
+ }
+ logger.PrintResults(results)
+
+ return nil
+}
diff --git a/util.go b/util.go
deleted file mode 100644
index 994156b..0000000
--- a/util.go
+++ /dev/null
@@ -1,44 +0,0 @@
-package schema
-
-import (
- "context"
- "database/sql"
- "log"
-)
-
-func optional[T any](defaultValue T, values ...T) T {
- return optionalAtIndex(0, defaultValue, values...)
-}
-
-func optionalAtIndex[T any](index int, defaultValue T, values ...T) T {
- if index < len(values) {
- return values[index]
- }
- return defaultValue
-}
-
-func execContext(ctx context.Context, tx *sql.Tx, queries ...string) error {
- for _, query := range queries {
- if debug {
- log.Printf("Executing SQL: %s\n", query)
- }
- if _, err := tx.ExecContext(ctx, query); err != nil {
- return err
- }
- }
- return nil
-}
-
-func queryRowContext(ctx context.Context, tx *sql.Tx, query string, args ...any) *sql.Row {
- if debug {
- log.Printf("Executing Query: %s with args: %v\n", query, args)
- }
- return tx.QueryRowContext(ctx, query, args...)
-}
-
-func queryContext(ctx context.Context, tx *sql.Tx, query string, args ...any) (*sql.Rows, error) {
- if debug {
- log.Printf("Executing Query: %s with args: %v\n", query, args)
- }
- return tx.QueryContext(ctx, query, args...)
-}