@@ -139,6 +139,13 @@ private[spark] class CometExecRDD(
ctx.addTaskCompletionListener[Unit] { _ =>
it.close()
subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub))

nativeMetrics.metrics
.get("bytes_scanned")
.foreach(m => ctx.taskMetrics().inputMetrics.setBytesRead(m.value))
nativeMetrics.metrics
.get("output_rows")
.foreach(m => ctx.taskMetrics().inputMetrics.setRecordsRead(m.value))
Gemini review comment on lines +143 to +148
Severity: medium

The current implementation only retrieves metrics from the root node of the nativeMetrics tree. In most Comet execution plans (e.g., Scan -> Filter -> Project), the root node will be an operator like Project or Filter, which does not contain the bytes_scanned metric. As a result, bytesRead will be reported as 0 in the Spark UI for these queries. Furthermore, output_rows at the root node reflects the final result count, which may not match the number of records read from the source if a filter was applied.

To correctly report task-level input metrics, you should traverse the nativeMetrics tree and aggregate bytes_scanned and output_rows from all nodes that represent a scan (typically identified by the presence of the bytes_scanned metric). Consider adding a recursive helper method to CometMetricNode to perform this aggregation.
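The suggested traversal can be sketched in isolation. The model below is a simplified, hypothetical stand-in (plain `Long` values instead of Comet's `SQLMetric` wrappers; `aggregateScanMetrics` is an illustrative name, not part of the actual `CometMetricNode` API):

```scala
// Simplified stand-in for CometMetricNode; hypothetical, for illustration only.
case class MetricNode(metrics: Map[String, Long], children: Seq[MetricNode]) {

  // Recursively sum (bytes_scanned, output_rows) over all scan nodes,
  // identifying a scan by the presence of the `bytes_scanned` metric.
  def aggregateScanMetrics: (Long, Long) = {
    val (childBytes, childRows) = children
      .map(_.aggregateScanMetrics)
      .foldLeft((0L, 0L)) { case ((b, r), (cb, cr)) => (b + cb, r + cr) }
    if (metrics.contains("bytes_scanned")) {
      (childBytes + metrics("bytes_scanned"),
        childRows + metrics.getOrElse("output_rows", 0L))
    } else {
      // Non-scan nodes (Filter, Project, ...) contribute nothing of their
      // own; their output_rows reflects post-operator counts, not input.
      (childBytes, childRows)
    }
  }
}

// Plan shape from the comment above: Scan -> Filter -> Project.
val scan    = MetricNode(Map("bytes_scanned" -> 1024L, "output_rows" -> 100L), Nil)
val filter  = MetricNode(Map("output_rows" -> 40L), Seq(scan))
val project = MetricNode(Map("output_rows" -> 40L), Seq(filter))

// Only the scan node contributes; the root's post-filter 40 rows are ignored.
println(project.aggregateScanMetrics)
```

With this shape, traversing from the root still surfaces the scan's 1024 bytes and 100 source rows even though the root is a `Project`.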

Owner Author


value:useful; category:bug; feedback: The Gemini AI reviewer is correct! Being executed in CometExecRDD this metrics collection is executed for all kinds of nodes, not just for the Scan ones. It would be better to move this logic to CometNativeScanExec#doExecuteColumnar(). This way it will collect only the Scan related metrics.


Augment review comment: output_rows is part of the baseline native metrics for many non-scan operators, so setting taskMetrics.inputMetrics.recordsRead from it here can misreport or overwrite Spark’s input metrics for non-scan stages. Consider gating this update so it only applies when the native metrics actually represent scan input (e.g., tied to scan-specific metrics like bytes_scanned).

Severity: medium

Owner Author


value:useful; category:bug; feedback: The Augment AI reviewer is correct! Being executed in CometExecRDD this metrics collection is executed for all kinds of nodes, not just for the Scan ones. It would be better to move this logic to CometNativeScanExec#doExecuteColumnar(). This way it will collect only the Scan related metrics.
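The gating both reviewers converge on can be modeled in isolation: only propagate native metrics into a task's input metrics when the node actually carries the scan-specific `bytes_scanned` key. A minimal stand-alone sketch (`InputMetricsSink` and `reportScanInput` are hypothetical stand-ins; the real code would call `ctx.taskMetrics().inputMetrics` from the scan execution path):

```scala
// Mutable stand-in for Spark's InputMetrics; hypothetical, for illustration.
final class InputMetricsSink {
  var bytesRead: Long = 0L
  var recordsRead: Long = 0L
}

// Update input metrics only when the native metrics describe a scan,
// i.e. when the scan-specific `bytes_scanned` key is present. For a
// non-scan node (Filter, Project, ...) this is a no-op, so the node's
// post-operator output_rows never overwrites recordsRead.
def reportScanInput(nativeMetrics: Map[String, Long], sink: InputMetricsSink): Unit = {
  nativeMetrics.get("bytes_scanned").foreach { bytes =>
    sink.bytesRead = bytes
    sink.recordsRead = nativeMetrics.getOrElse("output_rows", 0L)
  }
}

val scanSink = new InputMetricsSink
reportScanInput(Map("bytes_scanned" -> 2048L, "output_rows" -> 10L), scanSink)

val filterSink = new InputMetricsSink
reportScanInput(Map("output_rows" -> 4L), filterSink) // no-op: not a scan

println(s"${scanSink.bytesRead} ${scanSink.recordsRead} ${filterSink.recordsRead}")
```

Hoisting this into `CometNativeScanExec#doExecuteColumnar()`, as the owner suggests, makes the gate structural: only scan nodes ever reach the reporting code.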

}
}

@@ -79,6 +79,7 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM
}
}

// Called via JNI from `comet_metric_node.rs`
def set_all_from_bytes(bytes: Array[Byte]): Unit = {
val metricNode = Metric.NativeMetricNode.parseFrom(bytes)
set_all(metricNode)
@@ -21,6 +21,7 @@ package org.apache.spark.sql.comet

import scala.collection.mutable

import org.apache.spark.executor.InputMetrics
import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.SparkListener
@@ -30,6 +31,8 @@ import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper

import org.apache.comet.CometConf

class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {

import testImplicits._
@@ -91,4 +94,66 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("native_datafusion scan reports task-level input metrics matching Spark") {

Severity: medium

This test case only covers a simple scan where the scan node is the root of the native plan. To ensure that input metrics are correctly reported in more realistic scenarios, consider adding a test case that includes a filter (e.g., sql("SELECT * FROM tbl WHERE _1 > 5000")). This will verify that metrics are correctly aggregated even when the scan node is not the root of the execution tree.

Owner Author


value:useful; category:bug; feedback: The Gemini AI reviewer is correct! Using a more complex SQL query (e.g. with a Filter) will expose the problem that the collection of the metrics is done on all CometExecRDDs. It should be done only for the Scan nodes, i.e. the ones created by CometNativeScanExec#doExecuteColumnar().
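A sketch of the filter test both comments point toward, assuming `collectInputMetrics` were generalized to take the query string (a hypothetical change; the helper added in this diff hardcodes `SELECT * FROM tbl`):

```scala
// Hypothetical variant: assumes collectInputMetrics(query, confs*) exists.
test("native_datafusion scan reports input metrics when a filter is applied") {
  withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") {
    val query = "SELECT * FROM tbl WHERE _1 > 5000"

    val (sparkBytes, sparkRecords) =
      collectInputMetrics(query, CometConf.COMET_ENABLED.key -> "false")
    val (cometBytes, cometRecords) = collectInputMetrics(
      query,
      CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

    // recordsRead counts rows read from the source, not rows surviving the
    // filter, so it should still match vanilla Spark even though the query
    // returns fewer rows than the table contains.
    assert(
      cometRecords == sparkRecords,
      s"recordsRead mismatch: comet=$cometRecords, spark=$sparkRecords")
    assert(sparkBytes > 0 && cometBytes > 0)
  }
}
```

Because the scan is no longer the root of the native plan here, this variant would expose the root-only metric lookup the reviewers flagged.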

withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") {
// Collect baseline input metrics from vanilla Spark (Comet disabled)
val (sparkBytes, sparkRecords) = collectInputMetrics(CometConf.COMET_ENABLED.key -> "false")

// Collect input metrics from Comet native_datafusion scan
val (cometBytes, cometRecords) = collectInputMetrics(
CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION)

// Records must match exactly
assert(
cometRecords == sparkRecords,
s"recordsRead mismatch: comet=$cometRecords, spark=$sparkRecords")

// Bytes should be in the same ballpark -- both read the same Parquet file(s),
// but the exact byte count can differ due to reader implementation details
// (e.g. footer reads, page headers, buffering granularity).
assert(sparkBytes > 0, s"Spark bytesRead should be > 0, got $sparkBytes")
assert(cometBytes > 0, s"Comet bytesRead should be > 0, got $cometBytes")
val ratio = cometBytes.toDouble / sparkBytes.toDouble
assert(
ratio >= 0.8 && ratio <= 1.2,
s"bytesRead ratio out of range: comet=$cometBytes, spark=$sparkBytes, ratio=$ratio")
}
}

/**
* Runs `SELECT * FROM tbl` with the given SQL config overrides and returns the aggregated
* (bytesRead, recordsRead) across all tasks.
*/
private def collectInputMetrics(confs: (String, String)*): (Long, Long) = {
val inputMetricsList = mutable.ArrayBuffer.empty[InputMetrics]

val listener = new SparkListener {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val im = taskEnd.taskMetrics.inputMetrics
inputMetricsList.synchronized {
inputMetricsList += im
}
}
}

spark.sparkContext.addSparkListener(listener)
try {
// Drain any earlier events
spark.sparkContext.listenerBus.waitUntilEmpty()

withSQLConf(confs: _*) {
sql("SELECT * FROM tbl").collect()
}

spark.sparkContext.listenerBus.waitUntilEmpty()

assert(inputMetricsList.nonEmpty, s"No input metrics found for confs=$confs")
val totalBytes = inputMetricsList.map(_.bytesRead).sum
val totalRecords = inputMetricsList.map(_.recordsRead).sum
(totalBytes, totalRecords)
} finally {
spark.sparkContext.removeSparkListener(listener)
}
}
}