
Commit d9378ee

Swap out AvroCodecOutput + KV store output types from case classes to PoJos (#294)
## Summary

We use case classes for many of our intermediate types in Flink, and a few of these are persisted in Flink state: the TimestampedIR in the tiled setup is one example, and in both the untiled and tiled setups the input and output types of the AsyncKVWriter are persisted to state. If we later need to extend these intermediate types with additional fields (as we're considering to support tiling), we won't be able to, because Flink doesn't support [state schema evolution for case classes](https://nightlies.apache.org/flink/flink-docs-release-1.17/docs/dev/datastream/fault-tolerance/serialization/schema_evolution/). We'd instead be forced into a time-consuming migration: spin up parallel operators with the new types, dual write, and then cut over in a subsequent job. Setting this up before we're in prod should minimize the chance of hitting that issue down the line.

This PR swaps the case classes for Scala PoJo types for the tiled aggregations, AvroCodecOutput (which feeds into the KV store writer), and the KV store write response. In a subsequent PR we can update the TimestampedTile to include the startTs of the tile (and plumb the latestTs through to the sink so we can track e2e lag).

(We chose the Scala PoJo route here as there's a bit of interplay with some of our aggregator libs and other Scala code related to these classes - e.g. the FlinkRowAggregationFunction - and a bunch of casting is needed to interop. The Flink 2.0 migration will likely be a decent-sized chunk given all our Flink code is in Scala, and we can bite that off then.)

## Checklist

- [ ] Added Unit Tests
- [x] Covered by existing CI
- [x] Integration tested - confirmed that this works by kicking off our TestFlinkJob on the cluster
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Introduced new data types (`AvroCodecOutput`, `TimestampedTile`, `TimestampedIR`, `WriteResponse`) for improved data handling in Flink jobs
  - Enhanced schema evolution support for stateful data processing
- **Refactor**
  - Replaced `PutRequest` with `AvroCodecOutput` across multiple Flink processing components
  - Updated method signatures and data type handling in various Flink-related classes
  - Simplified timestamp retrieval and data structure interactions
- **Chores**
  - Reorganized type imports and package structure
  - Updated test cases to align with new data type implementations
1 parent 8dba7ef commit d9378ee
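To make the schema-evolution constraint concrete, here is a minimal sketch (hypothetical class names, following Flink's documented POJO rules) contrasting the two serialization paths:

```scala
// A case class goes through Flink's Scala serializers, which do NOT support
// state schema evolution: adding a field breaks restores from old savepoints.
case class TileStateCase(ir: Array[Any], latestTsMillis: Long)

// The PoJo pattern this commit adopts: mutable fields plus a no-arg constructor.
// Flink's PojoSerializer can then tolerate added/removed fields across restores.
class TileStatePojo(var ir: Array[Any], var latestTsMillis: Long) {
  def this() = this(Array(), 0L) // required no-arg constructor
}
```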

File tree

10 files changed: +199 -81 lines changed

flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala

Lines changed: 18 additions & 11 deletions

```diff
@@ -1,5 +1,7 @@
 package ai.chronon.flink

+import ai.chronon.flink.types.AvroCodecOutput
+import ai.chronon.flink.types.WriteResponse
 import ai.chronon.online.Api
 import ai.chronon.online.KVStore
 import ai.chronon.online.KVStore.PutRequest
@@ -19,14 +21,12 @@ import scala.concurrent.Future
 import scala.util.Failure
 import scala.util.Success

-case class WriteResponse(putRequest: PutRequest, status: Boolean)
-
 object AsyncKVStoreWriter {
   private val kvStoreConcurrency = 10
   private val defaultTimeoutMillis = 1000L

-  def withUnorderedWaits(inputDS: DataStream[PutRequest],
-                         kvStoreWriterFn: RichAsyncFunction[PutRequest, WriteResponse],
+  def withUnorderedWaits(inputDS: DataStream[AvroCodecOutput],
+                         kvStoreWriterFn: RichAsyncFunction[AvroCodecOutput, WriteResponse],
                          featureGroupName: String,
                          timeoutMillis: Long = defaultTimeoutMillis,
                          capacity: Int = kvStoreConcurrency): DataStream[WriteResponse] = {
@@ -69,7 +69,7 @@ object AsyncKVStoreWriter {
   * @param featureGroupName Name of the FG we're writing to
   */
 class AsyncKVStoreWriter(onlineImpl: Api, featureGroupName: String)
-    extends RichAsyncFunction[PutRequest, WriteResponse] {
+    extends RichAsyncFunction[AvroCodecOutput, WriteResponse] {
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)

   @transient private var kvStore: KVStore = _
@@ -96,14 +96,17 @@ class AsyncKVStoreWriter(onlineImpl: Api, featureGroupName: String)
     kvStore = getKVStore
   }

-  override def timeout(input: PutRequest, resultFuture: ResultFuture[WriteResponse]): Unit = {
+  override def timeout(input: AvroCodecOutput, resultFuture: ResultFuture[WriteResponse]): Unit = {
     logger.error(s"Timed out writing to KV Store for object: $input")
     errorCounter.inc()
-    resultFuture.complete(util.Arrays.asList[WriteResponse](WriteResponse(input, status = false)))
+    resultFuture.complete(
+      util.Arrays.asList[WriteResponse](
+        new WriteResponse(input.keyBytes, input.valueBytes, input.dataset, input.tsMillis, status = false)))
   }

-  override def asyncInvoke(input: PutRequest, resultFuture: ResultFuture[WriteResponse]): Unit = {
-    val resultFutureRequested: Future[Seq[Boolean]] = kvStore.multiPut(Seq(input))
+  override def asyncInvoke(input: AvroCodecOutput, resultFuture: ResultFuture[WriteResponse]): Unit = {
+    val putRequest = PutRequest(input.keyBytes, input.valueBytes, input.dataset, Some(input.tsMillis))
+    val resultFutureRequested: Future[Seq[Boolean]] = kvStore.multiPut(Seq(putRequest))
     resultFutureRequested.onComplete {
       case Success(l) =>
         val succeeded = l.forall(identity)
@@ -113,14 +116,18 @@ class AsyncKVStoreWriter(onlineImpl: Api, featureGroupName: String)
           errorCounter.inc()
           logger.error(s"Failed to write to KVStore for object: $input")
         }
-        resultFuture.complete(util.Arrays.asList[WriteResponse](WriteResponse(input, status = succeeded)))
+        resultFuture.complete(
+          util.Arrays.asList[WriteResponse](
+            new WriteResponse(input.keyBytes, input.valueBytes, input.dataset, input.tsMillis, status = succeeded)))
       case Failure(exception) =>
         // this should be rare and indicates we have an uncaught exception
         // in the KVStore - we log the exception and skip the object to
         // not fail the app
         errorCounter.inc()
         logger.error(s"Caught exception writing to KVStore for object: $input", exception)
-        resultFuture.complete(util.Arrays.asList[WriteResponse](WriteResponse(input, status = false)))
+        resultFuture.complete(
+          util.Arrays.asList[WriteResponse](
+            new WriteResponse(input.keyBytes, input.valueBytes, input.dataset, input.tsMillis, status = false)))
     }
   }
 }
```
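For orientation, a sketch of how this writer gets wired into a job; `putRecordDS`, `api`, and `featureGroupName` are assumed to exist in the surrounding job code, while the `withUnorderedWaits` signature matches the diff above:

```scala
// Hook the DataStream[AvroCodecOutput] into the async KV store writer.
val writeResponseDS: DataStream[WriteResponse] =
  AsyncKVStoreWriter.withUnorderedWaits(
    putRecordDS,                                   // DataStream[AvroCodecOutput]
    new AsyncKVStoreWriter(api, featureGroupName), // RichAsyncFunction[AvroCodecOutput, WriteResponse]
    featureGroupName
  )
```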

flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala

Lines changed: 10 additions & 10 deletions

```diff
@@ -5,10 +5,10 @@ import ai.chronon.api.DataModel
 import ai.chronon.api.Extensions.GroupByOps
 import ai.chronon.api.Query
 import ai.chronon.api.{StructType => ChrononStructType}
-import ai.chronon.flink.window.TimestampedTile
+import ai.chronon.flink.types.AvroCodecOutput
+import ai.chronon.flink.types.TimestampedTile
 import ai.chronon.online.AvroConversions
 import ai.chronon.online.GroupByServingInfoParsed
-import ai.chronon.online.KVStore.PutRequest
 import org.apache.flink.api.common.functions.RichFlatMapFunction
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.metrics.Counter
@@ -89,7 +89,7 @@ sealed abstract class BaseAvroCodecFn[IN, OUT] extends RichFlatMapFunction[IN, O
   * @tparam T The input data type
   */
 case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
-    extends BaseAvroCodecFn[Map[String, Any], PutRequest] {
+    extends BaseAvroCodecFn[Map[String, Any], AvroCodecOutput] {

   override def open(configuration: Configuration): Unit = {
     super.open(configuration)
@@ -101,7 +101,7 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)

   override def close(): Unit = super.close()

-  override def flatMap(value: Map[String, Any], out: Collector[PutRequest]): Unit =
+  override def flatMap(value: Map[String, Any], out: Collector[AvroCodecOutput]): Unit =
     try {
       out.collect(avroConvertMapToPutRequest(value))
     } catch {
@@ -113,11 +113,11 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
       avroConversionErrorCounter.inc()
     }

-  def avroConvertMapToPutRequest(in: Map[String, Any]): PutRequest = {
+  def avroConvertMapToPutRequest(in: Map[String, Any]): AvroCodecOutput = {
     val tsMills = in(timeColumnAlias).asInstanceOf[Long]
     val keyBytes = keyToBytes(keyColumns.map(in(_)))
     val valueBytes = valueToBytes(valueColumns.map(in(_)))
-    PutRequest(keyBytes, valueBytes, streamingDataset, Some(tsMills))
+    new AvroCodecOutput(keyBytes, valueBytes, streamingDataset, tsMills)
   }
 }

@@ -129,7 +129,7 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
   * @tparam T The input data type
   */
 case class TiledAvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
-    extends BaseAvroCodecFn[TimestampedTile, PutRequest] {
+    extends BaseAvroCodecFn[TimestampedTile, AvroCodecOutput] {
   override def open(configuration: Configuration): Unit = {
     super.open(configuration)
     val metricsGroup = getRuntimeContext.getMetricGroup
@@ -140,7 +140,7 @@ case class TiledAvroCodecFn[T](groupByServingInfoParse
   }
   override def close(): Unit = super.close()

-  override def flatMap(value: TimestampedTile, out: Collector[PutRequest]): Unit =
+  override def flatMap(value: TimestampedTile, out: Collector[AvroCodecOutput]): Unit =
     try {
       out.collect(avroConvertTileToPutRequest(value))
     } catch {
@@ -152,7 +152,7 @@ case class TiledAvroCodecFn[T](groupByServingInfoParse
       avroConversionErrorCounter.inc()
     }

-  def avroConvertTileToPutRequest(in: TimestampedTile): PutRequest = {
+  def avroConvertTileToPutRequest(in: TimestampedTile): AvroCodecOutput = {
     val tsMills = in.latestTsMillis

     // 'keys' is a map of (key name in schema -> key value), e.g. Map("card_number" -> "4242-4242-4242-4242")
@@ -170,6 +170,6 @@ case class TiledAvroCodecFn[T](groupByServingInfoParse
       |streamingDataset=$streamingDataset""".stripMargin
     )

-    PutRequest(keyBytes, valueBytes, streamingDataset, Some(tsMills))
+    new AvroCodecOutput(keyBytes, valueBytes, streamingDataset, tsMills)
   }
 }
```
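Note the timestamp handling above: the PoJo stores a bare `Long` where `PutRequest` used `Option[Long]`, and the wrapping is reintroduced only at the KV store boundary. A small illustrative round trip (the variables are those already in scope in `avroConvertMapToPutRequest`):

```scala
// AvroCodecOutput carries a bare Long timestamp; AsyncKVStoreWriter.asyncInvoke
// re-wraps it in Some(...) when building the actual PutRequest.
val output = new AvroCodecOutput(keyBytes, valueBytes, streamingDataset, tsMills)
val request = PutRequest(output.keyBytes, output.valueBytes, output.dataset, Some(output.tsMillis))
```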

flink/src/main/scala/ai/chronon/flink/FlinkJob.scala

Lines changed: 6 additions & 5 deletions

```diff
@@ -7,14 +7,15 @@ import ai.chronon.api.DataType
 import ai.chronon.api.Extensions.GroupByOps
 import ai.chronon.api.Extensions.SourceOps
 import ai.chronon.flink.FlinkJob.watermarkStrategy
+import ai.chronon.flink.types.AvroCodecOutput
+import ai.chronon.flink.types.TimestampedTile
+import ai.chronon.flink.types.WriteResponse
 import ai.chronon.flink.window.AlwaysFireOnElementTrigger
 import ai.chronon.flink.window.FlinkRowAggProcessFunction
 import ai.chronon.flink.window.FlinkRowAggregationFunction
 import ai.chronon.flink.window.KeySelector
-import ai.chronon.flink.window.TimestampedTile
 import ai.chronon.online.Api
 import ai.chronon.online.GroupByServingInfoParsed
-import ai.chronon.online.KVStore.PutRequest
 import ai.chronon.online.MetadataStore
 import ai.chronon.online.SparkConversions
 import ai.chronon.online.TopicInfo
@@ -58,7 +59,7 @@ import scala.concurrent.duration.FiniteDuration
   * @tparam T - The input data type
   */
 class FlinkJob[T](eventSrc: FlinkSource[T],
-                  sinkFn: RichAsyncFunction[PutRequest, WriteResponse],
+                  sinkFn: RichAsyncFunction[AvroCodecOutput, WriteResponse],
                   groupByServingInfoParsed: GroupByServingInfoParsed,
                   encoder: Encoder[T],
                   parallelism: Int) {
@@ -117,7 +118,7 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
       .name(s"Spark expression eval with timestamps for $groupByName")
       .setParallelism(sourceStream.parallelism)

-    val putRecordDS: DataStream[PutRequest] = sparkExprEvalDSWithWatermarks
+    val putRecordDS: DataStream[AvroCodecOutput] = sparkExprEvalDSWithWatermarks
       .flatMap(AvroCodecFn[T](groupByServingInfoParsed))
       .uid(s"avro-conversion-$groupByName")
       .name(s"Avro conversion for $groupByName")
@@ -221,7 +222,7 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
       .name(s"Tiling Side Output Late Data for $groupByName")
       .setParallelism(sourceStream.parallelism)

-    val putRecordDS: DataStream[PutRequest] = tilingDS
+    val putRecordDS: DataStream[AvroCodecOutput] = tilingDS
       .flatMap(new TiledAvroCodecFn[T](groupByServingInfoParsed))
       .uid(s"avro-conversion-01-$groupByName")
       .name(s"Avro conversion for $groupByName")
```

flink/src/main/scala/ai/chronon/flink/MetricsSink.scala

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,4 +1,5 @@
 package ai.chronon.flink
+import ai.chronon.flink.types.WriteResponse
 import com.codahale.metrics.ExponentiallyDecayingReservoir
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.dropwizard.metrics.DropwizardHistogramWrapper
@@ -30,7 +31,7 @@ class MetricsSink(groupByName: String) extends RichSinkFunction[WriteResponse] {
   }

   override def invoke(value: WriteResponse, context: SinkFunction.Context): Unit = {
-    val eventCreatedToSinkTime = System.currentTimeMillis() - value.putRequest.tsMillis.get
+    val eventCreatedToSinkTime = System.currentTimeMillis() - value.tsMillis
     eventCreatedToSinkTimeHistogram.update(eventCreatedToSinkTime)
   }
 }
```

flink/src/main/scala/ai/chronon/flink/TestFlinkJob.scala

Lines changed: 17 additions & 1 deletion

```diff
@@ -16,16 +16,22 @@ import ai.chronon.api.StringType
 import ai.chronon.api.TimeUnit
 import ai.chronon.api.Window
 import ai.chronon.api.{StructType => ApiStructType}
+import ai.chronon.flink.types.WriteResponse
 import ai.chronon.online.Api
 import ai.chronon.online.AvroCodec
 import ai.chronon.online.AvroConversions
 import ai.chronon.online.Extensions.StructTypeOps
 import ai.chronon.online.GroupByServingInfoParsed
+import org.apache.flink.api.common.serialization.DeserializationSchema
+import org.apache.flink.api.common.serialization.SerializationSchema
 import org.apache.flink.api.scala.createTypeInformation
+import org.apache.flink.metrics.groups.UnregisteredMetricsGroup
 import org.apache.flink.streaming.api.functions.sink.SinkFunction
 import org.apache.flink.streaming.api.functions.source.SourceFunction
 import org.apache.flink.streaming.api.scala.DataStream
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.util.SimpleUserCodeClassLoader
+import org.apache.flink.util.UserCodeClassLoader
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.avro.AvroDeserializationSupport
 import org.apache.spark.sql.types.StructType
@@ -71,11 +77,20 @@ class PrintSink extends SinkFunction[WriteResponse] {
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)

   override def invoke(value: WriteResponse, context: SinkFunction.Context): Unit = {
-    val elapsedTime = System.currentTimeMillis() - value.putRequest.tsMillis.get
+    val elapsedTime = System.currentTimeMillis() - value.tsMillis
     logger.info(s"Received write response with status ${value.status}; elapsedTime = $elapsedTime ms")
   }
 }

+class DummyInitializationContext
+    extends SerializationSchema.InitializationContext
+    with DeserializationSchema.InitializationContext {
+  override def getMetricGroup = new UnregisteredMetricsGroup
+
+  override def getUserCodeClassLoader: UserCodeClassLoader =
+    SimpleUserCodeClassLoader.create(classOf[DummyInitializationContext].getClassLoader)
+}
+
 object TestFlinkJob {
   val fields: Array[(String, DataType)] = Array(
     "id" -> StringType,
@@ -92,6 +107,7 @@ object TestFlinkJob {

   def makeSource(mockPartitionCount: Int): FlinkSource[Row] = {
     val avroCodec = AvroCodec.of(e2eTestEventAvroSchema)
+    avroDeserializationSchema.open(new DummyInitializationContext)
     val startTs = System.currentTimeMillis()
     val elements: Seq[Map[String, AnyRef]] = (0 until 10).map(i =>
       Map("id" -> s"test$i",
```
flink/src/main/scala/ai/chronon/flink/types (new file)

Lines changed: 120 additions & 0 deletions

```scala
package ai.chronon.flink.types

import java.util
import java.util.Objects

// This file contains PoJo classes that are persisted while taking checkpoints in Chronon's Flink jobs. These fall primarily
// into two buckets - tiled state and KV store incoming / outgoing records. The classes used in these cases need to allow for state
// schema evolution (https://nightlies.apache.org/flink/flink-docs-release-1.17/docs/dev/datastream/fault-tolerance/serialization/schema_evolution/).
// This allows us to add / remove fields without requiring us to migrate the state using dual write / read patterns.

/**
  * Combines the IR (intermediate result) with the timestamp of the event being processed.
  * We need the timestamp of the event processed so we can calculate processing lag down the line.
  *
  * Example: for a GroupBy with 2 windows, we'd have TimestampedIR( [IR for window 1, IR for window 2], timestamp ).
  *
  * @param ir the array of partial aggregates
  * @param latestTsMillis timestamp of the current event being processed
  */
class TimestampedIR(var ir: Array[Any], var latestTsMillis: Option[Long]) {
  def this() = this(Array(), None)

  override def toString: String =
    s"TimestampedIR(ir=${ir.mkString(", ")}, latestTsMillis=$latestTsMillis)"

  override def hashCode(): Int =
    Objects.hash(ir.deep, latestTsMillis)

  override def equals(other: Any): Boolean =
    other match {
      case e: TimestampedIR =>
        util.Arrays.deepEquals(ir.asInstanceOf[Array[AnyRef]],
                               e.ir.asInstanceOf[Array[AnyRef]]) && latestTsMillis == e.latestTsMillis
      case _ => false
    }
}

/**
  * Combines the entity keys, the encoded IR (intermediate result), and the timestamp of the event being processed.
  *
  * We need the timestamp of the event processed so we can calculate processing lag down the line.
  *
  * @param keys the GroupBy entity keys
  * @param tileBytes encoded tile IR
  * @param latestTsMillis timestamp of the current event being processed
  */
class TimestampedTile(var keys: List[Any], var tileBytes: Array[Byte], var latestTsMillis: Long) {
  def this() = this(List(), Array(), 0L)

  override def toString: String =
    s"TimestampedTile(keys=${keys.mkString(", ")}, tileBytes=${java.util.Base64.getEncoder
      .encodeToString(tileBytes)}, latestTsMillis=$latestTsMillis)"

  override def hashCode(): Int =
    Objects.hash(keys.toArray.deep, tileBytes, latestTsMillis.asInstanceOf[java.lang.Long])

  override def equals(other: Any): Boolean =
    other match {
      case e: TimestampedTile =>
        util.Arrays.deepEquals(keys.toArray.asInstanceOf[Array[AnyRef]], e.keys.toArray.asInstanceOf[Array[AnyRef]]) &&
          util.Arrays.equals(tileBytes, e.tileBytes) &&
          latestTsMillis == e.latestTsMillis
      case _ => false
    }
}

/**
  * Output emitted by the AvroCodecFn operator. This is fed into the Async KV store writer and objects of this type are persisted
  * while taking checkpoints.
  */
class AvroCodecOutput(var keyBytes: Array[Byte], var valueBytes: Array[Byte], var dataset: String, var tsMillis: Long) {
  def this() = this(Array(), Array(), "", 0L)

  override def hashCode(): Int =
    Objects.hash(
      keyBytes,
      valueBytes,
      dataset,
      tsMillis.asInstanceOf[java.lang.Long]
    )

  override def equals(other: Any): Boolean =
    other match {
      case o: AvroCodecOutput =>
        util.Arrays.equals(keyBytes, o.keyBytes) &&
          util.Arrays.equals(valueBytes, o.valueBytes) &&
          dataset == o.dataset &&
          tsMillis == o.tsMillis
      case _ => false
    }
}

/**
  * Output records emitted by the AsyncKVStoreWriter. Objects of this type are persisted while taking checkpoints.
  */
class WriteResponse(var keyBytes: Array[Byte],
                    var valueBytes: Array[Byte],
                    var dataset: String,
                    var tsMillis: Long,
                    var status: Boolean) {
  def this() = this(Array(), Array(), "", 0L, false)

  override def hashCode(): Int =
    Objects.hash(keyBytes,
                 valueBytes,
                 dataset,
                 tsMillis.asInstanceOf[java.lang.Long],
                 status.asInstanceOf[java.lang.Boolean])

  override def equals(other: Any): Boolean =
    other match {
      case o: WriteResponse =>
        util.Arrays.equals(keyBytes, o.keyBytes) &&
          util.Arrays.equals(valueBytes, o.valueBytes) &&
          dataset == o.dataset &&
          tsMillis == o.tsMillis &&
          status == o.status
      case _ => false
    }
}
```
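Since these classes hand-roll `equals`/`hashCode` (which case classes previously derived), a quick illustrative check of the content-based equality; the dataset name and byte values here are made up:

```scala
// equals compares array contents via util.Arrays.equals, so two independently
// constructed outputs with identical bytes compare equal.
val a = new AvroCodecOutput(Array[Byte](1, 2), Array[Byte](3), "my_dataset", 1700000000000L)
val b = new AvroCodecOutput(Array[Byte](1, 2), Array[Byte](3), "my_dataset", 1700000000000L)
assert(a == b) // true: content-based equality, not reference equality
```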
