Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion internal/mysql/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@ func (d *Client) getProviderClient() (provider.Interface, error) {
}

// Helper function to get all data for a table.
// *sql.Rows: Rows
// []string: Columns
// error: Error
func (d *Client) selectAllDataForTable(ctx context.Context, table string, params provider.DumpParams) (*sql.Rows, []string, error) {

client, err := d.getProviderClient()
if err != nil {
return nil, nil, err
Expand Down
17 changes: 14 additions & 3 deletions internal/mysql/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"bytes"
"fmt"
"io"
"slices"

"github.com/asaskevich/govalidator"
)

func getValue(raw string) (string, error) {
func getValue(raw string, columnType string) (string, error) {
if raw == "" {
return "''", nil
}
Expand All @@ -18,8 +19,18 @@ func getValue(raw string) (string, error) {
return "", err
}

if govalidator.IsInt(raw) {
return escaped, nil
// List of types from https://dev.mysql.com/doc/refman/8.4/en/data-types.html (just the Numeric ones)
asNumber := []string{
"INTEGER", "INT", "SMALLINT", "TINYINT", "MEDIUMINT", "BIGINT",
"DECIMAL", "NUMERIC", "DEC", "FIXED",
"FLOAT", "DOUBLE", "REAL", "DOUBLE PRECISION",
"BIT", "BOOL" }

// Only want to do this for numeric field types; it can cause problems when done for strings and JSON
if slices.Contains(asNumber, columnType) {
if govalidator.IsInt(raw) {
return escaped, nil
}
}

return fmt.Sprintf("'%s'", escaped), nil
Expand Down
10 changes: 7 additions & 3 deletions internal/mysql/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@ import (
)

func TestGetValue(t *testing.T) {
val, err := getValue("")
val, err := getValue("", "TEXT")
assert.NoError(t, err)
assert.Equal(t, "''", val)

val, err = getValue("1")
val, err = getValue("1", "INTEGER")
assert.NoError(t, err)
assert.Equal(t, "1", val)

val, err = getValue("foo")
val, err = getValue("1", "TEXT")
assert.NoError(t, err)
assert.Equal(t, "'1'", val)

val, err = getValue("foo", "TEXT")
assert.NoError(t, err)
assert.Equal(t, "'foo'", val)
}
Expand Down
12 changes: 10 additions & 2 deletions internal/mysql/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ func (d *Client) WriteTableData(ctx context.Context, w io.Writer, table string,

defer rows.Close()


// Get the column types
columntypes, err := rows.ColumnTypes()
if err != nil {
return err
}

// Set up and read the data, and print the INSERT statement
values := make([]*sql.RawBytes, len(columns))
scanArgs := make([]interface{}, len(values))

Expand Down Expand Up @@ -165,11 +173,11 @@ func (d *Client) WriteTableData(ctx context.Context, w io.Writer, table string,

var vals []string

for _, col := range values {
for columnindex, col := range values {
val := "NULL"

if col != nil {
val, err = getValue(string(*col))
val, err = getValue(string(*col), strings.ToUpper(columntypes[columnindex].DatabaseTypeName()))
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/mysql/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ func TestMySQLDumpTableData(t *testing.T) {
assert.Nil(t, dumper.WriteTableData(context.TODO(), buffer, "table", provider.DumpParams{
ExtendedInsertRows: 2}))

assert.Equal(t, strings.Count(buffer.String(), "INSERT INTO `table` VALUES"), 3)
assert.Equal(t, buffer.String(), "INSERT INTO `table` VALUES (1,'Go'),(2,'Java');\nINSERT INTO `table` VALUES (3,'C'),(4,'C++');\nINSERT INTO `table` VALUES (5,'Rust'),(6,'Closure');\n")
assert.Equal(t, 3, strings.Count(buffer.String(), "INSERT INTO `table` VALUES"))
assert.Equal(t, "INSERT INTO `table` VALUES ('1','Go'),('2','Java');\nINSERT INTO `table` VALUES ('3','C'),('4','C++');\nINSERT INTO `table` VALUES ('5','Rust'),('6','Closure');\n", buffer.String())
}

func TestMySQLDumpTableDataHandlingErrorFromSelectAllDataFor(t *testing.T) {
Expand Down