From 92b7df0d40b88654910f2062fd009a78b97bd57b Mon Sep 17 00:00:00 2001 From: Tim Nelson Date: Mon, 24 Mar 2025 10:07:20 +1100 Subject: [PATCH 1/6] First attempt --- internal/mysql/utils.go | 17 ++++++++++++++--- internal/mysql/write.go | 9 +++++++-- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/internal/mysql/utils.go b/internal/mysql/utils.go index 9b87603..8047848 100644 --- a/internal/mysql/utils.go +++ b/internal/mysql/utils.go @@ -8,7 +8,7 @@ import ( "github.com/asaskevich/govalidator" ) -func getValue(raw string) (string, error) { +func getValue(raw string, column_type ColumnType) (string, error) { if raw == "" { return "''", nil } @@ -18,8 +18,19 @@ func getValue(raw string) (string, error) { return "", err } - if govalidator.IsInt(raw) { - return escaped, nil + // List of types from https://dev.mysql.com/doc/refman/8.4/en/data-types.html (just the Numeric ones) + asNumber := []string{ + "INTEGER", "INT", "SMALLINT", "TINYINT", "MEDIUMINT", "BIGINT", + "DECIMAL", "NUMERIC", "DEC", "FIXED", + "FLOAT", "DOUBLE", "REAL", "DOUBLE PRECISION", + "BIT", "BOOL" + } + + // Only want to do this for numeric field types; it can cause problems when done for strings and JSON + if asNumber.Contains(column_type.DatabaseTypeName) { + if govalidator.IsInt(raw) { + return escaped, nil + } } return fmt.Sprintf("'%s'", escaped), nil diff --git a/internal/mysql/write.go b/internal/mysql/write.go index fdd9279..ee74e7c 100644 --- a/internal/mysql/write.go +++ b/internal/mysql/write.go @@ -127,6 +127,11 @@ func (d *Client) WriteTableData(w io.Writer, table string, params provider.DumpP defer rows.Close() + + // Get the column types + columntypes := rows.ColumnTypes() + + // Set up and read the data, and print the INSERT statement values := make([]*sql.RawBytes, len(columns)) scanArgs := make([]interface{}, len(values)) @@ -164,11 +169,11 @@ func (d *Client) WriteTableData(w io.Writer, table string, params provider.DumpP var vals []string - for _, col := range values { + for columnindex, col := range values { val := "NULL" if col != nil { - val, err = getValue(string(*col)) + val, err = getValue(string(*col), columntypes[columnindex]) if err != nil { return err } From 9c23dc6cf21f3fb4ae089d5b0e13fbb30f998015 Mon Sep 17 00:00:00 2001 From: Tim Nelson Date: Mon, 24 Mar 2025 11:18:11 +1100 Subject: [PATCH 2/6] Fix for: syntax error: unexpected newline in composite literal; possibly missing comma or } --- internal/mysql/utils.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/mysql/utils.go b/internal/mysql/utils.go index 8047848..b43527c 100644 --- a/internal/mysql/utils.go +++ b/internal/mysql/utils.go @@ -23,8 +23,7 @@ func getValue(raw string, column_type ColumnType) (string, error) { "INTEGER", "INT", "SMALLINT", "TINYINT", "MEDIUMINT", "BIGINT", "DECIMAL", "NUMERIC", "DEC", "FIXED", "FLOAT", "DOUBLE", "REAL", "DOUBLE PRECISION", - "BIT", "BOOL" - } + "BIT", "BOOL" } // Only want to do this for numeric field types; it can cause problems when done for strings and JSON if asNumber.Contains(column_type.DatabaseTypeName) { From 6ffbee7ee0d030d6cfe5ab943f2589c62fd16ba6 Mon Sep 17 00:00:00 2001 From: Tim Nelson Date: Tue, 15 Apr 2025 13:58:59 +1000 Subject: [PATCH 3/6] Tweaks to fix errors - Renamed variables as requested by linter - Used slices.Contains better - Did error-checking on rows.ColumnTypes --- internal/mysql/tables.go | 3 +++ internal/mysql/utils.go | 6 ++++-- internal/mysql/write.go | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/internal/mysql/tables.go b/internal/mysql/tables.go index 443fc44..603f010 100644 --- a/internal/mysql/tables.go +++ b/internal/mysql/tables.go @@ -84,6 +84,9 @@ func (d *Client) getProviderClient() (provider.Interface, error) { } // Helper function to get all data for a table. +// *sql.Rows: Rows +// []string: Columns +// error: Error func (d *Client) selectAllDataForTable(table string, params provider.DumpParams) (*sql.Rows, []string, error) { client, err := d.getProviderClient() diff --git a/internal/mysql/utils.go b/internal/mysql/utils.go index b43527c..fffcab9 100644 --- a/internal/mysql/utils.go +++ b/internal/mysql/utils.go @@ -4,11 +4,13 @@ import ( "bytes" "fmt" "io" + "slices" + "database/sql" "github.com/asaskevich/govalidator" ) -func getValue(raw string, column_type ColumnType) (string, error) { +func getValue(raw string, theColumnType ColumnType) (string, error) { if raw == "" { return "''", nil } @@ -26,7 +28,7 @@ func getValue(raw string, column_type ColumnType) (string, error) { "BIT", "BOOL" } // Only want to do this for numeric field types; it can cause problems when done for strings and JSON - if asNumber.Contains(column_type.DatabaseTypeName) { + if slices.Contains(asNumber, theColumnType.DatabaseTypeName) { if govalidator.IsInt(raw) { return escaped, nil } diff --git a/internal/mysql/write.go b/internal/mysql/write.go index ee74e7c..995d17c 100644 --- a/internal/mysql/write.go +++ b/internal/mysql/write.go @@ -129,7 +129,7 @@ func (d *Client) WriteTableData(w io.Writer, table string, params provider.DumpP // Get the column types - columntypes := rows.ColumnTypes() + columntypes, err := rows.ColumnTypes() // Set up and read the data, and print the INSERT statement values := make([]*sql.RawBytes, len(columns)) From 3c05d6024456d3765efe6118a6a7ad48940d6854 Mon Sep 17 00:00:00 2001 From: Tim Nelson Date: Tue, 15 Apr 2025 14:08:37 +1000 Subject: [PATCH 4/6] Passed column type to getValue as a string instead, to keep down the number of imports --- internal/mysql/utils.go | 5 ++--- internal/mysql/write.go | 5 ++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/internal/mysql/utils.go b/internal/mysql/utils.go index fffcab9..575d962 100644 --- a/internal/mysql/utils.go +++ b/internal/mysql/utils.go @@ -5,12 +5,11 @@ import ( "fmt" "io" "slices" - "database/sql" "github.com/asaskevich/govalidator" ) -func getValue(raw string, theColumnType ColumnType) (string, error) { +func getValue(raw string, columnType string) (string, error) { if raw == "" { return "''", nil } @@ -28,7 +27,7 @@ func getValue(raw string, theColumnType ColumnType) (string, error) { "BIT", "BOOL" } // Only want to do this for numeric field types; it can cause problems when done for strings and JSON - if slices.Contains(asNumber, theColumnType.DatabaseTypeName) { + if slices.Contains(asNumber, columnType) { if govalidator.IsInt(raw) { return escaped, nil } diff --git a/internal/mysql/write.go b/internal/mysql/write.go index 995d17c..cbe01e3 100644 --- a/internal/mysql/write.go +++ b/internal/mysql/write.go @@ -130,6 +130,9 @@ func (d *Client) WriteTableData(w io.Writer, table string, params provider.DumpP // Get the column types columntypes, err := rows.ColumnTypes() + if err != nil { + return err + } // Set up and read the data, and print the INSERT statement values := make([]*sql.RawBytes, len(columns)) @@ -173,7 +176,7 @@ func (d *Client) WriteTableData(w io.Writer, table string, params provider.DumpP val := "NULL" if col != nil { - val, err = getValue(string(*col), columntypes[columnindex]) + val, err = getValue(string(*col), columntypes[columnindex].databaseType) if err != nil { return err } From 9eb58d960e1d3c769dd8db1d345c46c56f32bea2 Mon Sep 17 00:00:00 2001 From: Tim Nelson Date: Tue, 15 Apr 2025 14:12:49 +1000 Subject: [PATCH 5/6] Use accessor to get database type name --- internal/mysql/write.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/mysql/write.go b/internal/mysql/write.go index cbe01e3..12c82bd 100644 --- a/internal/mysql/write.go +++ b/internal/mysql/write.go @@ -176,7 +176,7 @@ func (d *Client) WriteTableData(w io.Writer, table string, params provider.DumpP val := "NULL" if col != nil { - val, err = getValue(string(*col), columntypes[columnindex].databaseType) + val, err = getValue(string(*col), columntypes[columnindex].DatabaseTypeName()) if err != nil { return err } From c764ebb287603b77f954d83d766912b6ac6b3364 Mon Sep 17 00:00:00 2001 From: Tim Nelson Date: Tue, 15 Apr 2025 14:39:21 +1000 Subject: [PATCH 6/6] Updated Tests internal/mysql/utils_test.go: - Passed in second parameter for all tests - Added a new test so that we can test both numbers in a numeric type and numbers in a non-numeric type internal/mysql/write.go: - Converted type to uppercase before passing to getValue internal/mysql/write_test.go: - Switched around parameters to assert.Equal, because they were in the wrong order - Added quotes to the insert statements, because apparently that's what they do now (unless we can figure out a way to have sql_mock do column.DatabaseTypeName()) --- internal/mysql/utils_test.go | 10 +++++++--- internal/mysql/write.go | 2 +- internal/mysql/write_test.go | 4 ++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/internal/mysql/utils_test.go b/internal/mysql/utils_test.go index 57ccea4..a316b1c 100644 --- a/internal/mysql/utils_test.go +++ b/internal/mysql/utils_test.go @@ -7,15 +7,19 @@ import ( ) func TestGetValue(t *testing.T) { - val, err := getValue("") + val, err := getValue("", "TEXT") assert.NoError(t, err) assert.Equal(t, "''", val) - val, err = getValue("1") + val, err = getValue("1", "INTEGER") assert.NoError(t, err) assert.Equal(t, "1", val) - val, err = getValue("foo") + val, err = getValue("1", "TEXT") + assert.NoError(t, err) + assert.Equal(t, "'1'", val) + + val, err = getValue("foo", "TEXT") assert.NoError(t, err) assert.Equal(t, "'foo'", val) } diff --git a/internal/mysql/write.go b/internal/mysql/write.go index 12c82bd..ecd232e 100644 --- a/internal/mysql/write.go +++ b/internal/mysql/write.go @@ -176,7 +176,7 @@ func (d *Client) WriteTableData(w io.Writer, table string, params provider.DumpP val := "NULL" if col != nil { - val, err = getValue(string(*col), columntypes[columnindex].DatabaseTypeName()) + val, err = getValue(string(*col), strings.ToUpper(columntypes[columnindex].DatabaseTypeName())) if err != nil { return err } diff --git a/internal/mysql/write_test.go b/internal/mysql/write_test.go index 62c014d..b0703eb 100644 --- a/internal/mysql/write_test.go +++ b/internal/mysql/write_test.go @@ -74,8 +74,8 @@ func TestMySQLDumpTableData(t *testing.T) { assert.Nil(t, dumper.WriteTableData(buffer, "table", provider.DumpParams{ ExtendedInsertRows: 2})) - assert.Equal(t, strings.Count(buffer.String(), "INSERT INTO `table` VALUES"), 3) - assert.Equal(t, buffer.String(), "INSERT INTO `table` VALUES (1,'Go'),(2,'Java');\nINSERT INTO `table` VALUES (3,'C'),(4,'C++');\nINSERT INTO `table` VALUES (5,'Rust'),(6,'Closure');\n") + assert.Equal(t, 3, strings.Count(buffer.String(), "INSERT INTO `table` VALUES")) + assert.Equal(t, "INSERT INTO `table` VALUES ('1','Go'),('2','Java');\nINSERT INTO `table` VALUES ('3','C'),('4','C++');\nINSERT INTO `table` VALUES ('5','Rust'),('6','Closure');\n", buffer.String()) } func TestMySQLDumpTableDataHandlingErrorFromSelectAllDataFor(t *testing.T) {