diff --git a/README.md b/README.md
index 794e5c4f..1d863692 100644
--- a/README.md
+++ b/README.md
@@ -571,6 +571,16 @@ must also set a distribution key with the distkey option.
Since setting usestagingtable=false risks data loss or unavailability, we have chosen to deprecate it in favor of requiring users to manually drop the destination table themselves.
+
+ | include_column_list |
+ No |
+ false |
+
+ If true, this library will automatically extract the column names from the schema
+ and add them to the COPY command according to the Column List docs
+ (e.g. `COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])`).
+ |
+
| description |
No |
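
For context, a hypothetical usage sketch of the new option (not part of this patch): it shows how a caller might enable `include_column_list` on an append to an existing Redshift table. The connection URL, table name, and tempdir below are placeholders, and the write path assumes the standard spark-redshift data source.

```scala
// Hypothetical usage sketch (not part of this patch). URL, dbtable, and
// tempdir are placeholders; include_column_list is the option added here.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("include-column-list-example").getOrCreate()
val df = spark.createDataFrame(Seq((1, "a"), (2, "b"))).toDF("id", "name")

df.write
  .format("com.databricks.spark.redshift")
  .option("url", "jdbc:redshift://host:5439/db?user=user&password=pass")
  .option("dbtable", "my_table")
  .option("tempdir", "s3n://bucket/tmp")
  .option("forward_spark_s3_credentials", "true")
  .option("include_column_list", "true") // adds ("id","name") to the generated COPY
  .mode("append")
  .save()
```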
diff --git a/src/main/scala/com/databricks/spark/redshift/Parameters.scala b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
index 875f5b75..e897ba7c 100644
--- a/src/main/scala/com/databricks/spark/redshift/Parameters.scala
+++ b/src/main/scala/com/databricks/spark/redshift/Parameters.scala
@@ -38,7 +38,8 @@ private[redshift] object Parameters {
"diststyle" -> "EVEN",
"usestagingtable" -> "true",
"preactions" -> ";",
- "postactions" -> ";"
+ "postactions" -> ";",
+ "include_column_list" -> "false"
)
val VALID_TEMP_FORMATS = Set("AVRO", "CSV", "CSV GZIP")
@@ -285,5 +286,11 @@ private[redshift] object Parameters {
new BasicSessionCredentials(accessKey, secretAccessKey, sessionToken))
}
}
+
+ /**
+ * If true, this library will extract the column list from the schema to
+ * include in the COPY command (e.g. `COPY "PUBLIC"."tablename" ("column1" [,"column2", ...])`)
+ */
+ def includeColumnList: Boolean = parameters("include_column_list").toBoolean
}
}
diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
index 8383231d..784285d8 100644
--- a/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
+++ b/src/main/scala/com/databricks/spark/redshift/RedshiftWriter.scala
@@ -88,6 +88,7 @@ private[redshift] class RedshiftWriter(
*/
private def copySql(
sqlContext: SQLContext,
+ schema: StructType,
params: MergedParameters,
creds: AWSCredentialsProvider,
manifestUrl: String): String = {
@@ -98,7 +99,13 @@ private[redshift] class RedshiftWriter(
case "AVRO" => "AVRO 'auto'"
case csv if csv == "CSV" || csv == "CSV GZIP" => csv + s" NULL AS '${params.nullString}'"
}
- s"COPY ${params.table.get} FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
+ val columns = if (params.includeColumnList) {
+ "(" + schema.fieldNames.map(name => s""""$name"""").mkString(",") + ") "
+ } else {
+ ""
+ }
+
+ s"COPY ${params.table.get} ${columns}FROM '$fixedUrl' CREDENTIALS '$credsString' FORMAT AS " +
s"${format} manifest ${params.extraCopyOptions}"
}
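
To make the effect of the new `columns` fragment concrete, here is a small standalone sketch (not from the patch) that reproduces the same string construction for a hypothetical two-column schema; the table name and manifest URL are placeholders.

```scala
// Standalone sketch reproducing the column-list construction used in copySql.
// The schema, table name, and manifest URL are hypothetical placeholders.
import org.apache.spark.sql.types._

object ColumnListSketch extends App {
  val schema = StructType(Seq(
    StructField("id", IntegerType),
    StructField("name", StringType)))

  // Same expression as in copySql when include_column_list is true;
  // note the trailing space so it joins cleanly with "FROM".
  val columns = "(" + schema.fieldNames.map(name => s""""$name"""").mkString(",") + ") "

  // Prints: COPY "PUBLIC"."test_table" ("id","name") FROM 's3://bucket/manifest' ...
  println(s"""COPY "PUBLIC"."test_table" ${columns}FROM 's3://bucket/manifest' ...""")
}
```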
@@ -140,7 +147,7 @@ private[redshift] class RedshiftWriter(
manifestUrl.foreach { manifestUrl =>
// Load the temporary data into the new file
- val copyStatement = copySql(data.sqlContext, params, creds, manifestUrl)
+ val copyStatement = copySql(data.sqlContext, data.schema, params, creds, manifestUrl)
log.info(copyStatement)
try {
jdbcWrapper.executeInterruptibly(conn.prepareStatement(copyStatement))
diff --git a/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala b/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
index e4ed9d14..590b5505 100644
--- a/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/ParametersSuite.scala
@@ -28,7 +28,8 @@ class ParametersSuite extends FunSuite with Matchers {
"tempdir" -> "s3://foo/bar",
"dbtable" -> "test_schema.test_table",
"url" -> "jdbc:redshift://foo/bar?user=user&password=password",
- "forward_spark_s3_credentials" -> "true")
+ "forward_spark_s3_credentials" -> "true",
+ "include_column_list" -> "true")
val mergedParams = Parameters.mergeParameters(params)
@@ -37,9 +38,10 @@ class ParametersSuite extends FunSuite with Matchers {
mergedParams.jdbcUrl shouldBe params("url")
mergedParams.table shouldBe Some(TableName("test_schema", "test_table"))
assert(mergedParams.forwardSparkS3Credentials)
+ assert(mergedParams.includeColumnList)
// Check that the defaults have been added
- (Parameters.DEFAULT_PARAMETERS - "forward_spark_s3_credentials").foreach {
+ (Parameters.DEFAULT_PARAMETERS - "forward_spark_s3_credentials" - "include_column_list").foreach {
case (key, value) => mergedParams.parameters(key) shouldBe value
}
}
diff --git a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
index ac2a644a..ed2da22d 100644
--- a/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
+++ b/src/test/scala/com/databricks/spark/redshift/RedshiftSourceSuite.scala
@@ -435,6 +435,27 @@ class RedshiftSourceSuite
mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
}
+ test("Include Column List adds the schema columns to the COPY query") {
+ val copyCommand =
+ "COPY \"PUBLIC\".\"test_table\" \\(\"testbyte\",\"testbool\",\"testdate\",\"testdouble\"" +
+ ",\"testfloat\",\"testint\",\"testlong\",\"testshort\",\"teststring\",\"testtimestamp\"\\) FROM .*"
+ val expectedCommands =
+ Seq("CREATE TABLE IF NOT EXISTS \"PUBLIC\".\"test_table\" .*".r,
+ copyCommand.r)
+
+ val params = defaultParams ++ Map("include_column_list" -> "true")
+
+ val mockRedshift = new MockRedshift(
+ defaultParams("url"),
+ Map(TableName.parseFromEscaped(defaultParams("dbtable")).toString -> null))
+
+ val source = new DefaultSource(mockRedshift.jdbcWrapper, _ => mockS3Client)
+ source.createRelation(testSqlContext, SaveMode.Append, params, expectedDataDF)
+
+ mockRedshift.verifyThatConnectionsWereClosed()
+ mockRedshift.verifyThatExpectedQueriesWereIssued(expectedCommands)
+ }
+
test("configuring maxlength on string columns") {
val longStrMetadata = new MetadataBuilder().putLong("maxlength", 512).build()
val shortStrMetadata = new MetadataBuilder().putLong("maxlength", 10).build()