/**
 * Deserializes a protobuf-encoded `NativeMetricNode` and pushes its values into this
 * metric tree.
 *
 * Called via JNI from `comet_metric_node.rs` — the snake_case name is part of that
 * native contract and must not be renamed.
 *
 * @param bytes serialized `Metric.NativeMetricNode` protobuf message
 */
def set_all_from_bytes(bytes: Array[Byte]): Unit = {
  set_all(Metric.NativeMetricNode.parseFrom(bytes))
}
test("native_datafusion scan reports task-level input metrics matching Spark") {
  withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") {
    // Baseline run: vanilla Spark with Comet switched off entirely.
    val sparkRun = CometConf.COMET_ENABLED.key -> "false"
    // Run under test: Comet's native DataFusion Parquet scan.
    val cometRun = CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION

    val (sparkBytes, sparkRecords) = collectInputMetrics(sparkRun)
    val (cometBytes, cometRecords) = collectInputMetrics(cometRun)

    // Row counts must agree exactly between the two readers.
    assert(
      cometRecords == sparkRecords,
      s"recordsRead mismatch: comet=$cometRecords, spark=$sparkRecords")

    // Byte counts only need to be in the same ballpark: both read the same Parquet
    // file(s), but reader implementation details (footer reads, page headers,
    // buffering granularity) can shift the exact number.
    assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
    assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
    val ratio = cometBytes.toDouble / sparkBytes.toDouble
    assert(
      ratio >= 0.8 && ratio <= 1.2,
      s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
  }
}
+ */ + private def collectInputMetrics(confs: (String, String)*): (Long, Long) = { + val inputMetricsList = mutable.ArrayBuffer.empty[InputMetrics] + + val listener = new SparkListener { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val im = taskEnd.taskMetrics.inputMetrics + inputMetricsList.synchronized { + inputMetricsList += im + } + } + } + + spark.sparkContext.addSparkListener(listener) + try { + // Drain any earlier events + spark.sparkContext.listenerBus.waitUntilEmpty() + + withSQLConf(confs: _*) { + sql("SELECT * FROM tbl").collect() + } + + spark.sparkContext.listenerBus.waitUntilEmpty() + + assert(inputMetricsList.nonEmpty, s"No input metrics found for confs=$confs") + val totalBytes = inputMetricsList.map(_.bytesRead).sum + val totalRecords = inputMetricsList.map(_.recordsRead).sum + (totalBytes, totalRecords) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } }