From ca5a868b473949872c0818c33de4cfc88ab57a01 Mon Sep 17 00:00:00 2001 From: Deep1998 Date: Tue, 18 Jan 2022 03:06:36 +0530 Subject: [PATCH 1/3] CSV pt 2 --- cmd/cmd.go | 6 +- cmd/data.go | 7 +- cmd/schema.go | 3 + cmd/schema_and_data.go | 7 +- common/constants/constants.go | 3 + common/utils/utils.go | 49 +++- conversion/conversion.go | 58 ++-- profiles/source_profile.go | 7 +- sources/csv/data.go | 336 +++++++++++++++++++--- sources/csv/data_test.go | 80 +++--- sources/spanner/infoschema.go | 339 +++++++++++++++++++++++ sources/spanner/toddl.go | 79 ++++++ spanner/{ => writer}/batchwriter.go | 2 +- spanner/{ => writer}/batchwriter_test.go | 2 +- 14 files changed, 865 insertions(+), 113 deletions(-) create mode 100644 sources/spanner/infoschema.go create mode 100644 sources/spanner/toddl.go rename spanner/{ => writer}/batchwriter.go (99%) rename spanner/{ => writer}/batchwriter_test.go (99%) diff --git a/cmd/cmd.go b/cmd/cmd.go index a3509c10ed..588bd560af 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -18,9 +18,11 @@ package cmd import ( "context" "fmt" + "os" "strings" "time" + "github.com/cloudspannerecosystem/harbourbridge/common/constants" "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" @@ -97,7 +99,7 @@ func CommandLine(ctx context.Context, driver, targetDb, dbURI string, dataOnly, // We pass an empty string to the sqlConnectionStr parameter as this is the legacy codepath, // which reads the environment variables and constructs the string later on. 
- bw, err := conversion.DataConv(sourceProfile, targetProfile, ioHelper, client, conv, dataOnly) + bw, err := conversion.DataConv(ctx, sourceProfile, targetProfile, ioHelper, client, conv, dataOnly) if err != nil { return fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) } @@ -109,5 +111,7 @@ func CommandLine(ctx context.Context, driver, targetDb, dbURI string, dataOnly, banner := utils.GetBanner(now, dbURI) conversion.Report(driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, outputFilePrefix+reportFile, ioHelper.Out) conversion.WriteBadData(bw, conv, banner, outputFilePrefix+badDataFile, ioHelper.Out) + // Cleanup hb tmp data directory. + os.RemoveAll(os.TempDir() + constants.HB_TMP_DIR) return nil } diff --git a/cmd/data.go b/cmd/data.go index bd9883a8ea..cc82957a2a 100644 --- a/cmd/data.go +++ b/cmd/data.go @@ -8,6 +8,7 @@ import ( "path" "time" + "github.com/cloudspannerecosystem/harbourbridge/common/constants" "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" @@ -54,7 +55,7 @@ func (cmd *DataCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.sessionJSON, "session", "", "Specifies the file we restore session state from") f.StringVar(&cmd.target, "target", "Spanner", "Specifies the target DB, defaults to Spanner (accepted values: `Spanner`)") f.StringVar(&cmd.targetProfile, "target-profile", "", "Flag for specifying connection profile for target database e.g., \"dialect=postgresql\"") - flag.BoolVar(&cmd.skipForeignKeys, "skip-foreign-keys", false, "Skip creating foreign keys after data migration is complete (ddl statements for foreign keys can still be found in the downloaded schema.ddl.txt file and the same can be applied separately)") + f.BoolVar(&cmd.skipForeignKeys, "skip-foreign-keys", false, "Skip creating foreign keys after data migration is complete (ddl statements for foreign keys can still 
be found in the downloaded schema.ddl.txt file and the same can be applied separately)") f.StringVar(&cmd.filePrefix, "prefix", "", "File prefix for generated files") } @@ -136,7 +137,7 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface } } - bw, err := conversion.DataConv(sourceProfile, targetProfile, &ioHelper, client, conv, true) + bw, err := conversion.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, client, conv, true) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) return subcommands.ExitFailure @@ -150,5 +151,7 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface banner := utils.GetBanner(now, dbURI) conversion.Report(sourceProfile.Driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, cmd.filePrefix+reportFile, ioHelper.Out) conversion.WriteBadData(bw, conv, banner, cmd.filePrefix+badDataFile, ioHelper.Out) + // Cleanup hb tmp data directory. + os.RemoveAll(os.TempDir() + constants.HB_TMP_DIR) return subcommands.ExitSuccess } diff --git a/cmd/schema.go b/cmd/schema.go index 1d5f9cc08c..69fa89729c 100644 --- a/cmd/schema.go +++ b/cmd/schema.go @@ -8,6 +8,7 @@ import ( "path" "time" + "github.com/cloudspannerecosystem/harbourbridge/common/constants" "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" @@ -106,5 +107,7 @@ func (cmd *SchemaCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfa conversion.WriteSchemaFile(conv, now, cmd.filePrefix+schemaFile, ioHelper.Out) conversion.WriteSessionFile(conv, cmd.filePrefix+sessionFile, ioHelper.Out) conversion.Report(sourceProfile.Driver, nil, ioHelper.BytesRead, "", conv, cmd.filePrefix+reportFile, ioHelper.Out) + // Cleanup hb tmp data directory. 
+ os.RemoveAll(os.TempDir() + constants.HB_TMP_DIR) return subcommands.ExitSuccess } diff --git a/cmd/schema_and_data.go b/cmd/schema_and_data.go index b6bdd17cd0..6f7021891f 100644 --- a/cmd/schema_and_data.go +++ b/cmd/schema_and_data.go @@ -8,6 +8,7 @@ import ( "path" "time" + "github.com/cloudspannerecosystem/harbourbridge/common/constants" "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" @@ -52,7 +53,7 @@ func (cmd *SchemaAndDataCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.sourceProfile, "source-profile", "", "Flag for specifying connection profile for source database e.g., \"file=,format=dump\"") f.StringVar(&cmd.target, "target", "Spanner", "Specifies the target DB, defaults to Spanner (accepted values: `Spanner`)") f.StringVar(&cmd.targetProfile, "target-profile", "", "Flag for specifying connection profile for target database e.g., \"dialect=postgresql\"") - flag.BoolVar(&cmd.skipForeignKeys, "skip-foreign-keys", false, "Skip creating foreign keys after data migration is complete (ddl statements for foreign keys can still be found in the downloaded schema.ddl.txt file and the same can be applied separately)") + f.BoolVar(&cmd.skipForeignKeys, "skip-foreign-keys", false, "Skip creating foreign keys after data migration is complete (ddl statements for foreign keys can still be found in the downloaded schema.ddl.txt file and the same can be applied separately)") f.StringVar(&cmd.filePrefix, "prefix", "", "File prefix for generated files") } @@ -134,7 +135,7 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... 
return subcommands.ExitFailure } - bw, err := conversion.DataConv(sourceProfile, targetProfile, &ioHelper, client, conv, true) + bw, err := conversion.DataConv(ctx, sourceProfile, targetProfile, &ioHelper, client, conv, true) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) return subcommands.ExitFailure @@ -148,5 +149,7 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... banner := utils.GetBanner(now, dbURI) conversion.Report(sourceProfile.Driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, cmd.filePrefix+reportFile, ioHelper.Out) conversion.WriteBadData(bw, conv, banner, cmd.filePrefix+badDataFile, ioHelper.Out) + // Cleanup hb tmp data directory. + os.RemoveAll(os.TempDir() + constants.HB_TMP_DIR) return subcommands.ExitSuccess } diff --git a/common/constants/constants.go b/common/constants/constants.go index 83d5c61e83..170dfa52b0 100644 --- a/common/constants/constants.go +++ b/common/constants/constants.go @@ -33,4 +33,7 @@ const ( // Supported dialects for Cloud Spanner database. DIALECT_POSTGRESQL string = "postgresql" DIALECT_GOOGLESQL string = "google_standard_sql" + + // Temp directory name to write data which we cleanup at the end. 
+ HB_TMP_DIR string = "harbourbridge_tmp_data" ) diff --git a/common/utils/utils.go b/common/utils/utils.go index 4ba6456da8..ca4ae0c03a 100644 --- a/common/utils/utils.go +++ b/common/utils/utils.go @@ -13,6 +13,8 @@ import ( "net/url" "os" "os/exec" + "reflect" + "sort" "strings" "syscall" "time" @@ -22,6 +24,9 @@ import ( instance "cloud.google.com/go/spanner/admin/instance/apiv1" "cloud.google.com/go/storage" "github.com/cloudspannerecosystem/harbourbridge/common/constants" + "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/sources/common" + "github.com/cloudspannerecosystem/harbourbridge/sources/spanner" "golang.org/x/crypto/ssh/terminal" "google.golang.org/api/iterator" "google.golang.org/api/option" @@ -51,7 +56,7 @@ func NewIOStreams(driver string, dumpFile string) IOStreams { if u.Scheme == "gs" { bucketName := u.Host filePath := u.Path[1:] // removes "/" from beginning of path - f, err = downloadFromGCS(bucketName, filePath) + f, err = DownloadFromGCS(bucketName, filePath, "harbourbridge.gcs.data") } else { f, err = os.Open(dumpFile) } @@ -64,8 +69,8 @@ func NewIOStreams(driver string, dumpFile string) IOStreams { return io } -// downloadFromGCS returns the dump file that is downloaded from GCS -func downloadFromGCS(bucketName string, filePath string) (*os.File, error) { +// DownloadFromGCS returns the dump file that is downloaded from GCS. 
+func DownloadFromGCS(bucketName, filePath, tmpFile string) (*os.File, error) { ctx := context.Background() client, err := storage.NewClient(ctx) @@ -85,22 +90,23 @@ func downloadFromGCS(bucketName string, filePath string) (*os.File, error) { defer rc.Close() r := bufio.NewReader(rc) - tmpfile, err := ioutil.TempFile("", "harbourbridge.gcs.data") + tmpDir := os.TempDir() + constants.HB_TMP_DIR + os.MkdirAll(tmpDir, os.ModePerm) + tmpfile, err := os.Create(tmpDir + "/" + tmpFile) if err != nil { fmt.Printf("saveFile: unable to open temporary file to save dump file from GCS bucket %v", err) log.Fatal(err) return nil, err } - syscall.Unlink(tmpfile.Name()) // File will be deleted when this process exits. - fmt.Printf("\nDownloading dump file from GCS bucket %s, path %s\n", bucketName, filePath) + fmt.Printf("\nDownloading file from GCS bucket %s, path %s\n", bucketName, filePath) buffer := make([]byte, 1024) for { // read a chunk n, err := r.Read(buffer[:cap(buffer)]) if err != nil && err != io.EOF { - fmt.Printf("readFile: unable to read entire dump file from bucket %s, file %s: %v", bucketName, filePath, err) + fmt.Printf("readFile: unable to read entire file from bucket %s, file %s: %v", bucketName, filePath, err) log.Fatal(err) return nil, err } @@ -289,6 +295,15 @@ func ContainsAny(s string, l []string) bool { return false } +// CheckEqualSets checks if the set of values in a and b are equal. +func CheckEqualSets(a, b []string) bool { + tmp_a := append(make([]string, len(a)), a...) + tmp_b := append(make([]string, len(b)), b...) + sort.Strings(tmp_a) + sort.Strings(tmp_b) + return reflect.DeepEqual(tmp_a, tmp_b) +} + func GetFileSize(f *os.File) (int64, error) { info, err := f.Stat() if err != nil { @@ -415,3 +430,23 @@ func IsLegacyModeSupportedDriver(driver string) bool { func GetLegacyModeSupportedDrivers() []string { return GetValidDrivers()[:5] } + +// ReadSpannerSchema fills conv by querying Spanner infoschema treating Spanner as both the source and dest. 
+func ReadSpannerSchema(ctx context.Context, conv *internal.Conv, client *sp.Client) error { + infoSchema := spanner.InfoSchemaImpl{Client: client, Ctx: ctx, TargetDb: conv.TargetDb} + err := common.ProcessSchema(conv, infoSchema) + if err != nil { + return fmt.Errorf("error trying to read and convert spanner schema: %v", err) + } + parentTables, err := infoSchema.GetInterleaveTables() + if err != nil { + return fmt.Errorf("error trying to fetch interleave table info from schema: %v", err) + } + // Assign parents if any. + for table, parent := range parentTables { + spTable := conv.SpSchema[table] + spTable.Parent = parent + conv.SpSchema[table] = spTable + } + return nil +} diff --git a/conversion/conversion.go b/conversion/conversion.go index 9e2f3d92cb..d8cd85958f 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -54,8 +54,8 @@ import ( "github.com/cloudspannerecosystem/harbourbridge/sources/mysql" "github.com/cloudspannerecosystem/harbourbridge/sources/postgres" "github.com/cloudspannerecosystem/harbourbridge/sources/sqlserver" - "github.com/cloudspannerecosystem/harbourbridge/spanner" "github.com/cloudspannerecosystem/harbourbridge/spanner/ddl" + "github.com/cloudspannerecosystem/harbourbridge/spanner/writer" ) var ( @@ -92,8 +92,8 @@ func SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.Tar // - Driver is DynamoDB or a dump file mode. // - This function is called as part of the legacy global CLI flag mode. (This string is constructed from env variables later on) // When using source-profile, the sqlConnectionStr and schemaSampleSize are constructed from the input params. 
-func DataConv(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool) (*spanner.BatchWriter, error) { - config := spanner.BatchWriterConfig{ +func DataConv(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool) (*writer.BatchWriter, error) { + config := writer.BatchWriterConfig{ BytesLimit: 100 * 1000 * 1000, WriteLimit: 40, RetryLimit: 1000, @@ -108,7 +108,7 @@ func DataConv(sourceProfile profiles.SourceProfile, targetProfile profiles.Targe } return dataFromDump(sourceProfile.Driver, config, ioHelper, client, conv, dataOnly) case constants.CSV: - return dataFromCSV(sourceProfile, targetProfile, config, conv, client) + return dataFromCSV(ctx, sourceProfile, targetProfile, config, conv, client) default: return nil, fmt.Errorf("data conversion for driver %s not supported", sourceProfile.Driver) } @@ -170,7 +170,7 @@ func schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile prof return conv, common.ProcessSchema(conv, infoSchema) } -func dataFromDatabase(sourceProfile profiles.SourceProfile, config spanner.BatchWriterConfig, client *sp.Client, conv *internal.Conv) (*spanner.BatchWriter, error) { +func dataFromDatabase(sourceProfile profiles.SourceProfile, config writer.BatchWriterConfig, client *sp.Client, conv *internal.Conv) (*writer.BatchWriter, error) { infoSchema, err := GetInfoSchema(sourceProfile) if err != nil { return nil, err @@ -188,18 +188,18 @@ func dataFromDatabase(sourceProfile profiles.SourceProfile, config spanner.Batch p.MaybeReport(atomic.LoadInt64(&rows)) return nil } - writer := spanner.NewBatchWriter(config) + batchWriter := writer.NewBatchWriter(config) conv.SetDataMode() conv.SetDataSink( func(table string, cols []string, vals []interface{}) { - writer.AddRow(table, cols, vals) + 
batchWriter.AddRow(table, cols, vals) }) conv.DataFlush = func() { - writer.Flush() + batchWriter.Flush() } common.ProcessData(conv, infoSchema) - writer.Flush() - return writer, nil + batchWriter.Flush() + return batchWriter, nil } func getDynamoDBClientConfig() (*aws.Config, error) { @@ -234,7 +234,7 @@ func schemaFromDump(driver string, targetDb string, ioHelper *utils.IOStreams) ( return conv, nil } -func dataFromDump(driver string, config spanner.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool) (*spanner.BatchWriter, error) { +func dataFromDump(driver string, config writer.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool) (*writer.BatchWriter, error) { // TODO: refactor of the way we handle getSeekable // to avoid the code duplication here if !dataOnly { @@ -268,34 +268,41 @@ func dataFromDump(driver string, config spanner.BatchWriterConfig, ioHelper *uti p.MaybeReport(atomic.LoadInt64(&rows)) return nil } - writer := spanner.NewBatchWriter(config) + batchWriter := writer.NewBatchWriter(config) conv.SetDataMode() // Process data in dump; schema is unchanged. 
conv.SetDataSink( func(table string, cols []string, vals []interface{}) { - writer.AddRow(table, cols, vals) + batchWriter.AddRow(table, cols, vals) }) conv.DataFlush = func() { - writer.Flush() + batchWriter.Flush() } ProcessDump(driver, conv, r) - writer.Flush() + batchWriter.Flush() p.Done() - return writer, nil + return batchWriter, nil } -func dataFromCSV(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config spanner.BatchWriterConfig, conv *internal.Conv, client *sp.Client) (*spanner.BatchWriter, error) { +func dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, config writer.BatchWriterConfig, conv *internal.Conv, client *sp.Client) (*writer.BatchWriter, error) { if targetProfile.Conn.Sp.Dbname == "" { return nil, fmt.Errorf("dbname is mandatory in target-profile for csv source") } + conv.TargetDb = targetProfile.ToLegacyTargetDb() delimiterStr := sourceProfile.Csv.Delimiter if len(delimiterStr) != 1 { return nil, fmt.Errorf("delimiter should only be a single character long, found '%s'", delimiterStr) } delimiter := rune(delimiterStr[0]) - tables, err := csv.LoadManifest(conv, sourceProfile.Csv.Manifest) + + err := utils.ReadSpannerSchema(ctx, conv, client) if err != nil { - return nil, err + return nil, fmt.Errorf("error trying to read and convert spanner schema: %v", err) + } + + tables, err := csv.GetCSVFiles(conv, sourceProfile) + if err != nil { + return nil, fmt.Errorf("error finding csv files: %v", err) } // Find the number of rows in each csv file for generating stats. 
@@ -312,19 +319,22 @@ func dataFromCSV(sourceProfile profiles.SourceProfile, targetProfile profiles.Ta p.MaybeReport(atomic.LoadInt64(&rows)) return nil } - writer := spanner.NewBatchWriter(config) + batchWriter := writer.NewBatchWriter(config) conv.SetDataMode() conv.SetDataSink( func(table string, cols []string, vals []interface{}) { - writer.AddRow(table, cols, vals) + batchWriter.AddRow(table, cols, vals) }) + conv.DataFlush = func() { + batchWriter.Flush() + } err = csv.ProcessCSV(conv, tables, sourceProfile.Csv.NullStr, delimiter) if err != nil { return nil, fmt.Errorf("can't process csv: %v", err) } - writer.Flush() + batchWriter.Flush() p.Done() - return writer, nil + return batchWriter, nil } // Report generates a report of schema and data conversion. @@ -697,7 +707,7 @@ func ReadSessionFile(conv *internal.Conv, sessionJSON string) error { // WriteBadData prints summary stats about bad rows and writes detailed info // to file 'name'. -func WriteBadData(bw *spanner.BatchWriter, conv *internal.Conv, banner, name string, out *os.File) { +func WriteBadData(bw *writer.BatchWriter, conv *internal.Conv, banner, name string, out *os.File) { badConversions := conv.BadRows() badWrites := utils.SumMapValues(bw.DroppedRowsByTable()) if badConversions == 0 && badWrites == 0 { diff --git a/profiles/source_profile.go b/profiles/source_profile.go index cdb55829fb..23b474fbd8 100644 --- a/profiles/source_profile.go +++ b/profiles/source_profile.go @@ -431,12 +431,7 @@ func NewSourceProfile(s string, source string) (SourceProfile, error) { return SourceProfile{}, fmt.Errorf("could not parse source-profile, error = %v", err) } if strings.ToLower(source) == constants.CSV { - if _, ok := params["manifest"]; ok { - profile := NewSourceProfileCsv(params) - return SourceProfile{Ty: SourceProfileTypeCsv, Csv: profile}, nil - } else { - return SourceProfile{}, fmt.Errorf("csv source requires a manifest file, please specify manifest file in the source profile e.g., 
-source-profile=\"manifest=file_path\"") - } + return SourceProfile{Ty: SourceProfileTypeCsv, Csv: NewSourceProfileCsv(params)}, nil } if _, ok := params["file"]; ok || filePipedToStdin() { diff --git a/sources/csv/data.go b/sources/csv/data.go index 9b1fad5fd0..6c5bbc2704 100644 --- a/sources/csv/data.go +++ b/sources/csv/data.go @@ -21,13 +21,18 @@ import ( "io" "io/ioutil" "math/big" + "net/url" "os" "strconv" + "strings" "time" "cloud.google.com/go/civil" + "cloud.google.com/go/spanner" + "github.com/cloudspannerecosystem/harbourbridge/common/constants" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/internal" - "github.com/cloudspannerecosystem/harbourbridge/schema" + "github.com/cloudspannerecosystem/harbourbridge/profiles" "github.com/cloudspannerecosystem/harbourbridge/spanner/ddl" ) @@ -40,12 +45,37 @@ type Column struct { type Table struct { Table_name string `json:"table_name"` File_patterns []string `json:"file_patterns"` - Columns []Column `json:"columns"` } -// LoadManifest reads the manifest file and unmarshalls it into a list of Table struct. +// GetCSVFiles finds the appropriate files paths and downloads gcs files in any. +func GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tables []Table, err error) { + // If manifest file not provided, we assume the csvs exist in the same directory + // in table_name.csv format. + if sourceProfile.Csv.Manifest == "" { + fmt.Println("Manifest file not provided, checking for files named `[table_name].csv` in current working directory...") + for t := range conv.SpSchema { + tables = append(tables, Table{Table_name: t, File_patterns: []string{fmt.Sprintf("%s.csv", t)}}) + } + } else { + fmt.Println("Manifest file provided, reading csv file paths...") + // Read paths provided in manifest. + tables, err = loadManifest(conv, sourceProfile.Csv.Manifest) + if err != nil { + return nil, err + } + } + + // Download gcs files if any. 
+ tables, err = preloadGCSFiles(tables) + if err != nil { + return nil, fmt.Errorf("gcs file download error: %v", err) + } + return tables, nil +} + +// loadManifest reads the manifest file and unmarshalls it into a list of Table struct. // It also performs certain checks on the manifest. -func LoadManifest(conv *internal.Conv, manifestFile string) ([]Table, error) { +func loadManifest(conv *internal.Conv, manifestFile string) ([]Table, error) { manifest, err := ioutil.ReadFile(manifestFile) if err != nil { return nil, fmt.Errorf("can't read manifest file due to: %v", err) @@ -68,37 +98,61 @@ func VerifyManifest(conv *internal.Conv, tables []Table) error { if len(tables) == 0 { return fmt.Errorf("no tables found") } + missing := []string{} + for name := range conv.SrcSchema { + found := false + for _, table := range tables { + if name == table.Table_name { + found = true + break + } + } + if !found { + missing = append(missing, name) + } + } + if len(missing) > 0 { + fmt.Printf("WARNING: did not find manifest entries for tables [ %s ], ignoring and proceeding...\n", strings.Join(missing, ", ")) + conv.Unexpected(fmt.Sprintf("did not find manifest entries for tables [ %s ]", strings.Join(missing, ", "))) + } for i, table := range tables { name := table.Table_name if name == "" { return fmt.Errorf("table number %d (0-indexed) does not have a name", i) } + if _, ok := conv.SrcSchema[name]; !ok { + return fmt.Errorf("table %s provided in manifest does not exist in spanner", name) + } if len(table.File_patterns) == 0 { return fmt.Errorf("no file path provided for table %s", name) } - cols := table.Columns - if len(cols) == 0 { - return fmt.Errorf("`columns` field for table %s is empty", name) - } - // Populating just the table names in the conv for SrcSchema and SpSchema - // so the report for row stats is generated. 
- conv.SrcSchema[table.Table_name] = schema.Table{Name: table.Table_name} + } + return nil +} - // The map colDefs stores the mapping from column names to their final types. - colDefs := make(map[string]ddl.ColumnDef) - for j, col := range cols { - if col.Column_name == "" || col.Type_name == "" { - return fmt.Errorf("please provide column_name and type_name in `columns` field at position %d (0-indexed)", j) - } - ty, err := ToSpannerType(col.Type_name) +func preloadGCSFiles(tables []Table) ([]Table, error) { + for i, table := range tables { + for j, filePath := range table.File_patterns { + u, err := url.Parse(filePath) if err != nil { - return fmt.Errorf("can't map to spanner type: %v. Please use the data types as in your spanner database", err) + return nil, fmt.Errorf("unable parse file path %s for table %s", filePath, table.Table_name) + } + if u.Scheme == "gs" { + bucketName := u.Host + filePath := u.Path[1:] // removes "/" from beginning of path + tmpFile := strings.ReplaceAll(filePath, "/", ".") + // Files get downloaded to tmp dir. + fileLoc := os.TempDir() + constants.HB_TMP_DIR + "/" + tmpFile + _, err = utils.DownloadFromGCS(bucketName, filePath, tmpFile) + if err != nil { + return nil, fmt.Errorf("cannot download gcs file: %s for table %s", filePath, table.Table_name) + } + tables[i].File_patterns[j] = fileLoc + fmt.Printf("Downloaded file: %s\n", fileLoc) } - colDefs[col.Column_name] = ddl.ColumnDef{Name: col.Column_name, T: ty} } - conv.SpSchema[table.Table_name] = ddl.CreateTable{Name: table.Table_name, ColDefs: colDefs} } - return nil + return tables, nil } // SetRowStats calculates the number of rows per table. 
@@ -111,7 +165,7 @@ func SetRowStats(conv *internal.Conv, tables []Table, delimiter rune) { } r := csvReader.NewReader(csvFile) r.Comma = delimiter - count, err := getCSVRowCount(r) + count, err := getCSVDataRowCount(r, conv.SpSchema[table.Table_name].ColNames) if err != nil { conv.Unexpected(fmt.Sprintf("Couldn't get number of rows for table %s", table.Table_name)) continue @@ -120,14 +174,25 @@ func SetRowStats(conv *internal.Conv, tables []Table, delimiter rune) { conv.Unexpected(fmt.Sprintf("error processing table %s: file %s is empty.", table.Table_name, filePath)) continue } - conv.Stats.Rows[table.Table_name] += count - 1 + conv.Stats.Rows[table.Table_name] += count } } } -// getCSVRowCount returns the number of rows in the CSV file. -func getCSVRowCount(r *csvReader.Reader) (int64, error) { +// getCSVDataRowCount returns the number of data rows in the CSV file. This excludes the headers if present. +func getCSVDataRowCount(r *csvReader.Reader, colNames []string) (int64, error) { count := int64(0) + srcCols, err := r.Read() + if err == io.EOF { + return count, nil + } + if err != nil { + return count, fmt.Errorf("can't read csv headers for col names due to: %v", err) + } + // If the row read was not a header, increase count. + if !utils.CheckEqualSets(srcCols, colNames) { + count += 1 + } for { _, err := r.Read() if err == io.EOF { @@ -144,7 +209,17 @@ func getCSVRowCount(r *csvReader.Reader) (int64, error) { // ProcessCSV writes data across the tables provided in the manifest file. Each table's data can be provided // across multiple CSV files hence, the manifest accepts a list of file paths in the input. 
func ProcessCSV(conv *internal.Conv, tables []Table, nullStr string, delimiter rune) error { + orderedTableNames := ddl.OrderTables(conv.SpSchema) + nameToFiles := map[string][]string{} for _, table := range tables { + nameToFiles[table.Table_name] = table.File_patterns + } + orderedTables := []Table{} + for _, name := range orderedTableNames { + orderedTables = append(orderedTables, Table{name, nameToFiles[name]}) + } + + for _, table := range orderedTables { for _, filePath := range table.File_patterns { csvFile, err := os.Open(filePath) if err != nil { @@ -153,26 +228,38 @@ func ProcessCSV(conv *internal.Conv, tables []Table, nullStr string, delimiter r r := csvReader.NewReader(csvFile) r.Comma = delimiter - // First row is expected to be the column headers. + // Default column order is same as in Spanner schema. + colNames := conv.SpSchema[table.Table_name].ColNames srcCols, err := r.Read() if err == io.EOF { conv.Unexpected(fmt.Sprintf("error processing table %s: file %s is empty.", table.Table_name, filePath)) continue } if err != nil { - return fmt.Errorf("can't read csv headers for col names due to: %v", err) + return fmt.Errorf("can't read row for %s due to: %v", filePath, err) } + // If first row is some permutation of Spanner schema columns, we assume the first row is headers. + if utils.CheckEqualSets(srcCols, colNames) { + colNames = srcCols + } else { + // Write the first row since it was not a column header. 
+ processDataRow(conv, nullStr, table.Table_name, colNames, srcCols) + } + for { values, err := r.Read() if err == io.EOF { break } if err != nil { - return fmt.Errorf(fmt.Sprintf("can't read row names due to: %v", err)) + return fmt.Errorf("can't read row for %s due to: %v", filePath, err) } - processDataRow(conv, nullStr, table.Table_name, srcCols, values) + processDataRow(conv, nullStr, table.Table_name, colNames, values) } } + if conv.DataFlush != nil { + conv.DataFlush() + } } return nil } @@ -200,7 +287,14 @@ func convertData(conv *internal.Conv, nullStr, tableName string, srcCols []strin continue } colName := srcCols[i] - x, err := convScalar(colDefs[colName].T, val) + spColDef := colDefs[colName] + var x interface{} + var err error + if spColDef.T.IsArray { + x, err = convArray(spColDef.T, val) + } else { + x, err = convScalar(conv, spColDef.T, val) + } if err != nil { return nil, nil, err } @@ -210,7 +304,169 @@ func convertData(conv *internal.Conv, nullStr, tableName string, srcCols []strin return cvtCols, v, nil } -func convScalar(spannerType ddl.Type, val string) (interface{}, error) { +func convArray(spannerType ddl.Type, val string) (interface{}, error) { + val = strings.TrimSpace(val) + // Handle empty array. Note that we use an empty NullString array + // for all Spanner array types since this will be converted to the + // appropriate type by the Spanner client. + if val == "{}" || val == "[]" { + return []spanner.NullString{}, nil + } + braces := val[:1] + val[len(val)-1:] + if braces != "{}" && braces != "[]" { + return []interface{}{}, fmt.Errorf("unrecognized data format for array: expected {v1, v2, ...} or [v1, v2, ...]") + } + a := strings.Split(val[1:len(val)-1], ",") + + // The Spanner client for go does not accept []interface{} for arrays. + // Instead it only accepts slices of a specific type e.g. []int64, []string. + // Hence we have to do the following case analysis. 
+ switch spannerType.Name { + case ddl.Bool: + var r []spanner.NullBool + for _, s := range a { + if s == "NULL" { + r = append(r, spanner.NullBool{Valid: false}) + continue + } + s, err := processQuote(s) + if err != nil { + return []spanner.NullBool{}, err + } + b, err := convBool(s) + if err != nil { + return []spanner.NullBool{}, err + } + r = append(r, spanner.NullBool{Bool: b, Valid: true}) + } + return r, nil + case ddl.Bytes: + var r [][]byte + for _, s := range a { + if s == "NULL" { + r = append(r, nil) + continue + } + s, err := processQuote(s) + if err != nil { + return [][]byte{}, err + } + b, err := convBytes(s) + if err != nil { + return [][]byte{}, err + } + r = append(r, b) + } + return r, nil + case ddl.Date: + var r []spanner.NullDate + for _, s := range a { + if s == "NULL" { + r = append(r, spanner.NullDate{Valid: false}) + continue + } + s, err := processQuote(s) + if err != nil { + return []spanner.NullDate{}, err + } + date, err := convDate(s) + if err != nil { + return []spanner.NullDate{}, err + } + r = append(r, spanner.NullDate{Date: date, Valid: true}) + } + return r, nil + case ddl.Float64: + var r []spanner.NullFloat64 + for _, s := range a { + if s == "NULL" { + r = append(r, spanner.NullFloat64{Valid: false}) + continue + } + s, err := processQuote(s) + if err != nil { + return []spanner.NullFloat64{}, err + } + f, err := convFloat64(s) + if err != nil { + return []spanner.NullFloat64{}, err + } + r = append(r, spanner.NullFloat64{Float64: f, Valid: true}) + } + return r, nil + case ddl.Numeric: + var r []spanner.NullNumeric + for _, s := range a { + if s == "NULL" { + r = append(r, spanner.NullNumeric{Valid: false}) + continue + } + s, err := processQuote(s) + if err != nil { + return []spanner.NullFloat64{}, err + } + n, err := convNumeric(s) + if err != nil { + return []spanner.NullFloat64{}, err + } + r = append(r, spanner.NullNumeric{Numeric: n, Valid: true}) + } + return r, nil + case ddl.Int64: + var r []spanner.NullInt64 + 
for _, s := range a { + if s == "NULL" { + r = append(r, spanner.NullInt64{Valid: false}) + continue + } + s, err := processQuote(s) + if err != nil { + return []spanner.NullInt64{}, err + } + i, err := convInt64(s) + if err != nil { + return r, err + } + r = append(r, spanner.NullInt64{Int64: i, Valid: true}) + } + return r, nil + case ddl.String: + var r []spanner.NullString + for _, s := range a { + if s == "NULL" { + r = append(r, spanner.NullString{Valid: false}) + continue + } + s, err := processQuote(s) + if err != nil { + return []spanner.NullString{}, err + } + r = append(r, spanner.NullString{StringVal: s, Valid: true}) + } + return r, nil + case ddl.Timestamp: + var r []spanner.NullTime + for _, s := range a { + if s == "NULL" { + r = append(r, spanner.NullTime{Valid: false}) + continue + } + s, err := processQuote(s) + if err != nil { + return []spanner.NullTime{}, err + } + t, err := convTimestamp(s) + if err != nil { + return []spanner.NullTime{}, err + } + r = append(r, spanner.NullTime{Time: t, Valid: true}) + } + return r, nil + } + return []interface{}{}, fmt.Errorf("array type conversion not implemented for type []%v", spannerType.Name) +} + +func convScalar(conv *internal.Conv, spannerType ddl.Type, val string) (interface{}, error) { switch spannerType.Name { case ddl.Bool: return convBool(val) @@ -223,6 +479,9 @@ func convScalar(spannerType ddl.Type, val string) (interface{}, error) { case ddl.Int64: return convInt64(val) case ddl.Numeric: + if conv.TargetDb == constants.TargetExperimentalPostgres { + return spanner.PGNumeric{Numeric: val, Valid: true}, nil + } return convNumeric(val) case ddl.String: return val, nil @@ -275,12 +534,12 @@ func convInt64(val string) (int64, error) { // convNumeric maps a source database string value (representing a numeric) // into a string representing a valid Spanner numeric. 
-func convNumeric(val string) (interface{}, error) { +func convNumeric(val string) (big.Rat, error) { r := new(big.Rat) if _, ok := r.SetString(val); !ok { - return "", fmt.Errorf("can't convert %q to big.Rat", val) + return big.Rat{}, fmt.Errorf("can't convert %q to big.Rat", val) } - return r, nil + return *r, nil } func convTimestamp(val string) (t time.Time, err error) { @@ -290,3 +549,10 @@ func convTimestamp(val string) (t time.Time, err error) { } return t, err } + +func processQuote(s string) (string, error) { + if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { + return strconv.Unquote(s) + } + return s, nil +} diff --git a/sources/csv/data_test.go b/sources/csv/data_test.go index 39ea5658cc..a6f40235f6 100644 --- a/sources/csv/data_test.go +++ b/sources/csv/data_test.go @@ -34,26 +34,10 @@ func getManifestTables() []Table { { Table_name: ALL_TYPES_TABLE, File_patterns: []string{ALL_TYPES_CSV}, - Columns: []Column{ - {Column_name: "bool_col", Type_name: "BOOL"}, - {Column_name: "byte_col", Type_name: "BYTES"}, - {Column_name: "date_col", Type_name: "DATE"}, - {Column_name: "float_col", Type_name: "FLOAT64"}, - {Column_name: "int_col", Type_name: "INT64"}, - {Column_name: "numeric_col", Type_name: "NUMERIC"}, - {Column_name: "string_col", Type_name: "STRING"}, - {Column_name: "timestamp_col", Type_name: "TIMESTAMP"}, - {Column_name: "json_col", Type_name: "JSON"}, - }, }, { Table_name: SINGERS_TABLE, File_patterns: []string{SINGERS_1_CSV, SINGERS_2_CSV}, - Columns: []Column{ - {Column_name: "SingerId", Type_name: "INT64"}, - {Column_name: "FirstName", Type_name: "STRING"}, - {Column_name: "LastName", Type_name: "STRING"}, - }, }, } } @@ -103,34 +87,32 @@ func cleanupCSVs() { } func TestSetRowStats(t *testing.T) { - conv := internal.MakeConv() + conv := buildConv(getCreateTable()) writeCSVs(t) defer cleanupCSVs() SetRowStats(conv, getManifestTables(), ',') assert.Equal(t, map[string]int64{ALL_TYPES_TABLE: 1, SINGERS_TABLE: 2}, conv.Stats.Rows) } -func 
TestProcessDataRow(t *testing.T) { - conv := internal.MakeConv() +func TestProcessCSV(t *testing.T) { + writeCSVs(t) + defer cleanupCSVs() + tables := getManifestTables() + + conv := buildConv(getCreateTable()) var rows []spannerData conv.SetDataMode() conv.SetDataSink( func(table string, cols []string, vals []interface{}) { rows = append(rows, spannerData{table: table, cols: cols, vals: vals}) }) - - writeCSVs(t) - defer cleanupCSVs() - tables := getManifestTables() - VerifyManifest(conv, tables) err := ProcessCSV(conv, tables, "", ',') - fmt.Println(err) assert.Nil(t, err) assert.Equal(t, []spannerData{ { table: ALL_TYPES_TABLE, cols: []string{"bool_col", "byte_col", "date_col", "float_col", "int_col", "numeric_col", "string_col", "timestamp_col", "json_col"}, - vals: []interface{}{true, []uint8{0x74, 0x65, 0x73, 0x74}, getDate("2019-10-29"), 15.13, int64(100), big.NewRat(3994, 100), "Helloworld", getTime(t, "2019-10-29T05:30:00Z"), "{\"key1\": \"value1\", \"key2\": \"value2\"}"}, + vals: []interface{}{true, []uint8{0x74, 0x65, 0x73, 0x74}, getDate("2019-10-29"), 15.13, int64(100), *big.NewRat(3994, 100), "Helloworld", getTime(t, "2019-10-29T05:30:00Z"), "{\"key1\": \"value1\", \"key2\": \"value2\"}"}, }, {table: SINGERS_TABLE, cols: []string{"SingerId", "FirstName", "LastName"}, vals: []interface{}{int64(1), "fn1", "ln1"}}, {table: SINGERS_TABLE, cols: []string{"SingerId", "FirstName", "LastName"}, vals: []interface{}{int64(2), "fn2", "ln2"}}, @@ -150,7 +132,7 @@ func TestConvertData(t *testing.T) { {"date", ddl.Type{Name: ddl.Date}, "2019-10-29", getDate("2019-10-29")}, {"float64", ddl.Type{Name: ddl.Float64}, "42.6", float64(42.6)}, {"int64", ddl.Type{Name: ddl.Int64}, "42", int64(42)}, - {"numeric", ddl.Type{Name: ddl.Numeric}, "42.6", big.NewRat(426, 10)}, + {"numeric", ddl.Type{Name: ddl.Numeric}, "42.6", *big.NewRat(426, 10)}, {"string", ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, "eh", "eh"}, {"timestamp", ddl.Type{Name: ddl.Timestamp}, "2019-10-29 
05:30:00", getTime(t, "2019-10-29T05:30:00Z")}, {"json", ddl.Type{Name: ddl.JSON}, "{\"key1\": \"value1\"}", "{\"key1\": \"value1\"}"}, @@ -158,10 +140,9 @@ func TestConvertData(t *testing.T) { tableName := "testtable" for _, tc := range singleColTests { col := "a" - conv := buildConv( - ddl.CreateTable{ - Name: tableName, - ColDefs: map[string]ddl.ColumnDef{col: ddl.ColumnDef{Name: col, T: tc.ty}}}) + conv := buildConv([]ddl.CreateTable{{ + Name: tableName, + ColDefs: map[string]ddl.ColumnDef{col: ddl.ColumnDef{Name: col, T: tc.ty}}}}) _, av, err := convertData(conv, "", tableName, []string{col}, []string{tc.in}) // NULL scenario. if tc.ev == nil { @@ -200,15 +181,46 @@ func TestConvertData(t *testing.T) { }, } for _, tc := range errorTests { - conv := buildConv(spTable) + conv := buildConv([]ddl.CreateTable{spTable}) _, _, err := convertData(conv, "", tableName, cols, tc.vals) assert.NotNil(t, err, tc.name) } } -func buildConv(spTable ddl.CreateTable) *internal.Conv { +func getCreateTable() []ddl.CreateTable { + return []ddl.CreateTable{ + { + Name: ALL_TYPES_TABLE, + ColNames: []string{"bool_col", "byte_col", "date_col", "float_col", "int_col", "numeric_col", "string_col", "timestamp_col", "json_col"}, + ColDefs: map[string]ddl.ColumnDef{ + "bool_col": {Name: "bool_col", T: ddl.Type{Name: ddl.Bool}}, + "byte_col": {Name: "byte_col", T: ddl.Type{Name: ddl.Bytes}}, + "date_col": {Name: "date_col", T: ddl.Type{Name: ddl.Date}}, + "float_col": {Name: "float_col", T: ddl.Type{Name: ddl.Float64}}, + "int_col": {Name: "int_col", T: ddl.Type{Name: ddl.Int64}}, + "numeric_col": {Name: "numeric_col", T: ddl.Type{Name: ddl.Numeric}}, + "string_col": {Name: "string_col", T: ddl.Type{Name: ddl.String}}, + "timestamp_col": {Name: "timestamp_col", T: ddl.Type{Name: ddl.Timestamp}}, + "json_col": {Name: "json_col", T: ddl.Type{Name: ddl.JSON}}, + }, + }, + { + Name: SINGERS_TABLE, + ColNames: []string{"SingerId", "FirstName", "LastName"}, + ColDefs: map[string]ddl.ColumnDef{ + 
"SingerId": {Name: "SingerId", T: ddl.Type{Name: ddl.Int64}}, + "FirstName": {Name: "FirstName", T: ddl.Type{Name: ddl.String}}, + "LastName": {Name: "LastName", T: ddl.Type{Name: ddl.String}}, + }, + }, + } +} + +func buildConv(spTables []ddl.CreateTable) *internal.Conv { conv := internal.MakeConv() - conv.SpSchema[spTable.Name] = spTable + for _, spTable := range spTables { + conv.SpSchema[spTable.Name] = spTable + } return conv } diff --git a/sources/spanner/infoschema.go b/sources/spanner/infoschema.go new file mode 100644 index 0000000000..7bb6f7d8cc --- /dev/null +++ b/sources/spanner/infoschema.go @@ -0,0 +1,339 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "context" + "fmt" + "math" + "sort" + "strconv" + "strings" + + "cloud.google.com/go/spanner" + _ "github.com/lib/pq" // we will use database/sql package instead of using this package directly + "google.golang.org/api/iterator" + + "github.com/cloudspannerecosystem/harbourbridge/common/constants" + "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/schema" + "github.com/cloudspannerecosystem/harbourbridge/sources/common" + "github.com/cloudspannerecosystem/harbourbridge/spanner/ddl" +) + +// InfoSchemaImpl postgres specific implementation for InfoSchema. 
+type InfoSchemaImpl struct { + Client *spanner.Client + Ctx context.Context + TargetDb string +} + +// GetToDdl function below implement the common.InfoSchema interface. +func (isi InfoSchemaImpl) GetToDdl() common.ToDdl { + return ToDdlImpl{} +} + +// We leave the 3 functions below empty to be able to pass this as an infoSchema interface. We don't need these for now. +func (isi InfoSchemaImpl) ProcessData(conv *internal.Conv, srcTable string, srcSchema schema.Table, spTable string, spCols []string, spSchema ddl.CreateTable) error { + return nil +} + +func (isi InfoSchemaImpl) GetRowCount(table common.SchemaAndName) (int64, error) { + return 0, nil +} + +func (isi InfoSchemaImpl) GetRowsFromTable(conv *internal.Conv, srcTable string) (interface{}, error) { + return nil, nil +} + +// GetTableName returns table name. +func (isi InfoSchemaImpl) GetTableName(schema string, tableName string) string { + if isi.TargetDb == constants.TargetExperimentalPostgres { + if schema == "public" { // Drop public prefix for pg spanner. + return tableName + } + } else { + if schema == "" { + return tableName + } + } + return fmt.Sprintf("%s.%s", schema, tableName) +} + +// GetTables return list of tables in the selected database. +// TODO: All of the queries to get tables and table data should be in +// a single transaction to ensure we obtain a consistent snapshot of +// schema information and table data (pg_dump does something +// similar). 
+func (isi InfoSchemaImpl) GetTables() ([]common.SchemaAndName, error) { + q := `SELECT table_schema, table_name FROM information_schema.tables + WHERE table_type = 'BASE TABLE' AND table_schema = @tableSchema` + stmt := spanner.Statement{SQL: q} + if isi.TargetDb == constants.TargetExperimentalPostgres { + stmt.Params = map[string]interface{}{ + "tableSchema": "public", + } + } else { + stmt.Params = map[string]interface{}{ + "tableSchema": "", + } + } + iter := isi.Client.Single().Query(isi.Ctx, stmt) + defer iter.Stop() + + var tableSchema, tableName string + var tables []common.SchemaAndName + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, fmt.Errorf("couldn't get tables: %w", err) + } + err = row.Columns(&tableSchema, &tableName) + if err != nil { + return nil, err + } + tables = append(tables, common.SchemaAndName{Schema: tableSchema, Name: tableName}) + } + return tables, nil +} + +// GetColumns returns a list of Column objects and names +func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAndName, constraints map[string][]string, primaryKeys []string) (map[string]schema.Column, []string, error) { + q := `SELECT column_name, spanner_type, is_nullable + FROM information_schema.columns + WHERE table_schema = @tableSchema AND table_name = @tableName + ORDER BY ordinal_position;` + stmt := spanner.Statement{ + SQL: q, + Params: map[string]interface{}{ + "tableSchema": table.Schema, + "tableName": table.Name, + }, + } + iter := isi.Client.Single().Query(isi.Ctx, stmt) + defer iter.Stop() + + colDefs := make(map[string]schema.Column) + var colNames []string + var colName, spannerType, isNullable string + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, nil, fmt.Errorf("couldn't get column info for table %s: %s", table.Name, err) + } + err = row.Columns(&colName, &spannerType, &isNullable) + if err != nil { + return nil, nil, 
fmt.Errorf("cannot read row for table %s: %s", table.Name, err) + } + ignored := schema.Ignored{} + for _, c := range constraints[colName] { + // c can be UNIQUE, PRIMARY KEY, FOREIGN KEY, + // or CHECK (based on msql, sql server, postgres docs). + // We've already filtered out PRIMARY KEY. + switch c { + case "CHECK": + ignored.Check = true + case "FOREIGN KEY", "PRIMARY KEY", "UNIQUE": + // Nothing to do here -- these are handled elsewhere. + } + } + c := schema.Column{ + Name: colName, + Type: toType(spannerType), + NotNull: common.ToNotNull(conv, isNullable), + } + colDefs[colName] = c + colNames = append(colNames, colName) + } + return colDefs, colNames, nil +} + +// GetConstraints returns a list of primary keys and by-column map of +// other constraints. Note: we need to preserve ordinal order of +// columns in primary key constraints. +// Note that foreign key constraints are handled in getForeignKeys. +func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.SchemaAndName) ([]string, map[string][]string, error) { + q := `SELECT k.column_name, t.constraint_type + FROM information_schema.table_constraints AS t + INNER JOIN information_schema.KEY_COLUMN_USAGE AS k + ON t.constraint_name = k.constraint_name AND t.constraint_schema = k.constraint_schema + WHERE k.table_schema = @tableSchema AND k.table_name = @tableName ORDER BY k.ordinal_position;` + stmt := spanner.Statement{ + SQL: q, + Params: map[string]interface{}{ + "tableSchema": table.Schema, + "tableName": table.Name, + }, + } + iter := isi.Client.Single().Query(isi.Ctx, stmt) + defer iter.Stop() + + var primaryKeys []string + var col, constraint string + m := make(map[string][]string) + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, nil, fmt.Errorf("couldn't get row: %w", err) + } + err = row.Columns(&col, &constraint) + if err != nil { + return nil, nil, err + } + if col == "" || constraint == "" { + conv.Unexpected("Got empty 
col or constraint") + continue + } + switch constraint { + case "PRIMARY KEY": + primaryKeys = append(primaryKeys, col) + default: + m[col] = append(m[col], constraint) + } + } + return primaryKeys, m, nil +} + +// GetForeignKeys returns a list of all the foreign key constraints. +func (isi InfoSchemaImpl) GetForeignKeys(conv *internal.Conv, table common.SchemaAndName) (foreignKeys []schema.ForeignKey, err error) { + q := `SELECT k.constraint_name, k.column_name, t.table_name, c.column_name + FROM information_schema.key_column_usage AS k + JOIN information_schema.constraint_column_usage AS c ON k.constraint_name = c.constraint_name + JOIN information_schema.table_constraints AS t ON k.constraint_name = t.constraint_name + WHERE t.constraint_type='FOREIGN KEY' AND t.table_schema = @tableSchema AND t.table_name = @tableName + ORDER BY k.constraint_name, k.ordinal_position;` + stmt := spanner.Statement{ + SQL: q, + Params: map[string]interface{}{ + "tableSchema": table.Schema, + "tableName": table.Name, + }, + } + iter := isi.Client.Single().Query(isi.Ctx, stmt) + defer iter.Stop() + + var col, refCol, fKeyName, refTable string + fKeys := make(map[string]common.FkConstraint) + var keyNames []string + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, fmt.Errorf("couldn't get row: %w", err) + } + err = row.Columns(&fKeyName, &col, &refTable, &refCol) + if err != nil { + return nil, err + } + if _, found := fKeys[fKeyName]; found { + fk := fKeys[fKeyName] + fk.Cols = append(fk.Cols, col) + fk.Refcols = append(fk.Refcols, refCol) + fKeys[fKeyName] = fk + continue + } + fKeys[fKeyName] = common.FkConstraint{Name: fKeyName, Table: isi.GetTableName(table.Schema, refTable), Refcols: []string{refCol}, Cols: []string{col}} + keyNames = append(keyNames, fKeyName) + } + sort.Strings(keyNames) + for _, k := range keyNames { + // The query returns a crypted result for multi-col FKs. 
Currently for a FK from (a,b,c) -> (x,y,z), + // the returned rows like (a,x), (a,y), (a,z), (b,x), (b,y), (b,z), (c,x), (c,y), (c,z). + // Need to reduce it to (a,x), (b,y), (c,z). The logic below does that. + n := int(math.Sqrt(float64(len(fKeys[k].Cols)))) + cols, refcols := []string{}, []string{} + for i := 0; i < n; i++ { + cols = append(cols, fKeys[k].Cols[i*n]) + refcols = append(refcols, fKeys[k].Refcols[i]) + } + foreignKeys = append(foreignKeys, + schema.ForeignKey{ + Name: fKeys[k].Name, + Columns: cols, + ReferTable: fKeys[k].Table, + ReferColumns: refcols}) + } + return foreignKeys, nil +} + +// Skipped since we dont have a use case for storing indexes yet. +func (isi InfoSchemaImpl) GetIndexes(conv *internal.Conv, table common.SchemaAndName) ([]schema.Index, error) { + return []schema.Index{}, nil +} + +func (isi InfoSchemaImpl) GetInterleaveTables() (map[string]string, error) { + q := `SELECT table_name, parent_table_name FROM information_schema.tables + WHERE interleave_type = 'IN PARENT' AND table_type = 'BASE TABLE' AND table_schema = @tableSchema` + stmt := spanner.Statement{SQL: q} + if isi.TargetDb == constants.TargetExperimentalPostgres { + stmt.Params = map[string]interface{}{ + "tableSchema": "public", + } + } else { + stmt.Params = map[string]interface{}{ + "tableSchema": "", + } + } + iter := isi.Client.Single().Query(isi.Ctx, stmt) + defer iter.Stop() + + var tableName, parentTable string + parentTables := map[string]string{} + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, fmt.Errorf("couldn't get tables: %w", err) + } + err = row.Columns(&tableName, &parentTable) + if err != nil { + return nil, err + } + parentTables[tableName] = parentTable + } + return parentTables, nil +} + +func toType(dataType string) schema.Type { + switch { + case strings.Contains(dataType, "ARRAY"): + typeLenStr := dataType[(strings.Index(dataType, "<") + 1):(len(dataType) - 1)] + schemaType := 
toType(typeLenStr) + schemaType.ArrayBounds = []int64{-1} + return schemaType + case strings.Contains(dataType, "("): + typeLenStr := dataType[strings.Index(dataType, "("):(len(dataType) - 1)] + if typeLenStr == "MAX" { + return schema.Type{Name: dataType, Mods: []int64{ddl.MaxLength}} + } + typeLen, _ := strconv.ParseInt(typeLenStr, 10, 64) + return schema.Type{Name: dataType, Mods: []int64{typeLen}} + default: + return schema.Type{Name: dataType} + } +} diff --git a/sources/spanner/toddl.go b/sources/spanner/toddl.go new file mode 100644 index 0000000000..0f5ca2c5b1 --- /dev/null +++ b/sources/spanner/toddl.go @@ -0,0 +1,79 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "github.com/cloudspannerecosystem/harbourbridge/common/constants" + "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/schema" + "github.com/cloudspannerecosystem/harbourbridge/spanner/ddl" +) + +type ToDdlImpl struct { +} + +// ToSpannerType maps a scalar source schema type (defined by id and +// mods) into a Spanner type. This is the core source-to-Spanner type +// mapping. toSpannerType returns the Spanner type and a list of type +// conversion issues encountered. 
+// Functions below implement the common.ToDdl interface +func (tdi ToDdlImpl) ToSpannerType(conv *internal.Conv, columnType schema.Type) (ddl.Type, []internal.SchemaIssue) { + ty, issues := toSpannerTypeInternal(conv, columnType.Name, columnType.Mods) + if conv.TargetDb == constants.TargetExperimentalPostgres { + ty, issues = overrideExperimentalType(columnType, ty, issues) + } else { + ty.IsArray = len(columnType.ArrayBounds) == 1 + } + return ty, issues +} + +// toSpannerType maps a scalar source schema type (defined by id and +// mods) into a Spanner type. This is the core source-to-Spanner type +// mapping. toSpannerType returns the Spanner type and a list of type +// conversion issues encountered. +func toSpannerTypeInternal(conv *internal.Conv, id string, mods []int64) (ddl.Type, []internal.SchemaIssue) { + switch id { + case "BOOL": + return ddl.Type{Name: ddl.Bool}, nil + case "BYTES": + return ddl.Type{Name: ddl.Bytes, Len: mods[0]}, nil + case "DATE": + return ddl.Type{Name: ddl.Date}, nil + case "FLOAT64": + return ddl.Type{Name: ddl.Float64}, nil + case "INT64": + return ddl.Type{Name: ddl.Int64}, nil + case "JSON": + return ddl.Type{Name: ddl.JSON}, nil + case "NUMERIC": + return ddl.Type{Name: ddl.Numeric}, nil + case "STRING": + return ddl.Type{Name: ddl.String, Len: mods[0]}, nil + case "TIMESTAMP": + return ddl.Type{Name: ddl.Timestamp}, nil + } + return ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, []internal.SchemaIssue{internal.NoGoodType} +} + +// Override the types to map to experimental postgres types. 
+func overrideExperimentalType(columnType schema.Type, originalType ddl.Type, issues []internal.SchemaIssue) (ddl.Type, []internal.SchemaIssue) { + switch columnType.Name { + case "PG.NUMERIC": + return ddl.Type{Name: ddl.Numeric}, nil + case "PG.JSONB": + return ddl.Type{Name: ddl.JSON}, nil + } + return originalType, issues +} diff --git a/spanner/batchwriter.go b/spanner/writer/batchwriter.go similarity index 99% rename from spanner/batchwriter.go rename to spanner/writer/batchwriter.go index b27030bb72..c82a78bd3a 100644 --- a/spanner/batchwriter.go +++ b/spanner/writer/batchwriter.go @@ -15,7 +15,7 @@ // Package spanner provides high-level abstractions for working with // Cloud Spanner that are not available from the core Cloud Spanner // libraries. -package spanner +package writer import ( "fmt" diff --git a/spanner/batchwriter_test.go b/spanner/writer/batchwriter_test.go similarity index 99% rename from spanner/batchwriter_test.go rename to spanner/writer/batchwriter_test.go index 282de76ca7..4f8a503cb6 100644 --- a/spanner/batchwriter_test.go +++ b/spanner/writer/batchwriter_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package spanner +package writer import ( "errors" From afd697706d39137d317a6b7835bcf68137292a22 Mon Sep 17 00:00:00 2001 From: Deep1998 Date: Thu, 20 Jan 2022 01:11:38 +0530 Subject: [PATCH 2/3] Added tests --- cmd/cmd.go | 3 +- cmd/data.go | 2 + cmd/schema.go | 2 + cmd/schema_and_data.go | 2 + common/utils/utils.go | 37 ++++++++++++- sources/csv/data.go | 61 ++++----------------- sources/csv/data_test.go | 10 +++- sources/spanner/infoschema.go | 88 ++++++++++++++++-------------- sources/spanner/infoschema_test.go | 39 +++++++++++++ sources/spanner/toddl_test.go | 45 +++++++++++++++ testing/csv/integration_test.go | 19 ++++--- 11 files changed, 206 insertions(+), 102 deletions(-) create mode 100644 sources/spanner/infoschema_test.go create mode 100644 sources/spanner/toddl_test.go diff --git a/cmd/cmd.go b/cmd/cmd.go index 588bd560af..287cb76f19 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -43,7 +43,8 @@ var ( // 3. Run data conversion (if schemaOnly is set to false) // 4. Generate report func CommandLine(ctx context.Context, driver, targetDb, dbURI string, dataOnly, schemaOnly, skipForeignKeys bool, schemaSampleSize int64, sessionJSON string, ioHelper *utils.IOStreams, outputFilePrefix string, now time.Time) error { - + // Cleanup hb tmp data directory in case residuals remain from prev runs. 
+ os.RemoveAll(os.TempDir() + constants.HB_TMP_DIR) // Legacy mode is only supported for MySQL, PostgreSQL and DynamoDB if driver != "" && utils.IsValidDriver(driver) && !utils.IsLegacyModeSupportedDriver(driver) { return fmt.Errorf("legacy mode is not supported for drivers other than %s", strings.Join(utils.GetLegacyModeSupportedDrivers(), ", ")) diff --git a/cmd/data.go b/cmd/data.go index cc82957a2a..e88c06c767 100644 --- a/cmd/data.go +++ b/cmd/data.go @@ -60,6 +60,8 @@ func (cmd *DataCmd) SetFlags(f *flag.FlagSet) { } func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + // Cleanup hb tmp data directory in case residuals remain from prev runs. + os.RemoveAll(os.TempDir() + constants.HB_TMP_DIR) var err error defer func() { if err != nil { diff --git a/cmd/schema.go b/cmd/schema.go index 69fa89729c..42d9a0c06e 100644 --- a/cmd/schema.go +++ b/cmd/schema.go @@ -56,6 +56,8 @@ func (cmd *SchemaCmd) SetFlags(f *flag.FlagSet) { } func (cmd *SchemaCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + // Cleanup hb tmp data directory in case residuals remain from prev runs. + os.RemoveAll(os.TempDir() + constants.HB_TMP_DIR) var err error defer func() { if err != nil { diff --git a/cmd/schema_and_data.go b/cmd/schema_and_data.go index 6f7021891f..72b77ebdfc 100644 --- a/cmd/schema_and_data.go +++ b/cmd/schema_and_data.go @@ -58,6 +58,8 @@ func (cmd *SchemaAndDataCmd) SetFlags(f *flag.FlagSet) { } func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus { + // Cleanup hb tmp data directory in case residuals remain from prev runs. 
+ os.RemoveAll(os.TempDir() + constants.HB_TMP_DIR) var err error defer func() { if err != nil { diff --git a/common/utils/utils.go b/common/utils/utils.go index ca4ae0c03a..1a1d193ef2 100644 --- a/common/utils/utils.go +++ b/common/utils/utils.go @@ -39,6 +39,12 @@ type IOStreams struct { BytesRead int64 } +// Harbourbridge accepts a manifest file in the form of a json which unmarshalls into the ManifestTables struct. +type ManifestTable struct { + Table_name string `json:"table_name"` + File_patterns []string `json:"file_patterns"` +} + // NewIOStreams returns a new IOStreams struct such that input stream is set // to open file descriptor for dumpFile if driver is PGDUMP or MYSQLDUMP. // Input stream defaults to stdin. Output stream is always set to stdout. @@ -124,6 +130,32 @@ func DownloadFromGCS(bucketName, filePath, tmpFile string) (*os.File, error) { return tmpfile, nil } +// PreloadGCSFiles downloads gcs files to tmp and updates the file paths in manifest with the local path. +func PreloadGCSFiles(tables []ManifestTable) ([]ManifestTable, error) { + for i, table := range tables { + for j, filePath := range table.File_patterns { + u, err := url.Parse(filePath) + if err != nil { + return nil, fmt.Errorf("unable parse file path %s for table %s", filePath, table.Table_name) + } + if u.Scheme == "gs" { + bucketName := u.Host + filePath := u.Path[1:] // removes "/" from beginning of path + tmpFile := strings.ReplaceAll(filePath, "/", ".") + // Files get downloaded to tmp dir. + fileLoc := os.TempDir() + constants.HB_TMP_DIR + "/" + tmpFile + _, err = DownloadFromGCS(bucketName, filePath, tmpFile) + if err != nil { + return nil, fmt.Errorf("cannot download gcs file: %s for table %s", filePath, table.Table_name) + } + tables[i].File_patterns[j] = fileLoc + fmt.Printf("Downloaded file: %s\n", fileLoc) + } + } + } + return tables, nil +} + // GetProject returns the cloud project we should use for accessing Spanner. 
// Use environment variable GCLOUD_PROJECT if it is set. // Otherwise, use the default project returned from gcloud. @@ -440,7 +472,10 @@ func ReadSpannerSchema(ctx context.Context, conv *internal.Conv, client *sp.Clie } parentTables, err := infoSchema.GetInterleaveTables() if err != nil { - return fmt.Errorf("error trying to fetch interleave table info from schema: %v", err) + // We should ideally throw an error here as it could potentially cause a lot of failed writes. + // We raise an unexpected error for now to make it compatible with the integration tests. + // In the emulator, the interleave_type column is not supported hence the query fails. + conv.Unexpected(fmt.Sprintf("error trying to fetch interleave table info from schema: %v", err)) } // Assign parents if any. for table, parent := range parentTables { diff --git a/sources/csv/data.go b/sources/csv/data.go index 6c5bbc2704..4dac126764 100644 --- a/sources/csv/data.go +++ b/sources/csv/data.go @@ -21,7 +21,6 @@ import ( "io" "io/ioutil" "math/big" - "net/url" "os" "strconv" "strings" @@ -36,25 +35,14 @@ import ( "github.com/cloudspannerecosystem/harbourbridge/spanner/ddl" ) -type Column struct { - Column_name string `json:"column_name"` - Type_name string `json:"type_name"` -} - -// Harbourbridge accepts a manifest file in the form of a json which unmarshalls into the Table struct. -type Table struct { - Table_name string `json:"table_name"` - File_patterns []string `json:"file_patterns"` -} - // GetCSVFiles finds the appropriate files paths and downloads gcs files in any. -func GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tables []Table, err error) { +func GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tables []utils.ManifestTable, err error) { // If manifest file not provided, we assume the csvs exist in the same directory // in table_name.csv format.
if sourceProfile.Csv.Manifest == "" { fmt.Println("Manifest file not provided, checking for files named `[table_name].csv` in current working directory...") for t := range conv.SpSchema { - tables = append(tables, Table{Table_name: t, File_patterns: []string{fmt.Sprintf("%s.csv", t)}}) + tables = append(tables, utils.ManifestTable{Table_name: t, File_patterns: []string{fmt.Sprintf("%s.csv", t)}}) } } else { fmt.Println("Manifest file provided, reading csv file paths...") @@ -66,7 +54,7 @@ func GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tab } // Download gcs files if any. - tables, err = preloadGCSFiles(tables) + tables, err = utils.PreloadGCSFiles(tables) if err != nil { return nil, fmt.Errorf("gcs file download error: %v", err) } @@ -75,12 +63,12 @@ func GetCSVFiles(conv *internal.Conv, sourceProfile profiles.SourceProfile) (tab // loadManifest reads the manifest file and unmarshalls it into a list of Table struct. // It also performs certain checks on the manifest. -func loadManifest(conv *internal.Conv, manifestFile string) ([]Table, error) { +func loadManifest(conv *internal.Conv, manifestFile string) ([]utils.ManifestTable, error) { manifest, err := ioutil.ReadFile(manifestFile) if err != nil { return nil, fmt.Errorf("can't read manifest file due to: %v", err) } - tables := []Table{} + tables := []utils.ManifestTable{} err = json.Unmarshal(manifest, &tables) if err != nil { return nil, fmt.Errorf("unable to unmarshall json due to: %v", err) @@ -94,7 +82,7 @@ func loadManifest(conv *internal.Conv, manifestFile string) ([]Table, error) { // VerifyManifest performs certain prechecks on the structure of the manifest while populating the conv with // the ddl types. Also checks on valid file paths and empty CSVs are handled as conv.Unexpected errors later during processing. 
-func VerifyManifest(conv *internal.Conv, tables []Table) error { +func VerifyManifest(conv *internal.Conv, tables []utils.ManifestTable) error { if len(tables) == 0 { return fmt.Errorf("no tables found") } @@ -130,33 +118,8 @@ func VerifyManifest(conv *internal.Conv, tables []Table) error { return nil } -func preloadGCSFiles(tables []Table) ([]Table, error) { - for i, table := range tables { - for j, filePath := range table.File_patterns { - u, err := url.Parse(filePath) - if err != nil { - return nil, fmt.Errorf("unable parse file path %s for table %s", filePath, table.Table_name) - } - if u.Scheme == "gs" { - bucketName := u.Host - filePath := u.Path[1:] // removes "/" from beginning of path - tmpFile := strings.ReplaceAll(filePath, "/", ".") - // Files get downloaded to tmp dir. - fileLoc := os.TempDir() + constants.HB_TMP_DIR + "/" + tmpFile - _, err = utils.DownloadFromGCS(bucketName, filePath, tmpFile) - if err != nil { - return nil, fmt.Errorf("cannot download gcs file: %s for table %s", filePath, table.Table_name) - } - tables[i].File_patterns[j] = fileLoc - fmt.Printf("Downloaded file: %s\n", fileLoc) - } - } - } - return tables, nil -} - // SetRowStats calculates the number of rows per table. -func SetRowStats(conv *internal.Conv, tables []Table, delimiter rune) { +func SetRowStats(conv *internal.Conv, tables []utils.ManifestTable, delimiter rune) { for _, table := range tables { for _, filePath := range table.File_patterns { csvFile, err := os.Open(filePath) @@ -208,15 +171,15 @@ func getCSVDataRowCount(r *csvReader.Reader, colNames []string) (int64, error) { // ProcessCSV writes data across the tables provided in the manifest file. Each table's data can be provided // across multiple CSV files hence, the manifest accepts a list of file paths in the input. 
-func ProcessCSV(conv *internal.Conv, tables []Table, nullStr string, delimiter rune) error { +func ProcessCSV(conv *internal.Conv, tables []utils.ManifestTable, nullStr string, delimiter rune) error { orderedTableNames := ddl.OrderTables(conv.SpSchema) nameToFiles := map[string][]string{} for _, table := range tables { nameToFiles[table.Table_name] = table.File_patterns } - orderedTables := []Table{} + orderedTables := []utils.ManifestTable{} for _, name := range orderedTableNames { - orderedTables = append(orderedTables, Table{name, nameToFiles[name]}) + orderedTables = append(orderedTables, utils.ManifestTable{name, nameToFiles[name]}) } for _, table := range orderedTables { @@ -403,11 +366,11 @@ func convArray(spannerType ddl.Type, val string) (interface{}, error) { } s, err := processQuote(s) if err != nil { - return []spanner.NullFloat64{}, err + return []spanner.NullNumeric{}, err } n, err := convNumeric(s) if err != nil { - return []spanner.NullFloat64{}, err + return []spanner.NullNumeric{}, err } r = append(r, spanner.NullNumeric{Numeric: n, Valid: true}) } diff --git a/sources/csv/data_test.go b/sources/csv/data_test.go index a6f40235f6..0b2fc9bc94 100644 --- a/sources/csv/data_test.go +++ b/sources/csv/data_test.go @@ -9,6 +9,8 @@ import ( "time" "cloud.google.com/go/civil" + "cloud.google.com/go/spanner" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/internal" "github.com/cloudspannerecosystem/harbourbridge/spanner/ddl" "github.com/stretchr/testify/assert" @@ -29,8 +31,8 @@ const ( SINGERS_2_CSV string = SINGERS_TABLE + "_2.csv" ) -func getManifestTables() []Table { - return []Table{ +func getManifestTables() []utils.ManifestTable { + return []utils.ManifestTable{ { Table_name: ALL_TYPES_TABLE, File_patterns: []string{ALL_TYPES_CSV}, @@ -136,6 +138,10 @@ func TestConvertData(t *testing.T) { {"string", ddl.Type{Name: ddl.String, Len: ddl.MaxLength}, "eh", "eh"}, {"timestamp", 
ddl.Type{Name: ddl.Timestamp}, "2019-10-29 05:30:00", getTime(t, "2019-10-29T05:30:00Z")}, {"json", ddl.Type{Name: ddl.JSON}, "{\"key1\": \"value1\"}", "{\"key1\": \"value1\"}"}, + {"int_array", ddl.Type{Name: ddl.Int64, IsArray: true}, "{1,2,NULL}", []spanner.NullInt64{{Int64: int64(1), Valid: true}, {Int64: int64(2), Valid: true}, {Valid: false}}}, + {"string_array", ddl.Type{Name: ddl.String, IsArray: true}, "[ab,cd]", []spanner.NullString{{StringVal: "ab", Valid: true}, {StringVal: "cd", Valid: true}}}, + {"float_array", ddl.Type{Name: ddl.Float64, IsArray: true}, "{1.3,2.5}", []spanner.NullFloat64{{Float64: float64(1.3), Valid: true}, {Float64: float64(2.5), Valid: true}}}, + {"numeric_array", ddl.Type{Name: ddl.Numeric, IsArray: true}, "[1.7]", []spanner.NullNumeric{{Numeric: *big.NewRat(17, 10), Valid: true}}}, } tableName := "testtable" for _, tc := range singleColTests { diff --git a/sources/spanner/infoschema.go b/sources/spanner/infoschema.go index 7bb6f7d8cc..5cc701c27b 100644 --- a/sources/spanner/infoschema.go +++ b/sources/spanner/infoschema.go @@ -73,23 +73,14 @@ func (isi InfoSchemaImpl) GetTableName(schema string, tableName string) string { } // GetTables return list of tables in the selected database. -// TODO: All of the queries to get tables and table data should be in -// a single transaction to ensure we obtain a consistent snapshot of -// schema information and table data (pg_dump does something -// similar). 
func (isi InfoSchemaImpl) GetTables() ([]common.SchemaAndName, error) { q := `SELECT table_schema, table_name FROM information_schema.tables - WHERE table_type = 'BASE TABLE' AND table_schema = @tableSchema` - stmt := spanner.Statement{SQL: q} + WHERE table_type = 'BASE TABLE' AND table_schema = ''` if isi.TargetDb == constants.TargetExperimentalPostgres { - stmt.Params = map[string]interface{}{ - "tableSchema": "public", - } - } else { - stmt.Params = map[string]interface{}{ - "tableSchema": "", - } + q = `SELECT table_schema, table_name FROM information_schema.tables + WHERE table_type = 'BASE TABLE' AND table_schema = 'public'` } + stmt := spanner.Statement{SQL: q} iter := isi.Client.Single().Query(isi.Ctx, stmt) defer iter.Stop() @@ -116,13 +107,18 @@ func (isi InfoSchemaImpl) GetTables() ([]common.SchemaAndName, error) { func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAndName, constraints map[string][]string, primaryKeys []string) (map[string]schema.Column, []string, error) { q := `SELECT column_name, spanner_type, is_nullable FROM information_schema.columns - WHERE table_schema = @tableSchema AND table_name = @tableName + WHERE table_schema = '' AND table_name = @p1 ORDER BY ordinal_position;` + if isi.TargetDb == constants.TargetExperimentalPostgres { + q = `SELECT column_name, spanner_type, is_nullable + FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = $1 + ORDER BY ordinal_position;` + } stmt := spanner.Statement{ SQL: q, Params: map[string]interface{}{ - "tableSchema": table.Schema, - "tableName": table.Name, + "p1": table.Name, }, } iter := isi.Client.Single().Query(isi.Ctx, stmt) @@ -141,13 +137,10 @@ func (isi InfoSchemaImpl) GetColumns(conv *internal.Conv, table common.SchemaAnd } err = row.Columns(&colName, &spannerType, &isNullable) if err != nil { - return nil, nil, fmt.Errorf("cannot read row for table %s: %s", table.Name, err) + return nil, nil, fmt.Errorf("cannot read row for table %s 
while reading columns: %s", table.Name, err) } ignored := schema.Ignored{} for _, c := range constraints[colName] { - // c can be UNIQUE, PRIMARY KEY, FOREIGN KEY, - // or CHECK (based on msql, sql server, postgres docs). - // We've already filtered out PRIMARY KEY. switch c { case "CHECK": ignored.Check = true @@ -175,12 +168,18 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem FROM information_schema.table_constraints AS t INNER JOIN information_schema.KEY_COLUMN_USAGE AS k ON t.constraint_name = k.constraint_name AND t.constraint_schema = k.constraint_schema - WHERE k.table_schema = @tableSchema AND k.table_name = @tableName ORDER BY k.ordinal_position;` + WHERE k.table_schema = '' AND k.table_name = @p1 ORDER BY k.ordinal_position;` + if isi.TargetDb == constants.TargetExperimentalPostgres { + q = `SELECT k.column_name, t.constraint_type + FROM information_schema.table_constraints AS t + INNER JOIN information_schema.KEY_COLUMN_USAGE AS k + ON t.constraint_name = k.constraint_name AND t.constraint_schema = k.constraint_schema + WHERE k.table_schema = 'public' AND k.table_name = $1 ORDER BY k.ordinal_position;` + } stmt := spanner.Statement{ SQL: q, Params: map[string]interface{}{ - "tableSchema": table.Schema, - "tableName": table.Name, + "p1": table.Name, }, } iter := isi.Client.Single().Query(isi.Ctx, stmt) @@ -195,7 +194,7 @@ func (isi InfoSchemaImpl) GetConstraints(conv *internal.Conv, table common.Schem break } if err != nil { - return nil, nil, fmt.Errorf("couldn't get row: %w", err) + return nil, nil, fmt.Errorf("couldn't get row while reading constraints: %w", err) } err = row.Columns(&col, &constraint) if err != nil { @@ -221,13 +220,20 @@ func (isi InfoSchemaImpl) GetForeignKeys(conv *internal.Conv, table common.Schem FROM information_schema.key_column_usage AS k JOIN information_schema.constraint_column_usage AS c ON k.constraint_name = c.constraint_name JOIN information_schema.table_constraints AS t ON 
k.constraint_name = t.constraint_name - WHERE t.constraint_type='FOREIGN KEY' AND t.table_schema = @tableSchema AND t.table_name = @tableName + WHERE t.constraint_type='FOREIGN KEY' AND t.table_schema = '' AND t.table_name = @p1 ORDER BY k.constraint_name, k.ordinal_position;` + if isi.TargetDb == constants.TargetExperimentalPostgres { + q = `SELECT k.constraint_name, k.column_name, t.table_name, c.column_name + FROM information_schema.key_column_usage AS k + JOIN information_schema.constraint_column_usage AS c ON k.constraint_name = c.constraint_name + JOIN information_schema.table_constraints AS t ON k.constraint_name = t.constraint_name + WHERE t.constraint_type='FOREIGN KEY' AND t.table_schema = 'public' AND t.table_name = $1 + ORDER BY k.constraint_name, k.ordinal_position;` + } stmt := spanner.Statement{ SQL: q, Params: map[string]interface{}{ - "tableSchema": table.Schema, - "tableName": table.Name, + "p1": table.Name, }, } iter := isi.Client.Single().Query(isi.Ctx, stmt) @@ -242,7 +248,7 @@ func (isi InfoSchemaImpl) GetForeignKeys(conv *internal.Conv, table common.Schem break } if err != nil { - return nil, fmt.Errorf("couldn't get row: %w", err) + return nil, fmt.Errorf("couldn't get row while fetching foreign keys: %w", err) } err = row.Columns(&fKeyName, &col, &refTable, &refCol) if err != nil { @@ -286,17 +292,12 @@ func (isi InfoSchemaImpl) GetIndexes(conv *internal.Conv, table common.SchemaAnd func (isi InfoSchemaImpl) GetInterleaveTables() (map[string]string, error) { q := `SELECT table_name, parent_table_name FROM information_schema.tables - WHERE interleave_type = 'IN PARENT' AND table_type = 'BASE TABLE' AND table_schema = @tableSchema` - stmt := spanner.Statement{SQL: q} + WHERE interleave_type = 'IN PARENT' AND table_type = 'BASE TABLE' AND table_schema = ''` if isi.TargetDb == constants.TargetExperimentalPostgres { - stmt.Params = map[string]interface{}{ - "tableSchema": "public", - } - } else { - stmt.Params = map[string]interface{}{ - 
"tableSchema": "", - } + q = `SELECT table_name, parent_table_name FROM information_schema.tables + WHERE interleave_type = 'IN PARENT' AND table_type = 'BASE TABLE' AND table_schema = 'public'` } + stmt := spanner.Statement{SQL: q} iter := isi.Client.Single().Query(isi.Ctx, stmt) defer iter.Stop() @@ -308,7 +309,7 @@ func (isi InfoSchemaImpl) GetInterleaveTables() (map[string]string, error) { break } if err != nil { - return nil, fmt.Errorf("couldn't get tables: %w", err) + return nil, fmt.Errorf("couldn't read row while fetching interleaved tables: %w", err) } err = row.Columns(&tableName, &parentTable) if err != nil { @@ -327,12 +328,15 @@ func toType(dataType string) schema.Type { schemaType.ArrayBounds = []int64{-1} return schemaType case strings.Contains(dataType, "("): - typeLenStr := dataType[strings.Index(dataType, "("):(len(dataType) - 1)] + idx := strings.Index(dataType, "(") + typeLenStr := dataType[(idx + 1):(len(dataType) - 1)] + var typeLen int64 if typeLenStr == "MAX" { - return schema.Type{Name: dataType, Mods: []int64{ddl.MaxLength}} + typeLen = ddl.MaxLength + } else { + typeLen, _ = strconv.ParseInt(typeLenStr, 10, 64) } - typeLen, _ := strconv.ParseInt(typeLenStr, 10, 64) - return schema.Type{Name: dataType, Mods: []int64{typeLen}} + return schema.Type{Name: dataType[:idx], Mods: []int64{typeLen}} default: return schema.Type{Name: dataType} } diff --git a/sources/spanner/infoschema_test.go b/sources/spanner/infoschema_test.go new file mode 100644 index 0000000000..682e2c71a3 --- /dev/null +++ b/sources/spanner/infoschema_test.go @@ -0,0 +1,39 @@ +package spanner + +import ( + "testing" + + "github.com/cloudspannerecosystem/harbourbridge/schema" + "github.com/cloudspannerecosystem/harbourbridge/spanner/ddl" + "github.com/stretchr/testify/assert" +) + +func TestToType(t *testing.T) { + testCases := []struct { + name string + dataType string + expColumnType schema.Type + }{ + // Scalar inputs. 
+ {"bool", "BOOL", schema.Type{Name: "BOOL"}}, + {"int", "INT64", schema.Type{Name: "INT64"}}, + {"float", "FLOAT64", schema.Type{Name: "FLOAT64"}}, + {"date", "DATE", schema.Type{Name: "DATE"}}, + {"numeric", "NUMERIC", schema.Type{Name: "NUMERIC"}}, + {"json", "JSON", schema.Type{Name: "JSON"}}, + {"timestamp", "TIMESTAMP", schema.Type{Name: "TIMESTAMP"}}, + {"bytes", "BYTES(100)", schema.Type{Name: "BYTES", Mods: []int64{100}}}, + {"bytes", "BYTES(MAX)", schema.Type{Name: "BYTES", Mods: []int64{ddl.MaxLength}}}, + {"string", "STRING(100)", schema.Type{Name: "STRING", Mods: []int64{100}}}, + {"string", "STRING(MAX)", schema.Type{Name: "STRING", Mods: []int64{ddl.MaxLength}}}, + // Array types. + {"string_max_arr", "ARRAY", schema.Type{Name: "STRING", Mods: []int64{ddl.MaxLength}, ArrayBounds: []int64{-1}}}, + {"string_arr", "ARRAY", schema.Type{Name: "STRING", Mods: []int64{100}, ArrayBounds: []int64{-1}}}, + {"float_arr", "ARRAY", schema.Type{Name: "FLOAT64", ArrayBounds: []int64{-1}}}, + {"numeric_arr", "ARRAY", schema.Type{Name: "NUMERIC", ArrayBounds: []int64{-1}}}, + } + for _, tc := range testCases { + ty := toType(tc.dataType) + assert.Equal(t, tc.expColumnType, ty, tc.name) + } +} diff --git a/sources/spanner/toddl_test.go b/sources/spanner/toddl_test.go new file mode 100644 index 0000000000..6c0319bd00 --- /dev/null +++ b/sources/spanner/toddl_test.go @@ -0,0 +1,45 @@ +package spanner + +import ( + "testing" + + "github.com/cloudspannerecosystem/harbourbridge/common/constants" + "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/schema" + "github.com/cloudspannerecosystem/harbourbridge/spanner/ddl" + "github.com/stretchr/testify/assert" +) + +func TestToSpannerType(t *testing.T) { + conv := internal.MakeConv() + toDDLImpl := ToDdlImpl{} + toDDLTests := []struct { + name string + pgTarget bool + columnType schema.Type + expDDLType ddl.Type + }{ + // Exact inputs. 
+ {"bool", false, schema.Type{Name: "BOOL"}, ddl.Type{Name: ddl.Bool}}, + {"bytes", false, schema.Type{Name: "BYTES", Mods: []int64{100}}, ddl.Type{Name: ddl.Bytes, Len: 100}}, + {"date", false, schema.Type{Name: "DATE"}, ddl.Type{Name: ddl.Date}}, + {"float", false, schema.Type{Name: "FLOAT64"}, ddl.Type{Name: ddl.Float64}}, + {"int", false, schema.Type{Name: "INT64"}, ddl.Type{Name: ddl.Int64}}, + {"json", false, schema.Type{Name: "JSON"}, ddl.Type{Name: ddl.JSON}}, + {"numeric", false, schema.Type{Name: "NUMERIC"}, ddl.Type{Name: ddl.Numeric}}, + {"string", false, schema.Type{Name: "STRING", Mods: []int64{100}}, ddl.Type{Name: ddl.String, Len: 100}}, + {"timestamp", false, schema.Type{Name: "TIMESTAMP"}, ddl.Type{Name: ddl.Timestamp}}, + // PG target. + {"pg_numeric", true, schema.Type{Name: "PG.NUMERIC"}, ddl.Type{Name: ddl.Numeric}}, + {"pg_json", true, schema.Type{Name: "PG.JSONB"}, ddl.Type{Name: ddl.JSON}}, + } + for _, tc := range toDDLTests { + conv.TargetDb = constants.TargetSpanner + if tc.pgTarget { + conv.TargetDb = constants.TargetExperimentalPostgres + } + ty, err := toDDLImpl.ToSpannerType(conv, tc.columnType) + assert.Nil(t, err, tc.name) + assert.Equal(t, tc.expDDLType, ty, tc.name) + } +} diff --git a/testing/csv/integration_test.go b/testing/csv/integration_test.go index a5811337e6..7f43667af5 100644 --- a/testing/csv/integration_test.go +++ b/testing/csv/integration_test.go @@ -48,7 +48,7 @@ var ( ) const ( - ALL_TYPES_CSV string = "../../test_data/all_data_types.csv" + ALL_TYPES_CSV string = "all_data_types.csv" ) type SpannerRecord struct { @@ -61,6 +61,8 @@ type SpannerRecord struct { AttrString string AttrTimestamp time.Time AttrJson spanner.NullJSON + AttrStringArr []spanner.NullString + AttrInt64Arr []spanner.NullInt64 } func TestMain(m *testing.M) { @@ -137,6 +139,8 @@ func createSpannerSchema(t *testing.T, project, instance, dbName string) { "g STRING(50)," + "h TIMESTAMP," + "i JSON," + + "j ARRAY," + + "k ARRAY," + ") PRIMARY 
KEY(e)", } op, err := databaseAdmin.CreateDatabase(ctx, req) @@ -157,12 +161,11 @@ func TestIntegration_CSV_Command(t *testing.T) { dbName := "csv-test" dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) - manifest := "../../test_data/csv_manifest.json" writeCSVs(t) defer cleanupCSVs() createSpannerSchema(t, projectID, instanceID, dbName) - args := fmt.Sprintf("data -source=csv -source-profile='manifest=%s' -target-profile='instance=%s,dbname=%s'", manifest, instanceID, dbName) + args := fmt.Sprintf("data -source=csv -target-profile='instance=%s,dbname=%s'", instanceID, dbName) err := common.RunCommand(args, projectID) if err != nil { t.Fatal(err) @@ -181,8 +184,8 @@ func writeCSVs(t *testing.T) { { ALL_TYPES_CSV, []string{ - "a,b,c,d,e,f,g,h,i\n", - "true,test,2019-10-29,15.13,100,39.94,Helloworld,2019-10-29 05:30:00,\"{\"\"key1\"\": \"\"value1\"\", \"\"key2\"\": \"\"value2\"\"}\"", + "a,b,c,d,e,f,g,h,i,j,k\n", + "true,test,2019-10-29,15.13,100,39.94,Helloworld,2019-10-29 05:30:00,\"{\"\"key1\"\": \"\"value1\"\", \"\"key2\"\": \"\"value2\"\"}\",\"{ab,cd}\",\"[1,2]\"", }, }, } @@ -225,11 +228,13 @@ func checkRow(ctx context.Context, t *testing.T, client *spanner.Client) { AttrString: "Helloworld", AttrTimestamp: getTime(t, "2019-10-29T05:30:00Z"), AttrJson: spanner.NullJSON{Valid: true}, + AttrStringArr: []spanner.NullString{{StringVal: "ab", Valid: true}, {StringVal: "cd", Valid: true}}, + AttrInt64Arr: []spanner.NullInt64{{Int64: int64(1), Valid: true}, {Int64: int64(2), Valid: true}}, } json.Unmarshal([]byte("{\"key1\": \"value1\", \"key2\": \"value2\"}"), &wantRecord.AttrJson.Value) gotRecord := SpannerRecord{} - stmt := spanner.Statement{SQL: `SELECT a, b, c, d, e, f, g, h, i FROM all_data_types`} + stmt := spanner.Statement{SQL: `SELECT a, b, c, d, e, f, g, h, i, j, k FROM all_data_types`} iter := client.Single().Query(ctx, stmt) defer iter.Stop() for { @@ -245,7 +250,7 @@ func checkRow(ctx context.Context, t 
*testing.T, client *spanner.Client) { // We don't create big.Rat fields in the SpannerRecord structs // because cmp.Equal cannot compare big.Rat fields automatically. var AttrNumeric big.Rat - if err := row.Columns(&gotRecord.AttrBool, &gotRecord.AttrBytes, &gotRecord.AttrDate, &gotRecord.AttrFloat, &gotRecord.AttrInt, &AttrNumeric, &gotRecord.AttrString, &gotRecord.AttrTimestamp, &gotRecord.AttrJson); err != nil { + if err := row.Columns(&gotRecord.AttrBool, &gotRecord.AttrBytes, &gotRecord.AttrDate, &gotRecord.AttrFloat, &gotRecord.AttrInt, &AttrNumeric, &gotRecord.AttrString, &gotRecord.AttrTimestamp, &gotRecord.AttrJson, &gotRecord.AttrStringArr, &gotRecord.AttrInt64Arr); err != nil { log.Println("Error reading into variables: ", err) t.Fatal(err) break From 807b16c8439dca83527025f28f5cfdb7f5c430eb Mon Sep 17 00:00:00 2001 From: Deep1998 Date: Tue, 25 Jan 2022 21:07:14 +0530 Subject: [PATCH 3/3] Added check to match num of cols --- conversion/conversion.go | 5 ++++- sources/csv/data.go | 11 +++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/conversion/conversion.go b/conversion/conversion.go index d8cd85958f..edf3511d6a 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -306,7 +306,10 @@ func dataFromCSV(ctx context.Context, sourceProfile profiles.SourceProfile, targ } // Find the number of rows in each csv file for generating stats. - csv.SetRowStats(conv, tables, delimiter) + err = csv.SetRowStats(conv, tables, delimiter) + if err != nil { + return nil, err + } totalRows := conv.Rows() p := internal.NewProgress(totalRows, "Writing data to Spanner", internal.Verbose(), false) rows := int64(0) diff --git a/sources/csv/data.go b/sources/csv/data.go index 4dac126764..131396d33d 100644 --- a/sources/csv/data.go +++ b/sources/csv/data.go @@ -119,19 +119,18 @@ func VerifyManifest(conv *internal.Conv, tables []utils.ManifestTable) error { } // SetRowStats calculates the number of rows per table. 
-func SetRowStats(conv *internal.Conv, tables []utils.ManifestTable, delimiter rune) { +func SetRowStats(conv *internal.Conv, tables []utils.ManifestTable, delimiter rune) error { for _, table := range tables { for _, filePath := range table.File_patterns { csvFile, err := os.Open(filePath) if err != nil { - fmt.Printf("can't read csv file: %s due to: %v\n", filePath, err) + return fmt.Errorf("can't read csv file: %s due to: %v", filePath, err) } r := csvReader.NewReader(csvFile) r.Comma = delimiter count, err := getCSVDataRowCount(r, conv.SpSchema[table.Table_name].ColNames) if err != nil { - conv.Unexpected(fmt.Sprintf("Couldn't get number of rows for table %s", table.Table_name)) - continue + return fmt.Errorf("error reading file %s for table %s: %v", filePath, table.Table_name, err) } if count == 0 { conv.Unexpected(fmt.Sprintf("error processing table %s: file %s is empty.", table.Table_name, filePath)) @@ -140,6 +139,7 @@ func SetRowStats(conv *internal.Conv, tables []utils.ManifestTable, delimiter ru conv.Stats.Rows[table.Table_name] += count } } + return nil } // getCSVDataRowCount returns the number of data rows in the CSV file. This excludes the headers if present. @@ -152,6 +152,9 @@ func getCSVDataRowCount(r *csvReader.Reader, colNames []string) (int64, error) { if err != nil { return count, fmt.Errorf("can't read csv headers for col names due to: %v", err) } + if len(srcCols) != len(colNames) { + return 0, fmt.Errorf("found %d columns in csv, expected %d as per Spanner schema", len(srcCols), len(colNames)) + } // If the row read was not a header, increase count. if !utils.CheckEqualSets(srcCols, colNames) { count += 1