
Reuse Dedup Shuffle #937


Open · wants to merge 4 commits into main
10 changes: 7 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -565,10 +565,14 @@ class Join(joinConf: api.Join,

   bootstrapDf = bootstrapDf
     .select(includedColumns.map(col): _*)
-    // TODO: allow customization of deduplication logic
-    .dropDuplicates(part.keys(joinConf, tableUtils.partitionColumn).toArray)

-  coalescedJoin(partialDf, bootstrapDf, part.keys(joinConf, tableUtils.partitionColumn).toSeq)
+  val dedupedBootstrap = dropDuplicatesUsingJoinShuffle(
+    bootstrapDf,
+    partialDf,
+    part.keys(joinConf, tableUtils.partitionColumn).toSeq
+  )
+
+  coalescedJoin(partialDf, dedupedBootstrap, part.keys(joinConf, tableUtils.partitionColumn).toSeq)
// as part of the left outer join process, we update and maintain matched_hashes for each record
// that summarizes whether there is a join-match for each bootstrap source.
// later on we use this information to decide whether we still need to re-run the backfill logic
43 changes: 40 additions & 3 deletions spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
@@ -21,9 +21,10 @@ import ai.chronon.api.Constants
import ai.chronon.api.DataModel.Events
import ai.chronon.api.Extensions.{JoinOps, _}
import ai.chronon.spark.Extensions._
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.expressions.UserDefinedFunction
-import org.apache.spark.sql.functions.{coalesce, col, udf}
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
+import org.apache.spark.sql.functions.{coalesce, col, lit, row_number, udf}
import org.slf4j.LoggerFactory

import java.util
@@ -115,6 +116,42 @@ object JoinUtils {
    PartitionRange(leftStart, leftEnd)(tableUtils)
  }

  /**
    * Deduplicate the given [[toDedup]] dataframe using the same shuffle as
    * would be used when joining to the [[joinPartner]] dataframe.
    *
    * This is an optimization for the scenario where a caller deduplicates
    * a dataframe before joining:
    *   joinPartner.join(toDedup.dropDuplicates(keys), keys)
    *
    * By default, Spark does a poor job of reusing the deduplication shuffle in
    * this scenario -- the join query plan will normalize floating point numbers
    * in its keys, but dropDuplicates will not, resulting in subtly different
    * shuffles.
    *
    * This function plans a fake join and extracts the exact key expressions that
    * the join will use, then deduplicates on those expressions so that the same
    * shuffle can be reused.
    */
  def dropDuplicatesUsingJoinShuffle(toDedup: DataFrame, joinPartner: DataFrame, keys: Seq[String]): DataFrame = {
    val condition = keys.map(key => joinPartner(key) === toDedup(key)).reduce(_ && _)
    val plannedJoin = joinPartner.join(toDedup, condition)

    plannedJoin.queryExecution.logical match {
      case ExtractEquiJoinKeys(_, _, rightKeys, _, _, _, _) =>

[Review thread on lines +139 to +140]

Reviewer: Is planning the join necessary? I couldn't tell if you just needed rightKeys, and if you just needed that, is pulling it from ExtractEquiJoinKeys via the logical plan the only way?

Author: rightKeys is only knowable after planning the join; Spark will insert additional normalization functions to ensure correctness. (Screenshot of the planned query omitted.) The only way I've found to correctly detect and reuse the knownfloatingpointnormalized(normalizenanandzero(...)) functions is to plan a join.

        val cols = rightKeys.map(new Column(_))
        val w = Window.partitionBy(cols: _*).orderBy(cols: _*)
        toDedup
          .withColumn("dedup_row_number", row_number().over(w))
          .filter(col("dedup_row_number") === lit(1))
          .drop("dedup_row_number")
      case _ =>
        // This should never happen: the join planned above is an equijoin by construction.
        logger.warn(s"Couldn't plan an equijoin for $condition. Falling back to dropDuplicates.")
        toDedup.dropDuplicates(keys)
    }
  }
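
For illustration, a minimal sketch (not part of this diff) of the shuffle mismatch described above. It assumes a local session with AQE and broadcast joins disabled so the plan shape is stable for inspection; the session setup and column names here are hypothetical:

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.execution.exchange.Exchange

  val spark = SparkSession.builder().master("local[*]").appName("dedup-shuffle-demo").getOrCreate()
  spark.conf.set("spark.sql.adaptive.enabled", "false")        // keep the plan static for inspection
  spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") // force a shuffle-based join
  import spark.implicits._

  val left  = Seq((1.0, "l1")).toDF("key", "lval")
  val right = Seq((1.0, "r1"), (1.0, "r2")).toDF("key", "rval")

  // Naive pattern: dropDuplicates shuffles on `key`, while the join shuffles on
  // knownfloatingpointnormalized(normalizenanandzero(key)), so the dedup exchange
  // cannot be reused and an extra one is planned.
  val naive = left.join(right.dropDuplicates(Seq("key")), Seq("key"))
  naive.explain() // the hashpartitioning expressions on the join side show the normalization wrappers
  val exchanges = naive.queryExecution.executedPlan.collect { case e: Exchange => e }.size
  // Per this PR's description, expect 3 exchanges here, vs 2 with dropDuplicatesUsingJoinShuffle.
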

/** *
* join left and right dataframes, merging any shared columns if exists by the coalesce rule.
* fails if there is any data type mismatch between shared columns.
144 changes: 143 additions & 1 deletion spark/src/test/scala/ai/chronon/spark/test/JoinUtilsTest.scala
@@ -19,10 +19,11 @@ package ai.chronon.spark.test
import ai.chronon.aggregator.test.Column
import ai.chronon.api
import ai.chronon.api.{Builders, Constants}
-import ai.chronon.spark.JoinUtils.{contains_any, set_add}
+import ai.chronon.spark.JoinUtils.{contains_any, dropDuplicatesUsingJoinShuffle, set_add}
 import ai.chronon.spark.{GroupBy, JoinUtils, PartitionRange, SparkSessionBuilder, TableUtils}
 import ai.chronon.spark.Extensions._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
@@ -102,6 +103,147 @@ class JoinUtilsTest {
    }
  }

  @Test
  def reuseDedupShuffle(): Unit = {
    val spark: SparkSession =
      SparkSessionBuilder.build("JoinUtilsTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
    val leftDf = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
      (1.0, "a2")
    )))
    val rightDf = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
      (1.0, "a2")
    )))

    val keys = Seq("_1")
    val deduped = dropDuplicatesUsingJoinShuffle(rightDf, leftDf, keys)
      .as[(String, String)](spark.implicits.newProductEncoder)

    // One Exchange per join side means the dedup shuffle was reused; a third
    // Exchange would mean the deduplicated side was re-shuffled for the join.
    val exchanges = leftDf.join(deduped, keys)
      .queryExecution.executedPlan
      .collect { case _: Exchange => true }
      .length

    assertEquals(2, exchanges)
  }
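
A small helper along these lines (hypothetical, not part of this PR; it relies on the DataFrame and Exchange imports already present in this file) could make the exchange-counting assertion reusable across tests:

  // Counts Exchange (shuffle/broadcast) nodes in a DataFrame's physical plan.
  // Two exchanges after the join above means the dedup shuffle was reused;
  // three would mean the deduplicated side was re-shuffled for the join.
  private def countExchanges(df: DataFrame): Int =
    df.queryExecution.executedPlan.collect { case e: Exchange => e }.size
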

  private def runDropDuplicates(left: Seq[(String, String)],
                                right: Seq[(String, String)],
                                keys: Seq[String]): Seq[(String, String)] = {
    val spark: SparkSession =
      SparkSessionBuilder.build("JoinUtilsTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
    val leftDf = spark.createDataFrame(spark.sparkContext.parallelize(left))
    val rightDf = spark.createDataFrame(spark.sparkContext.parallelize(right))

    dropDuplicatesUsingJoinShuffle(rightDf, leftDf, keys)
      .as[(String, String)](spark.implicits.newProductEncoder)
      .collect()
      .toSeq
  }

  @Test
  def testNoDuplicates(): Unit = {
    val left = Seq(
      ("a1", "a2"),
      ("b1", "b2"),
      ("c1", "c2")
    )
    val right = Seq(
      ("a1", "a2"),
      ("b1", "b2"),
      ("c1", "c2")
    )
    val keys = Seq("_1")

    val expected = right
    val result = runDropDuplicates(left, right, keys)

    assertEquals(expected.length, result.length)
    result.foreach { r => assertTrue(expected.contains(r)) }
  }

  @Test
  def testNoDuplicatesColumn2(): Unit = {
    val left = Seq(
      ("a1", "a2"),
      ("b1", "b2"),
      ("c1", "c2")
    )
    val right = Seq(
      ("a1", "a2"),
      ("a1", "b2"),
      ("a1", "c2")
    )
    val keys = Seq("_1", "_2")

    val expected = right
    val result = runDropDuplicates(left, right, keys)

    assertEquals(expected.length, result.length)
    result.foreach { r => assertTrue(expected.contains(r)) }
  }

  @Test
  def testDuplicates(): Unit = {
    val left = Seq(
      ("a1", "a2"),
      ("b1", "b2"),
      ("c1", "c2")
    )
    val right = Seq(
      ("a1", "a2"),
      ("a1", "b2"),
      ("c1", "c2")
    )
    val keys = Seq("_1")

    // The window's ordering is a tie for duplicate keys, so either "a1" row
    // may survive the dedup; accept both outcomes.
    val expected1 = Seq(
      ("a1", "a2"),
      ("c1", "c2")
    )
    val expected2 = Seq(
      ("a1", "b2"),
      ("c1", "c2")
    )
    val result = runDropDuplicates(left, right, keys)

    assertEquals(expected1.length, result.length)
    result.foreach { r => assertTrue(expected1.contains(r) || expected2.contains(r)) }
  }

  @Test
  def test2ColumnDedup(): Unit = {
    val left = Seq(
      ("a1", "a2"),
      ("b1", "b2"),
      ("c1", "c2")
    )
    val right = Seq(
      ("a1", "a2"),
      ("a1", "a2"),
      ("a1", "c2")
    )
    val keys = Seq("_1", "_2")

    val expected = Seq(
      ("a1", "a2"),
      ("a1", "c2")
    )
    val result = runDropDuplicates(left, right, keys)

    assertEquals(expected.length, result.length)
    result.foreach { r => assertTrue(expected.contains(r)) }
  }

  @Test
  def testEmpty(): Unit = {
    val left = Seq()
    val right = Seq()
    val keys = Seq("_1")

    val result = runDropDuplicates(left, right, keys)

    assertTrue(result.isEmpty)
  }

  private def testJoinScenario(leftSchema: StructType,
                               rightSchema: StructType,
                               keys: Seq[String],