|
| 1 | +package ai.chronon.spark |
| 2 | + |
| 3 | +import ai.chronon.api |
| 4 | +import ai.chronon.api.DataModel.Events |
| 5 | +import ai.chronon.api.Extensions._ |
| 6 | +import ai.chronon.api._ |
| 7 | +import ai.chronon.online.Metrics |
| 8 | +import ai.chronon.spark.Extensions._ |
| 9 | +import org.apache.spark.sql.DataFrame |
| 10 | +import org.apache.spark.sql.functions.lit |
| 11 | +import org.apache.spark.sql.types.DataType |
| 12 | +import org.slf4j.{Logger, LoggerFactory} |
| 13 | + |
| 14 | +import scala.collection.JavaConverters._ |
| 15 | +import scala.collection.Seq |
| 16 | + |
// Running the label join for `ds` modifies the partitions of the join output table that are
// `ds - windowLength` days old. The same window size may appear in several label join parts,
// so the work is organised per window size using the two structs below:
//   - AllLabelOutputInfo:  which partition of the join output table to modify for a window size
//   - LabelPartOutputInfo: for one label join part, which columns have that window size

/** Columns of a single label join part that share one window size.
  *
  * @param labelPart         the label join part the columns belong to
  * @param outputColumnNames output column names of `labelPart` for the window size in question
  */
case class LabelPartOutputInfo(
    labelPart: JoinPart,
    outputColumnNames: Seq[String]
)

/** Everything that has to be written for one window size.
  *
  * @param joinDsAsRange        partition of the join output table to modify, expressed as a range
  * @param labelPartOutputInfos per-label-part column groups that have this window size
  */
case class AllLabelOutputInfo(
    joinDsAsRange: PartitionRange,
    labelPartOutputInfos: Seq[LabelPartOutputInfo]
)
| 24 | + |
/** Label join driver: for a given label partition (`labelDs`) it attaches label columns, read from the
  * output tables of the configured label join parts, onto earlier partitions of the join output and
  * writes the result into the join's label table (`outputLabelTableV2`).
  *
  * For every daily window of length `w` found in the label parts' aggregations, the partition
  * `labelDs - w` of the join output is the one that receives the labels of that window.
  *
  * @param joinConf   join configuration; must have an output namespace and label parts
  * @param tableUtils table access helper, also the source of the partition spec / partition column
  * @param labelDs    the label partition being processed
  */
