
Commit efb03b1

szehon-ho authored and cloud-fan committed
[SPARK-52403][SQL] Add metric to MergeRowsExec for rows that do not match condition
### What changes were proposed in this pull request?

MergeRowsExec can record some useful information, such as how many rows are not the target of any action (i.e., match no condition), which can be emitted as metrics.

### Why are the changes needed?

Improves debuggability of MERGE INTO.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added tests to MergeIntoTableSuiteBase.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51091 from szehon-ho/numTargetRowsCopied.

Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent: 5432402

File tree: 5 files changed (+222, −6 lines)

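To make the new metric concrete before diving into the diffs: with the data from the first new test below (target pk 1-3, source pk 1, 2, and 10), only row 1 satisfies the single MATCHED clause; rows 2 and 3 satisfy nothing and must be carried over by group-based (copy-on-write) sources. A minimal sketch, assuming an active SparkSession `spark` and hypothetical table names:

```scala
// Rows 2 and 3 match no clause: a group-based source rewrites their group and
// copies them over unmodified, so numTargetRowsCopied = 2; a delta-based
// source writes only row-level changes and reports 0.
spark.sql(
  """MERGE INTO cat.db.target t
    |USING source s
    |ON t.pk = s.pk
    |WHEN MATCHED AND t.salary < 200 THEN
    |  UPDATE SET t.salary = 1000
    |""".stripMargin)
```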

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta}
-import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
+import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Discard, Instruction, Keep, ROW_ID, Split}
 import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{OPERATION_COLUMN, WRITE_OPERATION, WRITE_WITH_METADATA_OPERATION}
 import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
 import org.apache.spark.sql.connector.write.{RowLevelOperationTable, SupportsDelta}

@@ -199,7 +199,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
     // as the last MATCHED and NOT MATCHED BY SOURCE instruction
     // this logic is specific to data sources that replace groups of data
     val carryoverRowsOutput = Literal(WRITE_WITH_METADATA_OPERATION) +: targetTable.output
-    val keepCarryoverRowsInstruction = Keep(TrueLiteral, carryoverRowsOutput)
+    val keepCarryoverRowsInstruction = Copy(carryoverRowsOutput)

     val matchedInstructions = matchedActions.map { action =>
       toInstruction(action, metadataAttrs)
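The change in this rule is behavior-preserving: Copy is just a Keep whose condition is pinned to TrueLiteral, split out so the executor can recognize and count carryover rows. A self-contained sketch of the equivalence (the single attribute is a stand-in for the real operation column plus target output built above):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Keep}
import org.apache.spark.sql.types.StringType

// Stand-in for the real carryoverRowsOutput built by the rule above.
val carryoverRowsOutput: Seq[Expression] = Seq(AttributeReference("dep", StringType)())

val before = Keep(TrueLiteral, carryoverRowsOutput) // old: always-true predicate
val after = Copy(carryoverRowsOutput)               // new: condition fixed to TrueLiteral
```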

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala

Lines changed: 17 additions & 1 deletion
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, ROW_ID}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.truncatedString

@@ -87,7 +88,22 @@ object MergeRows {
     override def dataType: DataType = NullType
   }

-  case class Keep(condition: Expression, output: Seq[Expression]) extends Instruction {
+  // A special case of Keep where the row is kept as is.
+  case class Copy(output: Seq[Expression]) extends Instruction {
+    override def condition: Expression = TrueLiteral
+    override def outputs: Seq[Seq[Expression]] = Seq(output)
+    override def children: Seq[Expression] = output
+
+    override protected def withNewChildrenInternal(
+        newChildren: IndexedSeq[Expression]): Expression = {
+      copy(output = newChildren)
+    }
+  }
+
+  case class Keep(
+      condition: Expression,
+      output: Seq[Expression])
+    extends Instruction {
     def children: Seq[Expression] = condition +: output
     override def outputs: Seq[Seq[Expression]] = Seq(output)
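To see where the new node slots in, here is an illustrative instruction list for a group-based MERGE with one MATCHED clause. Keep, Copy, and Instruction come from this file; the attributes and condition are invented for the example (the real rewrite also prepends an operation column to each output):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, LessThan, Literal}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Instruction, Keep}
import org.apache.spark.sql.types.IntegerType

val pk = AttributeReference("pk", IntegerType)()
val salary = AttributeReference("salary", IntegerType)()

val instructions: Seq[Instruction] = Seq(
  // WHEN MATCHED AND salary < 200 THEN UPDATE SET salary = 1000
  Keep(LessThan(salary, Literal(200)), Seq(pk, Literal(1000))),
  // Rows that satisfy no clause: kept as is, and now individually countable.
  Copy(Seq(pk, salary))
)
```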

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala

Lines changed: 23 additions & 2 deletions
@@ -26,14 +26,16 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeSet
 import org.apache.spark.sql.catalyst.expressions.BasePredicate
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.expressions.Projection
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
-import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard, Instruction, Keep, ROW_ID, Split}
+import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Discard, Instruction, Keep, ROW_ID, Split}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.UnaryExecNode
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

 case class MergeRowsExec(
     isSourceRowPresent: Expression,

@@ -45,6 +47,10 @@ case class MergeRowsExec(
     output: Seq[Attribute],
     child: SparkPlan) extends UnaryExecNode {

+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "numTargetRowsCopied" -> SQLMetrics.createMetric(sparkContext,
+      "Number of target rows copied unmodified because they did not match any action."))
+
   @transient override lazy val producedAttributes: AttributeSet = {
     AttributeSet(output.filterNot(attr => inputSet.contains(attr)))
   }

@@ -107,6 +113,9 @@ case class MergeRowsExec(

   private def planInstructions(instructions: Seq[Instruction]): Seq[InstructionExec] = {
     instructions.map {
+      case Copy(output) =>
+        CopyExec(createProjection(output))
+
       case Keep(cond, output) =>
         KeepExec(createPredicate(cond), createProjection(output))

@@ -127,7 +136,14 @@ case class MergeRowsExec(
     def condition: BasePredicate
   }

-  case class KeepExec(condition: BasePredicate, projection: Projection) extends InstructionExec {
+  case class CopyExec(projection: Projection) extends InstructionExec {
+    override lazy val condition: BasePredicate = createPredicate(TrueLiteral)
+    def apply(row: InternalRow): InternalRow = projection.apply(row)
+  }
+
+  case class KeepExec(
+      condition: BasePredicate,
+      projection: Projection) extends InstructionExec {
     def apply(row: InternalRow): InternalRow = projection.apply(row)
   }

@@ -220,6 +236,11 @@ case class MergeRowsExec(
       for (instruction <- instructions) {
         if (instruction.condition.eval(row)) {
           instruction match {
+            case copy: CopyExec =>
+              // group-based operations copy over target rows that didn't match any actions
+              longMetric("numTargetRowsCopied") += 1
+              return copy.apply(row)
+
             case keep: KeepExec =>
               return keep.apply(row)
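Reading the metric back follows the standard SQLMetric pattern. A minimal sketch mirroring the findMergeExec/assertMetric helpers added to the test suite below; it assumes a test that mixes in AdaptiveSparkPlanHelper, and mergeQuery/expectedCopies are placeholders:

```scala
// Run the MERGE, keep the executed plan, and locate the MergeRowsExec node.
val plan = executeAndKeepPlan { sql(mergeQuery) }
val mergeExec = collectFirst(plan) { case m: MergeRowsExec => m }
  .getOrElse(fail("MergeRowsExec not found in the plan"))
// SQL metrics are exposed on the physical node (and in the UI's SQL tab).
assert(mergeExec.metrics("numTargetRowsCopied").value == expectedCopies)
```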

sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableSuiteBase.scala

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,8 @@ abstract class DeltaBasedMergeIntoTableSuiteBase extends MergeIntoTableSuiteBase

   import testImplicits._

+  override protected def deltaMerge = true
+
   test("merge into schema pruning with WHEN MATCHED clause (update)") {
     withTempView("source") {
       createAndInitTable("pk INT NOT NULL, salary INT, country STRING, dep STRING",

sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala

Lines changed: 178 additions & 1 deletion
@@ -24,14 +24,19 @@ import org.apache.spark.sql.catalyst.optimizer.BuildLeft
 import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, TableInfo}
 import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.datasources.v2.MergeRowsExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegerType, StringType}

-abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
+abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
+  with AdaptiveSparkPlanHelper {

   import testImplicits._

+  protected def deltaMerge: Boolean = false
+
   test("merge into table with expression-based default values") {
     val columns = Array(
       Column.create("pk", IntegerType),
@@ -1771,6 +1776,166 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
     }
   }

+  test("Merge metrics with matched clause") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(1, 2, 10).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      val mergeExec = findMergeExec {
+        s"""MERGE INTO $tableNameAsString t
+           |USING source s
+           |ON t.pk = s.pk
+           |WHEN MATCHED AND salary < 200 THEN
+           | UPDATE SET salary = 1000
+           |""".stripMargin
+      }
+
+      assertMetric(mergeExec, "numTargetRowsCopied", if (deltaMerge) 0 else 2)
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(1, 1000, "hr"), // updated
+          Row(2, 200, "software"),
+          Row(3, 300, "hr")))
+    }
+  }
+
+  test("Merge metrics with matched and not matched clause") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(
+        (4, 100, "marketing"),
+        (5, 400, "executive"),
+        (6, 100, "hr")
+      ).toDF("pk", "salary", "dep")
+      sourceDF.createOrReplaceTempView("source")
+
+      val mergeExec = findMergeExec {
+        s"""MERGE INTO $tableNameAsString t
+           |USING source s
+           |ON t.pk = s.pk
+           |WHEN MATCHED THEN
+           | UPDATE SET salary = 9999
+           |WHEN NOT MATCHED AND salary > 200 THEN
+           | INSERT *
+           |""".stripMargin
+      }
+
+      assertMetric(mergeExec, "numTargetRowsCopied", 0)
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(1, 100, "hr"),
+          Row(2, 200, "software"),
+          Row(3, 300, "hr"),
+          Row(5, 400, "executive"))) // inserted
+    }
+  }
+
+  test("Merge metrics with matched and not matched by source clauses") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |{ "pk": 4, "salary": 400, "dep": "marketing" }
+          |{ "pk": 5, "salary": 500, "dep": "executive" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(1, 2, 10).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      val mergeExec = findMergeExec {
+        s"""MERGE INTO $tableNameAsString t
+           |USING source s
+           |ON t.pk = s.pk
+           |WHEN MATCHED AND salary < 200 THEN
+           | UPDATE SET salary = 1000
+           |WHEN NOT MATCHED BY SOURCE AND salary > 400 THEN
+           | UPDATE SET salary = -1
+           |""".stripMargin
+      }
+
+      assertMetric(mergeExec, "numTargetRowsCopied", if (deltaMerge) 0 else 3)
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(1, 1000, "hr"), // updated
+          Row(2, 200, "software"),
+          Row(3, 300, "hr"),
+          Row(4, 400, "marketing"),
+          Row(5, -1, "executive"))) // updated
+    }
+  }
+
+  test("Merge metrics with matched, not matched, and not matched by source clauses") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |{ "pk": 4, "salary": 400, "dep": "marketing" }
+          |{ "pk": 5, "salary": 500, "dep": "executive" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(1, 2, 6, 10).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      val mergeExec = findMergeExec {
+        s"""MERGE INTO $tableNameAsString t
+           |USING source s
+           |ON t.pk = s.pk
+           |WHEN MATCHED AND salary < 200 THEN
+           | UPDATE SET salary = 1000
+           |WHEN NOT MATCHED AND s.pk < 10 THEN
+           | INSERT (pk, salary, dep) VALUES (s.pk, -1, "dummy")
+           |WHEN NOT MATCHED BY SOURCE AND salary > 400 THEN
+           | UPDATE SET salary = -1
+           |""".stripMargin
+      }
+
+      assertMetric(mergeExec, "numTargetRowsCopied", if (deltaMerge) 0 else 3)
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(1, 1000, "hr"), // updated
+          Row(2, 200, "software"),
+          Row(3, 300, "hr"),
+          Row(4, 400, "marketing"),
+          Row(5, -1, "executive"), // updated
+          Row(6, -1, "dummy"))) // inserted
+    }
+  }
+
+  private def findMergeExec(query: String): MergeRowsExec = {
+    val plan = executeAndKeepPlan {
+      sql(query)
+    }
+    collectFirst(plan) {
+      case m: MergeRowsExec => m
+    } match {
+      case Some(m) => m
+      case None =>
+        fail("MergeRowsExec not found in the plan")
+    }
+  }
+
   private def assertNoLeftBroadcastOrReplication(query: String): Unit = {
     val plan = executeAndKeepPlan {
       sql(query)
@@ -1793,4 +1958,16 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
     }
     assert(e.getMessage.contains("ON search condition of the MERGE statement"))
   }
+
+  private def assertMetric(
+      mergeExec: MergeRowsExec,
+      metricName: String,
+      expected: Long): Unit = {
+    mergeExec.metrics.get(metricName) match {
+      case Some(metric) =>
+        assert(metric.value == expected,
+          s"Expected $metricName to be $expected, but got ${metric.value}")
+      case None => fail(s"$metricName metric not found")
+    }
+  }
 }
