Skip to content

Commit ec32096

Browse files
varant-zlai, david-zlai, ezvz
authored
Logging fix -- make root cause more clear if label job has misaligned dates (#611)
## Summary

- Improve logging when the `labelDs - window` date does not overlap with the join data.
- Add step day functionality to labelJoin.

## Checklist

- [ ] Added Unit Tests
- [x] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Enhanced flexibility in date handling for label join operations, allowing for an optional start date override.
  - Introduction of new JSON configuration files for the `quickstart.purchases.v1` dataset, facilitating structured data processing and analysis.
- **Bug Fixes**
  - Improved diagnostic messaging to clearly indicate when required data elements are missing during processing.
  - Enhanced error detection to ensure necessary computation windows are present before proceeding.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: david-zlai <[email protected]>
Co-authored-by: ezvz <[email protected]>
1 parent 8a4c4a5 commit ec32096

File tree

3 files changed

+61
-24
lines changed

3 files changed

+61
-24
lines changed

spark/src/main/scala/ai/chronon/spark/Driver.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,18 @@ object Driver {
407407

408408
def run(args: Args): Unit = {
409409
val tableUtils = args.buildTableUtils()
410+
411+
// Use startPartitionOverride if provided, otherwise use endDate for both (single day)
412+
val startDate = args.startPartitionOverride.toOption.getOrElse(args.endDate())
413+
val endDate = args.endDate()
414+
415+
// Create a DateRange with start and end dates
416+
val dateRange = new api.DateRange(startDate, endDate)
417+
410418
val labelJoin = new LabelJoinV2(
411419
args.joinConf,
412420
tableUtils,
413-
args.endDate()
421+
dateRange
414422
)
415423
labelJoin.compute()
416424

spark/src/main/scala/ai/chronon/spark/batch/LabelJoinV2.scala

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.collection.Seq
2121
case class LabelPartOutputInfo(labelPart: JoinPart, outputColumnNames: Seq[String])
2222
case class AllLabelOutputInfo(joinDsAsRange: PartitionRange, labelPartOutputInfos: Seq[LabelPartOutputInfo])
2323

24-
class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
24+
class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDateRange: api.DateRange) {
2525
@transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
2626
implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec
2727
assert(Option(joinConf.metaData.outputNamespace).nonEmpty, "output namespace could not be empty or null")
@@ -33,7 +33,7 @@ class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
3333
.map(_.asScala.toMap)
3434
.getOrElse(Map.empty[String, String])
3535
private val labelColumnPrefix = "label_"
36-
private val labelDsAsRange = PartitionRange(labelDs, labelDs)
36+
private val labelDsAsPartitionRange = labelDateRange.toPartitionRange
3737

3838
private def getLabelColSchema(labelOutputs: Seq[AllLabelOutputInfo]): Seq[(String, DataType)] = {
3939
val labelPartToOutputCols = labelOutputs
@@ -93,13 +93,21 @@ class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
9393
.mapValues(_.map(_._2)) // Drop the duplicate window
9494
.map { case (window, labelPartOutputInfos) =>
9595
// The labelDs is a lookback from the labelSnapshot partition back to the join output table
96-
val joinPartitionDsAsRange = labelDsAsRange.shift(window * -1)
96+
val joinPartitionDsAsRange = labelDsAsPartitionRange.shift(window * -1)
9797
window -> AllLabelOutputInfo(joinPartitionDsAsRange, labelPartOutputInfos)
9898
}
9999
.toMap
100100
}
101101

102102
/** Computes labels for each day in the requested label date range and unions the per-day results.
  *
  * The label range is split into one-day steps; the start partition of each step is passed to
  * `computeDay`, and the resulting DataFrames are combined with `union` in step order.
  *
  * NOTE(review): the per-day work appears to still receive the full `labelDsAsPartitionRange`
  * further down the call chain rather than the single day being processed -- confirm that
  * multi-day ranges behave as intended.
  *
  * @return the union of the DataFrames returned by `computeDay` for every day step
  * @throws IllegalArgumentException if the label date range produces no day steps
  */
def compute(): DataFrame = {
  val daySteps = labelDsAsPartitionRange.steps(days = 1)

  // Fail with an actionable message instead of the bare NoSuchElementException that
  // `head`/`reduce` would raise on an empty sequence of steps.
  require(
    daySteps.nonEmpty,
    s"Label date range $labelDsAsPartitionRange contains no partitions to compute labels for. " +
      "Check that the start partition is not after the end date."
  )

  daySteps
    .map(dayStep => computeDay(dayStep.start))
    .reduce(_ union _)
}
109+
110+
def computeDay(labelDs: String): DataFrame = {
103111
logger.info(s"Running LabelJoinV2 for $labelDs")
104112

105113
runAssertions()
@@ -119,26 +127,39 @@ class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
119127
}
120128

121129
if (missingWindowToOutputs.nonEmpty) {
122-
logger.info(
123-
s"""Missing following partitions from $joinTable: ${missingWindowToOutputs.values
124-
.map(_.joinDsAsRange.start)
125-
.mkString(", ")}
130+
131+
// Always log this no matter what.
132+
val baseLogString = s"""Missing following partitions from $joinTable: ${missingWindowToOutputs.values
133+
.map(_.joinDsAsRange.start)
134+
.mkString(", ")}
126135
|
127-
|Found existing partitions: ${existingJoinPartitions.mkString(", ")}
136+
|Found existing partitions of join output: ${existingJoinPartitions.mkString(", ")}
128137
|
129-
|Therefore unable to compute the labels for ${missingWindowToOutputs.keys.mkString(", ")}
138+
|Required dates are computed based on label date (the run date) - window for distinct windows that are used in label parts.
130139
|
131-
|For requested ds: $labelDs
140+
|In this case, the run date is: $labelDs, and given the existing partitions we are unable to compute the labels for the following windows: ${missingWindowToOutputs.keys
141+
.mkString(", ")} (days).
132142
|
133-
|Proceeding with valid windows: ${computableWindowToOutputs.keys.mkString(", ")}
143+
|""".stripMargin
144+
145+
// If there are no dates to run, also throw that error
146+
require(
147+
computableWindowToOutputs.nonEmpty,
148+
s"""$baseLogString
149+
|
150+
|There are no partitions that we can run the label join for. At least one window must be computable.
134151
|
152+
|Exiting.
135153
|""".stripMargin
136154
)
137155

138-
require(
139-
computableWindowToOutputs.isEmpty,
140-
"No valid windows to compute labels for given the existing join output range." +
141-
s"Consider backfilling the join output table for the following days: ${missingWindowToOutputs.values.map(_.joinDsAsRange.start)}."
156+
// Else log what we are running, but warn about missing windows
157+
logger.warn(
158+
s"""$baseLogString
159+
|
160+
|Proceeding with valid windows: ${computableWindowToOutputs.keys.mkString(", ")}
161+
|
162+
|""".stripMargin
142163
)
143164
}
144165

@@ -150,7 +171,11 @@ class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
150171
// Each unique window is an output partition in the joined table
151172
// Each window may contain a subset of the joinParts and their columns
152173
computableWindowToOutputs.foreach { case (windowLength, joinOutputInfo) =>
153-
computeOutputForWindow(windowLength, joinOutputInfo, existingLabelTableOutputPartitions, windowToLabelOutputInfos)
174+
computeOutputForWindow(windowLength,
175+
joinOutputInfo,
176+
existingLabelTableOutputPartitions,
177+
windowToLabelOutputInfos,
178+
labelDsAsPartitionRange)
154179
}
155180

156181
val allOutputDfs = computableWindowToOutputs.values
@@ -171,16 +196,17 @@ class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
171196
private def computeOutputForWindow(windowLength: Int,
172197
joinOutputInfo: AllLabelOutputInfo,
173198
existingLabelTableOutputPartitions: Seq[String],
174-
windowToLabelOutputInfos: Map[Int, AllLabelOutputInfo]): Unit = {
199+
windowToLabelOutputInfos: Map[Int, AllLabelOutputInfo],
200+
labelDsAsPartitionRange: PartitionRange): Unit = {
175201
logger.info(
176-
s"Computing labels for window: $windowLength days on labelDs: $labelDs \n" +
202+
s"Computing labels for window: $windowLength days on labelDs: ${labelDsAsPartitionRange.start} \n" +
177203
s"Includes the following joinParts and output cols: ${joinOutputInfo.labelPartOutputInfos
178204
.map(x => s"${x.labelPart.groupBy.metaData.name} -> ${x.outputColumnNames.mkString(", ")}")
179205
.mkString("\n")}")
180206

181207
val startMillis = System.currentTimeMillis()
182208
// This is the join output ds that we're working with
183-
val joinDsAsRange = labelDsAsRange.shift(windowLength * -1)
209+
val joinDsAsRange = labelDsAsPartitionRange.shift(windowLength * -1)
184210

185211
val joinBaseDf = if (existingLabelTableOutputPartitions.contains(joinDsAsRange.start)) {
186212
// If the existing join table has the partition, then we should use it, because another label column
@@ -208,7 +234,7 @@ class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
208234

209235
val snapshotQuery = Builders.Query(selects = selectCols)
210236
val snapshotTable = labelJoinPart.groupBy.metaData.outputTable
211-
val snapshotDf = tableUtils.scanDf(snapshotQuery, snapshotTable, range = Some(labelDsAsRange))
237+
val snapshotDf = tableUtils.scanDf(snapshotQuery, snapshotTable, range = Some(labelDsAsPartitionRange))
212238

213239
(labelJoinPart, snapshotDf)
214240
}

spark/src/test/scala/ai/chronon/spark/test/batch/LabelJoinV2Test.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ class LabelJoinV2Test extends AnyFlatSpec {
102102
tableUtils.sql(s"SELECT * FROM $labelGbOutputTable").show()
103103

104104
// Now compute the label join for thirty three days ago (label ds)
105-
val labelJoin = new LabelJoinV2(joinConf, tableUtils, thirtyThreeDaysAgo)
105+
val labelDateRange = new api.DateRange(thirtyThreeDaysAgo, thirtyThreeDaysAgo)
106+
val labelJoin = new LabelJoinV2(joinConf, tableUtils, labelDateRange)
106107
val labelComputed = labelJoin.compute()
107108
println("Label computed::")
108109
labelComputed.show()
@@ -219,7 +220,8 @@ class LabelJoinV2Test extends AnyFlatSpec {
219220
tableUtils.sql(s"SELECT * FROM $labelGbOutputTable2").show()
220221

221222
// Now compute the label join for thirty three days ago (label ds)
222-
val labelJoin = new LabelJoinV2(joinConf, tableUtils, thirtyThreeDaysAgo)
223+
val labelDateRange = new api.DateRange(thirtyThreeDaysAgo, thirtyThreeDaysAgo)
224+
val labelJoin = new LabelJoinV2(joinConf, tableUtils, labelDateRange)
223225
val labelComputed = labelJoin.compute()
224226
println("Label computed::")
225227
labelComputed.show()
@@ -296,7 +298,8 @@ class LabelJoinV2Test extends AnyFlatSpec {
296298
// Should get appended (i.e. the 10d column goes from all null to having values without losing the 7d values)
297299

298300
// compute the label join for thirty days ago (label ds)
299-
val labelJoin2 = new LabelJoinV2(joinConf, tableUtils, monthAgo)
301+
val labelDateRange2 = new api.DateRange(monthAgo, monthAgo)
302+
val labelJoin2 = new LabelJoinV2(joinConf, tableUtils, labelDateRange2)
300303
val labelComputed2 = labelJoin2.compute()
301304
println("Label computed (second run)::")
302305
labelComputed2.show()

0 commit comments

Comments
 (0)