diff --git a/internal/mysql/tables.go b/internal/mysql/tables.go index f502b06..cfa9af3 100644 --- a/internal/mysql/tables.go +++ b/internal/mysql/tables.go @@ -85,8 +85,10 @@ func (d *Client) getProviderClient() (provider.Interface, error) { } // Helper function to get all data for a table. +// *sql.Rows: Rows +// []string: Columns +// error: Error func (d *Client) selectAllDataForTable(ctx context.Context, table string, params provider.DumpParams) (*sql.Rows, []string, error) { - client, err := d.getProviderClient() if err != nil { return nil, nil, err diff --git a/internal/mysql/utils.go b/internal/mysql/utils.go index 9b87603..575d962 100644 --- a/internal/mysql/utils.go +++ b/internal/mysql/utils.go @@ -4,11 +4,12 @@ import ( "bytes" "fmt" "io" + "slices" "github.com/asaskevich/govalidator" ) -func getValue(raw string) (string, error) { +func getValue(raw string, columnType string) (string, error) { if raw == "" { return "''", nil } @@ -18,8 +19,18 @@ func getValue(raw string) (string, error) { return "", err } - if govalidator.IsInt(raw) { - return escaped, nil + // List of types from https://dev.mysql.com/doc/refman/8.4/en/data-types.html (just the Numeric ones) + asNumber := []string{ + "INTEGER", "INT", "SMALLINT", "TINYINT", "MEDIUMINT", "BIGINT", + "DECIMAL", "NUMERIC", "DEC", "FIXED", + "FLOAT", "DOUBLE", "REAL", "DOUBLE PRECISION", + "BIT", "BOOL" } + + // Only want to do this for numeric field types; it can cause problems when done for strings and JSON + if slices.Contains(asNumber, columnType) { + if govalidator.IsInt(raw) { + return escaped, nil + } } return fmt.Sprintf("'%s'", escaped), nil diff --git a/internal/mysql/utils_test.go b/internal/mysql/utils_test.go index 57ccea4..a316b1c 100644 --- a/internal/mysql/utils_test.go +++ b/internal/mysql/utils_test.go @@ -7,15 +7,19 @@ import ( ) func TestGetValue(t *testing.T) { - val, err := getValue("") + val, err := getValue("", "TEXT") assert.NoError(t, err) assert.Equal(t, "''", val) - val, err = getValue("1") + val, err = getValue("1", "INTEGER") assert.NoError(t, err) assert.Equal(t, "1", val) - val, err = getValue("foo") + val, err = getValue("1", "TEXT") + assert.NoError(t, err) + assert.Equal(t, "'1'", val) + + val, err = getValue("foo", "TEXT") assert.NoError(t, err) assert.Equal(t, "'foo'", val) } diff --git a/internal/mysql/write.go b/internal/mysql/write.go index cebd008..aa4f4cc 100644 --- a/internal/mysql/write.go +++ b/internal/mysql/write.go @@ -128,6 +128,14 @@ func (d *Client) WriteTableData(ctx context.Context, w io.Writer, table string, defer rows.Close() + + // Get the column types + columntypes, err := rows.ColumnTypes() + if err != nil { + return err + } + + // Set up and read the data, and print the INSERT statement values := make([]*sql.RawBytes, len(columns)) scanArgs := make([]interface{}, len(values)) @@ -165,11 +173,11 @@ func (d *Client) WriteTableData(ctx context.Context, w io.Writer, table string, var vals []string - for _, col := range values { + for columnindex, col := range values { val := "NULL" if col != nil { - val, err = getValue(string(*col)) + val, err = getValue(string(*col), strings.ToUpper(columntypes[columnindex].DatabaseTypeName())) if err != nil { return err } diff --git a/internal/mysql/write_test.go b/internal/mysql/write_test.go index 37e296f..40df7dc 100644 --- a/internal/mysql/write_test.go +++ b/internal/mysql/write_test.go @@ -75,8 +75,8 @@ func TestMySQLDumpTableData(t *testing.T) { assert.Nil(t, dumper.WriteTableData(context.TODO(), buffer, "table", provider.DumpParams{ ExtendedInsertRows: 2})) - assert.Equal(t, strings.Count(buffer.String(), "INSERT INTO `table` VALUES"), 3) - assert.Equal(t, buffer.String(), "INSERT INTO `table` VALUES (1,'Go'),(2,'Java');\nINSERT INTO `table` VALUES (3,'C'),(4,'C++');\nINSERT INTO `table` VALUES (5,'Rust'),(6,'Closure');\n") + assert.Equal(t, 3, strings.Count(buffer.String(), "INSERT INTO `table` VALUES")) + assert.Equal(t, "INSERT INTO `table` VALUES ('1','Go'),('2','Java');\nINSERT INTO `table` VALUES ('3','C'),('4','C++');\nINSERT INTO `table` VALUES ('5','Rust'),('6','Closure');\n", buffer.String()) } func TestMySQLDumpTableDataHandlingErrorFromSelectAllDataFor(t *testing.T) {