diff --git a/performance/bdg_perf/bdg_perf_sequila.scala b/performance/bdg_perf/bdg_perf_sequila.scala index 1b519589..be01906d 100644 --- a/performance/bdg_perf/bdg_perf_sequila.scala +++ b/performance/bdg_perf/bdg_perf_sequila.scala @@ -32,6 +32,7 @@ ss.sql(s""" |OPTIONS (path "${bedPath}", delimiter "\t")""".stripMargin) ss.sqlContext.setConf("spark.biodatageeks.bam.predicatePushdown","true") +ss.sqlContext.setConf("spark.biodatageeks.window.optimization", "true") val queries = Array( BDGQuery("bdg_seq_count_NA12878","SELECT COUNT(*) FROM reads WHERE sampleId='NA12878'"), @@ -48,7 +49,8 @@ val queries = Array( | ) | GROUP BY targets.contigName,targets.start,targets.end """.stripMargin), - BDGQuery("bdg_cov_window_fix_length_100_count_NA12878","SELECT COUNT(*) FROM bdg_coverage ('reads','NA12878', 'bases','500')") + BDGQuery("bdg_cov_window_fix_length_100_count_NA12878","SELECT COUNT(*) FROM bdg_coverage ('reads','NA12878', 'blocks','500')"), + BDGQuery("bdg_cov_window_bed_file_targets_count_NA12878", "SELECT COUNT(*) FROM bdg_coverage ('reads', 'NA12878', 'blocks', 'targets')") ) BDGPerfRunner.run(ss,queries) diff --git a/src/main/scala/org/biodatageeks/catalyst/utvf/ResolveTableValuedFunctionsSeq.scala b/src/main/scala/org/biodatageeks/catalyst/utvf/ResolveTableValuedFunctionsSeq.scala index 8c20443e..515780f4 100644 --- a/src/main/scala/org/biodatageeks/catalyst/utvf/ResolveTableValuedFunctionsSeq.scala +++ b/src/main/scala/org/biodatageeks/catalyst/utvf/ResolveTableValuedFunctionsSeq.scala @@ -93,7 +93,8 @@ object ResolveTableValuedFunctionsSeq extends Rule[LogicalPlan] { tvf("table" -> StringType, "sampleId" -> StringType, "result" -> StringType, "target" -> StringType) { case Seq(table: Any,sampleId:Any, result:Any, target: Any) => BDGCoverage(table.toString,sampleId.toString,result.toString, Some(target.toString)) - }), + } + ), "range" -> Map( /* range(end) */ tvf("end" -> LongType) { case Seq(end: Long) => diff --git a/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageMethods.scala b/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageMethods.scala index 61d6bcc1..12b88145 100644 --- a/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageMethods.scala +++ b/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageMethods.scala @@ -2,12 +2,14 @@ package org.biodatageeks.preprocessing.coverage -import htsjdk.samtools.{BAMFileReader, CigarOperator, SamReaderFactory, ValidationStringency} +import htsjdk.samtools.{CigarOperator, _} +import org.apache.log4j.Logger +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.biodatageeks.utils.BDGInternalParams + import scala.collection.mutable -import htsjdk.samtools._ -import org.apache.spark.broadcast.Broadcast -import org.apache.log4j.Logger abstract class AbstractCovRecord { @@ -171,7 +173,6 @@ object CoverageMethodsMos { val winLen = windowLength.get val windowStart = ( (i+posShift) / winLen) * windowLength.get val windowEnd = windowStart + winLen - 1 - // val lastWindowLength = (i + posShift) % winLen // (i + posShift -1) % winLen // -1 current val lastWindowLength = (i + posShift) % winLen - 1 // HACK to fix last window (omit last element) sum -= cov // HACK to fix last window (substract last element) @@ -181,7 +182,26 @@ object CoverageMethodsMos { indexShift } - def eventsToCoverage(sampleId:String, events: RDD[(String,(Array[Short],Int,Int,Int))], + @inline def addLastOptimizedWindow(contig: String, contigMax: Int, 
windowLength: Option[Int], posShift: Int, i:Int, covSum: Int, cov: Int, ind: Int, result: Array[AbstractCovRecord]) = { + var indexShift = ind + var sum = covSum + + if ((i + posShift - 1) == contigMax) { // add last window + val winLen = windowLength.get + val windowStart = ((i + posShift) / winLen) * winLen + val windowEnd = windowStart + winLen + val lastWindowLength = (i + posShift) % winLen - 1// HACK to fix last window (omit last element) + sum -= cov // HACK to fix last window (substract last element) + + result(ind) = CovRecordWindow(contig, windowStart, windowEnd, sum / lastWindowLength.toFloat, Some(lastWindowLength)) + + indexShift += 1 + } + + indexShift + } + + def eventsToCoverage(sampleId:String, events: RDD[(String, (Array[Short], Int, Int, Int, mutable.HashMap[Int, Int]))], contigMinMap: mutable.HashMap [String,(Int,Int)], blocksResult:Boolean, allPos: Boolean, windowLength: Option[Int], targetsTable:Option[String]) : RDD[AbstractCovRecord] = { events @@ -209,8 +229,12 @@ object CoverageMethodsMos { var prevCov = 0 var blockLength = 0 + val session = SparkSession.builder().getOrCreate() + val optimizeWindow = session.sqlContext.getConf(BDGInternalParams.OptimizationWindow,"false") + + val targetsTab = targetsTable.getOrElse(None) - if (windowLength.isEmpty) { // BLOCKS & BASES (NON-WINDOW) COVERAGE CALCULATIONS + if (windowLength.isEmpty && targetsTable.isEmpty) { // BLOCKS & BASES (NON-WINDOW) COVERAGE CALCULATIONS ind = addFirstBlock(contig, contigMinMap(contig)._1, posShift, blocksResult, allPos, ind, result) // add first block if necessary (if current positionshift is equal to the earliest read in the contig) while (i < covArrayLength) { @@ -237,7 +261,7 @@ object CoverageMethodsMos { result.take(ind).iterator - } else { // FIXED - WINDOW COVERAGE CALCULATIONS + } else if (windowLength.isDefined && targetsTable.isEmpty && optimizeWindow == "false") { // FIXED - WINDOW COVERAGE CALCULATIONS while (i < covArrayLength) { cov += r._2._1(i) @@ -249,6 +273,7 @@ object CoverageMethodsMos { val winLen = windowLength.get val windowStart = (((i + posShift) / winLen) - 1) * winLen val windowEnd = windowStart + winLen - 1 + result(ind) = CovRecordWindow(contig, windowStart, windowEnd, covSum / length.toFloat, Some(length)) covSum = 0 ind += 1 @@ -261,6 +286,197 @@ object CoverageMethodsMos { ind = addLastWindow(contig, windowLength, posShift, i, covSum, cov, ind, result) result.take(ind).iterator + } else if (windowLength.isDefined && targetsTable.isEmpty && optimizeWindow == "true"){ // optimized fixed windows + while (i < covArrayLength) { + cov += r._2._1(i) + + if ((i + posShift) % windowLength.get == 0 && (i + posShift) > 0) { + var length = + if (i < windowLength.get) i + else windowLength.get + + val winLen = windowLength.get + val windowStart = (((i + posShift) / winLen) - 1) * winLen + val windowEnd = windowStart + winLen - 1 + if (r._2._5.contains(windowStart)) { + covSum += r._2._5(windowStart) + length = winLen + } + + result(ind) = CovRecordWindow(contig, windowStart, windowEnd, covSum / length.toFloat, Some(length)) + covSum = 0 + ind += 1 + } + + covSum += cov + i += 1 + + } + + ind = addLastOptimizedWindow(contig, contigMinMap(contig)._2, windowLength, posShift, i, covSum, cov, ind, result) + + result.take(ind).iterator + + } else if (windowLength.isEmpty && targetsTable.isDefined && optimizeWindow == "true"){ + + val session = SparkSession.builder().getOrCreate() + + val targets = session.sql(s"SELECT * FROM ${targetsTab}") + + val targetsCount = targets.count + 
val targetsRowCollection = targets.rdd.map(x => x).collect() + for (j <- 0 until targetsCount.toInt) { + val row = targetsRowCollection(j) + + val targetContig = row(0).asInstanceOf[String] + val targetStart = row(1).asInstanceOf[Int] + val targetEnd = row(2).asInstanceOf[Int] + + val targetId = row(3).asInstanceOf[String] + + val previousTargetContig = + if (j > 0) targetsRowCollection(j - 1)(0).asInstanceOf[String] + else targetContig + + val previousTargetEnd = + if (j > 0 && previousTargetContig == targetContig) targetsRowCollection(j - 1)(2).asInstanceOf[Int] + else 0 + + if (previousTargetContig != targetContig) { + covSum = 0 + cov = 0 + + } + + val windowLen = targetEnd - targetStart + 1 + if (targetContig == contig) { + + if (targetStart > previousTargetEnd) { //shift over gap + for (k <- previousTargetEnd + 1 until targetStart) { + if (k - posShift + 1 < covArrayLength && k - posShift + 1 >= 0) { + cov += r._2._1(k - posShift + 1) + } + covSum = cov + } + } else { + for (k <- previousTargetEnd until targetStart - 1 by -1) { // shift back -> targets are crossing + if (k - posShift + 1 < covArrayLength && k - posShift + 1 >= 0) { + cov -= r._2._1(k - posShift + 1) + } + } + covSum = cov + } + + for (i <- targetStart to targetEnd) { + val iter = i - posShift + 1 + if (iter < covArrayLength && iter >= 0) { + cov += r._2._1(iter) + + if (i == targetEnd && i > 0) { + var length = + if (iter < windowLen) iter + else windowLen + + if (length < windowLen) { + if (r._2._5.contains(targetStart)) { + covSum += r._2._5(targetStart) + length = windowLen + } + } + + result(ind) = CovRecordWindow(targetContig, targetStart, targetEnd, covSum / length.toFloat, Some(length)) + + ind += 1 + covSum = 0 + } + covSum += cov + } + else if (iter == covArrayLength) { + ind = addLastOptimizedWindow(contig, contigMinMap(contig)._2, Some(windowLen), posShift, iter, covSum, cov, ind, result) + covSum = 0 + } + } + } + } + + result.take(ind).iterator + } else { + + val session = SparkSession.builder().getOrCreate() + + val targets = session.sql(s"SELECT * FROM ${targetsTab}") + + val targetsCount = targets.count + val targetsRowCollection = targets.rdd.map(x => x).collect() + for (j <- 0 until targetsCount.toInt) { + val row = targetsRowCollection(j) + + val targetContig = row(0).asInstanceOf[String] + val targetStart = row(1).asInstanceOf[Int] + val targetEnd = row(2).asInstanceOf[Int] + + val targetId = row(3).asInstanceOf[String] + + val previousTargetContig = + if (j > 0) targetsRowCollection(j - 1)(0).asInstanceOf[String] + else targetContig + + val previousTargetEnd = + if (j > 0 && previousTargetContig == targetContig) targetsRowCollection(j - 1)(2).asInstanceOf[Int] + else 0 + + if (previousTargetContig != targetContig) { + covSum = 0 + cov = 0 + + } + + val windowLen = targetEnd - targetStart + 1 + if (targetContig == contig) { + + if (targetStart > previousTargetEnd) { //shift over gap + for (k <- previousTargetEnd + 1 until targetStart) { + if (k - posShift + 1 < covArrayLength && k - posShift + 1 >= 0) { + cov += r._2._1(k - posShift + 1) + } + covSum = cov + } + } else { + for (k <- previousTargetEnd until targetStart - 1 by -1) { // shift back -> targets are crossing + if (k - posShift + 1 < covArrayLength && k - posShift + 1 >= 0) { + cov -= r._2._1(k - posShift + 1) + } + } + covSum = cov + } + + for (i <- targetStart to targetEnd) { + val iter = i - posShift + 1 + if (iter < covArrayLength && iter >= 0) { + cov += r._2._1(iter) + + if (i == targetEnd && i > 0) { + val length = + if (iter < 
windowLen) iter + else windowLen + + result(ind) = CovRecordWindow(targetContig, targetStart, targetEnd, covSum / length.toFloat, Some(length)) + + ind += 1 + covSum = 0 + } + covSum += cov + } + else if (iter == covArrayLength) { + ind = addLastWindow(contig, Some(windowLen), posShift, iter, covSum, cov, ind, result) + covSum = 0 + } + } + } + } + + result.take(ind).iterator + } }) @@ -274,12 +490,14 @@ object CoverageMethodsMos { val upd = b.value.upd val shrink = b.value.shrink val(contig,(eventsArray,minPos,maxPos,contigLength,maxCigarLength)) = c // to REFACTOR + var beginChange = 0 val updArray = upd.get( (contig,minPos) ) match { // check if there is a value for contigName and minPos in upd, returning array of coverage and cumSum to update current contigRange case Some((arr,covSum)) => { // array of covs and cumSum arr match { case Some(overlapArray) => { var i = 0 + beginChange = eventsArray(0) eventsArray(i) = (eventsArray(i) + covSum).toShort // add cumSum to zeroth element while (i < overlapArray.length) { @@ -304,8 +522,23 @@ object CoverageMethodsMos { } case None => updArray } - (contig, (shrinkArray, minPos, maxPos, contigLength) ) + (contig, (shrinkArray, minPos, maxPos, contigLength, beginChange) ) } } } + + def upateContigSumRange(b:Broadcast[UpdateSumStruct],reducedEvents: RDD[(String,(Array[Short],Int,Int,Int,Int))]) = { + reducedEvents.map{ + c => { + val upd = b.value.upd + + val (contig, (eventsArray, minPos, maxPos, contigLength, firstEvent)) = c // to REFACTOR + + if(upd.get( (contig,minPos) ).nonEmpty) + (contig, (eventsArray, minPos, maxPos, contigLength, upd((contig, minPos)))) + else + (contig, (eventsArray, minPos, maxPos, contigLength, new mutable.HashMap[Int, Int]())) + } + } + } } diff --git a/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageStrategy.scala b/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageStrategy.scala index 721bec0b..f864389e 100644 --- a/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageStrategy.scala +++ b/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageStrategy.scala @@ -1,14 +1,15 @@ package org.biodatageeks.preprocessing.coverage import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql._ +import org.apache.spark.sql.types.{IntegerType, StringType} import org.apache.spark.storage.StorageLevel import org.apache.spark.unsafe.types.UTF8String -import org.biodatageeks.datasources.BAM.{BDGAlignFileReaderWriter} +import org.biodatageeks.datasources.BAM.BDGAlignFileReaderWriter import org.biodatageeks.datasources.BDGInputDataType import org.biodatageeks.inputformats.BDGAlignInputFormat import org.biodatageeks.utils.{BDGInternalParams, BDGTableFuncs} @@ -17,6 +18,7 @@ import org.seqdoop.hadoop_bam.{BAMBDGInputFormat, CRAMBDGInputFormat} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import scala.util.Try class CoverageStrategy(spark: SparkSession) extends Strategy with Serializable { @@ -52,6 +54,9 @@ case class UpdateStruct( shrink:mutable.HashMap[(String,Int),(Int)], minmax:mutable.HashMap[String,(Int,Int)] ) +case class UpdateSumStruct( + upd:mutable.HashMap[(String,Int), mutable.HashMap[Int, Int]] + ) @@ -67,7 +72,7 @@ case class BDGCoveragePlan 
[T<:BDGAlignInputFormat](plan: LogicalPlan, spark: Sp spark .sparkContext .getPersistentRDDs - .filter((t)=> t._2.name==BDGInternalParams.RDDEventsName) + .filter(t=> t._2.name==BDGInternalParams.RDDEventsName) .foreach(_._2.unpersist()) val schema = plan.schema @@ -163,9 +168,48 @@ case class BDGCoveragePlan [T<:BDGAlignInputFormat](plan: LogicalPlan, spark: Sp UpdateStruct(updateMap, shrinkMap, minmax) } - val covBroad = spark.sparkContext.broadcast(prepareBroadcast(acc.value())) + def prepareSumBroadcast(a: CovSumUpdate, reducedEvents: RDD[(String, (Array[Short], Int, Int, Int, Int))]) = { - lazy val reducedEvents = CoverageMethodsMos.upateContigRange(covBroad, events) + val contigRanges = a.sumArray + val updateArray = reducedEvents + val updateMap = new mutable.HashMap[(String, Int), mutable.HashMap[Int, Int]]() + + contigRanges.foreach { + c => + val contig = c.contigName + + val filteredResult = updateArray + .filter(f => f._1 == c.contigName && f._2._2 == c.maxPos) + + if (!filteredResult.isEmpty()) { + val currentRDD=filteredResult + .first() + val firstElem = currentRDD + ._2 + ._1(0) + + var tempElem = firstElem.toInt - currentRDD._2._5 // remove change of first elem + var cumSum = 0 + + c.cov.reverse.foreach(x => { + tempElem -= x + cumSum += tempElem + }) + + val startPoint = c.maxPos - c.cov.length + + var currentHashMap = if (!updateMap.contains((contig, c.maxPos))) + new mutable.HashMap[Int, Int]() + else + updateMap((contig, c.maxPos)) + + currentHashMap += startPoint -> cumSum + updateMap += (contig, c.maxPos) -> currentHashMap + } + } + + UpdateSumStruct(updateMap) + } val blocksResult = { result.toLowerCase() match { @@ -178,45 +222,132 @@ case class BDGCoveragePlan [T<:BDGAlignInputFormat](plan: LogicalPlan, spark: Sp } } val allPos = spark.sqlContext.getConf(BDGInternalParams.ShowAllPositions, "false").toBoolean + val optimizeWindow = spark.sqlContext.getConf(BDGInternalParams.OptimizationWindow, "false").toBoolean + + val targetType = target match { + case Some(t) => if (Try(t.toInt).isSuccess) { + IntegerType + } else { + StringType + } + case _ => None + } // determine if target is fixed length windows or windows from table - //check if it's a window length or a table name - val maybeWindowLength = + + val windowLength = try { target match { case Some(t) => Some(t.toInt) case _ => None } } catch { - case e: Exception => None + case _: Exception => None } + val targetVal = + if (targetType.equals(IntegerType)) + windowLength + else if (targetType.equals(StringType)) + target + else + None + val covBroad = spark.sparkContext.broadcast(prepareBroadcast(acc.value())) + + lazy val reducedEvents = CoverageMethodsMos.upateContigRange(covBroad, events) + + val covSumUpdate = new CovSumUpdate(new ArrayBuffer[RightCovSumEdge]()) + val covSumAcc = new CovSumAccumulatorV2(covSumUpdate) + + spark + .sparkContext + .register(covSumAcc, "CoverageSumAcc") + + reducedEvents + .persist(StorageLevel.MEMORY_AND_DISK) + .foreach(x => { + val maxC = x._2._2 + x._2._1.length - 1 + if (!targetType.equals(StringType)) { + val maxCounter = if (targetType.equals(IntegerType)) + maxC % windowLength.get + else // Blocks and bases for faster computing + 1 + val rightCovSumEdge = RightCovSumEdge(x._1, maxC, x._2._1.takeRight(maxCounter)) + val rightCovSumUpdate = new CovSumUpdate(ArrayBuffer(rightCovSumEdge)) + covSumAcc.add(rightCovSumUpdate) + } else { //Targets from table + val targetsTab = target.get + val session = SparkSession.builder().getOrCreate() + val crossingTargets = 
session.sql(s"SELECT * FROM ${targetsTab} WHERE start < ${maxC} AND end > ${maxC} AND contigName = '${x._1}'") + crossingTargets + .collect() + .foreach(targetRow => { + val maxCounter = maxC - targetRow.get(1).asInstanceOf[Int] + val rightCovSumEdge = RightCovSumEdge(x._1, maxC, x._2._1.takeRight(maxCounter)) + val rightCovSumUpdate = new CovSumUpdate(ArrayBuffer(rightCovSumEdge)) + covSumAcc.add(rightCovSumUpdate) + }) + } + }) + + lazy val covSumBroad = spark.sparkContext.broadcast(prepareSumBroadcast(covSumAcc.value(), reducedEvents)) + + lazy val reducedSumEvents = CoverageMethodsMos.upateContigSumRange(covSumBroad, reducedEvents) - lazy val cov = - if(maybeWindowLength != None) //fixed-length window - CoverageMethodsMos.eventsToCoverage(sampleId, reducedEvents, covBroad.value.minmax, blocksResult, allPos,maybeWindowLength,None) + lazy val cov = + if (targetType.equals(IntegerType) && !optimizeWindow) // fixed-length window + CoverageMethodsMos.eventsToCoverage(sampleId, reducedSumEvents, covBroad.value.minmax, blocksResult, allPos, windowLength, None) .keyBy(_.key) - .reduceByKey((a,b) => - CovRecordWindow(a.contigName, - a.start, - a.end, - (a.asInstanceOf[CovRecordWindow].overLap.get * a.asInstanceOf[CovRecordWindow].cov + b.asInstanceOf[CovRecordWindow].overLap.get * b.asInstanceOf[CovRecordWindow].cov )/ - (a.asInstanceOf[CovRecordWindow].overLap.get + b.asInstanceOf[CovRecordWindow].overLap.get), - Some(a.asInstanceOf[CovRecordWindow].overLap.get + b.asInstanceOf[CovRecordWindow].overLap.get) ) ) + .reduceByKey( + (a, b) => + CovRecordWindow( + a.contigName, + a.start, + a.end, + (a.asInstanceOf[CovRecordWindow].overLap.get * a.asInstanceOf[CovRecordWindow].cov + b.asInstanceOf[CovRecordWindow].overLap.get * b.asInstanceOf[CovRecordWindow].cov) / (a.asInstanceOf[CovRecordWindow].overLap.get + b.asInstanceOf[CovRecordWindow].overLap.get), + Some(a.asInstanceOf[CovRecordWindow].overLap.get + b.asInstanceOf[CovRecordWindow].overLap.get) + ) + ) .map(_._2) + else if (targetType.equals(StringType) && !optimizeWindow) { + CoverageMethodsMos.eventsToCoverage(sampleId, reducedSumEvents, covBroad.value.minmax, blocksResult, allPos, None, target) + .keyBy(_.key) + .reduceByKey( + (a, b) => + CovRecordWindow( + a.contigName, + a.start, + a.end, + (a.asInstanceOf[CovRecordWindow].overLap.get * a.asInstanceOf[CovRecordWindow].cov + b.asInstanceOf[CovRecordWindow].overLap.get * b.asInstanceOf[CovRecordWindow].cov) / (a.asInstanceOf[CovRecordWindow].overLap.get + b.asInstanceOf[CovRecordWindow].overLap.get), + Some(a.asInstanceOf[CovRecordWindow].overLap.get + b.asInstanceOf[CovRecordWindow].overLap.get) + ) + ) + .map(_._2) + } + else if (targetType.equals(IntegerType) && optimizeWindow) + CoverageMethodsMos.eventsToCoverage(sampleId, reducedSumEvents, covBroad.value.minmax, blocksResult, allPos, windowLength, None) + else if (targetType.equals(StringType) && optimizeWindow) + CoverageMethodsMos.eventsToCoverage(sampleId, reducedSumEvents, covBroad.value.minmax, blocksResult, allPos, None, target) + else + CoverageMethodsMos.eventsToCoverage(sampleId, reducedSumEvents, covBroad.value.minmax, blocksResult, allPos, None, None) - else - CoverageMethodsMos.eventsToCoverage(sampleId, reducedEvents, covBroad.value.minmax, blocksResult, allPos,None, None) - - if(maybeWindowLength != None) { // windows + if (targetType.equals(IntegerType)) { // windows with fixed length cov.mapPartitions(p => { val proj = UnsafeProjection.create(schema) p.map(r => proj.apply(InternalRow.fromSeq(Seq( 
UTF8String.fromString(r.contigName), r.start, r.end, r.asInstanceOf[CovRecordWindow].cov)))) }) - } else { // regular blocks + } + else if (targetType.equals(StringType)) { // windows with targets from bed + cov.mapPartitions(p => { + val proj = UnsafeProjection.create(schema) + p.map(r => proj.apply(InternalRow.fromSeq(Seq(/*UTF8String.fromString(sampleId),*/ + UTF8String.fromString(r.contigName), r.start, r.end, r.asInstanceOf[CovRecordWindow].cov)))) + }) + } + else { // regular blocks cov.mapPartitions(p => { val proj = UnsafeProjection.create(schema) p.map(r => proj.apply(InternalRow.fromSeq(Seq(/*UTF8String.fromString(sampleId),*/ diff --git a/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageUpdate.scala b/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageUpdate.scala index e80e2015..3f376e1b 100644 --- a/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageUpdate.scala +++ b/src/main/scala/org/biodatageeks/preprocessing/coverage/CoverageUpdate.scala @@ -48,3 +48,46 @@ class CoverageAccumulatorV2(var covAcc: CovUpdate) extends AccumulatorV2[CovUpda } } +case class RightCovSumEdge(contigName: String, maxPos: Int, cov: Array[Short]) + + +class CovSumUpdate(var sumArray: ArrayBuffer[RightCovSumEdge]) extends Serializable { + + def reset(): Unit = { + sumArray = new ArrayBuffer[RightCovSumEdge]() + } + + def add(p: CovSumUpdate): CovSumUpdate = { + sumArray = sumArray ++ p.sumArray + return this + } + +} + +class CovSumAccumulatorV2(var covAcc: CovSumUpdate) extends AccumulatorV2[CovSumUpdate, CovSumUpdate] { + + def reset(): Unit = { + covAcc = new CovSumUpdate(new ArrayBuffer[RightCovSumEdge]()) + } + + def add(v: CovSumUpdate): Unit = { + covAcc.add(v) + } + + def value(): CovSumUpdate = { + return covAcc + } + + def isZero(): Boolean = { + return covAcc.sumArray.isEmpty + } + + def copy(): CovSumAccumulatorV2 = { + return new CovSumAccumulatorV2(covAcc) + } + + def merge(other: AccumulatorV2[CovSumUpdate, CovSumUpdate]) = { + covAcc.add(other.value) + } +} + diff --git a/src/main/scala/org/biodatageeks/utils/BDGInternalParams.scala b/src/main/scala/org/biodatageeks/utils/BDGInternalParams.scala index 3d7d5221..eb0f9407 100644 --- a/src/main/scala/org/biodatageeks/utils/BDGInternalParams.scala +++ b/src/main/scala/org/biodatageeks/utils/BDGInternalParams.scala @@ -21,4 +21,6 @@ object BDGInternalParams { final val RDDEventsName = "spark.biodatageeks.events" final val InputSplitSize = "spark.biodatageeks.bam.splitSize" + + final val OptimizationWindow = "spark.biodatageeks.window.optimization" } diff --git a/src/test/resources/test-target.bed b/src/test/resources/test-target.bed new file mode 100644 index 00000000..edd8f2ab --- /dev/null +++ b/src/test/resources/test-target.bed @@ -0,0 +1,34 @@ +chrM 6500 6599 66 +chrM 6600 6699 67 +chrM 6700 6799 68 +chrM 6800 6999 69 +chrM 6900 7099 70 +chrM 7100 7199 72 +chrM 7200 7299 73 +chrM 7300 7399 74 +chrM 7400 7499 75 +chrM 7500 7599 76 +chrM 7600 7999 77 +chrM 7700 7799 78 +chrM 7800 7899 79 +chrM 7900 7999 80 +chrM 8000 8099 81 +chrM 8100 8199 82 +chrM 8300 8399 84 +chrM 8400 8499 85 +chrM 8500 8599 86 +chrM 8600 8699 87 +chrM 8700 8799 88 +chrM 8800 8899 89 +chrM 14400 14499 145 +chrM 14500 14599 146 +chrM 14600 14799 147 +chrM 14900 14999 150 +chrM 15200 15399 153 +chrM 15000 15099 155 +chr1 2700 2799 158 +chrM 15400 15499 155 +chr1 3200 3299 159 +chrM 15500 15599 156 +chr1 10000 10099 160 +chrM 15600 15699 161 \ No newline at end of file diff --git 
a/src/test/scala/pl/edu/pw/ii/biodatageeks/tests/CoverageTestSuite.scala b/src/test/scala/pl/edu/pw/ii/biodatageeks/tests/CoverageTestSuite.scala index 16f44e5b..9eb016e6 100644 --- a/src/test/scala/pl/edu/pw/ii/biodatageeks/tests/CoverageTestSuite.scala +++ b/src/test/scala/pl/edu/pw/ii/biodatageeks/tests/CoverageTestSuite.scala @@ -9,63 +9,74 @@ import org.biodatageeks.preprocessing.coverage.CoverageStrategy import org.biodatageeks.utils.{BDGInternalParams, SequilaRegister} import org.scalatest.{BeforeAndAfter, FunSuite} -class CoverageTestSuite extends FunSuite with DataFrameSuiteBase with BeforeAndAfter with SharedSparkContext{ - - val bamPath = getClass.getResource("/NA12878.slice.bam").getPath - val bamMultiPath = getClass.getResource("/multichrom/NA12878.multichrom.bam").getPath - val adamPath = getClass.getResource("/NA12878.slice.adam").getPath - val metricsListener = new MetricsListener(new RecordedMetrics()) - val writer = new PrintWriter(new OutputStreamWriter(System.out)) - val cramPath = getClass.getResource("/test.cram").getPath - val refPath = getClass.getResource("/phix-illumina.fa").getPath - val tableNameBAM = "reads" - val tableNameMultiBAM = "readsMulti" - val tableNameADAM = "readsADAM" - val tableNameCRAM = "readsCRAM" - val splitSize = "1000000" - - before{ - - Metrics.initialize(sc) - sc.addSparkListener(metricsListener) - System.setSecurityManager(null) - spark.sql(s"DROP TABLE IF EXISTS ${tableNameBAM}") - spark.sql( - s""" - |CREATE TABLE ${tableNameBAM} - |USING org.biodatageeks.datasources.BAM.BAMDataSource - |OPTIONS(path "${bamPath}") - | +class CoverageTestSuite extends FunSuite with DataFrameSuiteBase with BeforeAndAfter with SharedSparkContext { + + val bamPath = getClass.getResource("/NA12878.slice.bam").getPath + val bamMultiPath = getClass.getResource("/multichrom/NA12878.multichrom.bam").getPath + val adamPath = getClass.getResource("/NA12878.slice.adam").getPath + val metricsListener = new MetricsListener(new RecordedMetrics()) + val writer = new PrintWriter(new OutputStreamWriter(System.out)) + val cramPath = getClass.getResource("/test.cram").getPath + val refPath = getClass.getResource("/phix-illumina.fa").getPath + val bedPath = getClass.getResource("/test-target.bed").getPath + val tableNameBAM = "reads" + val tableNameMultiBAM = "readsMulti" + val tableNameADAM = "readsADAM" + val tableNameCRAM = "readsCRAM" + val tableNameTargets = "targets" + val splitSize = "1000000" + + before { + + Metrics.initialize(sc) + sc.addSparkListener(metricsListener) + System.setSecurityManager(null) + spark.sql(s"DROP TABLE IF EXISTS ${tableNameBAM}") + spark.sql( + s""" + |CREATE TABLE ${tableNameBAM} + |USING org.biodatageeks.datasources.BAM.BAMDataSource + |OPTIONS(path "${bamPath}") + | """.stripMargin) - spark.sql(s"DROP TABLE IF EXISTS ${tableNameMultiBAM}") - spark.sql( - s""" - |CREATE TABLE ${tableNameMultiBAM} - |USING org.biodatageeks.datasources.BAM.BAMDataSource - |OPTIONS(path "${bamMultiPath}") - | + spark.sql(s"DROP TABLE IF EXISTS ${tableNameMultiBAM}") + spark.sql( + s""" + |CREATE TABLE ${tableNameMultiBAM} + |USING org.biodatageeks.datasources.BAM.BAMDataSource + |OPTIONS(path "${bamMultiPath}") + | """.stripMargin) - spark.sql(s"DROP TABLE IF EXISTS ${tableNameCRAM}") - spark.sql( - s""" - |CREATE TABLE ${tableNameCRAM} - |USING org.biodatageeks.datasources.BAM.CRAMDataSource - |OPTIONS(path "${cramPath}", refPath "${refPath}") - | + spark.sql(s"DROP TABLE IF EXISTS ${tableNameCRAM}") + spark.sql( + s""" + |CREATE TABLE ${tableNameCRAM} + 
|USING org.biodatageeks.datasources.BAM.CRAMDataSource + |OPTIONS(path "${cramPath}", refPath "${refPath}") + | """.stripMargin) - spark.sql(s"DROP TABLE IF EXISTS ${tableNameADAM}") - spark.sql( - s""" - |CREATE TABLE ${tableNameADAM} - |USING org.biodatageeks.datasources.ADAM.ADAMDataSource - |OPTIONS(path "${adamPath}") - | + spark.sql(s"DROP TABLE IF EXISTS ${tableNameADAM}") + spark.sql( + s""" + |CREATE TABLE ${tableNameADAM} + |USING org.biodatageeks.datasources.ADAM.ADAMDataSource + |OPTIONS(path "${adamPath}") + | """.stripMargin) - } + spark.sql(s"DROP TABLE IF EXISTS ${tableNameTargets}") + spark.sql( + s""" + |CREATE TABLE ${tableNameTargets}(contigName String, start Integer, end Integer, targetId String) + |USING csv + |OPTIONS (path "file:///${bedPath}", delimiter "\t") + | + """.stripMargin) + + } /* @@ -78,35 +89,142 @@ class CoverageTestSuite extends FunSuite with DataFrameSuiteBase with BeforeAndA */ - test("BAM - bdg_coverage - windows"){ + test("BAM - bdg_coverage - windows") { + + spark.sqlContext.setConf(BDGInternalParams.InputSplitSize, splitSize) + + val session: SparkSession = SequilaSession(spark) + SequilaRegister.register(session) + + val windowLength = 100 + val bdg = session.sql(s"SELECT * FROM bdg_coverage('${tableNameMultiBAM}','NA12878', 'blocks', '${windowLength.toString}')") + + assert(bdg.count == 267) + assert(bdg.first().getInt(1) % windowLength == 0) // check for fixed window start position + assert(bdg.first().getInt(2) % windowLength == windowLength - 1) // // check for fixed window end position + assert(bdg.where("contigName == 'chr1' and start == 2700").first().getFloat(3) == 4.65.toFloat) + assert(bdg.where("contigName == 'chr1' and start == 3200").first().getFloat(3) == 166.79.toFloat) + assert(bdg.where("contigName == 'chr1' and start == 10000").first().getFloat(3) == 1.5522388.toFloat) //value check [partition boundary] + assert(bdg.where("contigName == 'chrM' and start == 7800").first().getFloat(3) == 253.03001.toFloat) //value check [partition boundary] + assert(bdg.where("contigName == 'chrM' and start == 14400").first().getFloat(3) == 134.7.toFloat) //value check [partition boundary] + assert(bdg.groupBy("contigName", "start").count().where("count != 1").count == 0) // no duplicates check + } + + test("BAM - bdg_coverage - windows with targets from table") { + + spark.sqlContext.setConf(BDGInternalParams.InputSplitSize, splitSize) + + val session: SparkSession = SequilaSession(spark) + SequilaRegister.register(session) + + val windowLength = 100 + val bdg = session.sql(s"SELECT * FROM bdg_coverage('${tableNameMultiBAM}','NA12878', 'blocks', '${tableNameTargets.toString}')") + + assert(bdg.count == 34) // test-target.bed contains 34 lines (targets) + + /* + * group of tests which check the same values as for const windows with 100 size + * in bed file are only three chr1 targets -> it tests gaps between targets also + */ + assert(bdg.where("contigName == 'chr1' and start == 2700").first().getFloat(3) == 4.65.toFloat) + assert(bdg.where("contigName == 'chr1' and start == 3200").first().getFloat(3) == 166.79.toFloat) + assert(bdg.where("contigName == 'chr1' and start == 10000").first().getFloat(3) == 1.5522388.toFloat) //value check [partition boundary] + assert(bdg.where("contigName == 'chrM' and start == 7800").first().getFloat(3) == 253.03001.toFloat) //value check [partition boundary] + assert(bdg.where("contigName == 'chrM' and start == 14400").first().getFloat(3) == 134.7.toFloat) //value check [partition boundary] + + //checking values 
for crossing targets + assert(bdg.where("contigName == 'chrM' and start == 6800").first().getFloat(3) == 96.545.toFloat) //6800-6999 + assert(bdg.where("contigName == 'chrM' and start == 6900").first().getFloat(3) == 78.945.toFloat) //6900-7099 + + //checking values for target inside target + assert(bdg.where("contigName == 'chrM' and start == 7600").first().getFloat(3) == 155.69.toFloat) //7600 -> 7999 -> this contains next one inside -> [partition boundary] + assert(bdg.where("contigName == 'chrM' and start = 7700").first().getFloat(3) == 134.41.toFloat) //7700 - > 7799 -> after returning - from 7999 to 7799 + + // other gap checks + assert(bdg.where("contigName == 'chrM' and start = 14600").first().getFloat(3) == 63.68.toFloat) //14600 -> 14799 + assert(bdg.where("contigName == 'chrM' and start = 14900").first().getFloat(3) == 196.81.toFloat) //14900 -> 14999 + + assert(bdg.where("contigName == 'chrM' and start = 15200").first().getFloat(3) == 169.02.toFloat) //15200 -> 15399 -> no target between 15100 and 15200 - after gap check + assert(bdg.where("contigName == 'chrM' and start = 15000").first().getFloat(3) == 144.75.toFloat) //15000 -> 15099 -> returning in BED FILE + + + assert(bdg.groupBy("contigName", "start").count().where("count != 1").count == 0) // no duplicates check + } + + test("BAM - bdg_coverage - windows optimized") { spark.sqlContext.setConf(BDGInternalParams.InputSplitSize, splitSize) + spark.sqlContext.setConf(BDGInternalParams.OptimizationWindow, "true") val session: SparkSession = SequilaSession(spark) SequilaRegister.register(session) val windowLength = 100 val bdg = session.sql(s"SELECT * FROM bdg_coverage('${tableNameMultiBAM}','NA12878', 'blocks', '${windowLength}')") - + assert (bdg.count == 267) assert (bdg.first().getInt(1) % windowLength == 0) // check for fixed window start position assert (bdg.first().getInt(2) % windowLength == windowLength - 1) // // check for fixed window end position assert(bdg.where("contigName == 'chr1' and start == 2700").first().getFloat(3)==4.65.toFloat) assert(bdg.where("contigName == 'chr1' and start == 3200").first().getFloat(3)== 166.79.toFloat) assert(bdg.where("contigName == 'chr1' and start == 10000").first().getFloat(3)== 1.5522388.toFloat) //value check [partition boundary] - assert(bdg.where("contigName == 'chrM' and start == 7800").first().getFloat(3)== 253.03001.toFloat) //value check [partition boundary] + assert(bdg.where("contigName == 'chrM' and start == 7800").first().getFloat(3)== 253.03.toFloat) //value check [partition boundary] assert(bdg.where("contigName == 'chrM' and start == 14400").first().getFloat(3)== 134.7.toFloat) //value check [partition boundary] assert(bdg.groupBy("contigName", "start").count().where("count != 1").count == 0) // no duplicates check + } - test("BAM - bdg_coverage - blocks - allPositions"){ + test("BAM - bdg_coverage - windows with targets from table optmized") { + + spark.sqlContext.setConf(BDGInternalParams.InputSplitSize, splitSize) + spark.sqlContext.setConf(BDGInternalParams.OptimizationWindow, "true") + + val session: SparkSession = SequilaSession(spark) + SequilaRegister.register(session) + + val windowLength = 100 + val bdg = session.sql(s"SELECT * FROM bdg_coverage('${tableNameMultiBAM}','NA12878', 'blocks', '${tableNameTargets.toString}')") + + assert(bdg.count == 34) // test-target.bed contains 34 lines (targets) + + /* + * group of tests which check the same values as for const windows with 100 size + * in bed file are only three chr1 targets -> it tests gaps between 
targets also + */ + assert(bdg.where("contigName == 'chr1' and start == 2700").first().getFloat(3) == 4.65.toFloat) + assert(bdg.where("contigName == 'chr1' and start == 3200").first().getFloat(3) == 166.79.toFloat) + assert(bdg.where("contigName == 'chr1' and start == 10000").first().getFloat(3) == 1.5522388.toFloat) //value check [partition boundary] + assert(bdg.where("contigName == 'chrM' and start == 7800").first().getFloat(3) == 253.03.toFloat) //value check [partition boundary] + assert(bdg.where("contigName == 'chrM' and start == 14400").first().getFloat(3) == 134.7.toFloat) //value check [partition boundary] + + //checking values for crossing targets + assert(bdg.where("contigName == 'chrM' and start == 6800").first().getFloat(3) == 96.545.toFloat) //6800-6999 + assert(bdg.where("contigName == 'chrM' and start == 6900").first().getFloat(3) == 78.945.toFloat) //6900-7099 + + //checking values for target inside target + assert(bdg.where("contigName == 'chrM' and start == 7600").first().getFloat(3) == 155.69.toFloat) //7600 -> 7999 -> this contains next one inside -> [partition boundary] + assert(bdg.where("contigName == 'chrM' and start = 7700").first().getFloat(3) == 134.41.toFloat) //7700 - > 7799 -> after returning - from 7999 to 7799 + + // other gap checks + assert(bdg.where("contigName == 'chrM' and start = 14600").first().getFloat(3) == 63.68.toFloat) //14600 -> 14799 + assert(bdg.where("contigName == 'chrM' and start = 14900").first().getFloat(3) == 196.81.toFloat) //14900 -> 14999 + + assert(bdg.where("contigName == 'chrM' and start = 15200").first().getFloat(3) == 169.02.toFloat) //15200 -> 15399 -> no target between 15100 and 15200 - after gap check + assert(bdg.where("contigName == 'chrM' and start = 15000").first().getFloat(3) == 144.75.toFloat) //15000 -> 15099 -> returning in BED FILE + + + assert(bdg.groupBy("contigName", "start").count().where("count != 1").count == 0) // no duplicates check + + } + + test("BAM - bdg_coverage - blocks - allPositions") { spark.sqlContext.setConf(BDGInternalParams.InputSplitSize, splitSize) val session: SparkSession = SequilaSession(spark) SequilaRegister.register(session) session.experimental.extraStrategies = new CoverageStrategy(session) :: Nil - session.sqlContext.setConf(BDGInternalParams.ShowAllPositions,"true") + session.sqlContext.setConf(BDGInternalParams.ShowAllPositions, "true") val bdg = session.sql(s"SELECT * FROM bdg_coverage('${tableNameMultiBAM}','NA12878', 'blocks')") @@ -123,12 +241,12 @@ class CoverageTestSuite extends FunSuite with DataFrameSuiteBase with BeforeAndA assert(bdg.groupBy("contigName", "start").count().where("count != 1").count == 0) // no duplicates check } - test("BAM - bdg_coverage - blocks notAllPositions"){ + test("BAM - bdg_coverage - blocks notAllPositions") { spark.sqlContext.setConf(BDGInternalParams.InputSplitSize, splitSize) val session: SparkSession = SequilaSession(spark) SequilaRegister.register(session) - session.sqlContext.setConf(BDGInternalParams.ShowAllPositions,"false") + session.sqlContext.setConf(BDGInternalParams.ShowAllPositions, "false") val bdg = session.sql(s"SELECT * FROM bdg_coverage('${tableNameMultiBAM}','NA12878', 'blocks')") @@ -145,12 +263,12 @@ class CoverageTestSuite extends FunSuite with DataFrameSuiteBase with BeforeAndA assert(bdg.groupBy("contigName", "start").count().where("count != 1").count == 0) // no duplicates check } - test("BAM - bdg_coverage - bases - notAllPositions"){ + test("BAM - bdg_coverage - bases - notAllPositions") { 
spark.sqlContext.setConf(BDGInternalParams.InputSplitSize, splitSize) val session: SparkSession = SequilaSession(spark) SequilaRegister.register(session) - session.sqlContext.setConf(BDGInternalParams.ShowAllPositions,"false") + session.sqlContext.setConf(BDGInternalParams.ShowAllPositions, "false") val bdg = session.sql(s"SELECT contigName, start, end, coverage FROM bdg_coverage('${tableNameMultiBAM}','NA12878', 'bases')") assert(bdg.count() == 26598) // total count check // was 26598 @@ -168,7 +286,7 @@ class CoverageTestSuite extends FunSuite with DataFrameSuiteBase with BeforeAndA } - test("CRAM - bdg_coverage - show"){ + test("CRAM - bdg_coverage - show") { val session: SparkSession = SequilaSession(spark) SequilaRegister.register(session) @@ -179,15 +297,15 @@ class CoverageTestSuite extends FunSuite with DataFrameSuiteBase with BeforeAndA } test("BAM - bdg_coverage - wrong param, Exception should be thrown") { - val session: SparkSession = SequilaSession(spark) - SequilaRegister.register(session) + val session: SparkSession = SequilaSession(spark) + SequilaRegister.register(session) - assertThrows[Exception]( - session.sql(s"SELECT * FROM bdg_coverage('${tableNameMultiBAM}','NA12878', 'blaaaaaah')").show()) + assertThrows[Exception]( + session.sql(s"SELECT * FROM bdg_coverage('${tableNameMultiBAM}','NA12878', 'blaaaaaah')").show()) - } + } - after{ + after { Metrics.print(writer, Some(metricsListener.metrics.sparkMetrics.stageTimes)) writer.flush()
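
Usage note (not part of the patch): the sketch below strings together the pieces this diff introduces — the spark.biodatageeks.window.optimization switch registered in BDGInternalParams, the four-argument bdg_coverage table-valued function, and a CSV-backed targets table like the one created in CoverageTestSuite. It is a minimal illustration only; the paths are placeholders and the exact package of SequilaSession is assumed from the test code rather than guaranteed here.

import org.apache.spark.sql.{SequilaSession, SparkSession}
import org.biodatageeks.utils.{BDGInternalParams, SequilaRegister}

object WindowCoverageExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("bdg-window-coverage").getOrCreate()
    val ss = SequilaSession(spark)   // as in CoverageTestSuite
    SequilaRegister.register(ss)

    // Hypothetical input locations -- replace with real BAM/BED paths.
    val bamPath = "/data/NA12878.bam"
    val bedPath = "/data/targets.bed"

    ss.sql(
      s"""CREATE TABLE IF NOT EXISTS reads
         |USING org.biodatageeks.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")""".stripMargin)

    ss.sql(
      s"""CREATE TABLE IF NOT EXISTS targets(contigName String, start Integer, end Integer, targetId String)
         |USING csv
         |OPTIONS (path "$bedPath", delimiter "\t")""".stripMargin)

    // New switch added in this patch; it defaults to "false".
    ss.sqlContext.setConf(BDGInternalParams.OptimizationWindow, "true")

    // Numeric fourth argument -> fixed-length windows.
    ss.sql("SELECT * FROM bdg_coverage('reads', 'NA12878', 'blocks', '500')").show()

    // Non-numeric fourth argument -> one window per row of the 'targets' table.
    ss.sql("SELECT * FROM bdg_coverage('reads', 'NA12878', 'blocks', 'targets')").show()
  }
}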
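
The non-optimized window paths still merge per-partition partial windows with a reduceByKey, and the combiner added in BDGCoveragePlan computes an overlap-weighted mean of the two partial coverages. Below is a self-contained rendering of that arithmetic; WindowPart is a simplified stand-in for CovRecordWindow, introduced only for illustration.

// Simplified stand-in for CovRecordWindow: a window fragment covering `overlap`
// positions with mean coverage `cov`.
case class WindowPart(contigName: String, start: Int, end: Int, cov: Float, overlap: Int)

// Same formula as the reduceByKey combiner in BDGCoveragePlan: a weighted mean of the
// two fragments, weighted by the number of positions each fragment covers.
def mergeWindowParts(a: WindowPart, b: WindowPart): WindowPart =
  WindowPart(
    a.contigName, a.start, a.end,
    (a.overlap * a.cov + b.overlap * b.cov) / (a.overlap + b.overlap),
    a.overlap + b.overlap)

// Example: a 100 bp window split by a partition boundary into 60 bp at 10.0x and
// 40 bp at 5.0x merges to (60*10 + 40*5) / 100 = 8.0x over the full window.
val merged = mergeWindowParts(
  WindowPart("chrM", 7800, 7899, 10.0f, 60),
  WindowPart("chrM", 7800, 7899, 5.0f, 40))
// merged.cov == 8.0f, merged.overlap == 100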