Commit abd7556

perf: resolve schema only once and cache (#696)

1 parent bdeda94

9 files changed: +72 -52 lines changed

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryNative.scala
Lines changed: 3 additions & 1 deletion

@@ -14,7 +14,9 @@ case object BigQueryNative extends Format {

   private val internalBQCol = "__chronon_internal_bq_col__"

-  override def table(tableName: String, partitionFilters: String)(implicit sparkSession: SparkSession): DataFrame = {
+  // TODO(tchow): use the cache flag
+  override def table(tableName: String, partitionFilters: String, cacheDf: Boolean = false)(implicit
+      sparkSession: SparkSession): DataFrame = {
     import sparkSession.implicits._

     // First, need to clean the spark-based table name for the bigquery queries below.
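
For orientation, a call through the widened signature might look like the sketch below. The project/dataset/table name and the partition predicate are placeholders, a real run also needs the Spark BigQuery connector configured on the session, and per the TODO above the cacheDf flag is accepted but not yet acted on in this override.

import org.apache.spark.sql.{DataFrame, SparkSession}
import ai.chronon.integrations.cloud_gcp.BigQueryNative

object BigQueryNativeSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder session; real usage requires BigQuery connector settings.
    implicit val spark: SparkSession = SparkSession.builder()
      .appName("bq-native-sketch")
      .master("local[*]")
      .getOrCreate()

    // Same call shape as before, plus the new opt-in flag (currently ignored here, per the TODO).
    val df: DataFrame = BigQueryNative.table("my-project.my_dataset.my_table", "ds = '2025-01-01'", cacheDf = true)
    df.printSchema()
  }
}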

spark/src/main/scala/ai/chronon/spark/Extensions.scala
Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{LongType, StructType}
 import org.apache.spark.util.sketch.BloomFilter
 import org.slf4j.{Logger, LoggerFactory}
-import ai.chronon.spark.catalog.TableUtils
+import ai.chronon.spark.catalog.{TableCache, TableUtils}

 import java.util
 import scala.collection.Seq
@@ -76,7 +76,6 @@ object Extensions {
     if (intersectedCounts.isEmpty) return None
     Some(DfWithStats(df.prunePartition(range), intersectedCounts))
   }
-  def stats: DfStats = DfStats(count, partitionRange)
 }

 object DfWithStats {
@@ -143,6 +142,7 @@ object Extensions {
                 tableProperties: Map[String, String] = null,
                 partitionColumns: Seq[String] = List(tableUtils.partitionColumn),
                 autoExpand: Boolean = false): Unit = {
+
       TableUtils(df.sparkSession).insertPartitions(df,
                                                    tableName,
                                                    tableProperties,

spark/src/main/scala/ai/chronon/spark/GroupBy.scala
Lines changed: 2 additions & 1 deletion

@@ -687,7 +687,8 @@ object GroupBy {
       if (mutations) source.getEntities.mutationTable.cleanSpec else source.table,
       Option(source.query.wheres).map(_.toScala).getOrElse(Seq.empty[String]),
       partitionConditions,
-      Some(metaColumns ++ keys.map(_ -> null))
+      Some(metaColumns ++ keys.map(_ -> null)),
+      cacheDf = true
     )
   }
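
Because cacheDf is a defaulted trailing parameter all the way down the scan path, existing callers compile unchanged and this GroupBy scan simply opts in by name. A tiny, self-contained sketch of that pattern (the names below are illustrative, not the Chronon API):

object DefaultFlagSketch {
  // A defaulted trailing flag: old callers keep working, new callers opt in by name.
  def scan(table: String, wheres: Seq[String], cacheDf: Boolean = false): String =
    s"scan(table=$table, wheres=${wheres.size}, cacheDf=$cacheDf)"

  def main(args: Array[String]): Unit = {
    println(scan("db.purchases", Seq("ds >= '2025-01-01'")))                 // legacy call, cacheDf = false
    println(scan("db.purchases", Seq("ds >= '2025-01-01'"), cacheDf = true)) // opt-in, as in GroupBy above
  }
}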

spark/src/main/scala/ai/chronon/spark/Join.scala
Lines changed: 1 addition & 8 deletions

@@ -287,13 +287,6 @@ class Join(joinConf: api.Join,
       }
     }

-    val leftTimeRangeOpt = if (leftTaggedDf.schema.fieldNames.contains(Constants.TimePartitionColumn)) {
-      val leftTimePartitionMinMax = leftTaggedDf.range[String](Constants.TimePartitionColumn)
-      Some(PartitionRange(leftTimePartitionMinMax._1, leftTimePartitionMinMax._2))
-    } else {
-      None
-    }
-
     implicit val executionContext: ExecutionContextExecutorService =
       ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(tableUtils.joinPartParallelism))

@@ -342,7 +335,7 @@ class Join(joinConf: api.Join,
       }

       val runContext =
-        JoinPartJobContext(unfilledLeftDf, bloomFilterOpt, leftTimeRangeOpt, tableProps, runSmallMode)
+        JoinPartJobContext(unfilledLeftDf, bloomFilterOpt, tableProps, runSmallMode)

       val skewKeys: Option[Map[String, Seq[String]]] = Option(joinConfCloned.skewKeys).map { jmap =>
         val scalaMap = jmap.toScala

spark/src/main/scala/ai/chronon/spark/JoinUtils.scala
Lines changed: 8 additions & 6 deletions

@@ -71,27 +71,33 @@ object JoinUtils {
                  tableUtils: TableUtils,
                  allowEmpty: Boolean = false,
                  limit: Option[Int] = None): Option[DataFrame] = {
+
     val timeProjection = if (joinConf.left.dataModel == EVENTS) {
       Seq(Constants.TimeColumn -> Option(joinConf.left.query).map(_.timeColumn).orNull)
     } else {
       Seq()
     }
+
     var df = tableUtils.scanDf(joinConf.left.query,
                                joinConf.left.table,
                                Some((Map(tableUtils.partitionColumn -> null) ++ timeProjection).toMap),
                                range = Some(range))
+
     limit.foreach(l => df = df.limit(l))
+
     val skewFilter = joinConf.skewFilter()
     val result = skewFilter
       .map(sf => {
         logger.info(s"left skew filter: $sf")
         df.filter(sf)
       })
       .getOrElse(df)
+
     if (!allowEmpty && result.isEmpty) {
       logger.info(s"Left side query below produced 0 rows in range $range, and allowEmpty=false.")
       return None
     }
+
     Some(result)
   }

@@ -561,11 +567,7 @@ object JoinUtils {
     }.toMap)
   }

-  def shiftDays(leftDataModel: DataModel,
-                joinPart: JoinPart,
-                leftTimeRangeOpt: Option[PartitionRange],
-                leftDf: Option[DfWithStats],
-                leftRange: PartitionRange) = {
+  def shiftDays(leftDataModel: DataModel, joinPart: JoinPart, leftRange: PartitionRange): PartitionRange = {
     val shiftDays =
       if (leftDataModel == EVENTS && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) {
         -1
@@ -580,7 +582,7 @@ object JoinUtils {
     // events | entities | temporal => right part tables are aligned - so scan by leftRange
     // entities | entities | snapshot => right part tables are aligned - so scan by leftRange
     val rightRange = if (leftDataModel == EVENTS && joinPart.groupBy.inferredAccuracy == Accuracy.SNAPSHOT) {
-      // Diabling for now
+      // Disabling for now
       // val leftTimeRange = leftTimeRangeOpt.getOrElse(leftDf.get.timeRange.toPartitionRange)
       leftRange.shift(shiftDays)
     } else {
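
For intuition on the slimmed-down shiftDays: with an EVENTS left and a SNAPSHOT-accuracy group-by, the right side is scanned one day behind the left range. A standalone sketch of that arithmetic with a stand-in type (DayRange below is not the Chronon PartitionRange):

import java.time.LocalDate

// Stand-in for PartitionRange, just to illustrate the shift arithmetic.
case class DayRange(start: LocalDate, end: LocalDate) {
  def shift(days: Int): DayRange = DayRange(start.plusDays(days), end.plusDays(days))
}

object ShiftDaysSketch {
  def main(args: Array[String]): Unit = {
    val leftRange = DayRange(LocalDate.parse("2025-01-02"), LocalDate.parse("2025-01-08"))
    // EVENTS left + SNAPSHOT accuracy: scan the right side one day behind the left range.
    val rightRange = leftRange.shift(-1)
    println(rightRange) // DayRange(2025-01-01,2025-01-07)
  }
}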

spark/src/main/scala/ai/chronon/spark/batch/JoinPartJob.scala
Lines changed: 10 additions & 22 deletions

@@ -3,7 +3,7 @@ package ai.chronon.spark.batch
 import ai.chronon.api.DataModel.{ENTITIES, EVENTS}
 import ai.chronon.api.Extensions.{DateRangeOps, DerivationOps, GroupByOps, JoinPartOps, MetadataOps}
 import ai.chronon.api.PartitionRange.toTimeRange
-import ai.chronon.api.{Accuracy, Builders, Constants, DateRange, JoinPart, PartitionRange}
+import ai.chronon.api.{Accuracy, Builders, Constants, DateRange, JoinPart, PartitionRange, PartitionSpec}
 import ai.chronon.online.metrics.Metrics
 import ai.chronon.orchestration.JoinPartNode
 import ai.chronon.spark.Extensions._
@@ -20,13 +20,12 @@ import scala.jdk.CollectionConverters._

 case class JoinPartJobContext(leftDf: Option[DfWithStats],
                               joinLevelBloomMapOpt: Option[util.Map[String, BloomFilter]],
-                              leftTimeRangeOpt: Option[PartitionRange],
                               tableProps: Map[String, String],
                               runSmallMode: Boolean)

 class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)(implicit tableUtils: TableUtils) {
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
-  implicit val partitionSpec = tableUtils.partitionSpec
+  implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec

   private val leftTable = node.leftSourceTable
   private val joinPart = node.joinPart
@@ -50,14 +49,6 @@ class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)
     val query = Builders.Query(selects = relevantLeftCols.map(t => t -> t).toMap)
     val cachedLeftDf = tableUtils.scanDf(query = query, leftTable, range = Some(dateRange))

-    val leftTimeRangeOpt: Option[PartitionRange] =
-      if (cachedLeftDf.schema.fieldNames.contains(Constants.TimePartitionColumn)) {
-        val leftTimePartitionMinMax = cachedLeftDf.range[String](Constants.TimePartitionColumn)
-        Some(PartitionRange(leftTimePartitionMinMax._1, leftTimePartitionMinMax._2))
-      } else {
-        None
-      }
-
     val runSmallMode = JoinUtils.runSmallMode(tableUtils, cachedLeftDf)

     val leftWithStats = cachedLeftDf.withStats
@@ -67,7 +58,6 @@ class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)

     JoinPartJobContext(Option(leftWithStats),
                        joinLevelBloomMapOpt,
-                       leftTimeRangeOpt,
                        Option(node.metaData.tableProps).getOrElse(Map.empty[String, String]),
                        runSmallMode)
   }
@@ -77,27 +67,25 @@ class JoinPartJob(node: JoinPartNode, range: DateRange, showDf: Boolean = false)
       jobContext.leftDf,
       joinPart,
       dateRange,
-      jobContext.leftTimeRangeOpt,
       node.metaData.outputTable,
       jobContext.tableProps,
       jobContext.joinLevelBloomMapOpt,
       jobContext.runSmallMode
     )
   }

-  def computeRightTable(leftDfOpt: Option[DfWithStats],
-                        joinPart: JoinPart,
-                        leftRange: PartitionRange, // missing left partitions
-                        leftTimeRangeOpt: Option[PartitionRange], // range of timestamps within missing left partitions
-                        partTable: String,
-                        tableProps: Map[String, String] = Map(),
-                        joinLevelBloomMapOpt: Option[util.Map[String, BloomFilter]],
-                        smallMode: Boolean = false): Option[DataFrame] = {
+  private def computeRightTable(leftDfOpt: Option[DfWithStats],
+                                joinPart: JoinPart,
+                                leftRange: PartitionRange, // missing left partitions
+                                partTable: String,
+                                tableProps: Map[String, String] = Map(),
+                                joinLevelBloomMapOpt: Option[util.Map[String, BloomFilter]],
+                                smallMode: Boolean = false): Option[DataFrame] = {

     // val partMetrics = Metrics.Context(metrics, joinPart) -- TODO is this metrics context sufficient, or should we pass thru for monolith join?
     val partMetrics = Metrics.Context(Metrics.Environment.JoinOffline, joinPart.groupBy)

-    val rightRange = JoinUtils.shiftDays(node.leftDataModel, joinPart, leftTimeRangeOpt, leftDfOpt, leftRange)
+    val rightRange = JoinUtils.shiftDays(node.leftDataModel, joinPart, leftRange)

     // Can kill the option after we deprecate monolith join job
     leftDfOpt.map { leftDf =>

spark/src/main/scala/ai/chronon/spark/catalog/Format.scala
Lines changed: 31 additions & 4 deletions

@@ -1,20 +1,47 @@
 package ai.chronon.spark.catalog

-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.slf4j.{Logger, LoggerFactory}
-import org.apache.spark.sql.DataFrame
+
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+import java.util.function
+
+object TableCache {
+  private val dfMap: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap[String, DataFrame]()
+
+  def get(tableName: String)(implicit sparkSession: SparkSession): DataFrame = {
+    dfMap.computeIfAbsent(tableName,
+                          new function.Function[String, DataFrame] {
+                            override def apply(t: String): DataFrame = {
+                              sparkSession.read.table(t)
+                            }
+                          })
+  }
+
+  def remove(tableName: String): Unit = {
+    dfMap.remove(tableName)
+  }
+}

 trait Format {

   @transient protected lazy val logger: Logger = LoggerFactory.getLogger(getClass)

-  def table(tableName: String, partitionFilters: String)(implicit sparkSession: SparkSession): DataFrame = {
-    val df = sparkSession.read.table(tableName)
+  def table(tableName: String, partitionFilters: String, cacheDf: Boolean = false)(implicit
+      sparkSession: SparkSession): DataFrame = {
+
+    val df = if (cacheDf) {
+      TableCache.get(tableName)
+    } else {
+      sparkSession.read.table(tableName)
+    }
+
     if (partitionFilters.isEmpty) {
       df
     } else {
       df.where(partitionFilters)
     }
+
   }

   // Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided
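
A minimal usage sketch of the new cache, assuming a live SparkSession and an already-registered table (the name db.events is a placeholder): repeated lookups return the same DataFrame reference, so the table and its schema are resolved against the catalog only once, and remove drops the entry when a table is rewritten.

import org.apache.spark.sql.{DataFrame, SparkSession}
import ai.chronon.spark.catalog.TableCache

object TableCacheSketch {
  def main(args: Array[String]): Unit = {
    implicit val spark: SparkSession = SparkSession.builder()
      .appName("table-cache-sketch")
      .master("local[*]")
      .getOrCreate()

    // First lookup resolves the table via the catalog and stores the DataFrame.
    val first: DataFrame = TableCache.get("db.events")
    // Subsequent lookups return the cached DataFrame, skipping another catalog round trip.
    val second: DataFrame = TableCache.get("db.events")
    assert(first eq second)

    // After a write, evicting the entry lets the next read pick up new schema/partitions.
    TableCache.remove("db.events")
  }
}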

spark/src/main/scala/ai/chronon/spark/catalog/TableUtils.scala
Lines changed: 11 additions & 4 deletions

@@ -112,10 +112,12 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     }
   }

-  def loadTable(tableName: String, rangeWheres: Seq[String] = List.empty[String]): DataFrame = {
+  def loadTable(tableName: String,
+                rangeWheres: Seq[String] = List.empty[String],
+                cacheDf: Boolean = false): DataFrame = {
     tableFormatProvider
       .readFormat(tableName)
-      .map(_.table(tableName, andPredicates(rangeWheres))(sparkSession))
+      .map(_.table(tableName, andPredicates(rangeWheres), cacheDf)(sparkSession))
       .getOrElse(
         throw new RuntimeException(s"Could not load table: ${tableName} with partition filter: ${rangeWheres}"))
   }
@@ -292,14 +294,18 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       dfRearranged
     }

+    TableCache.remove(tableName)
+
     logger.info(s"Writing to $tableName ...")
+
     finalizedDf.write
       .mode(saveMode)
       // Requires table to exist before inserting.
      // Fails if schema does not match.
      // Does NOT overwrite the schema.
      // Handles dynamic partition overwrite.
      .insertInto(tableName)
+
     logger.info(s"Finished writing to $tableName")
   }

@@ -575,7 +581,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
                  table: String,
                  wheres: Seq[String],
                  rangeWheres: Seq[String],
-                 fallbackSelects: Option[Map[String, String]] = None): DataFrame = {
+                 fallbackSelects: Option[Map[String, String]] = None,
+                 cacheDf: Boolean = false): DataFrame = {

     val selects = QueryUtils.buildSelects(selectMap, fallbackSelects)

@@ -589,7 +596,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
       | ${rangeWheres.mkString(",\n ").green}
       |""".stripMargin)

-    var df = loadTable(table, rangeWheres)
+    var df = loadTable(table, rangeWheres, cacheDf)

     if (selects.nonEmpty) df = df.selectExpr(selects: _*)
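
Putting the pieces together, the read path with caching enabled looks roughly like the sketch below; the session, table name, and predicate are placeholders, and in the real code the TableCache.remove call happens inside the insert path shown above rather than at the call site.

import org.apache.spark.sql.{DataFrame, SparkSession}
import ai.chronon.spark.catalog.{TableCache, TableUtils}

object CachedScanSketch {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("cached-scan-sketch")
      .master("local[*]")
      .getOrCreate()
    val tableUtils = TableUtils(spark)

    // Read path: loadTable forwards cacheDf to Format.table, which consults TableCache,
    // so the schema is resolved only once per table name.
    val df: DataFrame = tableUtils.loadTable("db.events", Seq("ds = '2025-01-01'"), cacheDf = true)
    df.printSchema()

    // Write path (done automatically before insertInto above): evict so later reads are fresh.
    TableCache.remove("db.events")
  }
}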

spark/src/test/scala/ai/chronon/spark/test/ResultValidationAbilityTest.scala
Lines changed: 4 additions & 4 deletions

@@ -36,7 +36,7 @@ import org.scalatest.flatspec.AnyFlatSpec
 class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfter {
   val confPath = "joins/team/example_join.v1"
   val spark: SparkSession = SparkSessionBuilder.build("test", local = true)
-  val mockTableUtils: TableUtils = mock(classOf[TableUtils])
+  private val mockTableUtils: TableUtils = mock(classOf[TableUtils])

   before {
     when(mockTableUtils.partitionColumn).thenReturn("ds")
@@ -46,7 +46,7 @@ class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfter {
   class TestArgs(args: Array[String]) extends ScallopConf(args) with OfflineSubcommand with ResultValidationAbility {
     verify()

-    override def subcommandName: String = "test"
+    override def subcommandName(): String = "test"
     override def buildSparkSession(): SparkSession = spark
   }

@@ -69,7 +69,7 @@ class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfter {
     val rdd = args.sparkSession.sparkContext.parallelize(leftData)
     val df = args.sparkSession.createDataFrame(rdd).toDF(columns: _*)

-    when(mockTableUtils.loadTable(any(), any())).thenReturn(df)
+    when(mockTableUtils.loadTable(any(), any(), any())).thenReturn(df)

     assertTrue(args.validateResult(df, Seq("keyId", "ds"), mockTableUtils))
   }
@@ -85,7 +85,7 @@ class ResultValidationAbilityTest extends AnyFlatSpec with BeforeAndAfter {
     val rightRdd = args.sparkSession.sparkContext.parallelize(rightData)
     val rightDf = args.sparkSession.createDataFrame(rightRdd).toDF(columns: _*)

-    when(mockTableUtils.loadTable(any(), any())).thenReturn(rightDf)
+    when(mockTableUtils.loadTable(any(), any(), any())).thenReturn(rightDf)

     assertFalse(args.validateResult(leftDf, Seq("keyId", "ds"), mockTableUtils))
   }
