
Commit 5248cd2

richardc-db authored and himadripal committed
[SPARK-49074][SQL] Fix variant with df.cache()
### What changes were proposed in this pull request?

Currently, the `actualSize` method of the `VARIANT` `columnType` isn't overridden, so we use the default size of 2KB for the `actualSize`. We should define `actualSize` so the cached variant column can correctly be written to the byte buffer. Currently, if the average per-variant size is greater than 2KB and the total column size is greater than 128KB (the default initial buffer size), an exception is (incorrectly) thrown.

### Why are the changes needed?

To fix caching of larger variants (via `df.cache()`), such as the ones included in the UTs.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47559 from richardc-db/fix_variant_cache.

Authored-by: Richard Chen <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
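For context, a minimal repro sketch modeled on the new UTs (assuming a live `SparkSession` bound to `spark`; the key count and row count here are illustrative, not taken from the patch):

```scala
// Build a JSON object whose parsed variant is well over the 2KB default estimate.
val numKeys = 1024
val jsonStr = (0 until numKeys).map(i => s"\"k$i\": \"test\"").mkString("{", ", ", "}")

// 100 rows of multi-KB variants push the column past the 128KB initial buffer.
// Before this fix, materializing the cache could (incorrectly) throw.
val df = spark.sql(s"select parse_json('$jsonStr') v from range(0, 100)")
df.cache()
df.collect()
```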
1 parent 464f7e4 commit 5248cd2

File tree

2 files changed, 59 insertions(+), 0 deletions(-)


sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala

Lines changed: 6 additions & 0 deletions
@@ -829,6 +829,12 @@ private[columnar] object VARIANT
   /** Chosen to match the default size set in `VariantType`. */
   override def defaultSize: Int = 2048
 
+  override def actualSize(row: InternalRow, ordinal: Int): Int = {
+    val v = getField(row, ordinal)
+    // 4 bytes each for the integers representing the 'value' and 'metadata' lengths.
+    8 + v.getValue().length + v.getMetadata().length
+  }
+
   override def getField(row: InternalRow, ordinal: Int): VariantVal = row.getVariant(ordinal)
 
   override def setField(row: InternalRow, ordinal: Int, value: VariantVal): Unit =
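For intuition about the formula: per the comment above, a cached variant is written as two length-prefixed byte arrays, so the per-row size is two 4-byte length fields plus the raw `value` and `metadata` bytes. A quick sanity check of the arithmetic (the payloads below are dummy byte arrays, not a valid variant encoding):

```scala
import org.apache.spark.unsafe.types.VariantVal

// Dummy payloads: 3000 'value' bytes and 500 'metadata' bytes.
val v = new VariantVal(new Array[Byte](3000), new Array[Byte](500))

// 4 bytes for each of the two length fields, plus the payloads themselves.
val size = 8 + v.getValue().length + v.getMetadata().length
assert(size == 3508) // far above the fixed 2048-byte default the old code assumed
```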

sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala

Lines changed: 53 additions & 0 deletions
@@ -652,6 +652,21 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     checkAnswer(df, expected.collect())
   }
 
+  test("variant with many keys in a cached row-based df") {
+    // The initial size of the buffer backing a cached dataframe column is 128KB.
+    // See `ColumnBuilder`.
+    val numKeys = 128 * 1024
+    var keyIterator = (0 until numKeys).iterator
+    val entries = Array.fill(numKeys)(s"""\"${keyIterator.next()}\": \"test\"""")
+    val jsonStr = s"{${entries.mkString(", ")}}"
+    val query = s"""select parse_json('${jsonStr}') v from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
   test("struct of variant in a cached row-based df") {
     val query = """select named_struct(
       'v', parse_json(format_string('{\"a\": %s}', id)),
@@ -680,6 +695,21 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     checkAnswer(df, expected.collect())
   }
 
+  test("array variant with many keys in a cached row-based df") {
+    // The initial size of the buffer backing a cached dataframe column is 128KB.
+    // See `ColumnBuilder`.
+    val numKeys = 128 * 1024
+    var keyIterator = (0 until numKeys).iterator
+    val entries = Array.fill(numKeys)(s"""\"${keyIterator.next()}\": \"test\"""")
+    val jsonStr = s"{${entries.mkString(", ")}}"
+    val query = s"""select array(parse_json('${jsonStr}')) v from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
   test("map of variant in a cached row-based df") {
     val query = """select map(
       'v', parse_json(format_string('{\"a\": %s}', id)),
@@ -711,6 +741,29 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     }
   }
 
+  test("variant with many keys in a cached column-based df") {
+    withTable("t") {
+      // The initial size of the buffer backing a cached dataframe column is 128KB.
+      // See `ColumnBuilder`.
+      val numKeys = 128 * 1024
+      var keyIterator = (0 until numKeys).iterator
+      val entries = Array.fill(numKeys)(s"""\"${keyIterator.next()}\": \"test\"""")
+      val jsonStr = s"{${entries.mkString(", ")}}"
+      val query = s"""select named_struct(
+        'v', parse_json('$jsonStr'),
+        'null_v', cast(null as variant),
+        'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+      ) v
+      from range(0, 10)"""
+      spark.sql(query).write.format("parquet").mode("overwrite").saveAsTable("t")
+      val df = spark.sql("select * from t")
+      df.cache()
+
+      val expected = spark.sql(query)
+      checkAnswer(df, expected.collect())
+    }
+  }
+
   test("variant_get size") {
     val largeKey = "x" * 1000
     val df = Seq(s"""{ "$largeKey": {"a" : 1 },
