diff --git a/README.md b/README.md
index 794e5c4f..1d863692 100644
--- a/README.md
+++ b/README.md
@@ -571,6 +571,16 @@
 must also set a distribution key with the distkey option.

 Since setting usestagingtable=false risks data loss / unavailability, we have chosen to deprecate it in favor of requiring users to manually drop the destination table themselves.

+
+include_column_list
+No
+false
+
+If true then this library will automatically extract the columns from the schema
+and add them to the COPY command according to the Column List docs.
+(e.g. `COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])`).
+
+
 description
 No
diff --git a/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
index 875f5b75..e897ba7c 100644
--- a/src/main/scala/com/databricks/spark/redshift/Parameters.scala
+++ b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -38,7 +38,8 @@
     "diststyle" -> "EVEN",
     "usestagingtable" -> "true",
     "preactions" -> ";",
-    "postactions" -> ";"
+    "postactions" -> ";",
+    "include_column_list" -> "false"
   )
 
   val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP")
@@ -285,5 +286,11 @@ private[redshift] object Parameters {
          new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken))
      }
    }
+
+    /**
+     * If true then this library will extract the column list from the schema to
+     * include in the COPY command (e.g. `COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])`)
+     */
+    def includeColumnList: Boolean = parameters("include_column_list").toBoolean
  }
}
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
index 8383231d..784285d8 100644
--- a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
+++ b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
@@ -88,6 +88,7 @@
    */
   private def copySql(
       sqlContext: SQLContext,
+      schema: StructType,
       params: MergedParameters,
       creds: AWSCredentialsProvider,
       manifestUrl: String): String = {
@@ -98,7 +99,13 @@ private[redshift] class RedshiftWriter(
       case "AVRO" => "AVRO 'auto'"
       case csv if csv == "CSV" || csv == "CSV GZIP" => csv + s" NULL AS '${params.nullString}'"
     }
-    s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
+
+    val columns = if (params.includeColumnList) {
+      "(" + schema.fieldNames.map(name => s""""$name"""").mkString(",") + ") "
+    } else {
+      ""
+    }
+
+    s"COPY ${params.table.get} ${columns}FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
       s"${format} manifest ${params.extraCopyOptions}"
   }
@@ -140,7 +147,7 @@
 
     manifestUrl.foreach { manifestUrl =>
       // Load the temporary data into the new file
-      val copyStatement = copySql(data.sqlContext, params, creds, manifestUrl)
+      val copyStatement = copySql(data.sqlContext, data.schema, params, creds, manifestUrl)
       log.info(copyStatement)
       try {
         jdbcWrapper.executeInterruptibly(conn.prepareStatement(copyStatement))
diff --git a/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala b/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
index e4ed9d14..590b5505 100644
--- a/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
@@ -28,7 +28,8 @@ class ParametersSuite extends FunSuite with Matchers {
       "tempdir" -> "s3://foo/bar",
       "dbtable" -> "test_schema.test_table",
       "url" -> "jdbc:redshift://foo/bar?user=user&password=password",
-      "forward_spark_s3_credentials" -> "true")
+      "forward_spark_s3_credentials" -> "true",
+      "include_column_list" -> "true")
 
     val mergedParams = Parameters.mergeParameters(params)
 
@@ -37,9 +38,10 @@
     mergedParams.jdbcUrl shouldBe params("url")
     mergedParams.table shouldBe Some(TableName("test_schema", "test_table"))
     assert(mergedParams.forwardSparkS3Credentials)
+    assert(mergedParams.includeColumnList)
 
     // Check that the defaults have been added
-    (Parameters.DEFAULT_PARAMETERS - "forward_spark_s3_credentials").foreach {
+    (Parameters.DEFAULT_PARAMETERS - "forward_spark_s3_credentials" - "include_column_list").foreach {
       case (key, value) => mergedParams.parameters(key) shouldBe value
     }
   }
diff --git a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
index ac2a644a..ed2da22d 100644
--- a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
@@ -435,6 +435,27 @@ class RedshiftSourceSuite
     mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
   }
 
+  test("Include Column List adds the schema columns to the COPY query") {
+    val copyCommand =
+      "COPY \"PUBLIC\".\"test_table\" \\(\"testbyte\",\"testbool\",\"testdate\",\"testdouble\"" +
+        ",\"testfloat\",\"testint\",\"testlong\",\"testshort\",\"teststring\",\"testtimestamp\"\\) FROM .*"
+    val expectedCommands =
+      Seq("CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r,
+        copyCommand.r)
+
+    val params = defaultParams ++ Map("include_column_list" -> "true")
+
+    val mockRedshift = new MockRedshift(
+      defaultParams("url"),
+      Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
+
+    val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+    source.createRelation(testSqlContext, SaveMode.Append, params, expectedDataDF)
+
+    mockRedshift.verifyThatConnectionsWereClosed()
+    mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
+  }
+
   test("configuring maxlength on string columns") {
     val longStrMetadata = new MetadataBuilder().putLong("maxlength", 512).build()
     val shortStrMetadata = new MetadataBuilder().putLong("maxlength", 10).build()
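
For context, a minimal usage sketch (not part of this patch; the cluster URL, table, bucket, and source names below are placeholders) showing how a caller could enable the new include_column_list option when writing a DataFrame through this connector:

    // Hypothetical example: with include_column_list=true the writer adds an explicit
    // column list to the generated COPY, e.g. COPY "PUBLIC"."tablename" ("column1", ...).
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("include-column-list-example").getOrCreate()
    val df = spark.table("staging_source") // placeholder source DataFrame

    df.write
      .format("com.databricks.spark.redshift")
      .option("url", "jdbc:redshift://examplecluster:5439/dev?user=user&password=password")
      .option("dbtable", "public.tablename")
      .option("tempdir", "s3n://example-bucket/tmp")
      .option("forward_spark_s3_credentials", "true")
      .option("include_column_list", "true")
      .mode("append")
      .save()

Because the column list is built from the DataFrame's schema, this is mainly useful when the DataFrame's columns are a subset of, or ordered differently from, the destination table's columns.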