@@ -24,14 +24,19 @@ import org.apache.spark.sql.catalyst.optimizer.BuildLeft
24
24
import org .apache .spark .sql .connector .catalog .{Column , ColumnDefaultValue , TableInfo }
25
25
import org .apache .spark .sql .connector .expressions .{GeneralScalarExpression , LiteralValue }
26
26
import org .apache .spark .sql .execution .SparkPlan
27
+ import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
28
+ import org .apache .spark .sql .execution .datasources .v2 .MergeRowsExec
27
29
import org .apache .spark .sql .execution .joins .{BroadcastHashJoinExec , BroadcastNestedLoopJoinExec , CartesianProductExec }
28
30
import org .apache .spark .sql .internal .SQLConf
29
31
import org .apache .spark .sql .types .{IntegerType , StringType }
30
32
31
- abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
33
+ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
34
+ with AdaptiveSparkPlanHelper {
32
35
33
36
import testImplicits ._
34
37
38
+ protected def deltaMerge : Boolean = false
39
+
35
40
test(" merge into table with expression-based default values" ) {
36
41
val columns = Array (
37
42
Column .create(" pk" , IntegerType ),
@@ -1771,6 +1776,166 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
1771
1776
}
1772
1777
}
1773
1778
1779
+ test(" Merge metrics with matched clause" ) {
1780
+ withTempView(" source" ) {
1781
+ createAndInitTable(" pk INT NOT NULL, salary INT, dep STRING" ,
1782
+ """ { "pk": 1, "salary": 100, "dep": "hr" }
1783
+ |{ "pk": 2, "salary": 200, "dep": "software" }
1784
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
1785
+ |""" .stripMargin)
1786
+
1787
+ val sourceDF = Seq (1 , 2 , 10 ).toDF(" pk" )
1788
+ sourceDF.createOrReplaceTempView(" source" )
1789
+
1790
+ val mergeExec = findMergeExec {
1791
+ s """ MERGE INTO $tableNameAsString t
1792
+ |USING source s
1793
+ |ON t.pk = s.pk
1794
+ |WHEN MATCHED AND salary < 200 THEN
1795
+ | UPDATE SET salary = 1000
1796
+ | """ .stripMargin
1797
+ }
1798
+
1799
+ assertMetric(mergeExec, " numTargetRowsCopied" , if (deltaMerge) 0 else 2 )
1800
+
1801
+ checkAnswer(
1802
+ sql(s " SELECT * FROM $tableNameAsString" ),
1803
+ Seq (
1804
+ Row (1 , 1000 , " hr" ), // updated
1805
+ Row (2 , 200 , " software" ),
1806
+ Row (3 , 300 , " hr" )))
1807
+ }
1808
+ }
1809
+
1810
+ test(" Merge metrics with matched and not matched clause" ) {
1811
+ withTempView(" source" ) {
1812
+ createAndInitTable(" pk INT NOT NULL, salary INT, dep STRING" ,
1813
+ """ { "pk": 1, "salary": 100, "dep": "hr" }
1814
+ |{ "pk": 2, "salary": 200, "dep": "software" }
1815
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
1816
+ |""" .stripMargin)
1817
+
1818
+ val sourceDF = Seq (
1819
+ (4 , 100 , " marketing" ),
1820
+ (5 , 400 , " executive" ),
1821
+ (6 , 100 , " hr" )
1822
+ ).toDF(" pk" , " salary" , " dep" )
1823
+ sourceDF.createOrReplaceTempView(" source" )
1824
+
1825
+ val mergeExec = findMergeExec {
1826
+ s """ MERGE INTO $tableNameAsString t
1827
+ |USING source s
1828
+ |ON t.pk = s.pk
1829
+ |WHEN MATCHED THEN
1830
+ | UPDATE SET salary = 9999
1831
+ |WHEN NOT MATCHED AND salary > 200 THEN
1832
+ | INSERT *
1833
+ | """ .stripMargin
1834
+ }
1835
+
1836
+ assertMetric(mergeExec, " numTargetRowsCopied" , 0 )
1837
+
1838
+ checkAnswer(
1839
+ sql(s " SELECT * FROM $tableNameAsString" ),
1840
+ Seq (
1841
+ Row (1 , 100 , " hr" ),
1842
+ Row (2 , 200 , " software" ),
1843
+ Row (3 , 300 , " hr" ),
1844
+ Row (5 , 400 , " executive" ))) // inserted
1845
+ }
1846
+ }
1847
+
1848
+ test(" Merge metrics with matched and not matched by source clauses" ) {
1849
+ withTempView(" source" ) {
1850
+ createAndInitTable(" pk INT NOT NULL, salary INT, dep STRING" ,
1851
+ """ { "pk": 1, "salary": 100, "dep": "hr" }
1852
+ |{ "pk": 2, "salary": 200, "dep": "software" }
1853
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
1854
+ |{ "pk": 4, "salary": 400, "dep": "marketing" }
1855
+ |{ "pk": 5, "salary": 500, "dep": "executive" }
1856
+ |""" .stripMargin)
1857
+
1858
+ val sourceDF = Seq (1 , 2 , 10 ).toDF(" pk" )
1859
+ sourceDF.createOrReplaceTempView(" source" )
1860
+
1861
+ val mergeExec = findMergeExec {
1862
+ s """ MERGE INTO $tableNameAsString t
1863
+ |USING source s
1864
+ |ON t.pk = s.pk
1865
+ |WHEN MATCHED AND salary < 200 THEN
1866
+ | UPDATE SET salary = 1000
1867
+ |WHEN NOT MATCHED BY SOURCE AND salary > 400 THEN
1868
+ | UPDATE SET salary = -1
1869
+ | """ .stripMargin
1870
+ }
1871
+
1872
+
1873
+ assertMetric(mergeExec, " numTargetRowsCopied" , if (deltaMerge) 0 else 3 )
1874
+
1875
+ checkAnswer(
1876
+ sql(s " SELECT * FROM $tableNameAsString" ),
1877
+ Seq (
1878
+ Row (1 , 1000 , " hr" ), // updated
1879
+ Row (2 , 200 , " software" ),
1880
+ Row (3 , 300 , " hr" ),
1881
+ Row (4 , 400 , " marketing" ),
1882
+ Row (5 , - 1 , " executive" ))) // updated
1883
+ }
1884
+ }
1885
+
1886
+ test(" Merge metrics with matched, not matched, and not matched by source clauses" ) {
1887
+ withTempView(" source" ) {
1888
+ createAndInitTable(" pk INT NOT NULL, salary INT, dep STRING" ,
1889
+ """ { "pk": 1, "salary": 100, "dep": "hr" }
1890
+ |{ "pk": 2, "salary": 200, "dep": "software" }
1891
+ |{ "pk": 3, "salary": 300, "dep": "hr" }
1892
+ |{ "pk": 4, "salary": 400, "dep": "marketing" }
1893
+ |{ "pk": 5, "salary": 500, "dep": "executive" }
1894
+ |""" .stripMargin)
1895
+
1896
+ val sourceDF = Seq (1 , 2 , 6 , 10 ).toDF(" pk" )
1897
+ sourceDF.createOrReplaceTempView(" source" )
1898
+
1899
+ val mergeExec = findMergeExec {
1900
+ s """ MERGE INTO $tableNameAsString t
1901
+ |USING source s
1902
+ |ON t.pk = s.pk
1903
+ |WHEN MATCHED AND salary < 200 THEN
1904
+ | UPDATE SET salary = 1000
1905
+ |WHEN NOT MATCHED AND s.pk < 10 THEN
1906
+ | INSERT (pk, salary, dep) VALUES (s.pk, -1, "dummy")
1907
+ |WHEN NOT MATCHED BY SOURCE AND salary > 400 THEN
1908
+ | UPDATE SET salary = -1
1909
+ | """ .stripMargin
1910
+ }
1911
+
1912
+ assertMetric(mergeExec, " numTargetRowsCopied" , if (deltaMerge) 0 else 3 )
1913
+
1914
+ checkAnswer(
1915
+ sql(s " SELECT * FROM $tableNameAsString" ),
1916
+ Seq (
1917
+ Row (1 , 1000 , " hr" ), // updated
1918
+ Row (2 , 200 , " software" ),
1919
+ Row (3 , 300 , " hr" ),
1920
+ Row (4 , 400 , " marketing" ),
1921
+ Row (5 , - 1 , " executive" ), // updated
1922
+ Row (6 , - 1 , " dummy" ))) // inserted
1923
+ }
1924
+ }
1925
+
1926
+ private def findMergeExec (query : String ): MergeRowsExec = {
1927
+ val plan = executeAndKeepPlan {
1928
+ sql(query)
1929
+ }
1930
+ collectFirst(plan) {
1931
+ case m : MergeRowsExec => m
1932
+ } match {
1933
+ case Some (m) => m
1934
+ case None =>
1935
+ fail(" MergeRowsExec not found in the plan" )
1936
+ }
1937
+ }
1938
+
1774
1939
private def assertNoLeftBroadcastOrReplication (query : String ): Unit = {
1775
1940
val plan = executeAndKeepPlan {
1776
1941
sql(query)
@@ -1793,4 +1958,16 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase {
1793
1958
}
1794
1959
assert(e.getMessage.contains(" ON search condition of the MERGE statement" ))
1795
1960
}
1961
+
1962
+ private def assertMetric (
1963
+ mergeExec : MergeRowsExec ,
1964
+ metricName : String ,
1965
+ expected : Long ): Unit = {
1966
+ mergeExec.metrics.get(metricName) match {
1967
+ case Some (metric) =>
1968
+ assert(metric.value == expected,
1969
+ s " Expected $metricName to be $expected, but got ${metric.value}" )
1970
+ case None => fail(s " $metricName metric not found " )
1971
+ }
1972
+ }
1796
1973
}
0 commit comments