class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
  @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
  implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec
  assert(Option(joinConf.metaData.outputNamespace).nonEmpty, "output namespace could not be empty or null")

  val metrics: Metrics.Context = Metrics.Context(Metrics.Environment.LabelJoin, joinConf)
  private val outputLabelTable = joinConf.metaData.outputLabelTableV2
  private val labelJoinConf = joinConf.labelParts
  private val confTableProps = Option(joinConf.metaData.tableProperties)
    .map(_.asScala.toMap)
    .getOrElse(Map.empty[String, String])
  // NOTE(review): the padding/drop logic below builds names as s"${labelColumnPrefix}_$colName", which
  // yields "label__<col>" (two underscores). Presumably `prefixColumnNames` produces the same names --
  // confirm before changing either side.
  private val labelColumnPrefix = "label_"
  private val labelDsAsRange = PartitionRange(labelDs, labelDs)

  /** Resolves the data type of every label output column referenced by `labelOutputs`.
    *
    * Columns are grouped per label part, and the schema of each part's groupBy output table is
    * scanned once to look the types up.
    *
    * @param labelOutputs label output infos (typically one per window size)
    * @return (unprefixed column name, data type) pairs
    */
  private def getLabelColSchema(labelOutputs: Seq[AllLabelOutputInfo]): Seq[(String, DataType)] = {
    val labelPartToOutputCols = labelOutputs
      .flatMap(_.labelPartOutputInfos)
      .groupBy(_.labelPart)
      .map { case (labelPart, infos) => labelPart -> infos.flatMap(_.outputColumnNames) }

    labelPartToOutputCols.flatMap { case (labelPart, outputCols) =>
      val labelPartSchema = tableUtils.scanDf(null, labelPart.groupBy.metaData.outputTable).schema
      outputCols.map(col => (col, labelPartSchema(col).dataType))
    }.toSeq
  }

  /** Validates the configuration: the left side and every label groupBy must be Events, the team must
    * be set, and every label groupBy must define at least one aggregation with a daily window.
    */
  private def runAssertions(): Unit = {
    assert(joinConf.left.dataModel == Events,
           s"join.left.dataModel needs to be Events for label join ${joinConf.metaData.name}")

    assert(Option(joinConf.metaData.team).nonEmpty,
           s"join.metaData.team needs to be set for join ${joinConf.metaData.name}")

    labelJoinConf.labels.asScala.foreach { jp =>
      assert(jp.groupBy.dataModel == Events,
             s"groupBy.dataModel must be Events for label join with aggregations ${jp.groupBy.metaData.name}")

      assert(Option(jp.groupBy.aggregations).isDefined,
             s"aggregations must be defined for label join ${jp.groupBy.metaData.name}")

      val windows = jp.groupBy.aggregations.asScala.flatMap(_.windows.asScala).filter(_.timeUnit == TimeUnit.DAYS)

      assert(windows.nonEmpty,
             s"at least one aggregation with a daily window must be defined for label join ${jp.groupBy.metaData.name}")
    }
  }

  /** Builds a map from window length (in days) to the label outputs that use that window.
    *
    * A window length can be shared by several label join parts; only windows whose time unit is DAYS
    * are considered. Bucketed aggregations are rejected with an assertion.
    *
    * @return window length -> (join output partition to modify, per-label-part output columns)
    */
  private def getWindowToLabelOutputInfos: Map[Int, AllLabelOutputInfo] = {
    val labelJoinParts = labelJoinConf.labels.asScala

    labelJoinParts
      .flatMap { labelJoinPart =>
        // Per label part: window length -> output columns of this part with that window
        labelJoinPart.groupBy.aggregations.asScala
          .flatMap { agg =>
            agg.windows.asScala.filter(_.timeUnit == TimeUnit.DAYS).map { w =>
              // TODO -- support buckets
              assert(Option(agg.buckets).isEmpty, "Buckets as labels are not yet supported in LabelJoinV2")
              val aggPart = Builders.AggregationPart(agg.operation, agg.inputColumn, w)
              (w.length, aggPart.outputColumnName)
            }
          }
          .groupBy(_._1)
          .map { case (window, windowAndOutputCols) =>
            (window, LabelPartOutputInfo(labelJoinPart, windowAndOutputCols.map(_._2)))
          }
      }
      .groupBy(_._1) // Combine the per-part maps into one map keyed by window
      .map { case (window, windowAndInfos) =>
        // The join output partition for this window is `window` days before the label partition
        val joinPartitionDsAsRange = labelDsAsRange.shift(window * -1)
        window -> AllLabelOutputInfo(joinPartitionDsAsRange, windowAndInfos.map(_._2))
      }
      .toMap
  }

  /** Runs the label join for `labelDs`.
    *
    * Windows whose corresponding join output partition does not exist are skipped (and logged); the
    * remaining windows are computed and written to the label table, one partition per window.
    *
    * @return the union of the label table partitions written by this run
    * @throws IllegalArgumentException if partitions are missing and no window at all is computable
    */
  def compute(): DataFrame = {
    logger.info(s"Running LabelJoinV2 for $labelDs")

    runAssertions()

    // First get a map of window to LabelOutputInfo
    val windowToLabelOutputInfos = getWindowToLabelOutputInfos

    // Find existing partitions in the join table
    val joinTable = joinConf.metaData.outputTable
    val existingJoinPartitions = tableUtils.partitions(joinTable)

    // Split the windows into two groups, one that has a corresponding partition in the join table and one that
    // doesn't. If a partition is missing, we can't compute the labels for that window, but the job proceeds
    // with the rest.
    val (computableWindowToOutputs, missingWindowToOutputs) = windowToLabelOutputInfos.partition {
      case (_, labelOutputInfo) =>
        existingJoinPartitions.contains(labelOutputInfo.joinDsAsRange.start)
    }

    if (missingWindowToOutputs.nonEmpty) {
      logger.info(
        s"""Missing following partitions from $joinTable: ${missingWindowToOutputs.values
          .map(_.joinDsAsRange.start)
          .mkString(", ")}
           |
           |Found existing partitions: ${existingJoinPartitions.mkString(", ")}
           |
           |Therefore unable to compute the labels for ${missingWindowToOutputs.keys.mkString(", ")}
           |
           |For requested ds: $labelDs
           |
           |Proceeding with valid windows: ${computableWindowToOutputs.keys.mkString(", ")}
           |
           |""".stripMargin
      )

      // Only fail when nothing can be computed; partially available windows must still be processed.
      require(
        computableWindowToOutputs.nonEmpty,
        "No valid windows to compute labels for given the existing join output range. " +
          s"Consider backfilling the join output table for the following days: ${missingWindowToOutputs.values.map(_.joinDsAsRange.start)}."
      )
    }

    // Find existing partitions in the outputLabelTable (different from the join output table used above)
    // This is used in computeOutputForWindow to choose the base dataframe
    val existingLabelTableOutputPartitions = tableUtils.partitions(outputLabelTable)
    logger.info(s"Found existing partitions in Label Table: ${existingLabelTableOutputPartitions.mkString(", ")}")

    // Each unique window is an output partition in the joined table
    // Each window may contain a subset of the joinParts and their columns
    computableWindowToOutputs.foreach { case (windowLength, joinOutputInfo) =>
      computeOutputForWindow(windowLength, joinOutputInfo, existingLabelTableOutputPartitions, windowToLabelOutputInfos)
    }

    // Read back every partition that was just written and return them as one dataframe
    val allOutputDfs = computableWindowToOutputs.values
      .map(_.joinDsAsRange)
      .map { range =>
        tableUtils.scanDf(null, outputLabelTable, range = Some(range))
      }
      .toSeq

    // reduce on a single element returns that element unchanged
    allOutputDfs.reduce(_ union _)
  }

  /** Writes out a single partition of the label table with all labels for the corresponding window.
    *
    * @param windowLength                       window length in days; the target partition is `labelDs - windowLength`
    * @param joinOutputInfo                     label parts and columns that use this window
    * @param existingLabelTableOutputPartitions partitions already present in the label table
    * @param windowToLabelOutputInfos           all windows, used to pad the schema with every label column
    */
  private def computeOutputForWindow(windowLength: Int,
                                     joinOutputInfo: AllLabelOutputInfo,
                                     existingLabelTableOutputPartitions: Seq[String],
                                     windowToLabelOutputInfos: Map[Int, AllLabelOutputInfo]): Unit = {
    val partSummaries = joinOutputInfo.labelPartOutputInfos
      .map(x => s"${x.labelPart.groupBy.metaData.name} -> ${x.outputColumnNames.mkString(", ")}")
      .mkString("\n")
    logger.info(
      s"Computing labels for window: $windowLength days on labelDs: $labelDs \n" +
        s"Includes the following joinParts and output cols: $partSummaries")

    val startMillis = System.currentTimeMillis()
    // This is the join output ds that we're working with
    val joinDsAsRange = labelDsAsRange.shift(windowLength * -1)

    val joinBaseDf = if (existingLabelTableOutputPartitions.contains(joinDsAsRange.start)) {
      // The label table already has this partition: start from it, because label columns of another
      // window may already have landed for this date and must be kept.
      logger.info(s"Found existing partition in Label Table: ${joinDsAsRange.start}")
      tableUtils.scanDf(null, outputLabelTable, range = Some(joinDsAsRange))
    } else {
      // Otherwise start from the join output, padded with null-valued columns for every label column
      // (of all windows) so that the written schema is complete.
      logger.info(s"Did not find existing partition in Label Table, querying from Join Output: ${joinDsAsRange.start}")
      val joinOutputDf = tableUtils.scanDf(null, joinConf.metaData.outputTable, range = Some(joinDsAsRange))
      val allLabelCols = getLabelColSchema(windowToLabelOutputInfos.values.toSeq)
      allLabelCols.foldLeft(joinOutputDf) { case (currentDf, (colName, dataType)) =>
        val prefixedColName = s"${labelColumnPrefix}_$colName"
        currentDf.withColumn(prefixedColName, lit(null).cast(dataType))
      }
    }

    // One snapshot dataframe per label part, restricted to the join keys and this window's columns
    val joinPartsAndDfs = joinOutputInfo.labelPartOutputInfos.map { labelOutputInfo =>
      val labelJoinPart = labelOutputInfo.labelPart

      val outputColumnNames = labelOutputInfo.outputColumnNames
      val selectCols: Map[String, String] =
        (labelJoinPart.rightToLeft.keys ++ outputColumnNames).map(x => x -> x).toMap

      val snapshotQuery = Builders.Query(selects = selectCols)
      val snapshotTable = labelJoinPart.groupBy.metaData.outputTable
      val snapshotDf = tableUtils.scanDf(snapshotQuery, snapshotTable, range = Some(labelDsAsRange))

      (labelJoinPart, snapshotDf)
    }

    val joined = joinPartsAndDfs.foldLeft(joinBaseDf) { case (left, (joinPart, rightDf)) =>
      joinWithLeft(left, rightDf, joinPart)
    }

    joined.save(outputLabelTable, confTableProps, Seq(tableUtils.partitionColumn), autoExpand = true)

    // Measure after the save: the dataframe is evaluated lazily, so the join and the write happen
    // inside `save` and would otherwise be excluded from the reported latency.
    val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)

    metrics.gauge(Metrics.Name.LatencyMinutes, elapsedMins)

    logger.info(s"Wrote to table $outputLabelTable, into partitions: ${joinDsAsRange.start} in $elapsedMins mins")
  }

  /** Left-outer joins the label columns of `rightDf` onto `leftDf`.
    *
    * Right-side key columns are renamed to their left-side names, value columns get the label prefix,
    * and any same-named label columns already on the left are dropped so they are overwritten.
    *
    * @param leftDf   base dataframe (join output or existing label table partition)
    * @param rightDf  snapshot dataframe of the label part
    * @param joinPart the label join part, providing the right-to-left key mapping
    * @return `leftDf` without the overwritten label columns, joined with the prefixed right side
    */
  def joinWithLeft(leftDf: DataFrame, rightDf: DataFrame, joinPart: JoinPart): DataFrame = {
    val partLeftKeys = joinPart.rightToLeft.values.toArray

    // apply key-renaming to key columns
    val keyRenamedRight = joinPart.rightToLeft.foldLeft(rightDf) { case (updatedRight, (rightKey, leftKey)) =>
      updatedRight.withColumnRenamed(rightKey, leftKey)
    }

    // Everything on the right that is neither a key nor a time/partition column is a label value
    val nonValueColumns = joinPart.rightToLeft.keys.toArray ++ Array(Constants.TimeColumn,
                                                                     tableUtils.partitionColumn,
                                                                     Constants.TimePartitionColumn,
                                                                     Constants.LabelPartitionColumn)
    val valueColumns = rightDf.schema.names.filterNot(nonValueColumns.contains)

    // In this case, since we're joining with the full-schema dataframe,
    // we need to drop the columns that we're attempting to overwrite
    val cleanLeftDf = valueColumns.foldLeft(leftDf)((df, colName) => df.drop(s"${labelColumnPrefix}_$colName"))

    val prefixedRight = keyRenamedRight.prefixColumnNames(labelColumnPrefix, valueColumns)

    val partName = joinPart.groupBy.metaData.name

    logger.info(s"""Join keys for $partName: ${partLeftKeys.mkString(", ")}
         |Left Schema:
         |${leftDf.schema.pretty}
         |
         |Right Schema:
         |${prefixedRight.schema.pretty}
         |
         |""".stripMargin)

    cleanLeftDf.validateJoinKeys(prefixedRight, partLeftKeys)
    cleanLeftDf.join(prefixedRight, partLeftKeys, "left_outer")
  }
}
0 commit comments