
Commit c188e7d

david-zlai, tchow-zlai, and thomaschow authored
Add flag to skip repartition before writing. (#239)
## Summary

## Checklist

- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Added a configurable repartitioning option for DataFrame writes.
  - Introduced a new configuration setting to control repartitioning behavior.
  - Enhanced the test suite with functionality to handle empty DataFrames.
- **Chores**
  - Improved code formatting and logging for the DataFrame writing process.

---------

Co-authored-by: tchow-zlai <[email protected]>
Co-authored-by: Thomas Chow <[email protected]>
1 parent: d60a669 · commit: c188e7d
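
The change is driven by a new Spark conf key, `spark.chronon.write.repartition`, which defaults to `true` (the previous repartition-then-write behavior). A minimal sketch of how a job might turn the repartition step off; only the conf key comes from the diff below, the session setup is illustrative:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: the conf key is from this commit; the session setup is illustrative.
val spark = SparkSession.builder()
  .appName("chronon-write-example")
  .getOrCreate()

// Defaults to "true" (repartition before writing); "false" skips the shuffle.
spark.conf.set("spark.chronon.write.repartition", "false")
```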

File tree

2 files changed: +32 −13 lines

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 15 additions & 13 deletions
```diff
@@ -457,15 +457,23 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
                 sortByCols: Seq[String] = Seq.empty): Unit = {
     wrapWithCache(s"repartition & write to $tableName", df) {
       logger.info("Repartitioning before writing...")
-      repartitionAndWriteInternal(df, tableName, saveMode, stats, sortByCols)
+      val dataPointer = DataPointer.from(tableName, sparkSession)
+      val repartitioned =
+        if (sparkSession.conf.get("spark.chronon.write.repartition", true.toString).toBoolean)
+          repartitionInternal(df, tableName, stats, sortByCols)
+        else df
+      repartitioned.write
+        .mode(saveMode)
+        .save(dataPointer)
+
+      logger.info(s"Finished writing to $tableName")
     }.get
   }
 
-  private def repartitionAndWriteInternal(df: DataFrame,
-                                          tableName: String,
-                                          saveMode: SaveMode,
-                                          stats: Option[DfStats],
-                                          sortByCols: Seq[String]): Unit = {
+  private def repartitionInternal(df: DataFrame,
+                                  tableName: String,
+                                  stats: Option[DfStats],
+                                  sortByCols: Seq[String]): DataFrame = {
 
     // get row count and table partition count statistics
 
@@ -483,7 +491,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
 
     // set to one if tablePartitionCount=0 to avoid division by zero
     val nonZeroTablePartitionCount = if (tablePartitionCount == 0) 1 else tablePartitionCount
-
     logger.info(s"$rowCount rows requested to be written into table $tableName")
     if (rowCount > 0) {
       val columnSizeEstimate = columnSizeEstimator(df.schema)
@@ -527,18 +534,13 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
         (Seq(partitionColumn, saltCol), Seq(partitionColumn) ++ sortByCols)
       } else { (Seq(saltCol), sortByCols) }
       logger.info(s"Sorting within partitions with cols: $partitionSortCols")
-      val dataPointer = DataPointer.from(tableName, sparkSession)
 
       saltedDf
         .repartition(shuffleParallelism, repartitionCols.map(saltedDf.col): _*)
         .drop(saltCol)
         .sortWithinPartitions(partitionSortCols.map(col): _*)
-        .write
-        .mode(saveMode)
-        .save(dataPointer)
-
-      logger.info(s"Finished writing to $tableName")
     }
+    df
   }
 
   def chunk(partitions: Set[String]): Seq[PartitionRange] = {
```
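
The shape of the refactor: the old `repartitionAndWriteInternal` both repartitioned and wrote, while the new `repartitionInternal` only returns a DataFrame and the caller owns a single write path. A standalone sketch of that pattern (the helper name `writeWithOptionalRepartition` is illustrative, not Chronon API):

```scala
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

// Illustrative sketch of the pattern above: decide whether to repartition,
// then funnel both branches through one shared write path.
def writeWithOptionalRepartition(spark: SparkSession,
                                 df: DataFrame,
                                 path: String,
                                 repartition: DataFrame => DataFrame): Unit = {
  val enabled = spark.conf.get("spark.chronon.write.repartition", "true").toBoolean
  val toWrite  = if (enabled) repartition(df) else df
  toWrite.write.mode(SaveMode.Overwrite).save(path)
}
```

Keeping the flag check in exactly one place means disabling repartitioning can never accidentally skip the write itself.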

spark/src/test/scala/ai/chronon/spark/test/TableUtilsTest.scala

Lines changed: 17 additions & 0 deletions
```diff
@@ -31,6 +31,8 @@ import org.scalatest.flatspec.AnyFlatSpec
 
 import scala.util.Try
 
+case class TestRecord(ds: String, id: String)
+
 class SimpleAddUDF extends UDF {
   def evaluate(value: Int): Int = {
     value + 20
@@ -482,6 +484,21 @@ class TableUtilsTest extends AnyFlatSpec {
     }
   }
 
+  it should "repartitioning an empty dataframe should work" in {
+    import spark.implicits._
+    val tableName = "db.test_empty_table"
+    tableUtils.createDatabase("db")
+
+    tableUtils.insertPartitions(spark.emptyDataset[TestRecord].toDF(), tableName)
+    val res = tableUtils.loadTable(tableName)
+    assertEquals(0, res.count)
+
+    tableUtils.insertPartitions(spark.createDataFrame(List(TestRecord("2025-01-01", "a"))), tableName)
+    val newRes = tableUtils.loadTable(tableName)
+
+    assertEquals(1, newRes.count)
+  }
+
   it should "create table" in {
     val tableName = "db.test_create_table"
     spark.sql("CREATE DATABASE IF NOT EXISTS db")
```
