Commit a566aad

varant-zlai and ezvz authored
Simple LabelJoin flow (#546)
## Summary

Implements the simple label join logic: creates a materialized table by joining forward-looking partitions from the snapshot tables of the labelJoinParts back to the join output (see the date-arithmetic sketch below).

## Checklist
- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Introduced an updated mechanism for formatting and standardizing label outputs with the addition of `outputLabelTableV2`.
  - Added a new distributed label join operation with the `LabelJoinV2` class, featuring robust validations, comprehensive error handling, and detailed logging for improved data integration.
  - Implemented a comprehensive test suite for the `LabelJoinV2` functionality to ensure accuracy and reliability of label joins.
- **Updates**
  - Replaced the existing `LabelJoin` class with the new `LabelJoinV2` class, enhancing the label join process.

Co-authored-by: ezvz <[email protected]>
1 parent 6892eff commit a566aad
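For intuition, the core of the flow is date arithmetic: a label snapshot landing on `ds` back-fills the join output partition that is `windowLength` days older. A minimal sketch of that arithmetic (the dates and the `shiftDays` helper are illustrative stand-ins, not part of this change):

```scala
import java.time.LocalDate
import java.time.format.DateTimeFormatter

// Illustrative stand-in for PartitionRange.shift: running the label join on
// `labelDs` with a `windowDays`-day label window updates the join output
// partition at `labelDs - windowDays`.
object LabelPartitionMath {
  private val fmt = DateTimeFormatter.ofPattern("yyyy-MM-dd")

  def shiftDays(labelDs: String, windowDays: Int): String =
    LocalDate.parse(labelDs, fmt).minusDays(windowDays).format(fmt)

  def main(args: Array[String]): Unit = {
    // A 7-day label landing on 2024-03-30 back-fills the 2024-03-23 partition;
    // a 30-day label landing the same day back-fills 2024-02-29.
    println(shiftDays("2024-03-30", 7))  // 2024-03-23
    println(shiftDays("2024-03-30", 30)) // 2024-02-29
  }
}
```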

File tree

4 files changed: +642 −2 lines changed

api/src/main/scala/ai/chronon/api/Extensions.scala

Lines changed: 2 additions & 0 deletions

```diff
@@ -122,6 +122,8 @@ object Extensions {
     def outputLabelTable: String = s"${metaData.outputNamespace}.${metaData.cleanName}_labels"
     def outputFinalView: String = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled"
     def outputLatestLabelView: String = s"${metaData.outputNamespace}.${metaData.cleanName}_labeled_latest"
+    def outputLabelTableV2: String =
+      s"${metaData.outputNamespace}.${metaData.cleanName}_with_labels" // Used for the LabelJoinV2 flow
     def loggedTable: String = s"${outputTable}_logged"
     def summaryTable: String = s"${outputTable}_summary"
     def packedSummaryTable: String = s"${outputTable}_summary_packed"
```
spark/src/main/scala/ai/chronon/spark/Driver.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -417,12 +417,12 @@ object Driver {
 
     def run(args: Args): Unit = {
       val tableUtils = args.buildTableUtils()
-      val labelJoin = new LabelJoin(
+      val labelJoin = new LabelJoinV2(
         args.joinConf,
         tableUtils,
         args.endDate()
       )
-      labelJoin.computeLabelJoin(args.stepDays.toOption)
+      labelJoin.compute()
 
       if (args.shouldExport()) {
         args.exportTableToLocal(args.joinConf.metaData.outputLabelTable, tableUtils)
```
spark/src/main/scala/ai/chronon/spark/LabelJoinV2.scala

Lines changed: 264 additions & 0 deletions

```scala
package ai.chronon.spark

import ai.chronon.api
import ai.chronon.api.DataModel.Events
import ai.chronon.api.Extensions._
import ai.chronon.api._
import ai.chronon.online.Metrics
import ai.chronon.spark.Extensions._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.DataType
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.JavaConverters._
import scala.collection.Seq

// Let's say we are running the label join on `ds`; we want to modify partitions of the join output table
// that are `ds - windowLength` days old. Window sizes could repeat across different label join parts,
// so we create a struct mapping each window size to the join output partitions to modify (AllLabelOutputInfo)
// and, for each label join part, the columns that have that particular window size (LabelPartOutputInfo).
case class LabelPartOutputInfo(labelPart: JoinPart, outputColumnNames: Seq[String])
case class AllLabelOutputInfo(joinDsAsRange: PartitionRange, labelPartOutputInfos: Seq[LabelPartOutputInfo])

class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
  @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
  implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec
  assert(Option(joinConf.metaData.outputNamespace).nonEmpty, "output namespace cannot be empty or null")

  val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.LabelJoin, joinConf)
  private val outputLabelTable = joinConf.metaData.outputLabelTableV2
  private val labelJoinConf = joinConf.labelParts
  private val confTableProps = Option(joinConf.metaData.tableProperties)
    .map(_.asScala.toMap)
    .getOrElse(Map.empty[String, String])
  private val labelColumnPrefix = "label_"
  private val labelDsAsRange = PartitionRange(labelDs, labelDs)

  private def getLabelColSchema(labelOutputs: Seq[AllLabelOutputInfo]): Seq[(String, DataType)] = {
    val labelPartToOutputCols = labelOutputs
      .flatMap(_.labelPartOutputInfos)
      .groupBy(_.labelPart)
      .mapValues(_.flatMap(_.outputColumnNames))

    labelPartToOutputCols.flatMap { case (labelPart, outputCols) =>
      val labelPartSchema = tableUtils.scanDf(null, labelPart.groupBy.metaData.outputTable).schema
      outputCols.map(col => (col, labelPartSchema(col).dataType))
    }.toSeq
  }

  private def runAssertions(): Unit = {
    assert(joinConf.left.dataModel == Events,
           s"join.left.dataModel needs to be Events for label join ${joinConf.metaData.name}")

    assert(Option(joinConf.metaData.team).nonEmpty,
           s"join.metaData.team needs to be set for join ${joinConf.metaData.name}")

    labelJoinConf.labels.asScala.foreach { jp =>
      assert(jp.groupBy.dataModel == Events,
             s"groupBy.dataModel must be Events for label join with aggregations ${jp.groupBy.metaData.name}")

      assert(Option(jp.groupBy.aggregations).isDefined,
             s"aggregations must be defined for label join ${jp.groupBy.metaData.name}")

      val windows = jp.groupBy.aggregations.asScala.flatMap(_.windows.asScala).filter(_.timeUnit == TimeUnit.DAYS)

      assert(windows.nonEmpty,
             s"at least one aggregation with a daily window must be defined for label join ${jp.groupBy.metaData.name}")
    }
  }

  private def getWindowToLabelOutputInfos: Map[Int, AllLabelOutputInfo] = {
    // Create a map of window to LabelOutputInfo
    // Each window could be shared across multiple labelJoinParts
    val labelJoinParts = labelJoinConf.labels.asScala

    labelJoinParts
      .flatMap { labelJoinPart =>
        labelJoinPart.groupBy.aggregations.asScala
          .flatMap { agg =>
            agg.windows.asScala.filter(_.timeUnit == TimeUnit.DAYS).map { w =>
              // TODO -- support buckets
              assert(Option(agg.buckets).isEmpty, "Buckets as labels are not yet supported in LabelJoinV2")
              val aggPart = Builders.AggregationPart(agg.operation, agg.inputColumn, w)
              (w.length, aggPart.outputColumnName)
            }
          }
          .groupBy(_._1)
          .map { case (window, windowAndOutputCols) =>
            (window, LabelPartOutputInfo(labelJoinPart, windowAndOutputCols.map(_._2)))
          }
      }
      .groupBy(_._1) // Flatten map and combine into one map with window as key
      .mapValues(_.map(_._2)) // Drop the duplicate window
      .map { case (window, labelPartOutputInfos) =>
        // Look back from the label snapshot partition to the corresponding join output partition
        val joinPartitionDsAsRange = labelDsAsRange.shift(window * -1)
        window -> AllLabelOutputInfo(joinPartitionDsAsRange, labelPartOutputInfos)
      }
      .toMap
  }

  def compute(): DataFrame = {
    logger.info(s"Running LabelJoinV2 for $labelDs")

    runAssertions()

    // First get a map of window to LabelOutputInfo
    val windowToLabelOutputInfos = getWindowToLabelOutputInfos

    // Find existing partitions in the join table
    val joinTable = joinConf.metaData.outputTable
    val existingJoinPartitions = tableUtils.partitions(joinTable)

    // Split the windows into two groups, one that has a corresponding partition in the join table and one that doesn't
    // If a partition is missing, we can't compute the labels for that window, but the job will proceed with the rest
    val (computableWindowToOutputs, missingWindowToOutputs) = windowToLabelOutputInfos.partition {
      case (_, labelOutputInfo) =>
        existingJoinPartitions.contains(labelOutputInfo.joinDsAsRange.start)
    }

    if (missingWindowToOutputs.nonEmpty) {
      logger.info(
        s"""Missing following partitions from $joinTable: ${missingWindowToOutputs.values
            .map(_.joinDsAsRange.start)
            .mkString(", ")}
           |
           |Found existing partitions: ${existingJoinPartitions.mkString(", ")}
           |
           |Therefore unable to compute the labels for ${missingWindowToOutputs.keys.mkString(", ")}
           |
           |For requested ds: $labelDs
           |
           |Proceeding with valid windows: ${computableWindowToOutputs.keys.mkString(", ")}
           |""".stripMargin
      )

      require(
        computableWindowToOutputs.nonEmpty,
        "No valid windows to compute labels for given the existing join output range. " +
          s"Consider backfilling the join output table for the following days: ${missingWindowToOutputs.values.map(_.joinDsAsRange.start)}."
      )
    }

    // Find existing partitions in the outputLabelTable (different from the join output table used above)
    // This is used below in computing baseJoinDf
    val existingLabelTableOutputPartitions = tableUtils.partitions(outputLabelTable)
    logger.info(s"Found existing partitions in Label Table: ${existingLabelTableOutputPartitions.mkString(", ")}")

    // Each unique window is an output partition in the joined table
    // Each window may contain a subset of the joinParts and their columns
    computableWindowToOutputs.foreach { case (windowLength, joinOutputInfo) =>
      computeOutputForWindow(windowLength, joinOutputInfo, existingLabelTableOutputPartitions, windowToLabelOutputInfos)
    }

    val allOutputDfs = computableWindowToOutputs.values
      .map(_.joinDsAsRange)
      .map { range =>
        tableUtils.scanDf(null, outputLabelTable, range = Some(range))
      }
      .toSeq

    if (allOutputDfs.length == 1) {
      allOutputDfs.head
    } else {
      allOutputDfs.reduce(_ union _)
    }
  }

  // Writes out a single partition of the label table with all labels for the corresponding window
  private def computeOutputForWindow(windowLength: Int,
                                     joinOutputInfo: AllLabelOutputInfo,
                                     existingLabelTableOutputPartitions: Seq[String],
                                     windowToLabelOutputInfos: Map[Int, AllLabelOutputInfo]): Unit = {
    logger.info(
      s"Computing labels for window: $windowLength days on labelDs: $labelDs \n" +
        s"Includes the following joinParts and output cols: ${joinOutputInfo.labelPartOutputInfos
          .map(x => s"${x.labelPart.groupBy.metaData.name} -> ${x.outputColumnNames.mkString(", ")}")
          .mkString("\n")}")

    val startMillis = System.currentTimeMillis()
    // This is the join output ds that we're working with
    val joinDsAsRange = labelDsAsRange.shift(windowLength * -1)

    val joinBaseDf = if (existingLabelTableOutputPartitions.contains(joinDsAsRange.start)) {
      // If the existing label table has the partition, then we should use it, because another label column
      // may have landed for this date; otherwise we can use the base join output padded with null label columns
      logger.info(s"Found existing partition in Label Table: ${joinDsAsRange.start}")
      tableUtils.scanDf(null, outputLabelTable, range = Some(joinDsAsRange))
    } else {
      // Otherwise we need to use the join output, but pad the schema to include other label columns that might
      // be on the schema
      logger.info(s"Did not find existing partition in Label Table, querying from Join Output: ${joinDsAsRange.start}")
      val joinOutputDf = tableUtils.scanDf(null, joinConf.metaData.outputTable, range = Some(joinDsAsRange))
      val allLabelCols = getLabelColSchema(windowToLabelOutputInfos.values.toSeq)
      allLabelCols.foldLeft(joinOutputDf) { case (currentDf, (colName, dataType)) =>
        val prefixedColName = s"${labelColumnPrefix}_$colName"
        currentDf.withColumn(prefixedColName, lit(null).cast(dataType))
      }
    }

    val joinPartsAndDfs = joinOutputInfo.labelPartOutputInfos.map { labelOutputInfo =>
      val labelJoinPart = labelOutputInfo.labelPart

      val outputColumnNames = labelOutputInfo.outputColumnNames
      val selectCols: Map[String, String] =
        (labelJoinPart.rightToLeft.keys ++ outputColumnNames).map(x => x -> x).toMap

      val snapshotQuery = Builders.Query(selects = selectCols)
      val snapshotTable = labelJoinPart.groupBy.metaData.outputTable
      val snapshotDf = tableUtils.scanDf(snapshotQuery, snapshotTable, range = Some(labelDsAsRange))

      (labelJoinPart, snapshotDf)
    }

    val joined = joinPartsAndDfs.foldLeft(joinBaseDf) { case (left, (joinPart, rightDf)) =>
      joinWithLeft(left, rightDf, joinPart)
    }

    val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)

    metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)

    joined.save(outputLabelTable, confTableProps, Seq(tableUtils.partitionColumn), autoExpand = true)

    logger.info(s"Wrote to table $outputLabelTable, into partitions: ${joinDsAsRange.start} in $elapsedMins mins")
  }

  def joinWithLeft(leftDf: DataFrame, rightDf: DataFrame, joinPart: JoinPart): DataFrame = {
    val partLeftKeys = joinPart.rightToLeft.values.toArray

    // apply key-renaming to key columns
    val keyRenamedRight = joinPart.rightToLeft.foldLeft(rightDf) { case (updatedRight, (rightKey, leftKey)) =>
      updatedRight.withColumnRenamed(rightKey, leftKey)
    }

    val nonValueColumns = joinPart.rightToLeft.keys.toArray ++ Array(Constants.TimeColumn,
                                                                     tableUtils.partitionColumn,
                                                                     Constants.TimePartitionColumn,
                                                                     Constants.LabelPartitionColumn)
    val valueColumns = rightDf.schema.names.filterNot(nonValueColumns.contains)

    // In this case, since we're joining with the full-schema dataframe,
    // we need to drop the columns that we're attempting to overwrite
    val cleanLeftDf = valueColumns.foldLeft(leftDf)((df, colName) => df.drop(s"${labelColumnPrefix}_$colName"))

    val prefixedRight = keyRenamedRight.prefixColumnNames(labelColumnPrefix, valueColumns)

    val partName = joinPart.groupBy.metaData.name

    logger.info(s"""Join keys for $partName: ${partLeftKeys.mkString(", ")}
                   |Left Schema:
                   |${leftDf.schema.pretty}
                   |
                   |Right Schema:
                   |${prefixedRight.schema.pretty}
                   |""".stripMargin)

    cleanLeftDf.validateJoinKeys(prefixedRight, partLeftKeys)
    cleanLeftDf.join(prefixedRight, partLeftKeys, "left_outer")
  }
}
```
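For reference, a minimal sketch of how the new class is driven, mirroring the Driver.scala change above (the SparkSession setup and the `???` join-conf loading are placeholders, not part of this commit):

```scala
import ai.chronon.api
import ai.chronon.spark.{LabelJoinV2, TableUtils}
import org.apache.spark.sql.SparkSession

object RunLabelJoinV2 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("label-join-v2").getOrCreate()
    val tableUtils = TableUtils(spark) // assumes the usual TableUtils(SparkSession) factory
    val joinConf: api.Join = ???       // placeholder: load a join conf with labelParts configured

    // Compute labels for the 2024-03-30 snapshot. Each distinct daily window
    // back-fills one join output partition; the returned DataFrame is the
    // union of all partitions that received label columns.
    val labeled = new LabelJoinV2(joinConf, tableUtils, "2024-03-30").compute()
    labeled.show()
  }
}
```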
