Add Flink validation job + expose verb in streaming #495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 32 commits merged on Mar 12, 2025
Changes from 31 commits

Commits (32)
d648220
Python updates to add validate verb
piyush-zlai Mar 10, 2025
f479da6
Spark expr eval + CU updates to support spark df based eval
piyush-zlai Mar 6, 2025
d450471
Flink basic validation job support
piyush-zlai Mar 6, 2025
46b03c7
Rework validate a bit + add integ test
piyush-zlai Mar 7, 2025
744231c
style: Apply scalafix and scalafmt changes
piyush-zlai Mar 7, 2025
034bc23
Add validation unit tests
piyush-zlai Mar 10, 2025
df92152
style: Apply scalafix and scalafmt changes
piyush-zlai Mar 10, 2025
a8a782c
Drop useCatalyst flag
piyush-zlai Mar 10, 2025
23aec0e
Reworked to use bulk apis
piyush-zlai Mar 10, 2025
de964d1
style: Apply scalafix and scalafmt changes
piyush-zlai Mar 10, 2025
c9254f0
Add logs
piyush-zlai Mar 10, 2025
a3ab0a1
Lower //ism
piyush-zlai Mar 10, 2025
10ae03b
Add bounds to kafka src
piyush-zlai Mar 10, 2025
d3197bd
style: Apply scalafix and scalafmt changes
piyush-zlai Mar 10, 2025
bcf1d9d
Rework bulk spark eval method
piyush-zlai Mar 11, 2025
e67b7af
style: Apply scalafix and scalafmt changes
piyush-zlai Mar 11, 2025
4a372c2
Try and rework things
piyush-zlai Mar 11, 2025
f1810a6
Try a limit on flink main
piyush-zlai Mar 11, 2025
0663f83
Revert Kafka bounds
piyush-zlai Mar 11, 2025
d9b2eb5
Bump validateRows limit
piyush-zlai Mar 11, 2025
d9626e7
Improved failure toString
piyush-zlai Mar 11, 2025
9e7a479
Reinstate commons text exclusion
piyush-zlai Mar 11, 2025
2c76796
style: Apply scalafix and scalafmt changes
piyush-zlai Mar 11, 2025
30744ff
Clean up expr eval code + comments
piyush-zlai Mar 11, 2025
1adef2d
Add more comments
piyush-zlai Mar 11, 2025
cdc256c
style: Apply scalafix and scalafmt changes
piyush-zlai Mar 11, 2025
a92bbb8
Chain the validate job with flink job based on flag
piyush-zlai Mar 11, 2025
fff7aa1
style: Apply scalafix and scalafmt changes
piyush-zlai Mar 11, 2025
13f8792
Move commons-text exclusion around to make test pass
piyush-zlai Mar 11, 2025
960486e
Scala 2.13 related test fixes & updates
piyush-zlai Mar 11, 2025
80d7ce9
style: Apply scalafix and scalafmt changes
piyush-zlai Mar 11, 2025
578aa20
Merge branch 'main' of https://github.com/zipline-ai/chronon into piy…
piyush-zlai Mar 11, 2025
2 changes: 2 additions & 0 deletions api/py/ai/chronon/repo/default_runner.py
@@ -24,6 +24,8 @@ def __init__(self, args, jar_path):
self.kafka_bootstrap = args.get("kafka_bootstrap")
self.mock_source = args.get("mock_source")
self.savepoint_uri = args.get("savepoint_uri")
self.validate = args.get("validate")
self.validate_rows = args.get("validate_rows")

valid_jar = args["online_jar"] and os.path.exists(args["online_jar"])

6 changes: 5 additions & 1 deletion api/py/ai/chronon/repo/gcp.py
@@ -258,9 +258,13 @@ def run_dataproc_flink_streaming(self):
"-ZGCP_PROJECT_ID": GcpRunner.get_gcp_project_id(),
"-ZGCP_BIGTABLE_INSTANCE_ID": GcpRunner.get_gcp_bigtable_instance_id(),
"--savepoint-uri": self.savepoint_uri,
"--validate-rows": self.validate_rows,
}

flag_args = {"--mock-source": self.mock_source}
flag_args = {
"--mock-source": self.mock_source,
"--validate": self.validate
}
flag_args_str = " ".join(key for key, value in flag_args.items() if value)

user_args_str = " ".join(
6 changes: 6 additions & 0 deletions api/py/ai/chronon/repo/run.py
@@ -159,6 +159,10 @@ def set_defaults(ctx):
help="Use a mocked data source instead of a real source for groupby-streaming Flink.",
)
@click.option("--savepoint-uri", help="Savepoint URI for Flink streaming job")
@click.option("--validate", is_flag=True,
help="Validate the catalyst util Spark expression evaluation logic")
@click.option("--validate-rows", default="10000",
help="Number of rows to run the validation on")
@click.pass_context
def main(
ctx,
@@ -190,6 +194,8 @@ def main(
kafka_bootstrap,
mock_source,
savepoint_uri,
validate,
validate_rows
):
unknown_args = ctx.args
click.echo("Running with args: {}".format(ctx.params))
11 changes: 11 additions & 0 deletions api/src/main/scala/ai/chronon/api/ScalaJavaConversions.scala
@@ -37,6 +37,17 @@ object ScalaJavaConversions {
}
}

implicit class IterableOps[T](iterable: java.lang.Iterable[T]) {
def toScala: Iterable[T] = {
iterable.asScala
}
}
implicit class JIterableOps[T](iterable: Iterable[T]) {
def toJava: java.lang.Iterable[T] = {
iterable.asJava
}
}

implicit class IteratorOps[T](iterator: java.util.Iterator[T]) {
def toScala: Iterator[T] = {
iterator.asScala
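For reference, a quick sketch of how the new toScala / toJava Iterable conversions might be used (illustrative example, not part of this diff):

    import ai.chronon.api.ScalaJavaConversions._

    // java.lang.Iterable -> scala Iterable via the new IterableOps implicit
    val javaIterable: java.lang.Iterable[String] = java.util.Arrays.asList("a", "b", "c")
    val scalaIterable: Iterable[String] = javaIterable.toScala

    // scala Iterable -> java.lang.Iterable via the new JIterableOps implicit
    val roundTripped: java.lang.Iterable[String] = scalaIterable.toJava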
1 change: 1 addition & 0 deletions flink/BUILD.bazel
@@ -27,6 +27,7 @@ scala_library(
maven_artifact("org.apache.hadoop:hadoop-common"),
maven_artifact("org.apache.hadoop:hadoop-client-api"),
maven_artifact("org.apache.hadoop:hadoop-yarn-api"),
maven_artifact("org.apache.commons:commons-lang3"),
],
)

20 changes: 20 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/FlinkJob.scala
@@ -12,6 +12,7 @@ import ai.chronon.flink.SchemaRegistrySchemaProvider.RegistryHostKey
import ai.chronon.flink.types.AvroCodecOutput
import ai.chronon.flink.types.TimestampedTile
import ai.chronon.flink.types.WriteResponse
import ai.chronon.flink.validation.ValidationFlinkJob
import ai.chronon.flink.window.AlwaysFireOnElementTrigger
import ai.chronon.flink.window.FlinkRowAggProcessFunction
import ai.chronon.flink.window.FlinkRowAggregationFunction
@@ -291,6 +292,12 @@ object FlinkJob {
// Kafka config is optional as we can support other sources in the future
val kafkaBootstrap: ScallopOption[String] =
opt[String](required = false, descr = "Kafka bootstrap server in host:port format")
// Run in validate mode - We read rows using Kafka and run them through Spark Df and compare against CatalystUtil output
val validate: ScallopOption[Boolean] =
opt[Boolean](required = false, descr = "Run in validate mode", default = Some(false))
// Number of rows to use for validation
val validateRows: ScallopOption[Int] =
opt[Int](required = false, descr = "Number of rows to use for validation", default = Some(10000))

val apiProps: Map[String, String] = props[String]('Z', descr = "Props to configure API / KV Store")

@@ -304,10 +311,23 @@
val props = jobArgs.apiProps.map(identity)
val useMockedSource = jobArgs.mockSource()
val kafkaBootstrap = jobArgs.kafkaBootstrap.toOption
val validateMode = jobArgs.validate()
val validateRows = jobArgs.validateRows()

val api = buildApi(onlineClassName, props)
val metadataStore = new MetadataStore(FetchContext(api.genKvStore, MetadataDataset))

if (validateMode) {
Review comment (Contributor): i would pull this out as a function.

val validationResults = ValidationFlinkJob.run(metadataStore, kafkaBootstrap, groupByName, validateRows)
if (validationResults.map(_.totalMismatches).sum > 0) {
val validationSummary = s"Total records: ${validationResults.map(_.totalRecords).sum}, " +
s"Total matches: ${validationResults.map(_.totalMatches).sum}, " +
s"Total mismatches: ${validationResults.map(_.totalMismatches).sum}"
throw new IllegalStateException(
s"Spark DF vs Catalyst util validation failed. Validation summary: $validationSummary")
}
}

val flinkJob =
if (useMockedSource) {
// We will yank this conditional block when we wire up our real sources etc.
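Per the review comment above ("i would pull this out as a function"), a minimal sketch of what that refactor could look like inside object FlinkJob, reusing the values already in scope (hypothetical helper, not part of this PR):

    // Hypothetical extraction of the inline validate-mode block above.
    private def runValidation(metadataStore: MetadataStore,
                              kafkaBootstrap: Option[String],
                              groupByName: String,
                              validateRows: Int): Unit = {
      val validationResults = ValidationFlinkJob.run(metadataStore, kafkaBootstrap, groupByName, validateRows)
      val totalMismatches = validationResults.map(_.totalMismatches).sum
      if (totalMismatches > 0) {
        val validationSummary = s"Total records: ${validationResults.map(_.totalRecords).sum}, " +
          s"Total matches: ${validationResults.map(_.totalMatches).sum}, " +
          s"Total mismatches: $totalMismatches"
        throw new IllegalStateException(
          s"Spark DF vs Catalyst util validation failed. Validation summary: $validationSummary")
      }
    }

The main method would then call runValidation(metadataStore, kafkaBootstrap, groupByName, validateRows) when validate mode is enabled.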
flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala
@@ -16,12 +16,15 @@ import org.apache.flink.dropwizard.metrics.DropwizardHistogramWrapper
import org.apache.flink.metrics.Counter
import org.apache.flink.metrics.Histogram
import org.apache.flink.util.Collector
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import scala.collection.Seq

/** A Flink function that uses Chronon's CatalystUtil to evaluate the Spark SQL expression in a GroupBy.
* This function is instantiated for a given type T (specific case class object, Thrift / Proto object).
@@ -99,6 +102,7 @@ class SparkExpressionEvalFn[T](encoder: Encoder[T], groupBy: GroupBy) extends Ri
rowSerTimeHistogram.update(serFinish - start)

val maybeRow = catalystUtil.performSql(row)

exprEvalTimeHistogram.update(System.currentTimeMillis() - serFinish)
maybeRow.foreach(out.collect)
exprEvalSuccessCounter.inc()
@@ -115,4 +119,69 @@ class SparkExpressionEvalFn[T](encoder: Encoder[T], groupBy: GroupBy) extends Ri
super.close()
CatalystUtil.session.close()
}

// Utility method to help with result validation. This method is used to match results of the core catalyst util based
// eval against Spark DF based eval. To do the Spark Df based eval, we:
// 1. Create a df with the events + record_id tacked on
// 2. Apply the projections and filters based on how we've set up the CatalystUtil instance based on the input groupBy.
// 3. Collect the results and group them by record_id
def runSparkSQLBulk(idToRecords: Seq[(String, Row)]): Map[String, Seq[Map[String, Any]]] = {

val idField = StructField("__record_id", StringType, false)
val fullSchema = StructType(idField +: encoder.schema.fields)
val fullRows = idToRecords.map { case (id, row) =>
// Create a new Row with id as the first field, followed by all fields from the original row
Row.fromSeq(id +: row.toSeq)
}

val rowsRdd = CatalystUtil.session.sparkContext.parallelize(fullRows.toSeq)

val eventDfs = CatalystUtil.session
.createDataFrame(rowsRdd, fullSchema)

// Apply filtering if needed
val filteredDf = catalystUtil.whereClauseOpt match {
case Some(whereClause) => eventDfs.where(whereClause)
case None => eventDfs
}

// Apply projections while preserving the index
val projectedDf = filteredDf.selectExpr(
// Include the index column and all the select clauses
Array("__record_id") ++ catalystUtil.selectClauses: _*
)

// Collect the results
val results = projectedDf.collect()

// Group results by record ID
val resultsByRecordId = results.groupBy(row => row.getString(0))

// Map back to the original record order
idToRecords.map { record =>
val recordId = record._1
val resultRows = resultsByRecordId.getOrElse(recordId, Array.empty)

val maps = resultRows.map { row =>
val columnNames = projectedDf.columns.tail // Skip the record ID column
columnNames.zipWithIndex.map { case (colName, i) =>
(colName, row.get(i + 1)) // +1 to skip the record ID column
}.toMap
}.toSeq

(recordId, maps)
}.toMap
}

// Utility method to help with result validation. This method is used to match results of the core catalyst util based
// eval against Spark DF based eval. This method iterates over the input records and hits the catalyst performSql method
// to collect results.
def runCatalystBulk(records: Seq[(String, T)]): Map[String, Seq[Map[String, Any]]] = {
records.map { record =>
val recordId = record._1
val row = rowSerializer(record._2)
val maybeRow = catalystUtil.performSql(row)
(recordId, maybeRow)
}.toMap
}
}
flink/src/main/scala/ai/chronon/flink/validation/SparkExprEvalComparisonFn.scala
@@ -0,0 +1,105 @@
package ai.chronon.flink.validation

import org.apache.commons.lang3.builder.EqualsBuilder

import scala.collection.immutable.SortedMap
import scala.collection.mutable

case class ComparisonResult(recordId: String,
isMatch: Boolean,
catalystResult: Seq[Map[String, Any]],
sparkDfResult: Seq[Map[String, Any]],
differences: Map[String, (Any, Any)]) {
override def toString: String = {
s"""
|RecordId: $recordId
|Is Match: $isMatch
|Catalyst Result: $catalystResult
|Spark DF Result: $sparkDfResult
|Differences (diff_type -> (catalystValue, sparkDfValue) ) : $differences
|""".stripMargin
}
}

object SparkExprEvalComparisonFn {

/** Utility function to compare the results of Catalyst and Spark DataFrame evaluation
* for a given recordId.
* At a high level comparison is done as follows:
* 1. If the number of rows in the catalyst vs spark df result is different, the results are considered different ("result_count" -> (catalystSize, sparkDfSize))
* 2. As the rows in the result can be in any order (which is ok from a Catalyst perspective), we sort the rows prior to comparing.
* 3. For each row, we compare the key-value pairs in the maps.
* If the size of the maps is different, the results are considered different ("result_row_size_$i" -> (catalystSize, sparkDfSize))
* If the values are different, the results are considered different ("result_row_value_${i}_$k" -> (catalystValue, sparkDfValue))
*/
private[validation] def compareResultRows(recordId: String,
catalystResult: Seq[Map[String, Any]],
sparkDfResult: Seq[Map[String, Any]]): ComparisonResult = {
if (catalystResult.size != sparkDfResult.size) {
return ComparisonResult(
recordId = recordId,
isMatch = false,
catalystResult = catalystResult,
sparkDfResult = sparkDfResult,
differences = Map("result_count" -> (catalystResult.size, sparkDfResult.size))
)
}

// We can expect multiple rows in the result (e.g. for explode queries) and these rows
// might be ordered differently. We need to compare the rows in a way that is order-agnostic.
val sortedCatalystResult = catalystResult.map(m => SortedMap[String, Any]() ++ m).sortBy(_.toString)
val sortedSparkDfResult = sparkDfResult.map(m => SortedMap[String, Any]() ++ m).sortBy(_.toString)
Review comments on lines +50 to +51 (Contributor):
- if we make dataframes out of these - we have a nice Compare.sideBySide method that prints only the differing rows.
- you will be able to kill the rest of the code under.

// Compare each pair of maps
val differences = mutable.Map[String, (Any, Any)]()

for (i <- sortedCatalystResult.indices) {
val map1 = sortedCatalystResult(i)
val map2 = sortedSparkDfResult(i)

if (map1.size != map2.size) {
differences += s"result_row_size_$i" -> (map1.size, map2.size)
} else {
map1.foreach { case (k, v1) =>
val v2 = map2.getOrElse(k, null)

if (!deepEquals(v1, v2)) {
differences += s"result_row_value_${i}_$k" -> (v1, v2)
}
}
}
}

if (differences.isEmpty) {
ComparisonResult(
recordId = recordId,
isMatch = true,
catalystResult = catalystResult,
sparkDfResult = sparkDfResult,
differences = Map.empty
)
} else {
ComparisonResult(
recordId = recordId,
isMatch = false,
catalystResult = catalystResult,
sparkDfResult = sparkDfResult,
differences = differences.toMap
)
}
}

// Helper method for deep equality - primarily used to special case types like Maps that don't match correctly
// in EqualsBuilder.reflectionEquals across scala versions 2.12 and 2.13.
private def deepEquals(a: Any, b: Any): Boolean = (a, b) match {
Review comment (Contributor): neat!

case (null, null) => true
case (null, _) | (_, null) => false
case (a: Map[_, _], b: Map[_, _]) =>
a.size == b.size && a.asInstanceOf[Map[Any, Any]].forall { case (k, v) =>
b.asInstanceOf[Map[Any, Any]].get(k) match {
case Some(bValue) => deepEquals(v, bValue)
case None => false
}
}
case _ => EqualsBuilder.reflectionEquals(a, b)
}
}
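To illustrate the difference keys described in the comment above, a small usage sketch with made-up values (compareResultRows is private[validation], so a call like this would live in a test under ai.chronon.flink.validation):

    // Illustrative values only: one result row per side, differing in the "amount" column.
    val catalystResult = Seq(Map[String, Any]("user_id" -> "u1", "amount" -> 10L))
    val sparkDfResult = Seq(Map[String, Any]("user_id" -> "u1", "amount" -> 12L))

    val result = SparkExprEvalComparisonFn.compareResultRows("record-1", catalystResult, sparkDfResult)
    // result.isMatch == false
    // result.differences == Map("result_row_value_0_amount" -> (10L, 12L))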