diff --git a/pom.xml b/pom.xml
index fee8118..9f7088f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -333,6 +333,12 @@
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-s3</artifactId>
+      <version>1.10.32</version>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>org.bdgenomics.utils</groupId>
       <artifactId>utils-cli_${scala.version.prefix}</artifactId>
@@ -360,10 +366,24 @@
       <version>2.2.6</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.github.samtools</groupId>
+      <artifactId>htsjdk</artifactId>
+      <version>2.9.1</version>
+    </dependency>
+    <dependency>
+      <groupId>com.github.samtools</groupId>
+      <artifactId>htsjdk</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-s3</artifactId>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.version.prefix}</artifactId>
diff --git a/src/main/scala/net/fnothaft/copier/CircularBuffer.scala b/src/main/scala/net/fnothaft/copier/CircularBuffer.scala
new file mode 100644
index 0000000..e42013e
--- /dev/null
+++ b/src/main/scala/net/fnothaft/copier/CircularBuffer.scala
@@ -0,0 +1,157 @@
+/**
+ * Copyright 2017 Frank Austin Nothaft
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package net.fnothaft.copier
+
+import java.io.{ InputStream, OutputStream }
+
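+/**
+ * A fixed size, in-memory ring buffer of bytes.
+ *
+ * Bytes are appended through `outputStream` and consumed through
+ * `inputStream`. `start` points at the oldest unread byte, and `end` points
+ * at the next slot to be written. Writing to a full buffer throws an
+ * IllegalStateException.
+ *
+ * @param bufferSize The capacity of the buffer, in bytes.
+ */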
+class CircularBuffer(bufferSize: Int) {
+
+ val buffer = new Array[Byte](bufferSize)
+ var start = 0
+ var end = 0
+ var isEmpty = true
+
+ def entries: Int = {
+ if (isEmpty) {
+ 0
+ } else if (end == start) {
+ size
+ } else if (end > start) {
+ end - start
+ } else {
+ (end + bufferSize) - start
+ }
+ }
+
+ def size: Int = buffer.size
+
+ val inputStream: InputStream = new CircularBufferInputStream(this)
+ val outputStream: OutputStream = new CircularBufferOutputStream(this)
+}
+
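+/**
+ * An InputStream view of a CircularBuffer.
+ *
+ * Reading a byte normally releases it from the buffer. While the stream is
+ * marked, read bytes are retained until the read limit is passed, so that
+ * reset() can rewind to the marked position.
+ *
+ * @param buffer The circular buffer to read from.
+ */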
+case class CircularBufferInputStream private[copier] (
+ buffer: CircularBuffer) extends InputStream {
+
+ var pos: Int = buffer.start
+ var optMarkPos: Option[Int] = None
+ var limit: Int = -1
+
+ override def available(): Int = {
+ buffer.entries
+ }
+
+ override def close() {
+ // no-op
+ }
+
+ override def mark(readlimit: Int) {
+ optMarkPos = Some(pos)
+ limit = pos + readlimit
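+    // the read limit may extend past the end of the array, so wrap it around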
+ if (limit > buffer.size) {
+ limit -= buffer.size
+ }
+ }
+
+ override def markSupported(): Boolean = {
+ true
+ }
+
+ def read(): Int = {
+ if (buffer.isEmpty) {
+ -1
+ } else {
+
+      // mask to an unsigned value, per the InputStream contract, so that a
+      // data byte of 0xff is not mistaken for the -1 end-of-stream marker
+      val byteRead = buffer.buffer(pos) & 0xff
+ pos += 1
+ if (pos >= buffer.size) {
+ pos = 0
+ }
+ if (pos == buffer.end) {
+ buffer.isEmpty = true
+ }
+
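+      // if the stream is marked, bytes are only released from the buffer once
+      // the reader passes the mark's read limit; otherwise, each byte is
+      // released as soon as it has been read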
+ optMarkPos match {
+ case Some(markPos) => {
+ if ((limit > markPos && pos >= limit) ||
+ (limit < markPos && pos >= limit && pos < markPos)) {
+ optMarkPos = None
+ buffer.start = pos
+ }
+ }
+ case None => {
+ buffer.start += 1
+ if (buffer.start == buffer.size) {
+ buffer.start = 0
+ }
+ }
+ }
+
+ byteRead
+ }
+ }
+
+ override def reset() {
+ optMarkPos match {
+ case Some(markPos) => {
+ buffer.isEmpty = false
+ pos = markPos
+ }
+ case None => {
+ throw new IllegalStateException("Stream was not marked.")
+ }
+ }
+ }
+
+  override def skip(n: Long): Long = {
+    // the number of unread bytes between the current position and the end pointer
+    val bytes = if (pos > buffer.end) {
+      buffer.end + buffer.size - pos
+    } else {
+      buffer.end - pos
+    }
+    // skip at most the number of bytes that are actually available
+    val toSkip = if (n < bytes) {
+      n.toInt
+    } else {
+      bytes
+    }
+    pos += toSkip
+    if (pos >= buffer.size) {
+      pos -= buffer.size
+    }
+    toSkip
+  }
+}
+
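+/**
+ * An OutputStream view of a CircularBuffer.
+ *
+ * Bytes are appended at the buffer's end pointer, wrapping around the end of
+ * the backing array. Writing to a full buffer throws an IllegalStateException.
+ *
+ * @param buffer The circular buffer to write into.
+ */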
+case class CircularBufferOutputStream private[copier] (
+ buffer: CircularBuffer) extends OutputStream {
+
+ override def close() {
+ // no-op
+ }
+
+ override def flush() {
+ // no-op
+ }
+
+ def write(b: Int) {
+ if ((buffer.start == 0 && buffer.end == buffer.size) ||
+ (buffer.start == buffer.end && !buffer.isEmpty)) {
+ throw new IllegalStateException("Buffer is full.")
+ }
+ buffer.isEmpty = false
+ buffer.buffer(buffer.end) = b.toByte
+ buffer.end += 1
+ if (buffer.end == buffer.size) {
+ buffer.end = 0
+ }
+ }
+}
diff --git a/src/main/scala/net/fnothaft/copier/Copier.scala b/src/main/scala/net/fnothaft/copier/Copier.scala
index d8a5953..01b8115 100644
--- a/src/main/scala/net/fnothaft/copier/Copier.scala
+++ b/src/main/scala/net/fnothaft/copier/Copier.scala
@@ -143,7 +143,7 @@ object Copier extends BDGCommandCompanion with Logging {
/**
* Downloads the entirety of a file.
- *
+ *
* @param is The stream to the file.
* @param outFs The file system to write to.
* @param outPath The path to write the file to.
diff --git a/src/main/scala/net/fnothaft/copier/GzipToBgzfS3Copier.scala b/src/main/scala/net/fnothaft/copier/GzipToBgzfS3Copier.scala
new file mode 100644
index 0000000..84769d9
--- /dev/null
+++ b/src/main/scala/net/fnothaft/copier/GzipToBgzfS3Copier.scala
@@ -0,0 +1,201 @@
+/**
+ * Copyright 2017 Frank Austin Nothaft
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package net.fnothaft.copier
+
+import com.amazonaws.auth._
+import com.amazonaws.auth.profile.ProfileCredentialsProvider
+import com.amazonaws.services.s3.AmazonS3Client
+import com.amazonaws.services.s3.model._
+import htsjdk.samtools.util.BlockCompressedOutputStream
+import java.io.{ InputStream, OutputStream }
+import java.net.{ ConnectException, URL, URLConnection }
+import org.apache.commons.io.IOUtils
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.compress.GzipCodec
+import org.apache.hadoop.fs.{ FileSystem, Path }
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.bdgenomics.utils.cli._
+import org.bdgenomics.utils.misc.Logging
+import org.kohsuke.args4j.{ Argument, Option => Args4jOption }
+import scala.collection.JavaConversions._
+
+object GzipToBgzfS3Copier extends BDGCommandCompanion with Logging {
+
+ val commandName = "gzipToBgzfS3Copier"
+  val commandDescription = "Copy Gzip'ed files to S3, BGZFing them along the way."
+
+ def apply(cmdLine: Array[String]) = {
+ log.info("Running Copier with arguments: Array(%s)".format(
+ cmdLine.mkString(",")))
+
+ new GzipToBgzfS3Copier(Args4j[GzipToBgzfS3CopierArgs](cmdLine))
+ }
+
+  /**
+   * Copies a single path to S3.
+   *
+   * Gunzips the file, recompresses it as BGZF, and uploads it to S3 via a
+   * multipart upload.
+   *
+   * @param file The path to copy.
+   * @param outputBucket The S3 bucket to write the file to.
+   * @param blockSize The S3 multipart upload part size to use, in bytes.
+   * @param bufferSize The size of the circular buffer to use, in bytes.
+   * @param config The Hadoop Configuration values to use.
+   * @return Returns the number of bytes uploaded to S3.
+   */
+ private def copyFile(file: String,
+ outputBucket: String,
+ blockSize: Int,
+ bufferSize: Int,
+ config: Map[String, String]): Long = {
+
+ // create our path objects
+ val filename = file.split("/").last
+ val inPath = new Path(file)
+
+ // reconstruct the config and get the FS for the output path
+ val conf = ConfigUtil.makeConfig(config)
+ val inFs = inPath.getFileSystem(conf)
+
+    // create our streams
+    val is = inFs.open(inPath)
+    val gzipCodec = new GzipCodec()
+    // the codec needs a Hadoop configuration before it can create a decompressor
+    gzipCodec.setConf(conf)
+    val gzipIs = gzipCodec.createInputStream(is)
+    val buffer = new CircularBuffer(bufferSize)
+    val bgzfOs = new BlockCompressedOutputStream(buffer.outputStream, null)
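+    // the pipeline: gunzipped bytes -> BGZF compression -> circular buffer ->
+    // S3 multipart upload parts streamed back out of the buffer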
+
+ // create the s3 upload
+ val client = new AmazonS3Client(new ProfileCredentialsProvider())
+ val request = new InitiateMultipartUploadRequest(outputBucket, filename)
+ val uploadId = client.initiateMultipartUpload(request).getUploadId
+
+ // loop and copy blocks
+ val copyBuffer = new Array[Byte](bufferSize)
+ var doneCopying = false
+    // S3 multipart upload part numbers are 1-based
+    var blockIdx = 1
+ var tags = List[PartETag]()
+ var bytesCopied = 0L
+    while (!doneCopying) {
+      while (!doneCopying && buffer.entries < blockSize) {
+        val bytesRead = gzipIs.read(copyBuffer)
+        if (bytesRead < 0) {
+          // a negative return value signals end-of-stream
+          doneCopying = true
+        } else {
+          // only write the bytes that were actually read
+          bgzfOs.write(copyBuffer, 0, bytesRead)
+        }
+      }
+
+ val part = new UploadPartRequest()
+ .withBucketName(outputBucket)
+ .withKey(filename)
+ .withUploadId(uploadId)
+ .withPartNumber(blockIdx)
+ .withInputStream(buffer.inputStream)
+
+      // submit the part upload; the final part may be smaller than a full block
+      if (doneCopying) {
+
+        // flush the trailing BGZF blocks (including the BGZF end-of-file
+        // marker) into the buffer before sizing and uploading the final part
+        bgzfOs.close()
+
+        val remainingBytes = buffer.entries
+        tags = client.uploadPart(part.withPartSize(remainingBytes)).getPartETag :: tags
+
+ log.info("Uploaded final part %d of %s/%s with size %d.".format(blockIdx,
+ outputBucket,
+ filename,
+ remainingBytes))
+ bytesCopied += remainingBytes
+ } else {
+
+ tags = client.uploadPart(part.withPartSize(blockSize)).getPartETag :: tags
+ log.info("Uploaded part %d of %s/%s with size %d.".format(blockIdx,
+ outputBucket,
+ filename,
+ blockSize))
+ blockIdx += 1
+ bytesCopied += blockSize
+ }
+ }
+
+    // close the input stream and finish the multipart upload; the part tags
+    // were prepended to the list, so reverse them into ascending part-number order
+    gzipIs.close()
+
+    client.completeMultipartUpload(new CompleteMultipartUploadRequest()
+      .withBucketName(outputBucket)
+      .withKey(filename)
+      .withUploadId(uploadId)
+      .withPartETags(tags.reverse))
+
+ bytesCopied
+ }
+
+  /**
+   * Copies all of the paths in an RDD of paths to S3.
+   *
+   * @param rdd The RDD of paths to copy.
+   * @param outputBucket The S3 bucket to write the files to.
+   * @param blockSize The S3 multipart upload part size to use.
+   * @param bufferSize The size of the circular buffer to use on the output stream.
+   * @return Returns the total number of bytes uploaded across all files.
+   */
+ def copy(rdd: RDD[String],
+ outputBucket: String,
+ blockSize: Int,
+ bufferSize: Int): Long = {
+
+    // how many files do we have to copy?
+    // we will repartition so that each file lands in its own partition
+ val files = rdd.count().toInt
+
+ // excise the hadoop conf values
+ val configMap = ConfigUtil.extractConfig(rdd.context.hadoopConfiguration)
+
+    // repartition and run the copies
+ rdd.repartition(files)
+ .map(copyFile(_, outputBucket,
+ blockSize, bufferSize,
+ configMap))
+ .reduce(_ + _)
+ }
+}
+
+class GzipToBgzfS3CopierArgs extends Args4jBase {
+ @Argument(required = true,
+ metaVar = "PATHS",
+ usage = "The paths to download",
+ index = 0)
+ var inputPath: String = null
+ @Argument(required = true,
+ metaVar = "OUTPUT",
+ usage = "Location to write the downloaded data",
+ index = 1)
+ var outputPath: String = null
+  @Args4jOption(required = false,
+    name = "-buffer_size",
+    usage = "The size of the circular buffer used when writing to S3. Defaults to 16KiB.")
+  var bufferSize = 16 * 1024
+  @Args4jOption(required = false,
+    name = "-block_size",
+    usage = "The S3 multipart upload part size to use when writing. Defaults to 64MiB.")
+  var blockSize = 64 * 1024 * 1024
+}
+
+class GzipToBgzfS3Copier(
+ protected val args: GzipToBgzfS3CopierArgs) extends BDGSparkCommand[GzipToBgzfS3CopierArgs] {
+ val companion = GzipToBgzfS3Copier
+
+ def run(sc: SparkContext) {
+
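+    // each line of the input file names one path to copy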
+ val files = sc.textFile(args.inputPath)
+
+ companion.copy(files,
+ args.outputPath,
+ args.blockSize,
+ args.bufferSize)
+ }
+}
diff --git a/src/test/scala/net/fnothaft/copier/CircularBufferSuite.scala b/src/test/scala/net/fnothaft/copier/CircularBufferSuite.scala
new file mode 100644
index 0000000..3c3087f
--- /dev/null
+++ b/src/test/scala/net/fnothaft/copier/CircularBufferSuite.scala
@@ -0,0 +1,123 @@
+/**
+ * Copyright 2017 Frank Austin Nothaft
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package net.fnothaft.copier
+
+import org.scalatest.FunSuite
+
+class CircularBufferSuite extends FunSuite {
+
+ test("write the full size of the buffer and read out") {
+
+ val buffer = new CircularBuffer(5)
+ assert(buffer.size === 5)
+ assert(buffer.entries === 0)
+
+ // write once
+ buffer.outputStream.write("HELLO".getBytes())
+ assert(buffer.entries === 5)
+
+ // buffer is now full
+ intercept[IllegalStateException] {
+ buffer.outputStream.write("HELLO".getBytes())
+ }
+
+ // read from buffer
+ val bytes = new Array[Byte](5)
+ assert(buffer.inputStream.read(bytes) === 5)
+ assert(buffer.entries === 0)
+ assert(new String(bytes) === "HELLO")
+
+ // buffer is now empty
+ assert(buffer.inputStream.read(bytes) === -1)
+
+ // write and read again
+ buffer.outputStream.write("HELLO".getBytes())
+ assert(buffer.entries === 5)
+ assert(buffer.inputStream.read(bytes) === 5)
+ assert(buffer.entries === 0)
+ assert(new String(bytes) === "HELLO")
+
+ // write and read again, but mark the stream this time
+ buffer.outputStream.write("HELLO".getBytes())
+ assert(buffer.entries === 5)
+ assert(buffer.inputStream.markSupported())
+ buffer.inputStream.mark(10)
+ bytes.indices.foreach(i => bytes(i) = '0') // "zero" out bytes
+ assert(buffer.inputStream.read(bytes, 0, 3) === 3)
+ assert(buffer.entries === 5)
+ assert(new String(bytes) === "HEL00")
+ buffer.inputStream.reset()
+ assert(buffer.entries === 5)
+ bytes.indices.foreach(i => bytes(i) = '0') // "zero" out bytes
+ assert(buffer.inputStream.read(bytes) === 5)
+ assert(buffer.entries === 0)
+ assert(new String(bytes) === "HELLO")
+ }
+
+ test("partially fill the buffer, read, and then write around") {
+
+ val buffer = new CircularBuffer(5)
+ assert(buffer.size === 5)
+ assert(buffer.entries === 0)
+
+ // write once
+ buffer.outputStream.write("ALLO".getBytes())
+ assert(buffer.entries === 4)
+
+ // read from buffer
+ var bytes = new Array[Byte](4)
+ assert(buffer.inputStream.read(bytes) === 4)
+ assert(buffer.entries === 0)
+ assert(new String(bytes) === "ALLO")
+ assert(buffer.inputStream.read(bytes) === -1)
+
+ // write again, wraps around buffer
+ buffer.outputStream.write("ALLO".getBytes())
+ assert(buffer.entries === 4)
+
+ // write one character
+ buffer.outputStream.write("Z".getBytes())
+ assert(buffer.entries === 5)
+
+ // buffer is now full
+ intercept[IllegalStateException] {
+ buffer.outputStream.write("HELLO".getBytes())
+ }
+
+ // read from buffer
+ assert(buffer.inputStream.markSupported())
+ buffer.inputStream.mark(4)
+ bytes = new Array[Byte](2)
+ assert(buffer.inputStream.read(bytes) === 2)
+ assert(buffer.entries === 5)
+ assert(new String(bytes) === "AL")
+ buffer.inputStream.reset()
+ bytes = new Array[Byte](4)
+ assert(buffer.inputStream.read(bytes) === 4)
+ assert(buffer.entries === 1)
+ assert(new String(bytes) === "ALLO")
+
+ bytes = new Array[Byte](1)
+ assert(buffer.inputStream.read(bytes) === 1)
+ assert(buffer.entries === 0)
+ assert(new String(bytes) === "Z")
+
+ // mark should have expired
+ intercept[IllegalStateException] {
+ buffer.inputStream.reset()
+ }
+ }
+}