diff --git a/cmd/cmd.go b/cmd/cmd.go index ebd4a7c62b..1f4470d68f 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -20,8 +20,10 @@ import ( "fmt" "time" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/profiles" ) var ( @@ -37,13 +39,20 @@ var ( // 2. Create database (if schemaOnly is set to false) // 3. Run data conversion (if schemaOnly is set to false) // 4. Generate report -func CommandLine(ctx context.Context, driver, targetDb, dbURI string, dataOnly, schemaOnly, skipForeignKeys bool, schemaSampleSize int64, sessionJSON string, ioHelper *conversion.IOStreams, outputFilePrefix string, now time.Time) error { +func CommandLine(ctx context.Context, driver, targetDb, dbURI string, dataOnly, schemaOnly, skipForeignKeys bool, schemaSampleSize int64, sessionJSON string, ioHelper *utils.IOStreams, outputFilePrefix string, now time.Time) error { var conv *internal.Conv var err error + // Creating profiles from legacy flags. We only pass schema-sample-size here because thats the + // only flag passed through the arguments. Dumpfile params are contained within ioHelper + // and direct connect params will be fetched from the env variables. + sourceProfile, _ := profiles.NewSourceProfile(fmt.Sprintf("schema-sample-size=%d", schemaSampleSize), driver) + sourceProfile.Driver = driver + targetProfile, _ := profiles.NewTargetProfile("") + targetProfile.TargetDb = targetDb if !dataOnly { // We pass an empty string to the sqlConnectionStr parameter as this is the legacy codepath, // which reads the environment variables and constructs the string later on. 
- conv, err = conversion.SchemaConv(driver, "", targetDb, ioHelper, schemaSampleSize) + conv, err = conversion.SchemaConv(sourceProfile, targetProfile, ioHelper) if err != nil { return err } @@ -64,9 +73,9 @@ func CommandLine(ctx context.Context, driver, targetDb, dbURI string, dataOnly, return err } } - adminClient, err := conversion.NewDatabaseAdminClient(ctx) + adminClient, err := utils.NewDatabaseAdminClient(ctx) if err != nil { - return fmt.Errorf("can't create admin client: %w", conversion.AnalyzeError(err, dbURI)) + return fmt.Errorf("can't create admin client: %w", utils.AnalyzeError(err, dbURI)) } defer adminClient.Close() err = conversion.CreateOrUpdateDatabase(ctx, adminClient, dbURI, conv, ioHelper.Out) @@ -74,14 +83,14 @@ func CommandLine(ctx context.Context, driver, targetDb, dbURI string, dataOnly, return fmt.Errorf("can't create/update database: %v", err) } - client, err := conversion.GetClient(ctx, dbURI) + client, err := utils.GetClient(ctx, dbURI) if err != nil { return fmt.Errorf("can't create client for db %s: %v", dbURI, err) } // We pass an empty string to the sqlConnectionStr parameter as this is the legacy codepath, // which reads the environment variables and constructs the string later on. 
- bw, err := conversion.DataConv(driver, "", ioHelper, client, conv, dataOnly, schemaSampleSize) + bw, err := conversion.DataConv(sourceProfile, targetProfile, ioHelper, client, conv, dataOnly) if err != nil { return fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) } @@ -90,7 +99,7 @@ func CommandLine(ctx context.Context, driver, targetDb, dbURI string, dataOnly, return fmt.Errorf("can't perform update schema on db %s with foreign keys: %v", dbURI, err) } } - banner := conversion.GetBanner(now, dbURI) + banner := utils.GetBanner(now, dbURI) conversion.Report(driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, outputFilePrefix+reportFile, ioHelper.Out) conversion.WriteBadData(bw, conv, banner, outputFilePrefix+badDataFile, ioHelper.Out) return nil diff --git a/cmd/common.go b/cmd/common.go deleted file mode 100644 index 8e0d9f3679..0000000000 --- a/cmd/common.go +++ /dev/null @@ -1,111 +0,0 @@ -package cmd - -import ( - "context" - "encoding/csv" - "fmt" - "os" - "strings" - "time" - - "github.com/cloudspannerecosystem/harbourbridge/conversion" -) - -// Parses input string `s` as a map of key-value pairs. It's expected that the -// input string is of the form "key1=value1,key2=value2,..." etc. Return error -// otherwise. -func parseProfile(s string) (map[string]string, error) { - params := make(map[string]string) - if len(s) == 0 { - return params, nil - } - - // We use CSV reader to parse key=value pairs separated by a comma to - // handle the case where a value may contain a comma within a quote. We - // expect exactly one record to be returned. 
- r := csv.NewReader(strings.NewReader(s)) - r.Comma = ',' - r.TrimLeadingSpace = true - records, err := r.ReadAll() - if err != nil { - return params, err - } - if len(records) > 1 { - return params, fmt.Errorf("contains invalid newline characters") - } - - for _, kv := range records[0] { - s := strings.Split(strings.TrimSpace(kv), "=") - if len(s) != 2 { - return params, fmt.Errorf("invalid key=value pair (expected format: key1=value1): %v", kv) - } - if _, ok := params[s[0]]; ok { - return params, fmt.Errorf("duplicate key found: %v", s[0]) - } - params[s[0]] = s[1] - } - return params, nil -} - -func getResourceIds(ctx context.Context, targetProfile TargetProfile, now time.Time, driverName string, out *os.File) (string, string, string, error) { - var err error - project := targetProfile.conn.sp.project - if project == "" { - project, err = conversion.GetProject() - if err != nil { - return "", "", "", fmt.Errorf("can't get project: %v", err) - } - } - fmt.Println("Using Google Cloud project:", project) - - instance := targetProfile.conn.sp.instance - if instance == "" { - instance, err = conversion.GetInstance(ctx, project, out) - if err != nil { - return "", "", "", fmt.Errorf("can't get instance: %v", err) - } - } - fmt.Println("Using Cloud Spanner instance:", instance) - conversion.PrintPermissionsWarning(driverName, out) - - dbName := targetProfile.conn.sp.dbname - if dbName == "" { - dbName, err = conversion.GetDatabaseName(driverName, now) - if err != nil { - return "", "", "", fmt.Errorf("can't get database name: %v", err) - } - } - return project, instance, dbName, err -} - -func getSQLConnectionStr(sourceProfile SourceProfile) string { - sqlConnectionStr := "" - if sourceProfile.ty == SourceProfileTypeConnection { - switch sourceProfile.conn.ty { - case SourceProfileConnectionTypeMySQL: - connParams := sourceProfile.conn.mysql - return conversion.GetMYSQLConnectionStr(connParams.host, connParams.port, connParams.user, connParams.pwd, connParams.db) - 
case SourceProfileConnectionTypePostgreSQL: - connParams := sourceProfile.conn.pg - return conversion.GetPGSQLConnectionStr(connParams.host, connParams.port, connParams.user, connParams.pwd, connParams.db) - case SourceProfileConnectionTypeDynamoDB: - // For DynamoDB, client provided by aws-sdk reads connection credentials from env variables only. - // Thus, there is no need to create sqlConnectionStr for the same. We instead set the env variables - // programmatically if not set. - return "" - } - } - return sqlConnectionStr -} - -func getSchemaSampleSize(sourceProfile SourceProfile) int64 { - schemaSampleSize := int64(100000) - if sourceProfile.ty == SourceProfileTypeConnection { - if sourceProfile.conn.ty == SourceProfileConnectionTypeDynamoDB { - if sourceProfile.conn.dydb.schemaSampleSize != 0 { - schemaSampleSize = sourceProfile.conn.dydb.schemaSampleSize - } - } - } - return schemaSampleSize -} diff --git a/cmd/data.go b/cmd/data.go index 0b40afd72c..350a86a968 100644 --- a/cmd/data.go +++ b/cmd/data.go @@ -8,8 +8,10 @@ import ( "path" "time" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/profiles" "github.com/google/subcommands" ) @@ -64,32 +66,32 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface } }() - sourceProfile, err := NewSourceProfile(cmd.sourceProfile, cmd.source) + sourceProfile, err := profiles.NewSourceProfile(cmd.sourceProfile, cmd.source) if err != nil { return subcommands.ExitUsageError } - driverName, err := sourceProfile.ToLegacyDriver(cmd.source) + sourceProfile.Driver, err = sourceProfile.ToLegacyDriver(cmd.source) if err != nil { return subcommands.ExitUsageError } - targetProfile, err := NewTargetProfile(cmd.targetProfile) + targetProfile, err := profiles.NewTargetProfile(cmd.targetProfile) if err != nil { return 
subcommands.ExitUsageError } - targetDb := targetProfile.ToLegacyTargetDb() + targetProfile.TargetDb = targetProfile.ToLegacyTargetDb() dumpFilePath := "" - if sourceProfile.ty == SourceProfileTypeFile && (sourceProfile.file.format == "" || sourceProfile.file.format == "dump") { - dumpFilePath = sourceProfile.file.path + if sourceProfile.Ty == profiles.SourceProfileTypeFile && (sourceProfile.File.Format == "" || sourceProfile.File.Format == "dump") { + dumpFilePath = sourceProfile.File.Path } - ioHelper := conversion.NewIOStreams(driverName, dumpFilePath) + ioHelper := utils.NewIOStreams(sourceProfile.Driver, dumpFilePath) if ioHelper.SeekableIn != nil { defer ioHelper.In.Close() } now := time.Now() - project, instance, dbName, err := getResourceIds(ctx, targetProfile, now, driverName, ioHelper.Out) + project, instance, dbName, err := profiles.GetResourceIds(ctx, targetProfile, now, sourceProfile.Driver, ioHelper.Out) if err != nil { return subcommands.ExitUsageError } @@ -105,18 +107,18 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface if err != nil { return subcommands.ExitUsageError } - if targetDb != "" && conv.TargetDb != targetDb { - err = fmt.Errorf("running data migration for Spanner dialect: %v, whereas schema mapping was done for dialect: %v", targetDb, conv.TargetDb) + if targetProfile.TargetDb != "" && conv.TargetDb != targetProfile.TargetDb { + err = fmt.Errorf("running data migration for Spanner dialect: %v, whereas schema mapping was done for dialect: %v", targetProfile.TargetDb, conv.TargetDb) return subcommands.ExitUsageError } - adminClient, err := conversion.NewDatabaseAdminClient(ctx) + adminClient, err := utils.NewDatabaseAdminClient(ctx) if err != nil { - err = fmt.Errorf("can't create admin client: %w", conversion.AnalyzeError(err, dbURI)) + err = fmt.Errorf("can't create admin client: %w", utils.AnalyzeError(err, dbURI)) return subcommands.ExitFailure } defer adminClient.Close() - client, err := 
conversion.GetClient(ctx, dbURI) + client, err := utils.GetClient(ctx, dbURI) if err != nil { err = fmt.Errorf("can't create client for db %s: %v", dbURI, err) return subcommands.ExitFailure @@ -129,7 +131,7 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface return subcommands.ExitFailure } - bw, err := conversion.DataConv(driverName, getSQLConnectionStr(sourceProfile), &ioHelper, client, conv, true, getSchemaSampleSize(sourceProfile)) + bw, err := conversion.DataConv(sourceProfile, targetProfile, &ioHelper, client, conv, true) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) return subcommands.ExitFailure @@ -140,8 +142,8 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface return subcommands.ExitFailure } } - banner := conversion.GetBanner(now, dbURI) - conversion.Report(driverName, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, cmd.filePrefix+reportFile, ioHelper.Out) + banner := utils.GetBanner(now, dbURI) + conversion.Report(sourceProfile.Driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, cmd.filePrefix+reportFile, ioHelper.Out) conversion.WriteBadData(bw, conv, banner, cmd.filePrefix+badDataFile, ioHelper.Out) return subcommands.ExitSuccess } diff --git a/cmd/schema.go b/cmd/schema.go index d1e9d44ca6..1d5f9cc08c 100644 --- a/cmd/schema.go +++ b/cmd/schema.go @@ -8,8 +8,10 @@ import ( "path" "time" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/profiles" "github.com/google/subcommands" ) @@ -60,33 +62,33 @@ func (cmd *SchemaCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfa } }() - sourceProfile, err := NewSourceProfile(cmd.sourceProfile, cmd.source) + sourceProfile, err := profiles.NewSourceProfile(cmd.sourceProfile, cmd.source) if 
err != nil { return subcommands.ExitUsageError } - driverName, err := sourceProfile.ToLegacyDriver(cmd.source) + sourceProfile.Driver, err = sourceProfile.ToLegacyDriver(cmd.source) if err != nil { return subcommands.ExitUsageError } - targetProfile, err := NewTargetProfile(cmd.targetProfile) + targetProfile, err := profiles.NewTargetProfile(cmd.targetProfile) if err != nil { return subcommands.ExitUsageError } - targetDb := targetProfile.ToLegacyTargetDb() + targetProfile.TargetDb = targetProfile.ToLegacyTargetDb() dumpFilePath := "" - if sourceProfile.ty == SourceProfileTypeFile && (sourceProfile.file.format == "" || sourceProfile.file.format == "dump") { - dumpFilePath = sourceProfile.file.path + if sourceProfile.Ty == profiles.SourceProfileTypeFile && (sourceProfile.File.Format == "" || sourceProfile.File.Format == "dump") { + dumpFilePath = sourceProfile.File.Path } - ioHelper := conversion.NewIOStreams(driverName, dumpFilePath) + ioHelper := utils.NewIOStreams(sourceProfile.Driver, dumpFilePath) if ioHelper.SeekableIn != nil { defer ioHelper.In.Close() } // If filePrefix not explicitly set, use generated dbName. 
if cmd.filePrefix == "" { - dbName, err := conversion.GetDatabaseName(driverName, time.Now()) + dbName, err := utils.GetDatabaseName(sourceProfile.Driver, time.Now()) if err != nil { err = fmt.Errorf("can't generate database name for prefix: %v", err) return subcommands.ExitFailure @@ -95,7 +97,7 @@ func (cmd *SchemaCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfa } var conv *internal.Conv - conv, err = conversion.SchemaConv(driverName, getSQLConnectionStr(sourceProfile), targetDb, &ioHelper, getSchemaSampleSize(sourceProfile)) + conv, err = conversion.SchemaConv(sourceProfile, targetProfile, &ioHelper) if err != nil { return subcommands.ExitFailure } @@ -103,6 +105,6 @@ func (cmd *SchemaCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interfa now := time.Now() conversion.WriteSchemaFile(conv, now, cmd.filePrefix+schemaFile, ioHelper.Out) conversion.WriteSessionFile(conv, cmd.filePrefix+sessionFile, ioHelper.Out) - conversion.Report(driverName, nil, ioHelper.BytesRead, "", conv, cmd.filePrefix+reportFile, ioHelper.Out) + conversion.Report(sourceProfile.Driver, nil, ioHelper.BytesRead, "", conv, cmd.filePrefix+reportFile, ioHelper.Out) return subcommands.ExitSuccess } diff --git a/cmd/schema_and_data.go b/cmd/schema_and_data.go index 47657e3ace..b6bdd17cd0 100644 --- a/cmd/schema_and_data.go +++ b/cmd/schema_and_data.go @@ -8,8 +8,10 @@ import ( "path" "time" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/profiles" "github.com/google/subcommands" ) @@ -62,26 +64,26 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... 
} }() - sourceProfile, err := NewSourceProfile(cmd.sourceProfile, cmd.source) + sourceProfile, err := profiles.NewSourceProfile(cmd.sourceProfile, cmd.source) if err != nil { return subcommands.ExitUsageError } - driverName, err := sourceProfile.ToLegacyDriver(cmd.source) + sourceProfile.Driver, err = sourceProfile.ToLegacyDriver(cmd.source) if err != nil { return subcommands.ExitUsageError } - targetProfile, err := NewTargetProfile(cmd.targetProfile) + targetProfile, err := profiles.NewTargetProfile(cmd.targetProfile) if err != nil { return subcommands.ExitUsageError } - targetDb := targetProfile.ToLegacyTargetDb() + targetProfile.TargetDb = targetProfile.ToLegacyTargetDb() dumpFilePath := "" - if sourceProfile.ty == SourceProfileTypeFile && (sourceProfile.file.format == "" || sourceProfile.file.format == "dump") { - dumpFilePath = sourceProfile.file.path + if sourceProfile.Ty == profiles.SourceProfileTypeFile && (sourceProfile.File.Format == "" || sourceProfile.File.Format == "dump") { + dumpFilePath = sourceProfile.File.Path } - ioHelper := conversion.NewIOStreams(driverName, dumpFilePath) + ioHelper := utils.NewIOStreams(sourceProfile.Driver, dumpFilePath) if ioHelper.SeekableIn != nil { defer ioHelper.In.Close() } @@ -90,38 +92,36 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... // If filePrefix not explicitly set, use dbName as prefix. if cmd.filePrefix == "" { - dbName, err := conversion.GetDatabaseName(driverName, now) + dbName, err := utils.GetDatabaseName(sourceProfile.Driver, now) if err != nil { panic(fmt.Errorf("can't generate database name for prefix: %v", err)) } cmd.filePrefix = dbName + "." 
} - sqlConnectionStr := getSQLConnectionStr(sourceProfile) - schemaSampleSize := getSchemaSampleSize(sourceProfile) var conv *internal.Conv - conv, err = conversion.SchemaConv(driverName, sqlConnectionStr, targetDb, &ioHelper, schemaSampleSize) + conv, err = conversion.SchemaConv(sourceProfile, targetProfile, &ioHelper) if err != nil { panic(err) } conversion.WriteSchemaFile(conv, now, cmd.filePrefix+schemaFile, ioHelper.Out) conversion.WriteSessionFile(conv, cmd.filePrefix+sessionFile, ioHelper.Out) - conversion.Report(driverName, nil, ioHelper.BytesRead, "", conv, cmd.filePrefix+reportFile, ioHelper.Out) + conversion.Report(sourceProfile.Driver, nil, ioHelper.BytesRead, "", conv, cmd.filePrefix+reportFile, ioHelper.Out) - project, instance, dbName, err := getResourceIds(ctx, targetProfile, now, driverName, ioHelper.Out) + project, instance, dbName, err := profiles.GetResourceIds(ctx, targetProfile, now, sourceProfile.Driver, ioHelper.Out) if err != nil { return subcommands.ExitUsageError } dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, dbName) - adminClient, err := conversion.NewDatabaseAdminClient(ctx) + adminClient, err := utils.NewDatabaseAdminClient(ctx) if err != nil { - err = fmt.Errorf("can't create admin client: %w", conversion.AnalyzeError(err, dbURI)) + err = fmt.Errorf("can't create admin client: %w", utils.AnalyzeError(err, dbURI)) return subcommands.ExitFailure } defer adminClient.Close() - client, err := conversion.GetClient(ctx, dbURI) + client, err := utils.GetClient(ctx, dbURI) if err != nil { err = fmt.Errorf("can't create client for db %s: %v", dbURI, err) return subcommands.ExitFailure @@ -134,7 +134,7 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... 
return subcommands.ExitFailure } - bw, err := conversion.DataConv(driverName, sqlConnectionStr, &ioHelper, client, conv, true, schemaSampleSize) + bw, err := conversion.DataConv(sourceProfile, targetProfile, &ioHelper, client, conv, true) if err != nil { err = fmt.Errorf("can't finish data conversion for db %s: %v", dbURI, err) return subcommands.ExitFailure @@ -145,8 +145,8 @@ func (cmd *SchemaAndDataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ... return subcommands.ExitFailure } } - banner := conversion.GetBanner(now, dbURI) - conversion.Report(driverName, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, cmd.filePrefix+reportFile, ioHelper.Out) + banner := utils.GetBanner(now, dbURI) + conversion.Report(sourceProfile.Driver, bw.DroppedRowsByTable(), ioHelper.BytesRead, banner, conv, cmd.filePrefix+reportFile, ioHelper.Out) conversion.WriteBadData(bw, conv, banner, cmd.filePrefix+badDataFile, ioHelper.Out) return subcommands.ExitSuccess } diff --git a/common/utils/utils.go b/common/utils/utils.go new file mode 100644 index 0000000000..7568bec272 --- /dev/null +++ b/common/utils/utils.go @@ -0,0 +1,379 @@ +// Package utils contains common helper functions used across multiple other packages. +// Utils should not import any harbourbridge packages. +package utils + +import ( + "bufio" + "context" + "crypto/rand" + "fmt" + "io" + "io/ioutil" + "log" + "net/url" + "os" + "os/exec" + "strings" + "syscall" + "time" + + sp "cloud.google.com/go/spanner" + database "cloud.google.com/go/spanner/admin/database/apiv1" + instance "cloud.google.com/go/spanner/admin/instance/apiv1" + "cloud.google.com/go/storage" + "github.com/cloudspannerecosystem/harbourbridge/common/constants" + "golang.org/x/crypto/ssh/terminal" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" +) + +// IOStreams is a struct that contains the file descriptor for dumpFile. 
+type IOStreams struct { + In, SeekableIn, Out *os.File + BytesRead int64 +} + +// NewIOStreams returns a new IOStreams struct such that input stream is set +// to open file descriptor for dumpFile if driver is PGDUMP or MYSQLDUMP. +// Input stream defaults to stdin. Output stream is always set to stdout. +// A dumpFile whose URL scheme is gs is first downloaded from GCS into a temporary file. +// A failure to parse, download or open dumpFile ends the process through log.Fatal. +func NewIOStreams(driver string, dumpFile string) IOStreams { + // The local variable io shadows the imported io package inside this function. + io := IOStreams{In: os.Stdin, Out: os.Stdout} + u, err := url.Parse(dumpFile) + if err != nil { + fmt.Printf("parseFilePath: unable parse file path for dumpfile %s", dumpFile) + log.Fatal(err) + } + if (driver == constants.PGDUMP || driver == constants.MYSQLDUMP) && dumpFile != "" { + fmt.Printf("\nLoading dump file from path: %s\n", dumpFile) + var f *os.File + var err error + if u.Scheme == "gs" { + bucketName := u.Host + filePath := u.Path[1:] // removes "/" from beginning of path + f, err = downloadFromGCS(bucketName, filePath) + } else { + f, err = os.Open(dumpFile) + } + if err != nil { + fmt.Printf("\nError reading dump file: %v err:%v\n", dumpFile, err) + log.Fatal(err) + } + io.In = f + } + return io +} + +// downloadFromGCS copies object filePath of GCS bucket bucketName into a temporary file +// and returns that open file. Every failure path calls log.Fatal, so the process exits +// instead of handing the error back to the caller. +func downloadFromGCS(bucketName string, filePath string) (*os.File, error) { + ctx := context.Background() + + client, err := storage.NewClient(ctx) + if err != nil { + fmt.Printf("Failed to create GCS client for bucket %q", bucketName) + log.Fatal(err) + } + defer client.Close() + + bucket := client.Bucket(bucketName) + rc, err := bucket.Object(filePath).NewReader(ctx) + if err != nil { + fmt.Printf("readFile: unable to open file from bucket %q, file %q: %v", bucketName, filePath, err) + log.Fatal(err) + return nil, err + } + defer rc.Close() + r := bufio.NewReader(rc) + + tmpfile, err := ioutil.TempFile("", "harbourbridge.gcs.data") + if err != nil { + fmt.Printf("saveFile: unable to open temporary file to save dump file from GCS bucket %v", err) + log.Fatal(err) + return nil, err + } + syscall.Unlink(tmpfile.Name()) 
// The name is unlinked right away; presumably (POSIX semantics) the open descriptor keeps the data until the file is closed or this process exits. + + fmt.Printf("\nDownloading dump file from GCS bucket %s, path %s\n", bucketName, filePath) + buffer := make([]byte, 1024) + for { + // read a chunk of at most cap(buffer) = 1024 bytes + n, err := r.Read(buffer[:cap(buffer)]) + + if err != nil && err != io.EOF { + fmt.Printf("readFile: unable to read entire dump file from bucket %s, file %s: %v", bucketName, filePath, err) + log.Fatal(err) + return nil, err + } + if n == 0 && err == io.EOF { + break + } + + // write exactly the n bytes just read + if _, err = tmpfile.Write(buffer[:n]); err != nil { + fmt.Printf("saveFile: unable to save read data from bucket %s, file %s: %v", bucketName, filePath, err) + log.Fatal(err) + } + } + + // NOTE(review): no Seek is done here, so the returned file's offset sits at the end of the downloaded data - confirm callers rewind before reading. + return tmpfile, nil +} + +// GetProject returns the cloud project we should use for accessing Spanner. +// Use environment variable GCLOUD_PROJECT if it is set. +// Otherwise, use the default project returned from gcloud. +// The gcloud output (stdout and stderr combined) is whitespace-trimmed and returned even when empty. +func GetProject() (string, error) { + project := os.Getenv("GCLOUD_PROJECT") + if project != "" { + return project, nil + } + cmd := exec.Command("gcloud", "config", "list", "--format", "value(core.project)") + out, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("call to gcloud to get project failed: %w", err) + } + project = strings.TrimSpace(string(out)) + return project, nil +} + +// GetInstance returns the Spanner instance we should use for creating DBs. +// It lists the instances of project through getInstances: exactly one instance is +// returned as-is, while none or several produce an error (the candidates are written +// to out so the user can pick one with the '--instance' flag). +// NOTE(review): a user-specified instance is not handled here; presumably callers only call this when none was given. 
+func GetInstance(ctx context.Context, project string, out *os.File) (string, error) { + l, err := getInstances(ctx, project) + if err != nil { + return "", err + } + if len(l) == 0 { + fmt.Fprintf(out, "Could not find any Spanner instances for project %s\n", project) + return "", fmt.Errorf("no Spanner instances for %s", project) + } + + // Note: we could ask for user input to select/confirm which Spanner + // instance to use, but that interacts poorly with piping pg_dump/mysqldump data + // to the tool via stdin. + if len(l) == 1 { + fmt.Fprintf(out, "Using only available Spanner instance: %s\n", l[0]) + return l[0], nil + } + fmt.Fprintf(out, "Available Spanner instances:\n") + for i, x := range l { + fmt.Fprintf(out, " %d) %s\n", i+1, x) + } + fmt.Fprintf(out, "Please pick one of the available instances and set the flag '--instance'\n\n") + return "", fmt.Errorf("auto-selection of instance failed: project %s has more than one Spanner instance. "+ + "Please use the flag '--instance' to select an instance", project) +} + +// getInstances returns the names of the Spanner instances of project with the +// projects/<project>/instances/ prefix trimmed; API errors are passed through AnalyzeError. +// NOTE(review): instanceClient is never closed in this function, and it is created with +// instance.NewInstanceAdminClient directly, so this function itself does not read +// SPANNER_API_ENDPOINT - confirm both are intended. +func getInstances(ctx context.Context, project string) ([]string, error) { + instanceClient, err := instance.NewInstanceAdminClient(ctx) + if err != nil { + return nil, AnalyzeError(err, fmt.Sprintf("projects/%s", project)) + } + it := instanceClient.ListInstances(ctx, &instancepb.ListInstancesRequest{Parent: fmt.Sprintf("projects/%s", project)}) + var l []string + for { + resp, err := it.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, AnalyzeError(err, fmt.Sprintf("projects/%s", project)) + } + l = append(l, strings.TrimPrefix(resp.Name, fmt.Sprintf("projects/%s/instances/", project))) + } + return l, nil +} + +// GetPassword prompts for a password on stdout and reads it with terminal.ReadPassword +// from stdin. It returns the whitespace-trimmed input, or an empty string when reading fails. +func GetPassword() string { + fmt.Print("Enter Password: ") + bytePassword, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + // NOTE(review): the message below misspells Couldn't. + fmt.Println("\nCoudln't read password") + return "" + } + fmt.Printf("\n") + return strings.TrimSpace(string(bytePassword)) +} + +// GetDatabaseName generates 
database name with driver_date prefix. +func GetDatabaseName(driver string, now time.Time) (string, error) { + return generateName(fmt.Sprintf("%s_%s", driver, now.Format("2006-01-02"))) +} + +func generateName(prefix string) (string, error) { + b := make([]byte, 4) + _, err := rand.Read(b) + if err != nil { + return "", fmt.Errorf("error generating name: %w", err) + + } + return fmt.Sprintf("%s_%x-%x", prefix, b[0:2], b[2:4]), nil +} + +// parseURI parses an unknown URI string that could be a database, instance or project URI. +// The kind of URI is detected by substring match on databases, then instances, then projects. +// Components that the URI does not carry are returned as empty strings. +func parseURI(URI string) (project, instance, dbName string) { + project, instance, dbName = "", "", "" + if strings.Contains(URI, "databases") { + project, instance, dbName = ParseDbURI(URI) + } else if strings.Contains(URI, "instances") { + project, instance = parseInstanceURI(URI) + } else if strings.Contains(URI, "projects") { + project = parseProjectURI(URI) + } + return +} + +// ParseDbURI splits a database URI of the form +// projects/<project>/instances/<instance>/databases/<db> into its three components. +// NOTE(review): split[1] is indexed without a length check here and in the two helpers below, +// so a URI missing /databases/, /instances/ or a slash would panic - confirm inputs are always well-formed. +func ParseDbURI(dbURI string) (project, instance, dbName string) { + split := strings.Split(dbURI, "/databases/") + project, instance = parseInstanceURI(split[0]) + dbName = split[1] + return +} + +// parseInstanceURI splits projects/<project>/instances/<instance> into project and instance. +func parseInstanceURI(instanceURI string) (project, instance string) { + split := strings.Split(instanceURI, "/instances/") + project = parseProjectURI(split[0]) + instance = split[1] + return +} + +// parseProjectURI returns the second slash-separated element of projectURI (the id in projects/<project>). +func parseProjectURI(projectURI string) (project string) { + split := strings.Split(projectURI, "/") + project = split[1] + return +} + +// AnalyzeError inspects an error returned from Cloud Spanner and adds information +// about potential root causes e.g. authentication issues. +// Detection is a case-insensitive substring match on err.Error(); an error that matches +// no known pattern is returned unchanged. +func AnalyzeError(err error, URI string) error { + project, instance, _ := parseURI(URI) + e := strings.ToLower(err.Error()) + if ContainsAny(e, []string{"unauthenticated", "cannot fetch token", "default credentials"}) { + return fmt.Errorf("%w."+` +Possible cause: credentials are mis-configured. 
Do you need to run + + gcloud auth application-default login + +or configure environment variable GOOGLE_APPLICATION_CREDENTIALS. +See https://cloud.google.com/docs/authentication/getting-started`, err) + } + // The instance hint is only added when URI carried an instance component. + if ContainsAny(e, []string{"instance not found"}) && instance != "" { + return fmt.Errorf("%w.\n"+` +Possible cause: Spanner instance specified via instance option does not exist. +Please check that '%s' is correct and that it is a valid Spanner +instance for project %s`, err, instance, project) + } + return err +} + +// PrintPermissionsWarning writes a warning about Spanner access control to out, naming +// driver as the source whose table-level and row-level ACLs are dropped. +// NOTE(review): driver is concatenated into the Fprintf format string, so a '%' in driver +// would be treated as a formatting verb. +func PrintPermissionsWarning(driver string, out *os.File) { + fmt.Fprintf(out, + ` +WARNING: Please check that permissions for this Spanner instance are +appropriate. Spanner manages access control at the database level, and the +database created by HarbourBridge will inherit default permissions from this +instance. All data written to Spanner will be visible to anyone who can +access the created database. Note that `+driver+` table-level and row-level +ACLs are dropped during conversion since they are not supported by Spanner. + +`) +} + +// ContainsAny reports whether s contains at least one of the substrings in l. +func ContainsAny(s string, l []string) bool { + for _, a := range l { + if strings.Contains(s, a) { + return true + } + } + return false +} + +// GetFileSize returns the size in bytes reported by f.Stat, or a wrapped error when Stat fails. +func GetFileSize(f *os.File) (int64, error) { + info, err := f.Stat() + if err != nil { + return 0, fmt.Errorf("can't stat file: %w", err) + } + return info.Size(), nil +} + +// SetupLogFile configures the file used for logs. +// By default we just drop logs on the floor. To enable them (e.g. to debug +// Cloud Spanner client library issues), set logfile to a non-empty filename. +// Note: this tool itself doesn't generate logs, but some of the libraries it +// uses do. If we don't set the log file, we see a number of unhelpful and +// unactionable logs spamming stdout, which is annoying and confusing. +// It returns the created log file, or nil when log output is discarded. +func SetupLogFile() (*os.File, error) { + // To enable debug logs, set logfile to a non-empty filename. 
+ logfile := "" + if logfile == "" { + log.SetOutput(ioutil.Discard) + return nil, nil + } + f, err := os.Create(logfile) + if err != nil { + return nil, err + } + log.SetOutput(f) + return f, nil +} + +// Close closes f when it is non-nil; the error returned by f.Close is ignored. +func Close(f *os.File) { + if f != nil { + f.Close() + } +} + +// PrintSeekError writes to out an explanation of why a seekable copy of the input could not +// be obtained, together with a suggested workaround that names driver. +// NOTE(review): driver is concatenated into Fprintf format strings, so a '%' in it would be read as a verb. +func PrintSeekError(driver string, err error, out *os.File) { + fmt.Fprintf(out, "\nCan't get seekable input file: %v\n", err) + fmt.Fprintf(out, "Likely cause: not enough space in %s.\n", os.TempDir()) + fmt.Fprintf(out, "Try writing "+driver+" output to a file first i.e.\n") + fmt.Fprintf(out, " "+driver+" > tmpfile\n") + fmt.Fprintf(out, " harbourbridge < tmpfile\n") +} + +// NewSpannerClient returns a new Spanner client. +// It respects SPANNER_API_ENDPOINT. +// When that variable is non-empty its value is passed to the client via option.WithEndpoint. +func NewSpannerClient(ctx context.Context, db string) (*sp.Client, error) { + if endpoint := os.Getenv("SPANNER_API_ENDPOINT"); endpoint != "" { + return sp.NewClient(ctx, db, option.WithEndpoint(endpoint)) + } + return sp.NewClient(ctx, db) +} + +// GetClient returns a new Spanner client for db by delegating to NewSpannerClient with the supplied ctx. +func GetClient(ctx context.Context, db string) (*sp.Client, error) { + return NewSpannerClient(ctx, db) +} + +// NewDatabaseAdminClient returns a new db-admin client. +// It respects SPANNER_API_ENDPOINT. +func NewDatabaseAdminClient(ctx context.Context) (*database.DatabaseAdminClient, error) { + if endpoint := os.Getenv("SPANNER_API_ENDPOINT"); endpoint != "" { + return database.NewDatabaseAdminClient(ctx, option.WithEndpoint(endpoint)) + } + return database.NewDatabaseAdminClient(ctx) +} + +// NewInstanceAdminClient returns a new instance-admin client. +// It respects SPANNER_API_ENDPOINT. 
+func NewInstanceAdminClient(ctx context.Context) (*instance.InstanceAdminClient, error) { + if endpoint := os.Getenv("SPANNER_API_ENDPOINT"); endpoint != "" { + return instance.NewInstanceAdminClient(ctx, option.WithEndpoint(endpoint)) + } + return instance.NewInstanceAdminClient(ctx) +} + +func SumMapValues(m map[string]int64) int64 { + n := int64(0) + for _, c := range m { + n += c + } + return n +} + +// GetBanner returns the banner line (generation time and db) that callers embed in generated output; it does not print anything itself. +func GetBanner(now time.Time, db string) string { + return fmt.Sprintf("Generated at %s for db %s\n\n", now.Format("2006-01-02 15:04:05"), db) +} diff --git a/conversion/conversion.go b/conversion/conversion.go index 3ce00a1c0b..ab2f1da8b5 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -25,16 +25,12 @@ package conversion import ( "bufio" "context" - "crypto/rand" "database/sql" "encoding/json" "fmt" "io" "io/ioutil" - "log" - "net/url" "os" - "os/exec" "strings" "sync" "sync/atomic" @@ -43,20 +39,15 @@ import ( sp "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" - instance "cloud.google.com/go/spanner/admin/instance/apiv1" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" dydb "github.com/aws/aws-sdk-go/service/dynamodb" - "golang.org/x/crypto/ssh/terminal" - "google.golang.org/api/iterator" - "google.golang.org/api/option" adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" - instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1" - - "cloud.google.com/go/storage" "github.com/cloudspannerecosystem/harbourbridge/common/constants" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/profiles" "github.com/cloudspannerecosystem/harbourbridge/sources/common" "github.com/cloudspannerecosystem/harbourbridge/sources/dynamodb" 
"github.com/cloudspannerecosystem/harbourbridge/sources/mysql" @@ -82,14 +73,14 @@ var ( // - Driver is DynamoDB or a dump file mode. // - This function is called as part of the legacy global CLI flag mode. (This string is constructed from env variables later on) // When using source-profile, the sqlConnectionStr is constructed from the input params. -func SchemaConv(driver, sqlConnectionStr, targetDb string, ioHelper *IOStreams, schemaSampleSize int64) (*internal.Conv, error) { - switch driver { +func SchemaConv(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams) (*internal.Conv, error) { + switch sourceProfile.Driver { case constants.POSTGRES, constants.MYSQL, constants.DYNAMODB: - return schemaFromDatabase(driver, sqlConnectionStr, targetDb, schemaSampleSize) + return schemaFromDatabase(sourceProfile, targetProfile) case constants.PGDUMP, constants.MYSQLDUMP: - return schemaFromDump(driver, targetDb, ioHelper) + return schemaFromDump(sourceProfile.Driver, targetProfile.TargetDb, ioHelper) default: - return nil, fmt.Errorf("schema conversion for driver %s not supported", driver) + return nil, fmt.Errorf("schema conversion for driver %s not supported", sourceProfile.Driver) } } @@ -99,89 +90,56 @@ func SchemaConv(driver, sqlConnectionStr, targetDb string, ioHelper *IOStreams, // - Driver is DynamoDB or a dump file mode. // - This function is called as part of the legacy global CLI flag mode. (This string is constructed from env variables later on) // When using source-profile, the sqlConnectionStr and schemaSampleSize are constructed from the input params. 
-func DataConv(driver, sqlConnectionStr string, ioHelper *IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool, schemaSampleSize int64) (*spanner.BatchWriter, error) { +func DataConv(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool) (*spanner.BatchWriter, error) { config := spanner.BatchWriterConfig{ BytesLimit: 100 * 1000 * 1000, WriteLimit: 40, RetryLimit: 1000, Verbose: internal.Verbose(), } - switch driver { + switch sourceProfile.Driver { case constants.POSTGRES, constants.MYSQL, constants.DYNAMODB: - return dataFromDatabase(driver, sqlConnectionStr, config, client, conv, schemaSampleSize) + return dataFromDatabase(sourceProfile, config, client, conv) case constants.PGDUMP, constants.MYSQLDUMP: if conv.SpSchema.CheckInterleaved() { return nil, fmt.Errorf("harbourBridge does not currently support data conversion from dump files\nif the schema contains interleaved tables. Suggest using direct access to source database\ni.e. using drivers postgres and mysql") } - return dataFromDump(driver, config, ioHelper, client, conv, dataOnly) + return dataFromDump(sourceProfile.Driver, config, ioHelper, client, conv, dataOnly) default: - return nil, fmt.Errorf("data conversion for driver %s not supported", driver) + return nil, fmt.Errorf("data conversion for driver %s not supported", sourceProfile.Driver) } } -func connectionConfig(driver string, sqlConnectionStr string) (interface{}, error) { - switch driver { +func connectionConfig(sourceProfile profiles.SourceProfile) (interface{}, error) { + switch sourceProfile.Driver { + // For PG and MYSQL, When called as part of the subcommand flow, host/user/db etc will + // never be empty as we error out right during source profile creation. If any of them + // are empty, that means this was called through the legacy cmd flow and we create the + // string using env vars. 
case constants.POSTGRES: - // If empty, this is called as part of the legacy mode witih global CLI flags. - // When using source-profile mode is used, the sqlConnectionStr is already populated. - if sqlConnectionStr == "" { - return generatePGSQLConnectionStr() + pgConn := sourceProfile.Conn.Pg + if !(pgConn.Host != "" && pgConn.User != "" && pgConn.Db != "") { + return profiles.GeneratePGSQLConnectionStr() + } else { + return profiles.GetSQLConnectionStr(sourceProfile), nil } - return sqlConnectionStr, nil case constants.MYSQL: // If empty, this is called as part of the legacy mode witih global CLI flags. // When using source-profile mode is used, the sqlConnectionStr is already populated. - if sqlConnectionStr == "" { - return generateMYSQLConnectionStr() + mysqlConn := sourceProfile.Conn.Mysql + if !(mysqlConn.Host != "" && mysqlConn.User != "" && mysqlConn.Db != "") { + return profiles.GenerateMYSQLConnectionStr() + } else { + return profiles.GetSQLConnectionStr(sourceProfile), nil } - return sqlConnectionStr, nil + // For Dynamodb, both legacy and new flows use env vars. 
case constants.DYNAMODB: return getDynamoDBClientConfig() default: - return "", fmt.Errorf("driver %s not supported", driver) + return "", fmt.Errorf("driver %s not supported", sourceProfile.Driver) } } -func generatePGSQLConnectionStr() (string, error) { - server := os.Getenv("PGHOST") - port := os.Getenv("PGPORT") - user := os.Getenv("PGUSER") - dbname := os.Getenv("PGDATABASE") - if server == "" || port == "" || user == "" || dbname == "" { - fmt.Printf("Please specify host, port, user and database using PGHOST, PGPORT, PGUSER and PGDATABASE environment variables\n") - return "", fmt.Errorf("could not connect to source database") - } - password := os.Getenv("PGPASSWORD") - if password == "" { - password = GetPassword() - } - return GetPGSQLConnectionStr(server, port, user, password, dbname), nil -} - -func GetPGSQLConnectionStr(server, port, user, password, dbname string) string { - return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", server, port, user, password, dbname) -} - -func generateMYSQLConnectionStr() (string, error) { - server := os.Getenv("MYSQLHOST") - port := os.Getenv("MYSQLPORT") - user := os.Getenv("MYSQLUSER") - dbname := os.Getenv("MYSQLDATABASE") - if server == "" || port == "" || user == "" || dbname == "" { - fmt.Printf("Please specify host, port, user and database using MYSQLHOST, MYSQLPORT, MYSQLUSER and MYSQLDATABASE environment variables\n") - return "", fmt.Errorf("could not connect to source database") - } - password := os.Getenv("MYSQLPWD") - if password == "" { - password = GetPassword() - } - return GetMYSQLConnectionStr(server, port, user, password, dbname), nil -} - -func GetMYSQLConnectionStr(server, port, user, password, dbname string) string { - return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", user, password, server, port, dbname) -} - func getDbNameFromSQLConnectionStr(driver, sqlConnectionStr string) string { switch driver { case constants.POSTGRES: @@ -193,18 +151,18 @@ func 
getDbNameFromSQLConnectionStr(driver, sqlConnectionStr string) string { return "" } -func schemaFromDatabase(driver, sqlConnectionStr, targetDb string, schemaSampleSize int64) (*internal.Conv, error) { +func schemaFromDatabase(sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile) (*internal.Conv, error) { conv := internal.MakeConv() - conv.TargetDb = targetDb - infoSchema, err := GetInfoSchema(driver, sqlConnectionStr, schemaSampleSize) + conv.TargetDb = targetProfile.TargetDb + infoSchema, err := GetInfoSchema(sourceProfile) if err != nil { return conv, err } return conv, common.ProcessSchema(conv, infoSchema) } -func dataFromDatabase(driver, sqlConnectionStr string, config spanner.BatchWriterConfig, client *sp.Client, conv *internal.Conv, schemaSampleSize int64) (*spanner.BatchWriter, error) { - infoSchema, err := GetInfoSchema(driver, sqlConnectionStr, schemaSampleSize) +func dataFromDatabase(sourceProfile profiles.SourceProfile, config spanner.BatchWriterConfig, client *sp.Client, conv *internal.Conv) (*spanner.BatchWriter, error) { + infoSchema, err := GetInfoSchema(sourceProfile) if err != nil { return nil, err } @@ -244,100 +202,10 @@ func getDynamoDBClientConfig() (*aws.Config, error) { return &cfg, nil } -// IOStreams is a struct that contains the file descriptor for dumpFile. 
-type IOStreams struct { - In, SeekableIn, Out *os.File - BytesRead int64 -} - -// downloadFromGCS returns the dump file that is downloaded from GCS -func downloadFromGCS(bucketName string, filePath string) (*os.File, error) { - ctx := context.Background() - - client, err := storage.NewClient(ctx) - if err != nil { - fmt.Printf("Failed to create GCS client for bucket %q", bucketName) - log.Fatal(err) - } - defer client.Close() - - bucket := client.Bucket(bucketName) - rc, err := bucket.Object(filePath).NewReader(ctx) - if err != nil { - fmt.Printf("readFile: unable to open file from bucket %q, file %q: %v", bucketName, filePath, err) - log.Fatal(err) - return nil, err - } - defer rc.Close() - r := bufio.NewReader(rc) - - tmpfile, err := ioutil.TempFile("", "harbourbridge.gcs.data") - if err != nil { - fmt.Printf("saveFile: unable to open temporary file to save dump file from GCS bucket %v", err) - log.Fatal(err) - return nil, err - } - syscall.Unlink(tmpfile.Name()) // File will be deleted when this process exits. - - fmt.Printf("\nDownloading dump file from GCS bucket %s, path %s\n", bucketName, filePath) - buffer := make([]byte, 1024) - for { - // read a chunk - n, err := r.Read(buffer[:cap(buffer)]) - - if err != nil && err != io.EOF { - fmt.Printf("readFile: unable to read entire dump file from bucket %s, file %s: %v", bucketName, filePath, err) - log.Fatal(err) - return nil, err - } - if n == 0 && err == io.EOF { - break - } - - // write a chunk - if _, err = tmpfile.Write(buffer[:n]); err != nil { - fmt.Printf("saveFile: unable to save read data from bucket %s, file %s: %v", bucketName, filePath, err) - log.Fatal(err) - } - } - - return tmpfile, nil -} - -// NewIOStreams returns a new IOStreams struct such that input stream is set -// to open file descriptor for dumpFile if driver is PGDUMP or MYSQLDUMP. -// Input stream defaults to stdin. Output stream is always set to stdout. 
-func NewIOStreams(driver string, dumpFile string) IOStreams { - io := IOStreams{In: os.Stdin, Out: os.Stdout} - u, err := url.Parse(dumpFile) - if err != nil { - fmt.Printf("parseFilePath: unable parse file path for dumpfile %s", dumpFile) - log.Fatal(err) - } - if (driver == constants.PGDUMP || driver == constants.MYSQLDUMP) && dumpFile != "" { - fmt.Printf("\nLoading dump file from path: %s\n", dumpFile) - var f *os.File - var err error - if u.Scheme == "gs" { - bucketName := u.Host - filePath := u.Path[1:] // removes "/" from beginning of path - f, err = downloadFromGCS(bucketName, filePath) - } else { - f, err = os.Open(dumpFile) - } - if err != nil { - fmt.Printf("\nError reading dump file: %v err:%v\n", dumpFile, err) - log.Fatal(err) - } - io.In = f - } - return io -} - -func schemaFromDump(driver string, targetDb string, ioHelper *IOStreams) (*internal.Conv, error) { +func schemaFromDump(driver string, targetDb string, ioHelper *utils.IOStreams) (*internal.Conv, error) { f, n, err := getSeekable(ioHelper.In) if err != nil { - printSeekError(driver, err, ioHelper.Out) + utils.PrintSeekError(driver, err, ioHelper.Out) return nil, fmt.Errorf("can't get seekable input file") } ioHelper.SeekableIn = f @@ -357,7 +225,7 @@ func schemaFromDump(driver string, targetDb string, ioHelper *IOStreams) (*inter return conv, nil } -func dataFromDump(driver string, config spanner.BatchWriterConfig, ioHelper *IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool) (*spanner.BatchWriter, error) { +func dataFromDump(driver string, config spanner.BatchWriterConfig, ioHelper *utils.IOStreams, client *sp.Client, conv *internal.Conv, dataOnly bool) (*spanner.BatchWriter, error) { // TODO: refactor of the way we handle getSeekable // to avoid the code duplication here if !dataOnly { @@ -371,7 +239,7 @@ func dataFromDump(driver string, config spanner.BatchWriterConfig, ioHelper *IOS // changes in showing progress for data migration. 
f, n, err := getSeekable(ioHelper.In) if err != nil { - printSeekError(driver, err, ioHelper.Out) + utils.PrintSeekError(driver, err, ioHelper.Out) return nil, fmt.Errorf("can't get seekable input file") } ioHelper.SeekableIn = f @@ -445,7 +313,7 @@ func Report(driver string, badWrites map[string]int64, BytesRead int64, banner s func getSeekable(f *os.File) (*os.File, int64, error) { _, err := f.Seek(0, 0) if err == nil { // Stdin is seekable, let's just use that. This happens when you run 'cmd < file'. - n, err := getSize(f) + n, err := utils.GetFileSize(f) return f, n, err } internal.VerbosePrintln("Creating a tmp file with a copy of stdin because stdin is not seekable.") @@ -468,7 +336,7 @@ func getSeekable(f *os.File) (*os.File, int64, error) { if err != nil { return nil, 0, fmt.Errorf("can't reset file offset: %w", err) } - n, _ := getSize(fcopy) + n, _ := utils.GetFileSize(fcopy) return fcopy, n, nil } @@ -488,7 +356,7 @@ func VerifyDb(ctx context.Context, adminClient *database.DatabaseAdminClient, db func CheckExistingDb(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string) (bool, error) { _, err := adminClient.GetDatabase(ctx, &adminpb.GetDatabaseRequest{Name: dbURI}) if err != nil { - if containsAny(strings.ToLower(err.Error()), []string{"database not found"}) { + if utils.ContainsAny(strings.ToLower(err.Error()), []string{"database not found"}) { return false, nil } return false, fmt.Errorf("can't get database info: %s", err) @@ -534,7 +402,7 @@ func CreateOrUpdateDatabase(ctx context.Context, adminClient *database.DatabaseA // Spanner instance to use, generates a new Spanner DB name, // and call into the Spanner admin interface to create the new DB. 
func CreateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string, conv *internal.Conv, out *os.File) error { - project, instance, dbName := parseDbURI(dbURI) + project, instance, dbName := utils.ParseDbURI(dbURI) fmt.Fprintf(out, "Creating new database %s in instance %s with default permissions ... \n", dbName, instance) // The schema we send to Spanner excludes comments (since Cloud // Spanner DDL doesn't accept them), and protects table and col names @@ -557,10 +425,10 @@ func CreateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClie op, err := adminClient.CreateDatabase(ctx, req) if err != nil { - return fmt.Errorf("can't build CreateDatabaseRequest: %w", AnalyzeError(err, dbURI)) + return fmt.Errorf("can't build CreateDatabaseRequest: %w", utils.AnalyzeError(err, dbURI)) } if _, err := op.Wait(ctx); err != nil { - return fmt.Errorf("createDatabase call failed: %w", AnalyzeError(err, dbURI)) + return fmt.Errorf("createDatabase call failed: %w", utils.AnalyzeError(err, dbURI)) } fmt.Fprintf(out, "Created database successfully.\n") @@ -578,62 +446,29 @@ func UpdateDatabase(ctx context.Context, adminClient *database.DatabaseAdminClie // Spanner DDL doesn't accept them), and protects table and col names // using backticks (to avoid any issues with Spanner reserved words). // Foreign Keys are set to false since we create them post data migration. 
- schema := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: false, Tables: true, ForeignKeys: false, TargetDb: conv.TargetDb}) + schema := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: false, TargetDb: conv.TargetDb}) req := &adminpb.UpdateDatabaseDdlRequest{ Database: dbURI, Statements: schema, } op, err := adminClient.UpdateDatabaseDdl(ctx, req) if err != nil { - return fmt.Errorf("can't build UpdateDatabaseDdlRequest: %w", AnalyzeError(err, dbURI)) + return fmt.Errorf("can't build UpdateDatabaseDdlRequest: %w", utils.AnalyzeError(err, dbURI)) } if err := op.Wait(ctx); err != nil { - return fmt.Errorf("UpdateDatabaseDdl call failed: %w", AnalyzeError(err, dbURI)) + return fmt.Errorf("UpdateDatabaseDdl call failed: %w", utils.AnalyzeError(err, dbURI)) } fmt.Fprintf(out, "Updated schema successfully.\n") return nil } -// parseURI parses an unknown URI string that could be a database, instance or project URI. -func parseURI(URI string) (project, instance, dbName string) { - project, instance, dbName = "", "", "" - if strings.Contains(URI, "databases") { - project, instance, dbName = parseDbURI(URI) - } else if strings.Contains(URI, "instances") { - project, instance = parseInstanceURI(URI) - } else if strings.Contains(URI, "projects") { - project = parseProjectURI(URI) - } - return -} - -func parseDbURI(dbURI string) (project, instance, dbName string) { - split := strings.Split(dbURI, "/databases/") - project, instance = parseInstanceURI(split[0]) - dbName = split[1] - return -} - -func parseInstanceURI(instanceURI string) (project, instance string) { - split := strings.Split(instanceURI, "/instances/") - project = parseProjectURI(split[0]) - instance = split[1] - return -} - -func parseProjectURI(projectURI string) (project string) { - split := strings.Split(projectURI, "/") - project = split[1] - return -} - // UpdateDDLForeignKeys updates the Spanner database with foreign key // constraints using ALTER 
TABLE statements. func UpdateDDLForeignKeys(ctx context.Context, adminClient *database.DatabaseAdminClient, dbURI string, conv *internal.Conv, out *os.File) error { // The schema we send to Spanner excludes comments (since Cloud // Spanner DDL doesn't accept them), and protects table and col names // using backticks (to avoid any issues with Spanner reserved words). - fkStmts := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: false, ForeignKeys: true}) + fkStmts := conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: false, ForeignKeys: true, TargetDb: conv.TargetDb}) if len(fkStmts) == 0 { return nil } @@ -698,72 +533,6 @@ Recommended value is between 20-30.`) return nil } -// GetProject returns the cloud project we should use for accessing Spanner. -// Use environment variable GCLOUD_PROJECT if it is set. -// Otherwise, use the default project returned from gcloud. -func GetProject() (string, error) { - project := os.Getenv("GCLOUD_PROJECT") - if project != "" { - return project, nil - } - cmd := exec.Command("gcloud", "config", "list", "--format", "value(core.project)") - out, err := cmd.CombinedOutput() - if err != nil { - return "", fmt.Errorf("call to gcloud to get project failed: %w", err) - } - project = strings.TrimSpace(string(out)) - return project, nil -} - -// GetInstance returns the Spanner instance we should use for creating DBs. -// If the user specified instance (via flag 'instance') then use that. -// Otherwise try to deduce the instance using gcloud. 
-func GetInstance(ctx context.Context, project string, out *os.File) (string, error) { - l, err := getInstances(ctx, project) - if err != nil { - return "", err - } - if len(l) == 0 { - fmt.Fprintf(out, "Could not find any Spanner instances for project %s\n", project) - return "", fmt.Errorf("no Spanner instances for %s", project) - } - - // Note: we could ask for user input to select/confirm which Spanner - // instance to use, but that interacts poorly with piping pg_dump/mysqldump data - // to the tool via stdin. - if len(l) == 1 { - fmt.Fprintf(out, "Using only available Spanner instance: %s\n", l[0]) - return l[0], nil - } - fmt.Fprintf(out, "Available Spanner instances:\n") - for i, x := range l { - fmt.Fprintf(out, " %d) %s\n", i+1, x) - } - fmt.Fprintf(out, "Please pick one of the available instances and set the flag '--instance'\n\n") - return "", fmt.Errorf("auto-selection of instance failed: project %s has more than one Spanner instance. "+ - "Please use the flag '--instance' to select an instance", project) -} - -func getInstances(ctx context.Context, project string) ([]string, error) { - instanceClient, err := instance.NewInstanceAdminClient(ctx) - if err != nil { - return nil, AnalyzeError(err, fmt.Sprintf("projects/%s", project)) - } - it := instanceClient.ListInstances(ctx, &instancepb.ListInstancesRequest{Parent: fmt.Sprintf("projects/%s", project)}) - var l []string - for { - resp, err := it.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, AnalyzeError(err, fmt.Sprintf("projects/%s", project)) - } - l = append(l, strings.TrimPrefix(resp.Name, fmt.Sprintf("projects/%s/instances/", project))) - } - return l, nil -} - // WriteSchemaFile writes DDL statements in a file. It includes CREATE TABLE // statements and ALTER TABLE statements to add foreign keys. // The parameter name should end with a .txt. 
@@ -778,7 +547,7 @@ func WriteSchemaFile(conv *internal.Conv, now time.Time, name string, out *os.Fi // and doesn't add backticks around table and column names. This file is // intended for explanatory and documentation purposes, and is not strictly // legal Cloud Spanner DDL (Cloud Spanner doesn't currently support comments). - spDDL := conv.SpSchema.GetDDL(ddl.Config{Comments: true, ProtectIds: false, Tables: true, ForeignKeys: true}) + spDDL := conv.SpSchema.GetDDL(ddl.Config{Comments: true, ProtectIds: false, Tables: true, ForeignKeys: true, TargetDb: conv.TargetDb}) if len(spDDL) == 0 { spDDL = []string{"\n-- Schema is empty -- no tables found\n"} } @@ -805,7 +574,7 @@ func WriteSchemaFile(conv *internal.Conv, now time.Time, name string, out *os.Fi // We change 'Comments' to false and 'ProtectIds' to true below to write out a // schema file that is a legal Cloud Spanner DDL. - spDDL = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: true}) + spDDL = conv.SpSchema.GetDDL(ddl.Config{Comments: false, ProtectIds: true, Tables: true, ForeignKeys: true, TargetDb: conv.TargetDb}) if len(spDDL) == 0 { spDDL = []string{"\n-- Schema is empty -- no tables found\n"} } @@ -878,7 +647,7 @@ func ReadSessionFile(conv *internal.Conv, sessionJSON string) error { // to file 'name'. func WriteBadData(bw *spanner.BatchWriter, conv *internal.Conv, banner, name string, out *os.File) { badConversions := conv.BadRows() - badWrites := sum(bw.DroppedRowsByTable()) + badWrites := utils.SumMapValues(bw.DroppedRowsByTable()) if badConversions == 0 && badWrites == 0 { os.Remove(name) // Cleanup bad-data file from previous run. return @@ -923,167 +692,6 @@ func WriteBadData(bw *spanner.BatchWriter, conv *internal.Conv, banner, name str fmt.Fprintf(out, "See file '%s' for details of bad rows\n", name) } -// GetDatabaseName generates database name with driver_date prefix. 
-func GetDatabaseName(driver string, now time.Time) (string, error) { - return generateName(fmt.Sprintf("%s_%s", driver, now.Format("2006-01-02"))) -} - -func GetPassword() string { - fmt.Print("Enter Password: ") - bytePassword, err := terminal.ReadPassword(int(syscall.Stdin)) - if err != nil { - fmt.Println("\nCoudln't read password") - return "" - } - fmt.Printf("\n") - return strings.TrimSpace(string(bytePassword)) -} - -// AnalyzeError inspects an error returned from Cloud Spanner and adds information -// about potential root causes e.g. authentication issues. -func AnalyzeError(err error, URI string) error { - project, instance, _ := parseURI(URI) - e := strings.ToLower(err.Error()) - if containsAny(e, []string{"unauthenticated", "cannot fetch token", "default credentials"}) { - return fmt.Errorf("%w."+` -Possible cause: credentials are mis-configured. Do you need to run - - gcloud auth application-default login - -or configure environment variable GOOGLE_APPLICATION_CREDENTIALS. -See https://cloud.google.com/docs/authentication/getting-started`, err) - } - if containsAny(e, []string{"instance not found"}) && instance != "" { - return fmt.Errorf("%w.\n"+` -Possible cause: Spanner instance specified via instance option does not exist. -Please check that '%s' is correct and that it is a valid Spanner -instance for project %s`, err, instance, project) - } - return err -} - -// PrintPermissionsWarning prints permission warning. -func PrintPermissionsWarning(driver string, out *os.File) { - fmt.Fprintf(out, - ` -WARNING: Please check that permissions for this Spanner instance are -appropriate. Spanner manages access control at the database level, and the -database created by HarbourBridge will inherit default permissions from this -instance. All data written to Spanner will be visible to anyone who can -access the created database. Note that `+driver+` table-level and row-level -ACLs are dropped during conversion since they are not supported by Spanner. 
- -`) -} - -func printSeekError(driver string, err error, out *os.File) { - fmt.Fprintf(out, "\nCan't get seekable input file: %v\n", err) - fmt.Fprintf(out, "Likely cause: not enough space in %s.\n", os.TempDir()) - fmt.Fprintf(out, "Try writing "+driver+" output to a file first i.e.\n") - fmt.Fprintf(out, " "+driver+" > tmpfile\n") - fmt.Fprintf(out, " harbourbridge < tmpfile\n") -} - -func containsAny(s string, l []string) bool { - for _, a := range l { - if strings.Contains(s, a) { - return true - } - } - return false -} - -func generateName(prefix string) (string, error) { - b := make([]byte, 4) - _, err := rand.Read(b) - if err != nil { - return "", fmt.Errorf("error generating name: %w", err) - - } - return fmt.Sprintf("%s_%x-%x", prefix, b[0:2], b[2:4]), nil -} - -// NewSpannerClient returns a new Spanner client. -// It respects SPANNER_API_ENDPOINT. -func NewSpannerClient(ctx context.Context, db string) (*sp.Client, error) { - if endpoint := os.Getenv("SPANNER_API_ENDPOINT"); endpoint != "" { - return sp.NewClient(ctx, db, option.WithEndpoint(endpoint)) - } - return sp.NewClient(ctx, db) -} - -// GetClient returns a new Spanner client. It uses the background context. -func GetClient(ctx context.Context, db string) (*sp.Client, error) { - return NewSpannerClient(ctx, db) -} - -// NewDatabaseAdminClient returns a new db-admin client. -// It respects SPANNER_API_ENDPOINT. -func NewDatabaseAdminClient(ctx context.Context) (*database.DatabaseAdminClient, error) { - if endpoint := os.Getenv("SPANNER_API_ENDPOINT"); endpoint != "" { - return database.NewDatabaseAdminClient(ctx, option.WithEndpoint(endpoint)) - } - return database.NewDatabaseAdminClient(ctx) -} - -// NewInstanceAdminClient returns a new instance-admin client. -// It respects SPANNER_API_ENDPOINT. 
-func NewInstanceAdminClient(ctx context.Context) (*instance.InstanceAdminClient, error) { - if endpoint := os.Getenv("SPANNER_API_ENDPOINT"); endpoint != "" { - return instance.NewInstanceAdminClient(ctx, option.WithEndpoint(endpoint)) - } - return instance.NewInstanceAdminClient(ctx) -} - -func getSize(f *os.File) (int64, error) { - info, err := f.Stat() - if err != nil { - return 0, fmt.Errorf("can't stat file: %w", err) - } - return info.Size(), nil -} - -// SetupLogFile configures the file used for logs. -// By default we just drop logs on the floor. To enable them (e.g. to debug -// Cloud Spanner client library issues), set logfile to a non-empty filename. -// Note: this tool itself doesn't generate logs, but some of the libraries it -// uses do. If we don't set the log file, we see a number of unhelpful and -// unactionable logs spamming stdout, which is annoying and confusing. -func SetupLogFile() (*os.File, error) { - // To enable debug logs, set logfile to a non-empty filename. - logfile := "" - if logfile == "" { - log.SetOutput(ioutil.Discard) - return nil, nil - } - f, err := os.Create(logfile) - if err != nil { - return nil, err - } - log.SetOutput(f) - return f, nil -} - -// Close closes file. -func Close(f *os.File) { - if f != nil { - f.Close() - } -} - -func sum(m map[string]int64) int64 { - n := int64(0) - for _, c := range m { - n += c - } - return n -} - -// GetBanner prints banner message after command line process is finished. -func GetBanner(now time.Time, db string) string { - return fmt.Sprintf("Generated at %s for db %s\n\n", now.Format("2006-01-02 15:04:05"), db) -} - // ProcessDump invokes process dump function from a sql package based on driver selected. 
func ProcessDump(driver string, conv *internal.Conv, r *internal.Reader) error { switch driver { @@ -1096,11 +704,12 @@ func ProcessDump(driver string, conv *internal.Conv, r *internal.Reader) error { } } -func GetInfoSchema(driver, sqlConnectionStr string, schemaSampleSize int64) (common.InfoSchema, error) { - connectionConfig, err := connectionConfig(driver, sqlConnectionStr) +func GetInfoSchema(sourceProfile profiles.SourceProfile) (common.InfoSchema, error) { + connectionConfig, err := connectionConfig(sourceProfile) if err != nil { return nil, err } + driver := sourceProfile.Driver switch driver { case constants.MYSQL: db, err := sql.Open(driver, connectionConfig.(string)) @@ -1118,7 +727,7 @@ func GetInfoSchema(driver, sqlConnectionStr string, schemaSampleSize int64) (com case constants.DYNAMODB: mySession := session.Must(session.NewSession()) dydbClient := dydb.New(mySession, connectionConfig.(*aws.Config)) - return dynamodb.InfoSchemaImpl{DynamoClient: dydbClient, SampleSize: schemaSampleSize}, nil + return dynamodb.InfoSchemaImpl{DynamoClient: dydbClient, SampleSize: profiles.GetSchemaSampleSize(sourceProfile)}, nil default: return nil, fmt.Errorf("driver %s not supported", driver) } diff --git a/main.go b/main.go index ffcfb8f9df..1f90c49233 100644 --- a/main.go +++ b/main.go @@ -30,7 +30,7 @@ import ( "github.com/cloudspannerecosystem/harbourbridge/cmd" "github.com/cloudspannerecosystem/harbourbridge/common/constants" - "github.com/cloudspannerecosystem/harbourbridge/conversion" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/internal" "github.com/cloudspannerecosystem/harbourbridge/web" "github.com/google/subcommands" @@ -91,12 +91,12 @@ Sample usage: func main() { ctx := context.Background() - lf, err := conversion.SetupLogFile() + lf, err := utils.SetupLogFile() if err != nil { fmt.Printf("\nCan't set up log file: %v\n", err) panic(fmt.Errorf("can't set up log file")) } - defer 
conversion.Close(lf) + defer utils.Close(lf) // TODO: Remove this check and always run HB in subcommands mode once // global command line mode is deprecated. We can also enable support for @@ -149,11 +149,11 @@ func main() { } fmt.Printf("Using driver (source DB): %s target-db: %s\n", driverName, targetDb) - ioHelper := conversion.NewIOStreams(driverName, dumpFilePath) + ioHelper := utils.NewIOStreams(driverName, dumpFilePath) var project, instance string if !schemaOnly { - project, err = conversion.GetProject() + project, err = utils.GetProject() if err != nil { fmt.Printf("\nCan't get project: %v\n", err) panic(fmt.Errorf("can't get project")) @@ -162,20 +162,20 @@ func main() { instance = instanceOverride if instance == "" { - instance, err = conversion.GetInstance(ctx, project, ioHelper.Out) + instance, err = utils.GetInstance(ctx, project, ioHelper.Out) if err != nil { fmt.Printf("\nCan't get instance: %v\n", err) panic(fmt.Errorf("can't get instance")) } } fmt.Println("Using Cloud Spanner instance:", instance) - conversion.PrintPermissionsWarning(driverName, ioHelper.Out) + utils.PrintPermissionsWarning(driverName, ioHelper.Out) } now := time.Now() dbName := dbNameOverride if dbName == "" { - dbName, err = conversion.GetDatabaseName(driverName, now) + dbName, err = utils.GetDatabaseName(driverName, now) if err != nil { fmt.Printf("\nCan't get database name: %v\n", err) panic(fmt.Errorf("can't get database name")) diff --git a/profiles/common.go b/profiles/common.go new file mode 100644 index 0000000000..c6cf1bddf8 --- /dev/null +++ b/profiles/common.go @@ -0,0 +1,151 @@ +package profiles + +import ( + "context" + "encoding/csv" + "fmt" + "os" + "strings" + "time" + + "github.com/cloudspannerecosystem/harbourbridge/common/utils" +) + +// Parses input string `s` as a map of key-value pairs. It's expected that the +// input string is of the form "key1=value1,key2=value2,..." etc. Return error +// otherwise. 
+func parseProfile(s string) (map[string]string, error) { + params := make(map[string]string) + if len(s) == 0 { + return params, nil + } + + // We use CSV reader to parse key=value pairs separated by a comma to + // handle the case where a value may contain a comma within a quote. We + // expect exactly one record to be returned. + r := csv.NewReader(strings.NewReader(s)) + r.Comma = ',' + r.TrimLeadingSpace = true + records, err := r.ReadAll() + if err != nil { + return params, err + } + if len(records) > 1 { + return params, fmt.Errorf("contains invalid newline characters") + } + + for _, kv := range records[0] { + s := strings.Split(strings.TrimSpace(kv), "=") + if len(s) != 2 { + return params, fmt.Errorf("invalid key=value pair (expected format: key1=value1): %v", kv) + } + if _, ok := params[s[0]]; ok { + return params, fmt.Errorf("duplicate key found: %v", s[0]) + } + params[s[0]] = s[1] + } + return params, nil +} + +func GetResourceIds(ctx context.Context, targetProfile TargetProfile, now time.Time, driverName string, out *os.File) (string, string, string, error) { + var err error + project := targetProfile.conn.sp.project + if project == "" { + project, err = utils.GetProject() + if err != nil { + return "", "", "", fmt.Errorf("can't get project: %v", err) + } + } + fmt.Println("Using Google Cloud project:", project) + + instance := targetProfile.conn.sp.instance + if instance == "" { + instance, err = utils.GetInstance(ctx, project, out) + if err != nil { + return "", "", "", fmt.Errorf("can't get instance: %v", err) + } + } + fmt.Println("Using Cloud Spanner instance:", instance) + utils.PrintPermissionsWarning(driverName, out) + + dbName := targetProfile.conn.sp.dbname + if dbName == "" { + dbName, err = utils.GetDatabaseName(driverName, now) + if err != nil { + return "", "", "", fmt.Errorf("can't get database name: %v", err) + } + } + return project, instance, dbName, err +} + +func GetSQLConnectionStr(sourceProfile SourceProfile) string { + 
sqlConnectionStr := "" + if sourceProfile.Ty == SourceProfileTypeConnection { + switch sourceProfile.Conn.Ty { + case SourceProfileConnectionTypeMySQL: + connParams := sourceProfile.Conn.Mysql + return getMYSQLConnectionStr(connParams.Host, connParams.Port, connParams.User, connParams.Pwd, connParams.Db) + case SourceProfileConnectionTypePostgreSQL: + connParams := sourceProfile.Conn.Pg + return getPGSQLConnectionStr(connParams.Host, connParams.Port, connParams.User, connParams.Pwd, connParams.Db) + case SourceProfileConnectionTypeDynamoDB: + // For DynamoDB, client provided by aws-sdk reads connection credentials from env variables only. + // Thus, there is no need to create sqlConnectionStr for the same. We instead set the env variables + // programmatically if not set. + return "" + } + } + return sqlConnectionStr +} + +func GeneratePGSQLConnectionStr() (string, error) { + server := os.Getenv("PGHOST") + port := os.Getenv("PGPORT") + user := os.Getenv("PGUSER") + dbname := os.Getenv("PGDATABASE") + if server == "" || port == "" || user == "" || dbname == "" { + fmt.Printf("Please specify host, port, user and database using PGHOST, PGPORT, PGUSER and PGDATABASE environment variables\n") + return "", fmt.Errorf("could not connect to source database") + } + password := os.Getenv("PGPASSWORD") + if password == "" { + password = utils.GetPassword() + } + return getPGSQLConnectionStr(server, port, user, password, dbname), nil +} + +func getPGSQLConnectionStr(server, port, user, password, dbname string) string { + return fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", server, port, user, password, dbname) +} + +func GenerateMYSQLConnectionStr() (string, error) { + server := os.Getenv("MYSQLHOST") + port := os.Getenv("MYSQLPORT") + user := os.Getenv("MYSQLUSER") + dbname := os.Getenv("MYSQLDATABASE") + if server == "" || port == "" || user == "" || dbname == "" { + fmt.Printf("Please specify host, port, user and database using MYSQLHOST, 
MYSQLPORT, MYSQLUSER and MYSQLDATABASE environment variables\n") + return "", fmt.Errorf("could not connect to source database") + } + password := os.Getenv("MYSQLPWD") + if password == "" { + password = utils.GetPassword() + } + return getMYSQLConnectionStr(server, port, user, password, dbname), nil +} + +func getMYSQLConnectionStr(server, port, user, password, dbname string) string { + return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", user, password, server, port, dbname) +} + +func GetSchemaSampleSize(sourceProfile SourceProfile) int64 { + schemaSampleSize := int64(100000) + if sourceProfile.Ty == SourceProfileTypeConnection { + if sourceProfile.Conn.Ty == SourceProfileConnectionTypeDynamoDB { + if sourceProfile.Conn.Dydb.SchemaSampleSize != 0 { + schemaSampleSize = sourceProfile.Conn.Dydb.SchemaSampleSize + } + } + } + return schemaSampleSize +} diff --git a/cmd/source_profile.go b/profiles/source_profile.go similarity index 74% rename from cmd/source_profile.go rename to profiles/source_profile.go index 84d09cd4f9..9e7ac300f5 100644 --- a/cmd/source_profile.go +++ b/profiles/source_profile.go @@ -1,4 +1,4 @@ -package cmd +package profiles import ( "fmt" @@ -7,7 +7,7 @@ import ( "strings" "github.com/cloudspannerecosystem/harbourbridge/common/constants" - "github.com/cloudspannerecosystem/harbourbridge/conversion" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" ) type SourceProfileType int @@ -20,21 +20,21 @@ const ( ) type SourceProfileFile struct { - path string - format string + Path string + Format string } func NewSourceProfileFile(params map[string]string) SourceProfileFile { profile := SourceProfileFile{} if !filePipedToStdin() { - profile.path = params["file"] + profile.Path = params["file"] } if format, ok := params["format"]; ok { - profile.format = format + profile.Format = format // TODO: Add check that format takes values from ["dump", "csv", "avro", ... 
etc] } else { fmt.Printf("source-profile format defaulting to `dump`\n") - profile.format = "dump" + profile.Format = "dump" } return profile } @@ -49,11 +49,11 @@ const ( ) type SourceProfileConnectionMySQL struct { - host string // Same as MYSQLHOST environment variable - port string // Same as MYSQLPORT environment variable - user string // Same as MYSQLUSER environment variable - db string // Same as MYSQLDATABASE environment variable - pwd string // Same as MYSQLPWD environment variable + Host string // Same as MYSQLHOST environment variable + Port string // Same as MYSQLPORT environment variable + User string // Same as MYSQLUSER environment variable + Db string // Same as MYSQLDATABASE environment variable + Pwd string // Same as MYSQLPWD environment variable } func NewSourceProfileConnectionMySQL(params map[string]string) (SourceProfileConnectionMySQL, error) { @@ -69,21 +69,21 @@ func NewSourceProfileConnectionMySQL(params map[string]string) (SourceProfileCon // No connection params provided through source-profile. Fetching from env variables. fmt.Printf("Connection parameters not specified in source-profile. Reading from " + "environment variables MYSQLHOST, MYSQLUSER, MYSQLDATABASE, MYSQLPORT, MYSQLPWD...\n") - mysql.host = os.Getenv("MYSQLHOST") - mysql.user = os.Getenv("MYSQLUSER") - mysql.db = os.Getenv("MYSQLDATABASE") - mysql.port = os.Getenv("MYSQLPORT") - mysql.pwd = os.Getenv("MYSQLPWD") + mysql.Host = os.Getenv("MYSQLHOST") + mysql.User = os.Getenv("MYSQLUSER") + mysql.Db = os.Getenv("MYSQLDATABASE") + mysql.Port = os.Getenv("MYSQLPORT") + mysql.Pwd = os.Getenv("MYSQLPWD") // Throw error if the input entered is empty. - if mysql.host == "" || mysql.user == "" || mysql.db == "" { + if mysql.Host == "" || mysql.User == "" || mysql.Db == "" { return mysql, fmt.Errorf("found empty string for MYSQLHOST/MYSQLUSER/MYSQLDATABASE. 
Please specify these environment variables with correct values") } } else if hostOk && userOk && dbOk { // If atleast host, username and dbname are provided through source-profile, // go ahead and use source-profile. Port and password handled later even if they are empty. - mysql.host, mysql.user, mysql.db, mysql.port, mysql.pwd = host, user, db, port, pwd + mysql.Host, mysql.User, mysql.Db, mysql.Port, mysql.Pwd = host, user, db, port, pwd // Throw error if the input entered is empty. - if mysql.host == "" || mysql.user == "" || mysql.db == "" { + if mysql.Host == "" || mysql.User == "" || mysql.Db == "" { return mysql, fmt.Errorf("found empty string for host/user/db_name. Please specify host, port, user and db_name in the source-profile") } } else { @@ -92,27 +92,27 @@ func NewSourceProfileConnectionMySQL(params map[string]string) (SourceProfileCon } // Throw same error if the input entered is empty. - if mysql.host == "" || mysql.user == "" || mysql.db == "" { + if mysql.Host == "" || mysql.User == "" || mysql.Db == "" { return mysql, fmt.Errorf("found empty string for host/user/db. please specify host, port, user and db_name in the source-profile") } - if mysql.port == "" { + if mysql.Port == "" { // Set default port for mysql, which rarely changes. 
- mysql.port = "3306" + mysql.Port = "3306" } - if mysql.pwd == "" { - mysql.pwd = conversion.GetPassword() + if mysql.Pwd == "" { + mysql.Pwd = utils.GetPassword() } return mysql, nil } type SourceProfileConnectionPostgreSQL struct { - host string // Same as PGHOST environment variable - port string // Same as PGPORT environment variable - user string // Same as PGUSER environment variable - db string // Same as PGDATABASE environment variable - pwd string // Same as PGPASSWORD environment variable + Host string // Same as PGHOST environment variable + Port string // Same as PGPORT environment variable + User string // Same as PGUSER environment variable + Db string // Same as PGDATABASE environment variable + Pwd string // Same as PGPASSWORD environment variable } func NewSourceProfileConnectionPostgreSQL(params map[string]string) (SourceProfileConnectionPostgreSQL, error) { @@ -128,20 +128,20 @@ func NewSourceProfileConnectionPostgreSQL(params map[string]string) (SourceProfi // No connection params provided through source-profile. Fetching from env variables. fmt.Printf("Connection parameters not specified in source-profile. Reading from " + "environment variables PGHOST, PGUSER, PGDATABASE, PGPORT, PGPASSWORD...\n") - pg.host = os.Getenv("PGHOST") - pg.user = os.Getenv("PGUSER") - pg.db = os.Getenv("PGDATABASE") - pg.port = os.Getenv("PGPORT") - pg.pwd = os.Getenv("PGPASSWORD") + pg.Host = os.Getenv("PGHOST") + pg.User = os.Getenv("PGUSER") + pg.Db = os.Getenv("PGDATABASE") + pg.Port = os.Getenv("PGPORT") + pg.Pwd = os.Getenv("PGPASSWORD") // Throw error if the input entered is empty. - if pg.host == "" || pg.user == "" || pg.db == "" { + if pg.Host == "" || pg.User == "" || pg.Db == "" { return pg, fmt.Errorf("found empty string for PGHOST/PGUSER/PGDATABASE. Please specify these environment variables with correct values") } } else if hostOk && userOk && dbOk { // All connection params provided through source-profile. Port and password handled later. 
- pg.host, pg.user, pg.db, pg.port, pg.pwd = host, user, db, port, pwd + pg.Host, pg.User, pg.Db, pg.Port, pg.Pwd = host, user, db, port, pwd // Throw error if the input entered is empty. - if pg.host == "" || pg.user == "" || pg.db == "" { + if pg.Host == "" || pg.User == "" || pg.Db == "" { return pg, fmt.Errorf("found empty string for host/user/db_name. Please specify host, port, user and db_name in the source-profile") } } else { @@ -149,12 +149,12 @@ func NewSourceProfileConnectionPostgreSQL(params map[string]string) (SourceProfi return pg, fmt.Errorf("please specify host, port, user and db_name in the source-profile") } - if pg.port == "" { + if pg.Port == "" { // Set default port for postgresql, which rarely changes. - pg.port = "5432" + pg.Port = "5432" } - if pg.pwd == "" { - pg.pwd = conversion.GetPassword() + if pg.Pwd == "" { + pg.Pwd = utils.GetPassword() } return pg, nil @@ -164,11 +164,11 @@ type SourceProfileConnectionDynamoDB struct { // These connection params are not used currently because the SDK reads directly from the env variables. // These are still kept around as reference when we refactor passing // SourceProfile instead of sqlConnectionStr around. 
- awsAccessKeyID string // Same as AWS_ACCESS_KEY_ID environment variable - awsSecretAccessKey string // Same as AWS_SECRET_ACCESS_KEY environment variable - awsRegion string // Same as AWS_REGION environment variable - dydbEndpoint string // Same as DYNAMODB_ENDPOINT_OVERRIDE environment variable - schemaSampleSize int64 // Number of rows to use for inferring schema (default 100,000) + AwsAccessKeyID string // Same as AWS_ACCESS_KEY_ID environment variable + AwsSecretAccessKey string // Same as AWS_SECRET_ACCESS_KEY environment variable + AwsRegion string // Same as AWS_REGION environment variable + DydbEndpoint string // Same as DYNAMODB_ENDPOINT_OVERRIDE environment variable + SchemaSampleSize int64 // Number of rows to use for inferring schema (default 100,000) } func NewSourceProfileConnectionDynamoDB(params map[string]string) (SourceProfileConnectionDynamoDB, error) { @@ -178,32 +178,32 @@ func NewSourceProfileConnectionDynamoDB(params map[string]string) (SourceProfile if err != nil { return dydb, fmt.Errorf("could not parse schema-sample-size = %v as a valid int64", schemaSampleSize) } - dydb.schemaSampleSize = int64(schemaSampleSizeInt) + dydb.SchemaSampleSize = int64(schemaSampleSizeInt) } // For DynamoDB, the preferred way to provide connection params is through env variables. // Unlike postgres and mysql, there may not be deprecation of env variables, hence it // is better to override env variables optionally via source profile params. 
var ok bool - if dydb.awsAccessKeyID, ok = params["aws-access-key-id"]; ok { - os.Setenv("AWS_ACCESS_KEY_ID", dydb.awsAccessKeyID) + if dydb.AwsAccessKeyID, ok = params["aws-access-key-id"]; ok { + os.Setenv("AWS_ACCESS_KEY_ID", dydb.AwsAccessKeyID) } - if dydb.awsSecretAccessKey, ok = params["aws-secret-access-key"]; ok { - os.Setenv("AWS_SECRET_ACCESS_KEY", dydb.awsAccessKeyID) + if dydb.AwsSecretAccessKey, ok = params["aws-secret-access-key"]; ok { + os.Setenv("AWS_SECRET_ACCESS_KEY", dydb.AwsSecretAccessKey) } - if dydb.awsRegion, ok = params["aws-region"]; ok { - os.Setenv("AWS_REGION", dydb.awsAccessKeyID) + if dydb.AwsRegion, ok = params["aws-region"]; ok { + os.Setenv("AWS_REGION", dydb.AwsRegion) } - if dydb.dydbEndpoint, ok = params["dydb-endpoint"]; ok { - os.Setenv("DYNAMODB_ENDPOINT_OVERRIDE", dydb.awsAccessKeyID) + if dydb.DydbEndpoint, ok = params["dydb-endpoint"]; ok { + os.Setenv("DYNAMODB_ENDPOINT_OVERRIDE", dydb.DydbEndpoint) } return dydb, nil } type SourceProfileConnection struct { - ty SourceProfileConnectionType - mysql SourceProfileConnectionMySQL - pg SourceProfileConnectionPostgreSQL - dydb SourceProfileConnectionDynamoDB + Ty SourceProfileConnectionType + Mysql SourceProfileConnectionMySQL + Pg SourceProfileConnectionPostgreSQL + Dydb SourceProfileConnectionDynamoDB } func NewSourceProfileConnection(source string, params map[string]string) (SourceProfileConnection, error) { @@ -212,24 +212,24 @@ func NewSourceProfileConnection(source string, params map[string]string) (Source switch strings.ToLower(source) { case "mysql": { - conn.ty = SourceProfileConnectionTypeMySQL - conn.mysql, err = NewSourceProfileConnectionMySQL(params) + conn.Ty = SourceProfileConnectionTypeMySQL + conn.Mysql, err = NewSourceProfileConnectionMySQL(params) if err != nil { return conn, err } } case "postgresql", "postgres", "pg": { - conn.ty = SourceProfileConnectionTypePostgreSQL - conn.pg, err = NewSourceProfileConnectionPostgreSQL(params) + conn.Ty = 
SourceProfileConnectionTypePostgreSQL + conn.Pg, err = NewSourceProfileConnectionPostgreSQL(params) if err != nil { return conn, err } } case "dynamodb": { - conn.ty = SourceProfileConnectionTypeDynamoDB - conn.dydb, err = NewSourceProfileConnectionDynamoDB(params) + conn.Ty = SourceProfileConnectionTypeDynamoDB + conn.Dydb, err = NewSourceProfileConnectionDynamoDB(params) if err != nil { return conn, err } @@ -249,10 +249,11 @@ func NewSourceProfileConfig(path string) SourceProfileConfig { } type SourceProfile struct { - ty SourceProfileType - file SourceProfileFile - conn SourceProfileConnection - config SourceProfileConfig + Driver string + Ty SourceProfileType + File SourceProfileFile + Conn SourceProfileConnection + Config SourceProfileConfig } // ToLegacyDriver converts source-profile to equivalent legacy global flags @@ -260,7 +261,7 @@ type SourceProfile struct { // same. TODO: Deprecate this function and pass around SourceProfile across the // codebase wherever information about source connection is required. func (src SourceProfile) ToLegacyDriver(source string) (string, error) { - switch src.ty { + switch src.Ty { case SourceProfileTypeFile: { switch strings.ToLower(source) { @@ -322,19 +323,19 @@ func NewSourceProfile(s string, source string) (SourceProfile, error) { if _, ok := params["file"]; ok || filePipedToStdin() { profile := NewSourceProfileFile(params) - return SourceProfile{ty: SourceProfileTypeFile, file: profile}, nil + return SourceProfile{Ty: SourceProfileTypeFile, File: profile}, nil } else if format, ok := params["format"]; ok { // File is not passed in from stdin or specified using "file" flag. 
- return SourceProfile{ty: SourceProfileTypeFile}, fmt.Errorf("file not specified, but format set to %v", format) + return SourceProfile{Ty: SourceProfileTypeFile}, fmt.Errorf("file not specified, but format set to %v", format) } else if file, ok := params["config"]; ok { config := NewSourceProfileConfig(file) - return SourceProfile{ty: SourceProfileTypeConfig, config: config}, fmt.Errorf("source-profile type config not yet implemented") + return SourceProfile{Ty: SourceProfileTypeConfig, Config: config}, fmt.Errorf("source-profile type config not yet implemented") } else { // Assume connection profile type connection by default, since // connection parameters could be specified as part of environment // variables. conn, err := NewSourceProfileConnection(source, params) - return SourceProfile{ty: SourceProfileTypeConnection, conn: conn}, err + return SourceProfile{Ty: SourceProfileTypeConnection, Conn: conn}, err } } diff --git a/cmd/source_profile_test.go b/profiles/source_profile_test.go similarity index 91% rename from cmd/source_profile_test.go rename to profiles/source_profile_test.go index 6562ee9822..57fad6e484 100644 --- a/cmd/source_profile_test.go +++ b/profiles/source_profile_test.go @@ -1,4 +1,4 @@ -package cmd +package profiles import ( "testing" @@ -17,31 +17,31 @@ func TestNewSourceProfileFile(t *testing.T) { name: "no params, file piped", params: map[string]string{}, pipedToStdin: true, - want: SourceProfileFile{format: "dump"}, + want: SourceProfileFile{Format: "dump"}, }, { name: "format param, file piped", - params: map[string]string{"format": "dump"}, + params: map[string]string{"Format": "dump"}, pipedToStdin: true, - want: SourceProfileFile{format: "dump"}, + want: SourceProfileFile{Format: "dump"}, }, { name: "format and path param, file piped -- piped file takes precedence", params: map[string]string{"format": "dump", "file": "file1.mysqldump"}, pipedToStdin: true, - want: SourceProfileFile{format: "dump"}, + want: SourceProfileFile{Format: 
"dump"}, }, { name: "format and path param, no file piped", params: map[string]string{"format": "dump", "file": "file1.mysqldump"}, pipedToStdin: false, - want: SourceProfileFile{format: "dump", path: "file1.mysqldump"}, + want: SourceProfileFile{Format: "dump", Path: "file1.mysqldump"}, }, { name: "only path param, no file piped -- default dump format", params: map[string]string{"file": "file1.mysqldump"}, pipedToStdin: false, - want: SourceProfileFile{format: "dump", path: "file1.mysqldump"}, + want: SourceProfileFile{Format: "dump", Path: "file1.mysqldump"}, }, } diff --git a/cmd/target_profile.go b/profiles/target_profile.go similarity index 97% rename from cmd/target_profile.go rename to profiles/target_profile.go index e5e5b0954b..65a920c97b 100644 --- a/cmd/target_profile.go +++ b/profiles/target_profile.go @@ -1,4 +1,4 @@ -package cmd +package profiles import ( "fmt" @@ -35,8 +35,9 @@ type TargetProfileConnection struct { } type TargetProfile struct { - ty TargetProfileType - conn TargetProfileConnection + TargetDb string + ty TargetProfileType + conn TargetProfileConnection } // ToLegacyTargetDb converts source-profile to equivalent legacy global flag diff --git a/sources/postgres/data.go b/sources/postgres/data.go index da3d237903..a8392816f7 100644 --- a/sources/postgres/data.go +++ b/sources/postgres/data.go @@ -76,7 +76,10 @@ func ConvertData(conv *internal.Conv, srcTable string, srcCols []string, vals [] } for i, spCol := range spCols { srcCol := srcCols[i] - if vals[i] == "\\N" { // PostgreSQL representation of empty column in COPY-FROM blocks. + // "\\N" is for PostgreSQL representation of empty column in COPY-FROM blocks. + // TODO: Consider using NullString to differentiate between an actual column having "NULL" as a string + // and NULL values. 
+ if vals[i] == "\\N" || vals[i] == "NULL" { continue } spColDef, ok1 := spSchema.ColDefs[spCol] diff --git a/sources/postgres/pgdump.go b/sources/postgres/pgdump.go index 4318ae886e..f64cc1f24d 100644 --- a/sources/postgres/pgdump.go +++ b/sources/postgres/pgdump.go @@ -724,6 +724,9 @@ func getRows(conv *internal.Conv, vll []*pg_query.Node, n *pg_query.InsertStmt) switch val := v.GetNode().(type) { case *pg_query.Node_AConst: switch c := val.AConst.Val.GetNode().(type) { + // Most data is dumped enclosed in quotes ('') like 'abc', '12:30:45' etc which is classified + // as type Node_String_ by the parser. Some data might not be quoted like (NULL, 14.67) and + // the type assigned to them is Node_Null and Node_Float respectively. case *pg_query.Node_String_: values = append(values, trimString(c.String_)) case *pg_query.Node_Integer: @@ -732,6 +735,12 @@ func getRows(conv *internal.Conv, vll []*pg_query.Node, n *pg_query.InsertStmt) // high priority (it isn't right now), then consider preserving int64 // here to avoid the int64 -> string -> int64 conversions. values = append(values, strconv.FormatInt(int64(c.Integer.Ival), 10)) + case *pg_query.Node_Float: + values = append(values, c.Float.Str) + case *pg_query.Node_Null: + values = append(values, "NULL") + // TODO: There might be other Node types like Node_IntList, Node_List, Node_BitString etc that + // need to be checked if they are handled or not. 
default: conv.Unexpected(fmt.Sprintf("Processing %v statement: found %s node for A_Const Val", printNodeType(n), printNodeType(c))) } diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index ea6a72fa6e..3f5cc052d7 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -158,7 +158,11 @@ type Config struct { func (c Config) quote(s string) string { if c.ProtectIds { - return "`" + s + "`" + if c.TargetDb == constants.TargetExperimentalPostgres { + return "\"" + s + "\"" + } else { + return "`" + s + "`" + } } return s } diff --git a/spanner/ddl/ast_test.go b/spanner/ddl/ast_test.go index cfaafceffe..2476702d20 100644 --- a/spanner/ddl/ast_test.go +++ b/spanner/ddl/ast_test.go @@ -88,7 +88,7 @@ func TestPrintColumnDefPG(t *testing.T) { {in: ColumnDef{Name: "col1", T: Type{Name: Int64, IsArray: true}}, expected: "col1 VARCHAR(2621440)"}, {in: ColumnDef{Name: "col1", T: Type{Name: Int64}, NotNull: true}, expected: "col1 INT8 NOT NULL"}, {in: ColumnDef{Name: "col1", T: Type{Name: Int64, IsArray: true}, NotNull: true}, expected: "col1 VARCHAR(2621440) NOT NULL"}, - {in: ColumnDef{Name: "col1", T: Type{Name: Int64}}, protectIds: true, expected: "`col1` INT8"}, + {in: ColumnDef{Name: "col1", T: Type{Name: Int64}}, protectIds: true, expected: "\"col1\" INT8"}, } for _, tc := range tests { s, _ := tc.in.PrintColumnDef(Config{ProtectIds: tc.protectIds, TargetDb: constants.TargetExperimentalPostgres}) @@ -100,15 +100,16 @@ func TestPrintIndexKey(t *testing.T) { tests := []struct { in IndexKey protectIds bool + targetDb string expected string }{ {in: IndexKey{Col: "col1"}, expected: "col1"}, {in: IndexKey{Col: "col1", Desc: true}, expected: "col1 DESC"}, {in: IndexKey{Col: "col1"}, protectIds: true, expected: "`col1`"}, + {in: IndexKey{Col: "col1"}, protectIds: true, targetDb: constants.TargetExperimentalPostgres, expected: "\"col1\""}, } for _, tc := range tests { - assert.Equal(t, tc.expected, tc.in.PrintIndexKey(Config{ProtectIds: tc.protectIds})) - assert.Equal(t, 
tc.expected, tc.in.PrintIndexKey(Config{ProtectIds: tc.protectIds, TargetDb: constants.TargetExperimentalPostgres})) + assert.Equal(t, tc.expected, tc.in.PrintIndexKey(Config{ProtectIds: tc.protectIds, TargetDb: tc.targetDb})) } } @@ -226,11 +227,11 @@ func TestPrintCreateTablePG(t *testing.T) { "quote", true, t1, - "CREATE TABLE `mytable` (\n" + - " `col1` INT8 NOT NULL,\n" + - " `col2` VARCHAR(2621440),\n" + - " `col3` BYTEA,\n" + - " PRIMARY KEY (`col1` DESC)\n" + + "CREATE TABLE \"mytable\" (\n" + + " \"col1\" INT8 NOT NULL,\n" + + " \"col2\" VARCHAR(2621440),\n" + + " \"col3\" BYTEA,\n" + + " PRIMARY KEY (\"col1\" DESC)\n" + ")", }, { @@ -267,16 +268,18 @@ func TestPrintCreateIndex(t *testing.T) { tests := []struct { name string protectIds bool + targetDb string index CreateIndex expected string }{ - {"no quote non unique", false, ci[0], "CREATE INDEX myindex ON mytable (col1 DESC, col2)"}, - {"quote non unique", true, ci[0], "CREATE INDEX `myindex` ON `mytable` (`col1` DESC, `col2`)"}, - {"unique key", true, ci[1], "CREATE UNIQUE INDEX `myindex2` ON `mytable` (`col1` DESC, `col2`)"}, + {"no quote non unique", false, "", ci[0], "CREATE INDEX myindex ON mytable (col1 DESC, col2)"}, + {"quote non unique", true, "", ci[0], "CREATE INDEX `myindex` ON `mytable` (`col1` DESC, `col2`)"}, + {"unique key", true, "", ci[1], "CREATE UNIQUE INDEX `myindex2` ON `mytable` (`col1` DESC, `col2`)"}, + {"quote non unique PG", true, constants.TargetExperimentalPostgres, ci[0], "CREATE INDEX \"myindex\" ON \"mytable\" (\"col1\" DESC, \"col2\")"}, + {"unique key PG", true, constants.TargetExperimentalPostgres, ci[1], "CREATE UNIQUE INDEX \"myindex2\" ON \"mytable\" (\"col1\" DESC, \"col2\")"}, } for _, tc := range tests { - assert.Equal(t, tc.expected, tc.index.PrintCreateIndex(Config{ProtectIds: tc.protectIds})) - assert.Equal(t, tc.expected, tc.index.PrintCreateIndex(Config{ProtectIds: tc.protectIds, TargetDb: constants.TargetExperimentalPostgres})) + assert.Equal(t, 
tc.expected, tc.index.PrintCreateIndex(Config{ProtectIds: tc.protectIds, TargetDb: tc.targetDb})) } } @@ -298,16 +301,17 @@ func TestPrintForeignKey(t *testing.T) { tests := []struct { name string protectIds bool + targetDb string expected string fk Foreignkey }{ - {"no quote", false, "CONSTRAINT fk_test FOREIGN KEY (c1, c2) REFERENCES ref_table (ref_c1, ref_c2)", fk[0]}, - {"quote", true, "CONSTRAINT `fk_test` FOREIGN KEY (`c1`, `c2`) REFERENCES `ref_table` (`ref_c1`, `ref_c2`)", fk[0]}, - {"no constraint name", false, "FOREIGN KEY (c1) REFERENCES ref_table (ref_c1)", fk[1]}, + {"no quote", false, "", "CONSTRAINT fk_test FOREIGN KEY (c1, c2) REFERENCES ref_table (ref_c1, ref_c2)", fk[0]}, + {"quote", true, "", "CONSTRAINT `fk_test` FOREIGN KEY (`c1`, `c2`) REFERENCES `ref_table` (`ref_c1`, `ref_c2`)", fk[0]}, + {"no constraint name", false, "", "FOREIGN KEY (c1) REFERENCES ref_table (ref_c1)", fk[1]}, + {"quote PG", true, constants.TargetExperimentalPostgres, "CONSTRAINT \"fk_test\" FOREIGN KEY (\"c1\", \"c2\") REFERENCES \"ref_table\" (\"ref_c1\", \"ref_c2\")", fk[0]}, } for _, tc := range tests { - assert.Equal(t, tc.expected, tc.fk.PrintForeignKey(Config{ProtectIds: tc.protectIds})) - assert.Equal(t, tc.expected, tc.fk.PrintForeignKey(Config{ProtectIds: tc.protectIds, TargetDb: constants.TargetExperimentalPostgres})) + assert.Equal(t, tc.expected, tc.fk.PrintForeignKey(Config{ProtectIds: tc.protectIds, TargetDb: tc.targetDb})) } } @@ -330,16 +334,17 @@ func TestPrintForeignKeyAlterTable(t *testing.T) { name string table string protectIds bool + targetDb string expected string fk Foreignkey }{ - {"no quote", "table1", false, "ALTER TABLE table1 ADD CONSTRAINT fk_test FOREIGN KEY (c1, c2) REFERENCES ref_table (ref_c1, ref_c2)", fk[0]}, - {"quote", "table1", true, "ALTER TABLE `table1` ADD CONSTRAINT `fk_test` FOREIGN KEY (`c1`, `c2`) REFERENCES `ref_table` (`ref_c1`, `ref_c2`)", fk[0]}, - {"no constraint name", "table1", false, "ALTER TABLE table1 ADD FOREIGN KEY 
(c1) REFERENCES ref_table (ref_c1)", fk[1]}, + {"no quote", "table1", false, "", "ALTER TABLE table1 ADD CONSTRAINT fk_test FOREIGN KEY (c1, c2) REFERENCES ref_table (ref_c1, ref_c2)", fk[0]}, + {"quote", "table1", true, "", "ALTER TABLE `table1` ADD CONSTRAINT `fk_test` FOREIGN KEY (`c1`, `c2`) REFERENCES `ref_table` (`ref_c1`, `ref_c2`)", fk[0]}, + {"no constraint name", "table1", false, "", "ALTER TABLE table1 ADD FOREIGN KEY (c1) REFERENCES ref_table (ref_c1)", fk[1]}, + {"quote PG", "table1", true, constants.TargetExperimentalPostgres, "ALTER TABLE \"table1\" ADD CONSTRAINT \"fk_test\" FOREIGN KEY (\"c1\", \"c2\") REFERENCES \"ref_table\" (\"ref_c1\", \"ref_c2\")", fk[0]}, } for _, tc := range tests { - assert.Equal(t, tc.expected, tc.fk.PrintForeignKeyAlterTable(Config{ProtectIds: tc.protectIds}, tc.table)) - assert.Equal(t, tc.expected, tc.fk.PrintForeignKeyAlterTable(Config{ProtectIds: tc.protectIds, TargetDb: constants.TargetExperimentalPostgres}, tc.table)) + assert.Equal(t, tc.expected, tc.fk.PrintForeignKeyAlterTable(Config{ProtectIds: tc.protectIds, TargetDb: tc.targetDb}, tc.table)) } } diff --git a/testing/dynamodb/integration_test.go b/testing/dynamodb/integration_test.go index 213c3bdc0b..4711ba4bd1 100644 --- a/testing/dynamodb/integration_test.go +++ b/testing/dynamodb/integration_test.go @@ -27,7 +27,7 @@ import ( "time" "github.com/cloudspannerecosystem/harbourbridge/common/constants" - "github.com/cloudspannerecosystem/harbourbridge/conversion" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/testing/common" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" @@ -214,7 +214,7 @@ func TestIntegration_DYNAMODB_Command(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := conversion.GetDatabaseName(constants.DYNAMODB, now) + dbName, _ := utils.GetDatabaseName(constants.DYNAMODB, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", 
projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName+".") diff --git a/testing/mysql/integration_test.go b/testing/mysql/integration_test.go index 9ef60b2a13..c1367183f5 100644 --- a/testing/mysql/integration_test.go +++ b/testing/mysql/integration_test.go @@ -26,7 +26,7 @@ import ( "time" "github.com/cloudspannerecosystem/harbourbridge/common/constants" - "github.com/cloudspannerecosystem/harbourbridge/conversion" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/testing/common" "cloud.google.com/go/spanner" @@ -110,7 +110,7 @@ func TestIntegration_MYSQLDUMP_Command(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := conversion.GetDatabaseName(constants.MYSQLDUMP, now) + dbName, _ := utils.GetDatabaseName(constants.MYSQLDUMP, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) dataFilepath := "../../test_data/mysqldump.test.out" filePrefix := filepath.Join(tmpdir, dbName+".") @@ -159,7 +159,7 @@ func TestIntegration_MYSQL_Command(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := conversion.GetDatabaseName(constants.MYSQL, now) + dbName, _ := utils.GetDatabaseName(constants.MYSQL, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName+".") diff --git a/testing/postgres/integration_test.go b/testing/postgres/integration_test.go index 767324674d..6696a1511b 100644 --- a/testing/postgres/integration_test.go +++ b/testing/postgres/integration_test.go @@ -30,7 +30,7 @@ import ( "cloud.google.com/go/spanner" database "cloud.google.com/go/spanner/admin/database/apiv1" "github.com/cloudspannerecosystem/harbourbridge/common/constants" - "github.com/cloudspannerecosystem/harbourbridge/conversion" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" 
"github.com/cloudspannerecosystem/harbourbridge/testing/common" "google.golang.org/api/iterator" databasepb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" @@ -112,7 +112,7 @@ func TestIntegration_PGDUMP_Command(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := conversion.GetDatabaseName(constants.PGDUMP, now) + dbName, _ := utils.GetDatabaseName(constants.PGDUMP, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) dataFilepath := "../../test_data/pg_dump.test.out" @@ -136,7 +136,7 @@ func TestIntegration_PGDUMP_SchemaAndDataSubcommand(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := conversion.GetDatabaseName(constants.PGDUMP, now) + dbName, _ := utils.GetDatabaseName(constants.PGDUMP, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) dataFilepath := "../../test_data/pg_dump.test.out" @@ -176,7 +176,7 @@ func TestIntegration_POSTGRES_Command(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := conversion.GetDatabaseName(constants.POSTGRES, now) + dbName, _ := utils.GetDatabaseName(constants.POSTGRES, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName+".") @@ -199,7 +199,7 @@ func TestIntegration_POSTGRES_SchemaAndDataSubcommand(t *testing.T) { defer os.RemoveAll(tmpdir) now := time.Now() - dbName, _ := conversion.GetDatabaseName(constants.POSTGRES, now) + dbName, _ := utils.GetDatabaseName(constants.POSTGRES, now) dbURI := fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectID, instanceID, dbName) filePrefix := filepath.Join(tmpdir, dbName+".") diff --git a/web/session.go b/web/session.go index 6f6689457a..c30262de06 100644 --- a/web/session.go +++ b/web/session.go @@ -22,6 +22,7 @@ import ( "os" "time" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" 
"github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" ) @@ -37,12 +38,12 @@ type session struct { } func createSession(w http.ResponseWriter, r *http.Request) { - ioHelper := &conversion.IOStreams{In: os.Stdin, Out: os.Stdout} + ioHelper := &utils.IOStreams{In: os.Stdin, Out: os.Stdout} now := time.Now() dbName := sessionState.dbName var err error if dbName == "" { - dbName, err = conversion.GetDatabaseName(sessionState.driver, now) + dbName, err = utils.GetDatabaseName(sessionState.driver, now) if err != nil { http.Error(w, fmt.Sprintf("Can not create database name : %v", err), http.StatusInternalServerError) } diff --git a/web/web.go b/web/web.go index 8a2a882692..5fc6cef394 100644 --- a/web/web.go +++ b/web/web.go @@ -35,8 +35,10 @@ import ( "time" "github.com/cloudspannerecosystem/harbourbridge/common/constants" + "github.com/cloudspannerecosystem/harbourbridge/common/utils" "github.com/cloudspannerecosystem/harbourbridge/conversion" "github.com/cloudspannerecosystem/harbourbridge/internal" + "github.com/cloudspannerecosystem/harbourbridge/profiles" "github.com/cloudspannerecosystem/harbourbridge/sources/common" "github.com/cloudspannerecosystem/harbourbridge/sources/mysql" "github.com/cloudspannerecosystem/harbourbridge/sources/postgres" @@ -167,7 +169,12 @@ func convertSchemaDump(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("failed to open dump file %v : %v", dc.FilePath, err), http.StatusNotFound) return } - conv, err := conversion.SchemaConv(dc.Driver, "", constants.TargetSpanner, &conversion.IOStreams{In: f, Out: os.Stdout}, 0) + // We don't support Dynamodb in web hence no need to pass schema sample size here. 
+ sourceProfile, _ := profiles.NewSourceProfile("", dc.Driver) + sourceProfile.Driver = dc.Driver + targetProfile, _ := profiles.NewTargetProfile("") + targetProfile.TargetDb = constants.TargetSpanner + conv, err := conversion.SchemaConv(sourceProfile, targetProfile, &utils.IOStreams{In: f, Out: os.Stdout}) if err != nil { http.Error(w, fmt.Sprintf("Schema Conversion Error : %v", err), http.StatusNotFound) return @@ -402,7 +409,7 @@ func getConversionRate(w http.ResponseWriter, r *http.Request) { // getSchemaFile generates schema file and returns file path. func getSchemaFile(w http.ResponseWriter, r *http.Request) { - ioHelper := &conversion.IOStreams{In: os.Stdin, Out: os.Stdout} + ioHelper := &utils.IOStreams{In: os.Stdin, Out: os.Stdout} var err error now := time.Now() filePrefix, err := getFilePrefix(now) @@ -421,7 +428,7 @@ func getSchemaFile(w http.ResponseWriter, r *http.Request) { // getReportFile generates report file and returns file path. func getReportFile(w http.ResponseWriter, r *http.Request) { - ioHelper := &conversion.IOStreams{In: os.Stdin, Out: os.Stdout} + ioHelper := &utils.IOStreams{In: os.Stdin, Out: os.Stdout} var err error now := time.Now() filePrefix, err := getFilePrefix(now) @@ -774,7 +781,7 @@ func dropSecondaryIndex(w http.ResponseWriter, r *http.Request) { // updateSessionFile updates the content of session file with // latest sessionState.conv while also dumping schemas and report. func updateSessionFile() error { - ioHelper := &conversion.IOStreams{In: os.Stdin, Out: os.Stdout} + ioHelper := &utils.IOStreams{In: os.Stdin, Out: os.Stdout} _, err := conversion.WriteConvGeneratedFiles(sessionState.conv, sessionState.dbName, sessionState.driver, ioHelper.BytesRead, ioHelper.Out) if err != nil { return fmt.Errorf("encountered error %w. 
Cannot write files", err) @@ -1093,7 +1100,7 @@ func getFilePrefix(now time.Time) (string, error) { dbName := sessionState.dbName var err error if dbName == "" { - dbName, err = conversion.GetDatabaseName(sessionState.driver, now) + dbName, err = utils.GetDatabaseName(sessionState.driver, now) if err != nil { return "", fmt.Errorf("Can not create database name : %v", err) }