Logging fix -- make root cause more clear if label job has misaligned dates #611

Merged · 6 commits · Apr 9, 2025

1 change: 0 additions & 1 deletion api/python/ai/chronon/cli/compile/parse_configs.py
@@ -11,7 +11,6 @@

logger = get_logger()
-

def from_folder(
    cls: type, input_dir: str, compile_context: CompileContext
) -> List[CompiledObj]:
21 changes: 20 additions & 1 deletion api/python/ai/chronon/cli/compile/parse_teams.py
@@ -11,7 +11,7 @@
    EnvironmentVariables,
    ExecutionInfo,
)
-from ai.chronon.api.ttypes import Team
+from ai.chronon.api.ttypes import Join, MetaData, Team
from ai.chronon.cli.logger import get_logger

logger = get_logger()
@@ -88,6 +88,25 @@ def update_metadata(obj: Any, team_dict: Dict[str, Team]):

    metadata.outputNamespace = team_dict[team].outputNamespace

+    if isinstance(obj, Join):
+        join_namespace = obj.metaData.outputNamespace
+        # set the metadata for each join part and labelParts
+        def set_join_part_metadata(jp, output_namespace):
+            if jp.groupBy is not None:
+                if jp.groupBy.metaData and not jp.groupBy.metaData.outputNamespace:
+                    jp.groupBy.metaData.outputNamespace = output_namespace
+                else:
+                    jp.groupBy.metaData = MetaData()
+                    jp.groupBy.metaData.outputNamespace = output_namespace
+
+        if obj.joinParts:
+            for jp in (obj.joinParts or []):
+                set_join_part_metadata(jp, join_namespace)
+
+        if obj.labelParts:
+            for lb in (obj.labelParts.labels or []):
+                set_join_part_metadata(lb, join_namespace)
+
    if metadata.executionInfo is None:
        metadata.executionInfo = ExecutionInfo()

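With this change, every joinPart and labelPart inherits the parent join's outputNamespace unless it already carries one. A minimal Scala sketch of that inheritance rule — toy stand-ins, not Chronon's generated thrift classes, and names here (`NamespaceInheritance`, `inheritNamespace`) are illustrative:

```scala
object NamespaceInheritance extends App {
  // Toy stand-ins for the compiled thrift objects; only the outputNamespace
  // plumbing from the diff above is modeled.
  case class Meta(var outputNamespace: String = null)
  case class Part(var meta: Meta = null)

  // The intended rule: a part keeps an explicitly set namespace,
  // otherwise it inherits the parent join's.
  def inheritNamespace(part: Part, joinNamespace: String): Unit = {
    if (part.meta == null) part.meta = Meta()
    if (part.meta.outputNamespace == null) part.meta.outputNamespace = joinNamespace
  }

  val labelPart = Part() // compiled with no metadata at all
  inheritNamespace(labelPart, "team_namespace")
  assert(labelPart.meta.outputNamespace == "team_namespace")
}
```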
12 changes: 11 additions & 1 deletion spark/BUILD.bazel
@@ -206,7 +206,17 @@ scala_test_suite(
scala_test_suite(
    name = "streaming_test",
    srcs = glob(["src/test/scala/ai/chronon/spark/test/streaming/*.scala"]),
-   data = glob(["spark/src/test/resources/**/*"]),
+   data = glob(["src/test/resources/**/*"]),
    # defined in prelude_bazel file
    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
    visibility = ["//visibility:public"],
    deps = test_deps + [":test_lib"],
)
+
+scala_test_suite(
+   name = "submission_test",
+   srcs = glob(["src/test/scala/ai/chronon/spark/test/submission/*.scala"]),
+   data = ["//spark/src/test/resources:test-resources"],
+   # defined in prelude_bazel file
+   jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
+   visibility = ["//visibility:public"],
49 changes: 35 additions & 14 deletions spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -17,6 +17,7 @@
package ai.chronon.spark

import ai.chronon.api
+import ai.chronon.api.Constants
import ai.chronon.api.Constants.MetadataDataset
import ai.chronon.api.Extensions.{GroupByOps, JoinPartOps, MetadataOps, SourceOps}
import ai.chronon.api.planner.RelevantLeftForJoinPart
@@ -27,30 +28,39 @@ import ai.chronon.online.{Api, MetadataDirWalker, MetadataEndPoint, TopicChecker
import ai.chronon.orchestration.{JoinMergeNode, JoinPartNode}
import ai.chronon.spark.batch._
import ai.chronon.spark.format.Format
-import ai.chronon.spark.stats.drift.{Summarizer, SummaryPacker, SummaryUploader}
-import ai.chronon.spark.stats.{CompareBaseJob, CompareJob, ConsistencyJob}
+import ai.chronon.spark.stats.CompareBaseJob
+import ai.chronon.spark.stats.CompareJob
+import ai.chronon.spark.stats.ConsistencyJob
+import ai.chronon.spark.stats.drift.Summarizer
+import ai.chronon.spark.stats.drift.SummaryPacker
+import ai.chronon.spark.stats.drift.SummaryUploader
import ai.chronon.spark.streaming.JoinSourceRunner
import org.apache.commons.io.FileUtils
import org.apache.spark.SparkFiles
import org.apache.spark.sql.streaming.StreamingQueryListener
-import org.apache.spark.sql.streaming.StreamingQueryListener.{
-  QueryProgressEvent,
-  QueryStartedEvent,
-  QueryTerminatedEvent
-}
-import org.apache.spark.sql.{DataFrame, SparkSession, SparkSessionExtensions}
+import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
+import org.apache.spark.sql.streaming.StreamingQueryListener.QueryStartedEvent
+import org.apache.spark.sql.streaming.StreamingQueryListener.QueryTerminatedEvent
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.SparkSessionExtensions
import org.json4s._
import org.json4s.jackson.JsonMethods._
-import org.rogach.scallop.{ScallopConf, ScallopOption, Subcommand}
-import org.slf4j.{Logger, LoggerFactory}
+import org.rogach.scallop.ScallopConf
+import org.rogach.scallop.ScallopOption
+import org.rogach.scallop.Subcommand
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
import org.yaml.snakeyaml.Yaml

import java.io.File
-import java.nio.file.{Files, Paths}
+import java.nio.file.Files
+import java.nio.file.Paths
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, Future}
+import scala.concurrent.Await
+import scala.concurrent.Future
import scala.reflect.ClassTag
import scala.reflect.internal.util.ScalaClassLoader

@@ -397,10 +407,18 @@ object Driver {

    def run(args: Args): Unit = {
      val tableUtils = args.buildTableUtils()
+
+     // Use startPartitionOverride if provided, otherwise use endDate for both (single day)
+     val startDate = args.startPartitionOverride.toOption.getOrElse(args.endDate())
+     val endDate = args.endDate()
+
+     // Create a DateRange with start and end dates
+     val dateRange = new api.DateRange(startDate, endDate)
+
      val labelJoin = new LabelJoinV2(
        args.joinConf,
        tableUtils,
-       args.endDate()
+       dateRange
      )
      labelJoin.compute()
@@ -990,7 +1008,10 @@
    val partitionNames = args.partitionNames()
    val tablesToPartitionSpec = partitionNames.map((p) =>
      p.split("/").toList match {
-       case fullTableName :: partitionSpec :: Nil => (fullTableName, Format.parseHiveStylePartition(partitionSpec))
+       case fullTableName :: partitionParts if partitionParts.nonEmpty =>
+         // Join all partition parts with "/" and parse as one combined partition spec.
+         val partitionSpec = partitionParts.mkString("/")
+         (fullTableName, Format.parseHiveStylePartition(partitionSpec))
        case fullTableName :: Nil =>
          throw new IllegalArgumentException(
            s"A partition spec must be specified for ${fullTableName}. ${helpNamingConvention}")
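The reworked match accepts multi-key partition names such as db.table/ds=2025-04-09/hr=00 by rejoining everything after the table name into one spec. A self-contained sketch of that parsing rule, with inline key=value splitting standing in for Format.parseHiveStylePartition:

```scala
object PartitionNameParsing extends App {
  // The new rule: everything after the first "/" is one hive-style partition
  // spec, possibly with several key=value parts.
  def parsePartitionName(p: String): (String, Map[String, String]) =
    p.split("/").toList match {
      case table :: parts if parts.nonEmpty =>
        val spec = parts.map { kv =>
          val Array(k, v) = kv.split("=", 2)
          k -> v
        }.toMap
        (table, spec)
      case _ =>
        throw new IllegalArgumentException(s"A partition spec must be specified, got: $p")
    }

  println(parsePartitionName("db.table/ds=2025-04-09/hr=00"))
  // (db.table,Map(ds -> 2025-04-09, hr -> 00))
}
```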
66 changes: 46 additions & 20 deletions spark/src/main/scala/ai/chronon/spark/batch/LabelJoinV2.scala
@@ -21,7 +21,7 @@ import scala.collection.Seq
case class LabelPartOutputInfo(labelPart: JoinPart, outputColumnNames: Seq[String])
case class AllLabelOutputInfo(joinDsAsRange: PartitionRange, labelPartOutputInfos: Seq[LabelPartOutputInfo])

-class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
+class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDateRange: api.DateRange) {
  @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
  implicit val partitionSpec: PartitionSpec = tableUtils.partitionSpec
  assert(Option(joinConf.metaData.outputNamespace).nonEmpty, "output namespace could not be empty or null")
@@ -33,7 +33,7 @@ class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
    .map(_.asScala.toMap)
    .getOrElse(Map.empty[String, String])
  private val labelColumnPrefix = "label_"
- private val labelDsAsRange = PartitionRange(labelDs, labelDs)
+ private val labelDsAsPartitionRange = labelDateRange.toPartitionRange

  private def getLabelColSchema(labelOutputs: Seq[AllLabelOutputInfo]): Seq[(String, DataType)] = {
    val labelPartToOutputCols = labelOutputs
@@ -93,13 +93,21 @@ class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
      .mapValues(_.map(_._2)) // Drop the duplicate window
      .map { case (window, labelPartOutputInfos) =>
        // The labelDs is a lookback from the labelSnapshot partition back to the join output table
-       val joinPartitionDsAsRange = labelDsAsRange.shift(window * -1)
+       val joinPartitionDsAsRange = labelDsAsPartitionRange.shift(window * -1)
        window -> AllLabelOutputInfo(joinPartitionDsAsRange, labelPartOutputInfos)
      }
      .toMap
  }

  def compute(): DataFrame = {
+   val resultDfsPerDay = labelDsAsPartitionRange.steps(days = 1).map { dayStep =>
+     computeDay(dayStep.start)
+   }
+
+   resultDfsPerDay.tail.foldLeft(resultDfsPerDay.head)((acc, df) => acc.union(df))
+ }
+
+ def computeDay(labelDs: String): DataFrame = {
    logger.info(s"Running LabelJoinV2 for $labelDs")

    runAssertions()
@@ -119,26 +127,39 @@
    }

    if (missingWindowToOutputs.nonEmpty) {
-     logger.info(
-       s"""Missing following partitions from $joinTable: ${missingWindowToOutputs.values
-         .map(_.joinDsAsRange.start)
-         .mkString(", ")}
-          |
-          |Found existing partitions: ${existingJoinPartitions.mkString(", ")}
-          |
-          |Therefore unable to compute the labels for ${missingWindowToOutputs.keys.mkString(", ")}
-          |
-          |For requested ds: $labelDs
-          |
-          |Proceeding with valid windows: ${computableWindowToOutputs.keys.mkString(", ")}
-          |""".stripMargin
-     )
+     // Always log this no matter what.
+     val baseLogString = s"""Missing following partitions from $joinTable: ${missingWindowToOutputs.values
+       .map(_.joinDsAsRange.start)
+       .mkString(", ")}
+          |
+          |Found existing partitions of join output: ${existingJoinPartitions.mkString(", ")}
+          |
+          |Required dates are computed based on label date (the run date) - window for distinct windows that are used in label parts.
+          |
+          |In this case, the run date is: $labelDs, and given the existing partitions we are unable to compute the labels for the following windows: ${missingWindowToOutputs.keys
+       .mkString(", ")} (days).
+          |""".stripMargin

-     require(
-       computableWindowToOutputs.isEmpty,
-       "No valid windows to compute labels for given the existing join output range." +
-         s"Consider backfilling the join output table for the following days: ${missingWindowToOutputs.values.map(_.joinDsAsRange.start)}."
-     )
+     // If there are no dates to run, also throw that error
+     require(
+       computableWindowToOutputs.nonEmpty,
+       s"""$baseLogString
+          |
+          |There are no partitions that we can run the label join for. At least one window must be computable.
+          |
+          |Exiting.
+          |""".stripMargin
+     )
+
+     // Else log what we are running, but warn about missing windows
+     logger.warn(
+       s"""$baseLogString
+          |
+          |Proceeding with valid windows: ${computableWindowToOutputs.keys.mkString(", ")}
+          |
+          |""".stripMargin
+     )
    }

@@ -150,7 +171,11 @@ class LabelJoinV2(joinConf: api.Join, tableUtils: TableUtils, labelDs: String) {
    // Each unique window is an output partition in the joined table
    // Each window may contain a subset of the joinParts and their columns
    computableWindowToOutputs.foreach { case (windowLength, joinOutputInfo) =>
-     computeOutputForWindow(windowLength, joinOutputInfo, existingLabelTableOutputPartitions, windowToLabelOutputInfos)
+     computeOutputForWindow(windowLength,
+                            joinOutputInfo,
+                            existingLabelTableOutputPartitions,
+                            windowToLabelOutputInfos,
+                            labelDsAsPartitionRange)
    }

    val allOutputDfs = computableWindowToOutputs.values
@@ -171,16 +196,17 @@
  private def computeOutputForWindow(windowLength: Int,
                                     joinOutputInfo: AllLabelOutputInfo,
                                     existingLabelTableOutputPartitions: Seq[String],
-                                    windowToLabelOutputInfos: Map[Int, AllLabelOutputInfo]): Unit = {
+                                    windowToLabelOutputInfos: Map[Int, AllLabelOutputInfo],
+                                    labelDsAsPartitionRange: PartitionRange): Unit = {
    logger.info(
-     s"Computing labels for window: $windowLength days on labelDs: $labelDs \n" +
+     s"Computing labels for window: $windowLength days on labelDs: ${labelDsAsPartitionRange.start} \n" +
        s"Includes the following joinParts and output cols: ${joinOutputInfo.labelPartOutputInfos
          .map(x => s"${x.labelPart.groupBy.metaData.name} -> ${x.outputColumnNames.mkString(", ")}")
          .mkString("\n")}")

    val startMillis = System.currentTimeMillis()
    // This is the join output ds that we're working with
-   val joinDsAsRange = labelDsAsRange.shift(windowLength * -1)
+   val joinDsAsRange = labelDsAsPartitionRange.shift(windowLength * -1)

    val joinBaseDf = if (existingLabelTableOutputPartitions.contains(joinDsAsRange.start)) {
      // If the existing join table has the partition, then we should use it, because another label column
@@ -208,7 +234,7 @@

      val snapshotQuery = Builders.Query(selects = selectCols)
      val snapshotTable = labelJoinPart.groupBy.metaData.outputTable
-     val snapshotDf = tableUtils.scanDf(snapshotQuery, snapshotTable, range = Some(labelDsAsRange))
+     val snapshotDf = tableUtils.scanDf(snapshotQuery, snapshotTable, range = Some(labelDsAsPartitionRange))

      (labelJoinPart, snapshotDf)
    }
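The compute path now fans out over each day of the requested range and unions the per-day results. A minimal sketch of that step-and-union shape, with plain collections standing in for PartitionRange.steps(days = 1) and DataFrame.union (daySteps and computeDay here are illustrative stand-ins):

```scala
import java.time.LocalDate

object StepAndUnion extends App {
  // Stand-in for PartitionRange.steps(days = 1): every day in [start, end].
  def daySteps(start: String, end: String): Seq[String] = {
    val (s, e) = (LocalDate.parse(start), LocalDate.parse(end))
    Iterator.iterate(s)(_.plusDays(1)).takeWhile(!_.isAfter(e)).map(_.toString).toSeq
  }

  // Stand-in for computeDay: one "DataFrame" of label rows per day.
  def computeDay(ds: String): Seq[String] = Seq(s"label-row@$ds")

  val perDay = daySteps("2025-04-01", "2025-04-03").map(computeDay)
  // Mirrors resultDfsPerDay.tail.foldLeft(resultDfsPerDay.head)(_ union _).
  val all = perDay.tail.foldLeft(perDay.head)(_ ++ _)
  println(all) // rows for all three days, in order
}
```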
41 changes: 22 additions & 19 deletions spark/src/main/scala/ai/chronon/spark/submission/JobSubmitter.scala
@@ -43,33 +43,36 @@ object JobSubmitter {
    val confTypeValue = getArgValue(args, ConfTypeArgKeyword)

    val modeConfigProperties = if (localConfPathValue.isDefined && confTypeValue.isDefined) {
-     val executionInfo = confTypeValue.get match {
-       case "joins" => parseConf[api.Join](localConfPathValue.get).metaData.executionInfo
-       case "group_bys" => parseConf[api.GroupBy](localConfPathValue.get).metaData.executionInfo
-       case "staging_queries" => parseConf[api.StagingQuery](localConfPathValue.get).metaData.executionInfo
-       case "models" => parseConf[api.Model](localConfPathValue.get).metaData.executionInfo
+     val metadata = confTypeValue.get match {
+       case "joins" => parseConf[api.Join](localConfPathValue.get).metaData
+       case "group_bys" => parseConf[api.GroupBy](localConfPathValue.get).metaData
+       case "staging_queries" => parseConf[api.StagingQuery](localConfPathValue.get).metaData
+       case "models" => parseConf[api.Model](localConfPathValue.get).metaData
        case _ => throw new Exception("Invalid conf type")
      }

-     val originalMode = getArgValue(args, OriginalModeArgKeyword)
+     val executionInfo = Option(metadata.getExecutionInfo)

-     (Option(executionInfo.conf), originalMode) match {
-       case (Some(conf), Some(mode)) =>
-         Option(conf.getModeConfigs).map(modeConfigs => {
-           if (modeConfigs.containsKey(mode)) {
-             modeConfigs.get(mode).toScala
-           } else {
-             // check common
-             if (conf.isSetCommon) {
-               conf.getCommon.toScala
-             } else {
-               Map[String, String]()
-             }
-           }
-         })
-       case _ => None
-     }
+     if (executionInfo.isEmpty) {
+       None
+     } else {
+       val originalMode = getArgValue(args, OriginalModeArgKeyword)
+
+       (Option(executionInfo.get.conf), originalMode) match {
+         case (Some(conf), Some(mode)) =>
+           val modeConfs = if (conf.isSetModeConfigs && conf.getModeConfigs.containsKey(mode)) {
+             conf.getModeConfigs.get(mode).toScala
+           } else if (conf.isSetCommon) {
+             conf.getCommon.toScala
+           } else {
+             Map[String, String]()
+           }
+           Option(modeConfs)
+         case _ => None
+       }
+     }
    } else None

    modeConfigProperties
  }
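Net effect on config resolution: a missing executionInfo now yields None instead of an NPE, and otherwise mode-specific configs win over common, which wins over empty. A condensed sketch of that precedence, with plain maps standing in for the thrift conf (Conf and resolve are illustrative names):

```scala
object ModeConfigPrecedence extends App {
  // Plain stand-ins for the thrift conf: per-mode maps plus a "common" fallback.
  case class Conf(modeConfigs: Map[String, Map[String, String]] = Map.empty,
                  common: Map[String, String] = Map.empty)

  // Mode-specific configs win, then common, then empty; no conf or mode => None.
  def resolve(conf: Option[Conf], mode: Option[String]): Option[Map[String, String]] =
    (conf, mode) match {
      case (Some(c), Some(m)) => Some(c.modeConfigs.getOrElse(m, c.common))
      case _                  => None
    }

  val conf = Conf(modeConfigs = Map("backfill" -> Map("spark.executor.memory" -> "8g")),
                  common = Map("spark.executor.memory" -> "4g"))
  assert(resolve(Some(conf), Some("backfill")) == Some(Map("spark.executor.memory" -> "8g")))
  assert(resolve(Some(conf), Some("streaming")) == Some(Map("spark.executor.memory" -> "4g")))
}
```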