diff --git a/pom.xml b/pom.xml
index fee8118..9f7088f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -333,6 +333,12 @@
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-s3</artifactId>
+      <version>1.10.32</version>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>org.bdgenomics.utils</groupId>
       <artifactId>utils-cli_${scala.version.prefix}</artifactId>
@@ -360,10 +366,24 @@
       <version>2.2.6</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.github.samtools</groupId>
+      <artifactId>htsjdk</artifactId>
+      <version>2.9.1</version>
+    </dependency>
+    <dependency>
+      <groupId>com.github.samtools</groupId>
+      <artifactId>htsjdk</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-s3</artifactId>
+      <scope>provided</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.version.prefix}</artifactId>
diff --git a/src/main/scala/net/fnothaft/copier/CircularBuffer.scala b/src/main/scala/net/fnothaft/copier/CircularBuffer.scala
new file mode 100644
index 0000000..e42013e
--- /dev/null
+++ b/src/main/scala/net/fnothaft/copier/CircularBuffer.scala
@@ -0,0 +1,157 @@
+/**
+ * Copyright 2017 Frank Austin Nothaft
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package net.fnothaft.copier
+
+import java.io.{ InputStream, OutputStream }
+
+class CircularBuffer(bufferSize: Int) {
+
+  val buffer = new Array[Byte](bufferSize)
+  var start = 0
+  var end = 0
+  var isEmpty = true
+
+  def entries: Int = {
+    if (isEmpty) {
+      0
+    } else if (end == start) {
+      size
+    } else if (end > start) {
+      end - start
+    } else {
+      (end + bufferSize) - start
+    }
+  }
+
+  def size: Int = buffer.size
+
+  val inputStream: InputStream = new CircularBufferInputStream(this)
+  val outputStream: OutputStream = new CircularBufferOutputStream(this)
+}
+
+case class CircularBufferInputStream private[copier] (
+    buffer: CircularBuffer) extends InputStream {
+
+  var pos: Int = buffer.start
+  var optMarkPos: Option[Int] = None
+  var limit: Int = -1
+
+  override def available(): Int = {
+    buffer.entries
+  }
+
+  override def close() {
+    // no-op
+  }
+
+  override def mark(readlimit: Int) {
+    optMarkPos = Some(pos)
+    limit = pos + readlimit
+    if (limit > buffer.size) {
+      limit -= buffer.size
+    }
+  }
+
+  override def markSupported(): Boolean = {
+    true
+  }
+
+  def read(): Int = {
+    if (buffer.isEmpty) {
+      -1
+    } else {
+
+      // mask to an unsigned value; InputStream.read() must return 0-255 or -1
+      val byteRead = buffer.buffer(pos) & 0xff
+      pos += 1
+      if (pos >= buffer.size) {
+        pos = 0
+      }
+      if (pos == buffer.end) {
+        buffer.isEmpty = true
+      }
+
+      optMarkPos match {
+        case Some(markPos) => {
+          if ((limit > markPos && pos >= limit) ||
+            (limit < markPos && pos >= limit && pos < markPos)) {
+            optMarkPos = None
+            buffer.start = pos
+          }
+        }
+        case None => {
+          buffer.start += 1
+          if (buffer.start == buffer.size) {
+            buffer.start = 0
+          }
+        }
+      }
+
+      byteRead
+    }
+  }
+
+  override def reset() {
+    optMarkPos match {
+      case Some(markPos) => {
+        buffer.isEmpty = false
+        pos = markPos
+      }
+      case None => {
+        throw new IllegalStateException("Stream was not marked.")
+      }
+    }
+  }
+
+  override def skip(n: Long): Long = {
+    // we can skip at most the number of unread bytes between pos and end
+    val bytes = if (pos > buffer.end) {
+      buffer.end + buffer.size - pos
+    } else {
+      buffer.end - pos
+    }
+    val toSkip = if (n > bytes) {
+      bytes
+    } else {
+      n.toInt
+    }
+    pos += toSkip
+    if (pos >= buffer.size) {
+      pos -= buffer.size
+    }
+    toSkip
+  }
+}
+
+case class CircularBufferOutputStream private[copier] (
+    buffer: CircularBuffer) extends OutputStream {
+
+  override def close() {
+    // no-op
+  }
+
+  override def flush() {
+    // no-op
+  }
+
+  def write(b: Int) {
+    if ((buffer.start == 0 && buffer.end == buffer.size) ||
+      (buffer.start == buffer.end && !buffer.isEmpty)) {
+      throw new IllegalStateException("Buffer is full.")
+    }
+    buffer.isEmpty = false
+    buffer.buffer(buffer.end) = b.toByte
+    buffer.end += 1
+    if (buffer.end == buffer.size) {
+      buffer.end = 0
+    }
+  }
+}
diff --git a/src/main/scala/net/fnothaft/copier/Copier.scala b/src/main/scala/net/fnothaft/copier/Copier.scala
index d8a5953..01b8115 100644
--- a/src/main/scala/net/fnothaft/copier/Copier.scala
+++ b/src/main/scala/net/fnothaft/copier/Copier.scala
@@ -143,7 +143,7 @@
 
   /**
    * Downloads the entirety of a file.
-   * 
+   *
    * @param is The stream to the file.
    * @param outFs The file system to write to.
    * @param outPath The path to write the file to.
diff --git a/src/main/scala/net/fnothaft/copier/GzipToBgzfS3Copier.scala b/src/main/scala/net/fnothaft/copier/GzipToBgzfS3Copier.scala
new file mode 100644
index 0000000..84769d9
--- /dev/null
+++ b/src/main/scala/net/fnothaft/copier/GzipToBgzfS3Copier.scala
@@ -0,0 +1,201 @@
+/**
+ * Copyright 2017 Frank Austin Nothaft
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package net.fnothaft.copier
+
+import com.amazonaws.auth._
+import com.amazonaws.auth.profile.ProfileCredentialsProvider
+import com.amazonaws.services.s3.AmazonS3Client
+import com.amazonaws.services.s3.model._
+import htsjdk.samtools.util.BlockCompressedOutputStream
+import java.io.{ InputStream, OutputStream }
+import java.net.{ ConnectException, URL, URLConnection }
+import org.apache.commons.io.IOUtils
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.compress.GzipCodec
+import org.apache.hadoop.fs.{ FileSystem, Path }
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.bdgenomics.utils.cli._
+import org.bdgenomics.utils.misc.Logging
+import org.kohsuke.args4j.{ Argument, Option => Args4jOption }
+import scala.collection.JavaConversions._
+
+object GzipToBgzfS3Copier extends BDGCommandCompanion with Logging {
+
+  val commandName = "gzipToBgzfS3Copier"
+  val commandDescription = "Copy Gzip'ed files to S3, BGZFing them along the way."
+
+  def apply(cmdLine: Array[String]) = {
+    log.info("Running GzipToBgzfS3Copier with arguments: Array(%s)".format(
+      cmdLine.mkString(",")))
+
+    new GzipToBgzfS3Copier(Args4j[GzipToBgzfS3CopierArgs](cmdLine))
+  }
+
+  /**
+   * Copies a single path to S3.
+   *
+   * Gunzips the file, BGZFs it, and uploads it to S3.
+   *
+   * @param file The path to copy.
+   * @param outputBucket The S3 bucket to write the file to.
+   * @param blockSize The S3 block size to use.
+   * @param bufferSize The buffer size to use on the output stream.
+   * @param config The Hadoop Configuration values to use.
+   * @return Returns the number of bytes written to S3.
+   */
+  private def copyFile(file: String,
+                       outputBucket: String,
+                       blockSize: Int,
+                       bufferSize: Int,
+                       config: Map[String, String]): Long = {
+
+    // create our path objects
+    val filename = file.split("/").last
+    val inPath = new Path(file)
+
+    // reconstruct the config and get the FS for the input path
+    val conf = ConfigUtil.makeConfig(config)
+    val inFs = inPath.getFileSystem(conf)
+
+    // create our streams
+    val is = inFs.open(inPath)
+    val gzipIs = new GzipCodec().createInputStream(is)
+    val buffer = new CircularBuffer(bufferSize)
+    val bgzfOs = new BlockCompressedOutputStream(buffer.outputStream, null)
+
+    // create the s3 upload
+    val client = new AmazonS3Client(new ProfileCredentialsProvider())
+    val request = new InitiateMultipartUploadRequest(outputBucket, filename)
+    val uploadId = client.initiateMultipartUpload(request).getUploadId
+
+    // loop and copy blocks
+    val copyBuffer = new Array[Byte](bufferSize)
+    var doneCopying = false
+    var blockIdx = 1 // S3 part numbers are 1-based
+    var tags = List[PartETag]()
+    var bytesCopied = 0L
+    while (!doneCopying) {
+      while (!doneCopying && buffer.entries < blockSize) {
+        val bytesRead = gzipIs.read(copyBuffer)
+        if (bytesRead < 0) {
+          // we hit EOF; close the BGZF stream so it flushes its final block
+          doneCopying = true
+          bgzfOs.close()
+        } else {
+          bgzfOs.write(copyBuffer, 0, bytesRead)
+        }
+      }
+
+      val part = new UploadPartRequest()
+        .withBucketName(outputBucket)
+        .withKey(filename)
+        .withUploadId(uploadId)
+        .withPartNumber(blockIdx)
+        .withInputStream(buffer.inputStream)
+
+      // submit the part upload
+      if (doneCopying) {
+
+        val remainingBytes = buffer.entries
+        tags = client.uploadPart(part.withPartSize(remainingBytes)).getPartETag :: tags
+
+        log.info("Uploaded final part %d of %s/%s with size %d.".format(blockIdx,
+          outputBucket,
+          filename,
+          remainingBytes))
+        bytesCopied += remainingBytes
+      } else {
+
+        tags = client.uploadPart(part.withPartSize(blockSize)).getPartETag :: tags
+        log.info("Uploaded part %d of %s/%s with size %d.".format(blockIdx,
+          outputBucket,
+          filename,
+          blockSize))
+        blockIdx += 1
+        bytesCopied += blockSize
+      }
+    }
+
+    client.completeMultipartUpload(new CompleteMultipartUploadRequest()
+      .withBucketName(outputBucket)
+      .withKey(filename)
+      .withUploadId(uploadId)
+      .withPartETags(tags))
+
+    bytesCopied
+  }
+
+  /**
+   * Copies all the URLs in an RDD of URLs to S3.
+   *
+   * @param rdd The RDD of URLs to download.
+   * @param outputBucket The S3 bucket to write the files to.
+   * @param blockSize The S3 block size to use.
+   * @param bufferSize The buffer size to use on the output stream.
+   * @return Returns the total number of bytes written to S3 across all files.
+   */
+  def copy(rdd: RDD[String],
+           outputBucket: String,
+           blockSize: Int,
+           bufferSize: Int): Long = {
+
+    // how many files do we have to download?
+    // we will repartition into one per partition
+    val files = rdd.count().toInt
+
+    // excise the hadoop conf values
+    val configMap = ConfigUtil.extractConfig(rdd.context.hadoopConfiguration)
+
+    // repartition and run the downloads
+    rdd.repartition(files)
+      .map(copyFile(_, outputBucket,
+        blockSize, bufferSize,
+        configMap))
+      .reduce(_ + _)
+  }
+}
+
+class GzipToBgzfS3CopierArgs extends Args4jBase {
+  @Argument(required = true,
+    metaVar = "PATHS",
+    usage = "The paths to download",
+    index = 0)
+  var inputPath: String = null
+  @Argument(required = true,
+    metaVar = "OUTPUT",
+    usage = "Location to write the downloaded data",
+    index = 1)
+  var outputPath: String = null
+  @Args4jOption(required = false,
+    name = "-buffer_size",
Defaults to 16Ki.") + var bufferSize = 16 * 1024 + @Args4jOption(required = false, + name = "-block_size", + usage = "The size of the S3 block size for writing. Defaults to 64Mi.") + var blockSize = 64 * 1024 * 1024 +} + +class GzipToBgzfS3Copier( + protected val args: GzipToBgzfS3CopierArgs) extends BDGSparkCommand[GzipToBgzfS3CopierArgs] { + val companion = GzipToBgzfS3Copier + + def run(sc: SparkContext) { + + val files = sc.textFile(args.inputPath) + + companion.copy(files, + args.outputPath, + args.blockSize, + args.bufferSize) + } +} diff --git a/src/test/scala/net/fnothaft/copier/CircularBufferSuite.scala b/src/test/scala/net/fnothaft/copier/CircularBufferSuite.scala new file mode 100644 index 0000000..3c3087f --- /dev/null +++ b/src/test/scala/net/fnothaft/copier/CircularBufferSuite.scala @@ -0,0 +1,123 @@ +/** + * Copyright 2017 Frank Austin Nothaft + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.fnothaft.copier + +import org.scalatest.FunSuite + +class CircularBufferSuite extends FunSuite { + + test("write the full size of the buffer and read out") { + + val buffer = new CircularBuffer(5) + assert(buffer.size === 5) + assert(buffer.entries === 0) + + // write once + buffer.outputStream.write("HELLO".getBytes()) + assert(buffer.entries === 5) + + // buffer is now full + intercept[IllegalStateException] { + buffer.outputStream.write("HELLO".getBytes()) + } + + // read from buffer + val bytes = new Array[Byte](5) + assert(buffer.inputStream.read(bytes) === 5) + assert(buffer.entries === 0) + assert(new String(bytes) === "HELLO") + + // buffer is now empty + assert(buffer.inputStream.read(bytes) === -1) + + // write and read again + buffer.outputStream.write("HELLO".getBytes()) + assert(buffer.entries === 5) + assert(buffer.inputStream.read(bytes) === 5) + assert(buffer.entries === 0) + assert(new String(bytes) === "HELLO") + + // write and read again, but mark the stream this time + buffer.outputStream.write("HELLO".getBytes()) + assert(buffer.entries === 5) + assert(buffer.inputStream.markSupported()) + buffer.inputStream.mark(10) + bytes.indices.foreach(i => bytes(i) = '0') // "zero" out bytes + assert(buffer.inputStream.read(bytes, 0, 3) === 3) + assert(buffer.entries === 5) + assert(new String(bytes) === "HEL00") + buffer.inputStream.reset() + assert(buffer.entries === 5) + bytes.indices.foreach(i => bytes(i) = '0') // "zero" out bytes + assert(buffer.inputStream.read(bytes) === 5) + assert(buffer.entries === 0) + assert(new String(bytes) === "HELLO") + } + + test("partially fill the buffer, read, and then write around") { + + val buffer = new CircularBuffer(5) + assert(buffer.size === 5) + assert(buffer.entries === 0) + + // write once + buffer.outputStream.write("ALLO".getBytes()) + assert(buffer.entries === 4) + + // read from buffer + var bytes = new Array[Byte](4) + assert(buffer.inputStream.read(bytes) === 4) + assert(buffer.entries === 0) + assert(new String(bytes) === "ALLO") + assert(buffer.inputStream.read(bytes) === -1) + + // write again, 
+    // write again, wraps around buffer
+    buffer.outputStream.write("ALLO".getBytes())
+    assert(buffer.entries === 4)
+
+    // write one character
+    buffer.outputStream.write("Z".getBytes())
+    assert(buffer.entries === 5)
+
+    // buffer is now full
+    intercept[IllegalStateException] {
+      buffer.outputStream.write("HELLO".getBytes())
+    }
+
+    // read from buffer
+    assert(buffer.inputStream.markSupported())
+    buffer.inputStream.mark(4)
+    bytes = new Array[Byte](2)
+    assert(buffer.inputStream.read(bytes) === 2)
+    assert(buffer.entries === 5)
+    assert(new String(bytes) === "AL")
+    buffer.inputStream.reset()
+    bytes = new Array[Byte](4)
+    assert(buffer.inputStream.read(bytes) === 4)
+    assert(buffer.entries === 1)
+    assert(new String(bytes) === "ALLO")
+
+    bytes = new Array[Byte](1)
+    assert(buffer.inputStream.read(bytes) === 1)
+    assert(buffer.entries === 0)
+    assert(new String(bytes) === "Z")
+
+    // mark should have expired
+    intercept[IllegalStateException] {
+      buffer.inputStream.reset()
+    }
+  }
+}
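
A minimal sketch of the mechanism this change relies on: BlockCompressedOutputStream compresses into the in-memory CircularBuffer, and the BGZF bytes are then read back out of the paired input stream, much as copyFile does for each multipart part. The BgzfBufferDemo object below is illustrative only and is not part of the diff above.

package net.fnothaft.copier

import htsjdk.samtools.util.BlockCompressedOutputStream

// Illustrative sketch only; not part of the change above.
object BgzfBufferDemo {
  def main(args: Array[String]) {
    // compressed output lands in the in-memory circular buffer
    val buffer = new CircularBuffer(1024 * 1024)
    val bgzfOs = new BlockCompressedOutputStream(buffer.outputStream, null)

    // write some uncompressed bytes, then close to flush the final BGZF block
    bgzfOs.write("ACGT".getBytes())
    bgzfOs.close()

    // the BGZF-compressed bytes are now readable from the paired input stream
    val compressed = new Array[Byte](buffer.entries)
    buffer.inputStream.read(compressed)
    println("BGZF bytes buffered: " + compressed.length)
  }
}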