
refactor: split fetcher logic into multiple files #425


Merged: 7 commits, Feb 25, 2025

16 changes: 16 additions & 0 deletions api/src/main/scala/ai/chronon/api/SerdeUtils.scala
@@ -0,0 +1,16 @@
package ai.chronon.api

import ai.chronon.api.thrift.protocol.{TBinaryProtocol, TCompactProtocol}
import ai.chronon.api.thrift.{TDeserializer, TSerializer}

object SerdeUtils {
  @transient
  lazy val compactSerializer: ThreadLocal[TSerializer] = new ThreadLocal[TSerializer] {
    override def initialValue(): TSerializer = new TSerializer(new TCompactProtocol.Factory())
  }

  @transient
  lazy val compactDeserializer: ThreadLocal[TDeserializer] = new ThreadLocal[TDeserializer] {
    override def initialValue(): TDeserializer = new TDeserializer(new TCompactProtocol.Factory())
  }
}
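Thrift's TSerializer and TDeserializer are stateful and not thread-safe, which is what the ThreadLocal wrapping above buys: each calling thread gets a private instance from .get(). A minimal round-trip sketch of the intended usage (assuming a Thrift-generated TileKey, as in TilingUtils below):

import ai.chronon.api.SerdeUtils
import ai.chronon.fetcher.TileKey

val key = new TileKey()
val bytes: Array[Byte] = SerdeUtils.compactSerializer.get().serialize(key)  // per-thread TSerializer
val decoded = new TileKey()
SerdeUtils.compactDeserializer.get().deserialize(decoded, bytes)            // per-thread TDeserializer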
22 changes: 3 additions & 19 deletions api/src/main/scala/ai/chronon/api/TilingUtils.scala
@@ -1,37 +1,21 @@
package ai.chronon.api

-import ai.chronon.api.thrift.TDeserializer
-import ai.chronon.api.thrift.TSerializer
-import ai.chronon.api.thrift.protocol.TBinaryProtocol
-import ai.chronon.api.thrift.protocol.TProtocolFactory
import ai.chronon.fetcher.TileKey
+import ai.chronon.api.SerdeUtils

-import java.io.Serializable
import scala.jdk.CollectionConverters._

// Convenience functions for working with tiling
object TilingUtils {
-  class SerializableSerializer(factory: TProtocolFactory) extends TSerializer(factory) with Serializable

-  // crazy bug in compact protocol - do not change to compact

-  @transient
-  lazy val binarySerializer: ThreadLocal[TSerializer] = new ThreadLocal[TSerializer] {
-    override def initialValue(): TSerializer = new TSerializer(new TBinaryProtocol.Factory())
-  }

-  @transient
-  lazy val binaryDeserializer: ThreadLocal[TDeserializer] = new ThreadLocal[TDeserializer] {
-    override def initialValue(): TDeserializer = new TDeserializer(new TBinaryProtocol.Factory())
-  }

  def serializeTileKey(key: TileKey): Array[Byte] = {
-    binarySerializer.get().serialize(key)
+    SerdeUtils.compactSerializer.get().serialize(key)
Contributor Author: no bugs here - just needed thread local

  }

  def deserializeTileKey(bytes: Array[Byte]): TileKey = {
    val key = new TileKey()
-    binaryDeserializer.get().deserialize(key, bytes)
+    SerdeUtils.compactDeserializer.get().deserialize(key, bytes)
    key
  }
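The comment above suggests the old "crazy bug in compact protocol" note was a red herring: the failures came from sharing one serializer across threads, not from the protocol itself. A hedged illustration of the now-safe concurrent usage (the thread pool and default-constructed keys are made up for the example):

import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}
import ai.chronon.api.TilingUtils
import ai.chronon.fetcher.TileKey

implicit val ec: ExecutionContext =
  ExecutionContext.fromExecutor(Executors.newFixedThreadPool(8))

// Safe under concurrency: each worker thread resolves its own serializer
// through SerdeUtils.compactSerializer.get() inside serializeTileKey.
val results: Seq[Future[Array[Byte]]] =
  (1 to 100).map(_ => Future(TilingUtils.serializeTileKey(new TileKey())))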

10 changes: 5 additions & 5 deletions flink/src/main/scala/ai/chronon/flink/FlinkJob.scala
@@ -19,9 +19,9 @@ import ai.chronon.flink.window.KeySelectorBuilder
import ai.chronon.online.Api
import ai.chronon.online.FlagStoreConstants
import ai.chronon.online.GroupByServingInfoParsed
-import ai.chronon.online.MetadataStore
import ai.chronon.online.SparkConversions
import ai.chronon.online.TopicInfo
+import ai.chronon.online.fetcher.{FetchContext, MetadataStore}
import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner
import org.apache.flink.api.common.eventtime.WatermarkStrategy
import org.apache.flink.configuration.CheckpointingOptions
@@ -243,7 +243,7 @@ object FlinkJob {
// we set an explicit max parallelism to ensure if we do make parallelism setting updates, there's still room
// to restore the job from prior state. Number chosen does have perf ramifications if too high (can impact rocksdb perf)
// so we've chosen one that should allow us to scale to jobs in the 10K-50K events / s range.
-  val MaxParallelism = 1260 // highly composite number
+  val MaxParallelism: Int = 1260 // highly composite number

// We choose to checkpoint frequently to ensure the incremental checkpoints are small in size
// as well as ensuring the catch-up backlog is fairly small in case of failures
@@ -254,11 +254,11 @@
val CheckpointTimeout: FiniteDuration = 5.minutes

// We use incremental checkpoints and we cap how many we keep around
-  val MaxRetainedCheckpoints = 10
+  val MaxRetainedCheckpoints: Int = 10

// how many consecutive checkpoint failures can we tolerate - default is 0, we choose a more lenient value
// to allow us a few tries before we give up
-  val TolerableCheckpointFailures = 5
+  val TolerableCheckpointFailures: Int = 5

// Keep windows open for a bit longer before closing to ensure we don't lose data due to late arrivals (needed in case of
// tiling implementation)
@@ -306,7 +306,7 @@
val kafkaBootstrap = jobArgs.kafkaBootstrap.toOption

val api = buildApi(onlineClassName, props)
-    val metadataStore = new MetadataStore(api.genKvStore, MetadataDataset, timeoutMillis = 10000)
Contributor Author: 10k is the default - there was a compiler warning

+    val metadataStore = new MetadataStore(FetchContext(api.genKvStore, MetadataDataset))

val flinkJob =
if (useMockedSource) {
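A sketch of the new wiring, assuming an Api instance as returned by buildApi: FetchContext bundles the KV store, metadata dataset, and timeout into one value, and its timeoutMillis already defaults to 10000, which is why the explicit argument could be dropped (the compiler warning noted above).

import ai.chronon.api.Constants.MetadataDataset
import ai.chronon.online.Api
import ai.chronon.online.fetcher.{FetchContext, MetadataStore}

def buildMetadataStore(api: Api): MetadataStore = {
  // timeoutMillis = 10000 is the FetchContext default, so it is omitted here
  val ctx = FetchContext(api.genKvStore, MetadataDataset)
  new MetadataStore(ctx)
}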
2 changes: 1 addition & 1 deletion flink/src/main/scala/ai/chronon/flink/TestFlinkJob.scala
@@ -19,10 +19,10 @@ import ai.chronon.api.{StructType => ApiStructType}
import ai.chronon.api.ScalaJavaConversions._
import ai.chronon.flink.types.WriteResponse
import ai.chronon.online.Api
-import ai.chronon.online.AvroCodec
import ai.chronon.online.AvroConversions
import ai.chronon.online.Extensions.StructTypeOps
import ai.chronon.online.GroupByServingInfoParsed
+import ai.chronon.online.serde.AvroCodec
import org.apache.flink.api.common.serialization.DeserializationSchema
import org.apache.flink.api.common.serialization.SerializationSchema
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup
@@ -8,8 +8,8 @@ import ai.chronon.api.Row
import ai.chronon.api.ScalaJavaConversions.ListOps
import ai.chronon.flink.types.TimestampedIR
import ai.chronon.flink.types.TimestampedTile
-import ai.chronon.online.ArrayRow
import ai.chronon.online.TileCodec
+import ai.chronon.online.serde.ArrayRow
import org.apache.flink.api.common.functions.AggregateFunction
import org.apache.flink.configuration.Configuration
import org.apache.flink.metrics.Counter
@@ -2,7 +2,8 @@ package org.apache.spark.sql.avro

import ai.chronon.api.{StructType => ChrononStructType}
import ai.chronon.flink.test.UserAvroSchema
-import ai.chronon.online.{AvroCodec, AvroConversions, CatalystUtil}
+import ai.chronon.online.serde.AvroCodec
+import ai.chronon.online.{AvroConversions, CatalystUtil}
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.flink.api.common.serialization.DeserializationSchema
37 changes: 22 additions & 15 deletions online/src/main/scala/ai/chronon/online/Api.scala
@@ -79,26 +79,33 @@ trait KVStore {

// helper method to blocking read a string - used for fetching metadata & not in hotpath.
def getString(key: String, dataset: String, timeoutMillis: Long): Try[String] = {
val response = getResponse(key, dataset, timeoutMillis)
if (response.values.isFailure) {
Failure(new RuntimeException(s"Request for key ${key} in dataset ${dataset} failed", response.values.failed.get))
} else {
response.values.get.length match {
case 0 => {
Failure(new RuntimeException(s"Empty response from KVStore for key=${key} in dataset=${dataset}."))
}
case _ => Success(new String(response.latest.get.bytes, Constants.UTF8))

getResponse(key, dataset, timeoutMillis).values
.recoverWith { case ex =>
// wrap with more info
Failure(new RuntimeException(s"Request for key $key in dataset $dataset failed", ex))
}
.flatMap { values =>
if (values.isEmpty)
Failure(new RuntimeException(s"Empty response from KVStore for key=$key in dataset=$dataset."))
else
Success(new String(values.maxBy(_.millis).bytes, Constants.UTF8))
}
}
}

def getStringArray(key: String, dataset: String, timeoutMillis: Long): Try[Seq[String]] = {
val response = getResponse(key, dataset, timeoutMillis)
if (response.values.isFailure) {
Failure(new RuntimeException(s"Request for key ${key} in dataset ${dataset} failed", response.values.failed.get))
} else {
Success(StringArrayConverter.bytesToStrings(response.latest.get.bytes))
}

response.values
.map { values =>
val latestBytes = values.maxBy(_.millis).bytes
StringArrayConverter.bytesToStrings(latestBytes)
}
.recoverWith { case ex =>
// Wrap with more info
Failure(new RuntimeException(s"Request for key $key in dataset $dataset failed", ex))
}

}

private def getResponse(key: String, dataset: String, timeoutMillis: Long): GetResponse = {
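The rework replaces nested if/else on response.values with Try combinators: recoverWith rewraps failures with key/dataset context, flatMap handles the empty case, and maxBy(_.millis) picks the most recent value (an exception thrown inside map or flatMap surfaces as a Failure). A stand-alone sketch of the same flow; TimedValue and the messages are invented for the example:

import scala.util.{Failure, Success, Try}

case class TimedValue(bytes: Array[Byte], millis: Long)

def latestString(values: Try[Seq[TimedValue]]): Try[String] =
  values
    .recoverWith { case ex => Failure(new RuntimeException("request failed", ex)) }
    .flatMap { vs =>
      if (vs.isEmpty) Failure(new RuntimeException("empty response"))
      else Success(new String(vs.maxBy(_.millis).bytes, "UTF-8"))
    }

// Picks the value with the highest timestamp:
// latestString(Success(Seq(TimedValue("old".getBytes, 1L), TimedValue("new".getBytes, 2L))))
//   == Success("new")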
4 changes: 2 additions & 2 deletions online/src/main/scala/ai/chronon/online/AvroConversions.scala
@@ -183,7 +183,7 @@ object AvroConversions {
}

def encodeBytes(schema: StructType, extraneousRecord: Any => Array[Any] = null): Any => Array[Byte] = {
-    val codec: AvroCodec = new AvroCodec(fromChrononSchema(schema).toString(true));
+    val codec: serde.AvroCodec = new serde.AvroCodec(fromChrononSchema(schema).toString(true));
{ data: Any =>
val record =
fromChrononRow(data, codec.chrononSchema, codec.schema, extraneousRecord).asInstanceOf[GenericData.Record]
@@ -193,7 +193,7 @@
}

def encodeJson(schema: StructType, extraneousRecord: Any => Array[Any] = null): Any => String = {
-    val codec: AvroCodec = new AvroCodec(fromChrononSchema(schema).toString(true));
+    val codec: serde.AvroCodec = new serde.AvroCodec(fromChrononSchema(schema).toString(true));
{ data: Any =>
val record =
fromChrononRow(data, codec.chrononSchema, codec.schema, extraneousRecord).asInstanceOf[GenericData.Record]
32 changes: 0 additions & 32 deletions online/src/main/scala/ai/chronon/online/CompatParColls.scala

This file was deleted.

2 changes: 1 addition & 1 deletion online/src/main/scala/ai/chronon/online/Extensions.scala
@@ -48,6 +48,6 @@ object Extensions {

def toAvroSchema(name: String = null): Schema = AvroConversions.fromChrononSchema(toChrononSchema(name))

-    def toAvroCodec(name: String = null): AvroCodec = new AvroCodec(toAvroSchema(name).toString())
+    def toAvroCodec(name: String = null): serde.AvroCodec = new serde.AvroCodec(toAvroSchema(name).toString())
}
}
@@ -17,8 +17,8 @@
package ai.chronon.online

import ai.chronon.api.Constants
-import ai.chronon.online.fetcher.Fetcher.Request
-import ai.chronon.online.fetcher.Fetcher.Response
+import ai.chronon.online.fetcher.Fetcher.{Request, Response}

import scala.collection.Seq
import scala.collection.mutable
@@ -21,9 +21,12 @@ import ai.chronon.api.Constants.ReversalField
import ai.chronon.api.Constants.TimeField
import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.api.Extensions.MetadataOps
+import ai.chronon.api.ScalaJavaConversions.ListOps
import ai.chronon.api._
import ai.chronon.online.OnlineDerivationUtil.DerivationFunc
import ai.chronon.online.OnlineDerivationUtil.buildDerivationFunction
+import ai.chronon.online.serde.AvroCodec

import org.apache.avro.Schema

import scala.collection.JavaConverters.asScalaBufferConverter
@@ -43,7 +46,7 @@ class GroupByServingInfoParsed(val groupByServingInfo: GroupByServingInfo, parti

lazy val aggregator: SawtoothOnlineAggregator = {
new SawtoothOnlineAggregator(batchEndTsMillis,
-      groupByServingInfo.groupBy.aggregations.asScala.toSeq,
+      groupByServingInfo.groupBy.aggregations.toScala,
valueChrononSchema.fields.map(sf => (sf.name, sf.fieldType)))
}

@@ -77,11 +80,11 @@
AvroConversions.fromChrononSchema(valueChrononSchema).toString()
}

-  def valueAvroCodec: AvroCodec = AvroCodec.of(valueAvroSchema)
-  def selectedCodec: AvroCodec = AvroCodec.of(selectedAvroSchema)
+  def valueAvroCodec: serde.AvroCodec = serde.AvroCodec.of(valueAvroSchema)
+  def selectedCodec: serde.AvroCodec = serde.AvroCodec.of(selectedAvroSchema)
lazy val irAvroSchema: String = AvroConversions.fromChrononSchema(irChrononSchema).toString()
-  def irCodec: AvroCodec = AvroCodec.of(irAvroSchema)
-  def outputCodec: AvroCodec = AvroCodec.of(outputAvroSchema)
+  def irCodec: serde.AvroCodec = serde.AvroCodec.of(irAvroSchema)
+  def outputCodec: serde.AvroCodec = serde.AvroCodec.of(outputAvroSchema)

// Start tiling specific variables

Expand All @@ -90,9 +93,12 @@ class GroupByServingInfoParsed(val groupByServingInfo: GroupByServingInfo, parti

// End tiling specific variables

-  def outputChrononSchema: StructType = {
-    StructType.from(s"${groupBy.metaData.cleanName}_OUTPUT", aggregator.windowedAggregator.outputSchema)
-  }
+  def outputChrononSchema: StructType =
+    if (groupByServingInfo.groupBy.aggregations == null) {
+      selectedChrononSchema
+    } else {
+      StructType.from(s"${groupBy.metaData.cleanName}_OUTPUT", aggregator.windowedAggregator.outputSchema)
+    }

lazy val outputAvroSchema: String = { AvroConversions.fromChrononSchema(outputChrononSchema).toString() }

@@ -118,7 +124,7 @@ class GroupByServingInfoParsed(val groupByServingInfo: GroupByServingInfo, parti
AvroConversions.toChrononSchema(parser.parse(mutationValueAvroSchema)).asInstanceOf[StructType]
}

-  def mutationValueAvroCodec: AvroCodec = AvroCodec.of(mutationValueAvroSchema)
+  def mutationValueAvroCodec: serde.AvroCodec = serde.AvroCodec.of(mutationValueAvroSchema)

// Schema for data consumed by the streaming job.
// Needs consistency with mutationDf Schema for backfill group by. (Shared queries)
8 changes: 5 additions & 3 deletions online/src/main/scala/ai/chronon/online/JoinCodec.scala
@@ -27,13 +27,15 @@ import ai.chronon.online.OnlineDerivationUtil.DerivationFunc
import ai.chronon.online.OnlineDerivationUtil.buildDerivationFunction
import ai.chronon.online.OnlineDerivationUtil.buildDerivedFields
import ai.chronon.online.OnlineDerivationUtil.buildRenameOnlyDerivationFunction
+import ai.chronon.online.serde.AvroCodec

import com.google.gson.Gson

case class JoinCodec(conf: JoinOps,
keySchema: StructType,
baseValueSchema: StructType,
-                     keyCodec: AvroCodec,
-                     baseValueCodec: AvroCodec)
+                     keyCodec: serde.AvroCodec,
+                     baseValueCodec: serde.AvroCodec)
extends Serializable {

@transient lazy val valueSchema: StructType = {
@@ -87,7 +89,7 @@

object JoinCodec {

-  def buildLoggingSchema(joinName: String, keyCodec: AvroCodec, valueCodec: AvroCodec): String = {
+  def buildLoggingSchema(joinName: String, keyCodec: serde.AvroCodec, valueCodec: serde.AvroCodec): String = {
val schemaMap = Map(
"join_name" -> joinName,
"key_schema" -> keyCodec.schemaStr,
1 change: 1 addition & 0 deletions online/src/main/scala/ai/chronon/online/TTLCache.scala
@@ -45,6 +45,7 @@ class TTLCache[I, O](f: I => O,

case class Entry(value: O, updatedAtMillis: Long, var markedForUpdate: AtomicBoolean = new AtomicBoolean(false))
@transient implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass)

private val updateWhenNull =
new function.BiFunction[I, Entry, Entry] {
override def apply(t: I, u: Entry): Entry = {
1 change: 1 addition & 0 deletions online/src/main/scala/ai/chronon/online/TileCodec.scala
@@ -25,6 +25,7 @@ import ai.chronon.api.Extensions.WindowUtils
import ai.chronon.api.GroupBy
import ai.chronon.api.ScalaJavaConversions._
import ai.chronon.api.StructType
+import ai.chronon.online.serde.AvroCodec
import org.apache.avro.generic.GenericData

import scala.collection.JavaConverters._
36 changes: 36 additions & 0 deletions online/src/main/scala/ai/chronon/online/fetcher/FetchContext.scala
@@ -0,0 +1,36 @@
package ai.chronon.online.fetcher

import ai.chronon.api.Constants.MetadataDataset
import ai.chronon.api.ScalaJavaConversions.JMapOps
import ai.chronon.online.{FlagStore, FlagStoreConstants, FlexibleExecutionContext, KVStore}

import scala.concurrent.ExecutionContext

case class FetchContext(kvStore: KVStore,
                        metadataDataset: String = MetadataDataset,
                        timeoutMillis: Long = 10000,
                        debug: Boolean = false,
                        flagStore: FlagStore = null,
                        disableErrorThrows: Boolean = false,
                        executionContextOverride: ExecutionContext = null) {

  def isTilingEnabled: Boolean = {
    Option(flagStore)
      .map(_.isSet(FlagStoreConstants.TILING_ENABLED, Map.empty[String, String].toJava))
      .exists(_.asInstanceOf[Boolean])
  }

  def isCachingEnabled(groupByName: String): Boolean = {
    Option(flagStore)
      .exists(_.isSet("enable_fetcher_batch_ir_cache", Map("group_by_streaming_dataset" -> groupByName).toJava))
  }

  def shouldStreamingDecodeThrow(groupByName: String): Boolean = {
    Option(flagStore)
      .exists(
        _.isSet("disable_streaming_decoding_error_throws", Map("group_by_streaming_dataset" -> groupByName).toJava))
  }

  def getOrCreateExecutionContext: ExecutionContext = {
    Option(executionContextOverride).getOrElse(FlexibleExecutionContext.buildExecutionContext)
  }
}
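A hypothetical usage sketch for FetchContext: with no FlagStore supplied, Option(null) short-circuits every flag check to false, and the execution context falls back to FlexibleExecutionContext. The kvStore parameter and the group-by name below stand in for an application-provided KVStore implementation and dataset.

import ai.chronon.online.KVStore
import ai.chronon.online.fetcher.FetchContext

def demo(kvStore: KVStore): Unit = {
  val ctx = FetchContext(kvStore)              // all other fields use defaults
  assert(!ctx.isTilingEnabled)                 // flagStore == null => false
  assert(!ctx.isCachingEnabled("my_group_by")) // "my_group_by" is a made-up name
  val ec = ctx.getOrCreateExecutionContext     // override is null => build default
}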