Skip to content

Commit db26796

Browse files
authored
Cherrypick OSS fetcher failure handling PRs - #932 and #964 (#706)
## Summary

Pull in PRs airbnb/chronon#964 and airbnb/chronon#932. We hit issues related to 964 in some of our tests at Etsy: groupByServingInfo lookups against BT timed out and we ended up caching the failure response. 964 addresses this, and it depends on 932, so we are pulling that in as well.

## Checklist

- [ ] Added Unit Tests
- [X] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **New Features**
  - Improved error handling and reporting for partial failures in join operations and key-value store lookups.
  - Enhanced cache refresh mechanisms for join configurations and metadata, improving system robustness during failures.
  - Added a configurable option to control strictness on invalid dataset references in the in-memory key-value store.
- **Bug Fixes**
  - Exceptions and partial failures are now more accurately surfaced in fetch responses, ensuring clearer diagnostics for end-users.
  - Updated error key naming for consistency in response maps.
- **Tests**
  - Added a new test to verify correct handling and reporting of partial failures in key-value store operations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent d6cabc5 commit db26796

File tree

11 files changed

+190
-98
lines changed

11 files changed

+190
-98
lines changed

online/src/main/scala/ai/chronon/online/Api.scala

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -73,44 +73,31 @@ trait KVStore {
7373

7474
// helper method to blocking read a string - used for fetching metadata & not in hotpath.
7575
def getString(key: String, dataset: String, timeoutMillis: Long): Try[String] = {
76-
77-
getResponse(key, dataset, timeoutMillis).values
78-
.recoverWith { case ex =>
79-
// wrap with more info
80-
Failure(new RuntimeException(s"Request for key $key in dataset $dataset failed", ex))
81-
}
82-
.flatMap { values =>
83-
if (values.isEmpty)
84-
Failure(new RuntimeException(s"Empty response from KVStore for key=$key in dataset=$dataset."))
85-
else
86-
Success(new String(values.maxBy(_.millis).bytes, Constants.UTF8))
87-
}
76+
val bytesTry = getResponse(key, dataset, timeoutMillis)
77+
bytesTry.map(bytes => new String(bytes, Constants.UTF8))
8878
}
8979

9080
def getStringArray(key: String, dataset: String, timeoutMillis: Long): Try[Seq[String]] = {
91-
val response = getResponse(key, dataset, timeoutMillis)
92-
93-
response.values
94-
.map { values =>
95-
val latestBytes = values.maxBy(_.millis).bytes
96-
StringArrayConverter.bytesToStrings(latestBytes)
97-
}
98-
.recoverWith { case ex =>
99-
// Wrap with more info
100-
Failure(new RuntimeException(s"Request for key $key in dataset $dataset failed", ex))
101-
}
102-
81+
val bytesTry = getResponse(key, dataset, timeoutMillis)
82+
bytesTry.map(bytes => StringArrayConverter.bytesToStrings(bytes))
10383
}
10484

105-
private def getResponse(key: String, dataset: String, timeoutMillis: Long): GetResponse = {
106-
try {
107-
val fetchRequest = KVStore.GetRequest(key.getBytes(Constants.UTF8), dataset)
108-
val responseFutureOpt = get(fetchRequest)
109-
Await.result(responseFutureOpt, Duration(timeoutMillis, MILLISECONDS))
110-
} catch {
111-
case ex: Exception =>
112-
ex.printStackTrace()
113-
throw ex
85+
private def getResponse(key: String, dataset: String, timeoutMillis: Long): Try[Array[Byte]] = {
86+
val fetchRequest = KVStore.GetRequest(key.getBytes(Constants.UTF8), dataset)
87+
val responseFutureOpt = get(fetchRequest)
88+
89+
def buildException(e: Throwable) =
90+
new RuntimeException(s"Request for key ${key} in dataset ${dataset} failed", e)
91+
92+
Try(Await.result(responseFutureOpt, Duration(timeoutMillis, MILLISECONDS))) match {
93+
case Failure(e) =>
94+
Failure(buildException(e))
95+
case Success(resp) =>
96+
if (resp.values.isFailure) {
97+
Failure(buildException(resp.values.failed.get))
98+
} else {
99+
Success(resp.latest.get.bytes)
100+
}
114101
}
115102
}
116103

online/src/main/scala/ai/chronon/online/JoinCodec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ case class JoinCodec(conf: JoinOps,
3535
keySchema: StructType,
3636
baseValueSchema: StructType,
3737
keyCodec: AvroCodec,
38-
baseValueCodec: AvroCodec)
38+
baseValueCodec: AvroCodec,
39+
hasPartialFailure: Boolean = false)
3940
extends Serializable {
4041

4142
@transient lazy val valueSchema: StructType = {

online/src/main/scala/ai/chronon/online/fetcher/Fetcher.scala

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,16 @@ class Fetcher(val kvStore: KVStore,
241241
ctx.distribution("derivation.latency.millis", requestEndTs - derivationStartTs)
242242
ctx.distribution("request.latency.millis", requestEndTs - ts)
243243

244-
ResponseWithContext(request, finalizedDerivedMap, baseMap)
244+
val response = ResponseWithContext(request, finalizedDerivedMap, baseMap)
245+
// Refresh joinCodec if it has partial failure
246+
if (joinCodec.hasPartialFailure) {
247+
joinCodecCache.refresh(joinName)
248+
}
249+
response
245250

246251
case Failure(exception) =>
247252
// more validation logic will be covered in compile.py to avoid this case
253+
joinCodecCache.refresh(joinName)
248254
ctx.incrementException(exception)
249255
ResponseWithContext(request, Map("join_codec_fetch_exception" -> exception.traceString), Map.empty)
250256

@@ -293,14 +299,15 @@ class Fetcher(val kvStore: KVStore,
293299

294300
val joinCodecTry = joinCodecCache(resp.request.name)
295301

296-
val loggingTry: Try[Unit] = joinCodecTry.map(codec => {
297-
val metaData = codec.conf.join.metaData
298-
val samplePercent = if (metaData.isSetSamplePercent) metaData.getSamplePercent else 0
302+
val loggingTry: Try[Unit] = joinCodecTry
303+
.map(codec => {
304+
val metaData = codec.conf.join.metaData
305+
val samplePercent = if (metaData.isSetSamplePercent) metaData.getSamplePercent else 0
299306

300-
if (samplePercent > 0)
301-
encodeAndPublishLog(resp, ts, codec, samplePercent)
307+
if (samplePercent > 0)
308+
encodeAndPublishLog(resp, ts, codec, samplePercent)
302309

303-
})
310+
})
304311

305312
loggingTry.failed.map { exception =>
306313
// to handle GroupByServingInfo staleness that results in encoding failure
@@ -310,6 +317,10 @@ class Fetcher(val kvStore: KVStore,
310317
_.incrementException(new RuntimeException(s"Logging failed due to: ${exception.traceString}", exception)))
311318
}
312319

320+
if (joinCodecTry.isSuccess && joinCodecTry.get.hasPartialFailure) {
321+
joinCodecCache.refresh(resp.request.name)
322+
}
323+
313324
Response(resp.request, Success(resp.derivedValues))
314325
}
315326

@@ -390,6 +401,7 @@ class Fetcher(val kvStore: KVStore,
390401
val joinName = request.name
391402
val joinConfTry: Try[JoinOps] = metadataStore.getJoinConf(request.name)
392403
if (joinConfTry.isFailure) {
404+
metadataStore.getJoinConf.refresh(request.name)
393405
resultMap.update(
394406
request,
395407
Failure(
@@ -412,6 +424,10 @@ class Fetcher(val kvStore: KVStore,
412424
// step-2 dedup external requests across joins
413425
val externalToJoinRequests: Seq[ExternalToJoinRequest] = validRequests
414426
.flatMap { joinRequest =>
427+
val joinConf = metadataStore.getJoinConf(joinRequest.name)
428+
if (joinConf.isFailure) {
429+
metadataStore.getJoinConf.refresh(joinRequest.name)
430+
}
415431
val parts =
416432
metadataStore
417433
.getJoinConf(joinRequest.name)
@@ -498,18 +514,22 @@ class Fetcher(val kvStore: KVStore,
498514

499515
val joinSchemaResponse = joinCodecTry
500516
.map { joinCodec =>
501-
JoinSchemaResponse(joinName,
502-
joinCodec.keyCodec.schemaStr,
503-
joinCodec.valueCodec.schemaStr,
504-
joinCodec.loggingSchemaHash)
517+
val response = JoinSchemaResponse(joinName,
518+
joinCodec.keyCodec.schemaStr,
519+
joinCodec.valueCodec.schemaStr,
520+
joinCodec.loggingSchemaHash)
521+
if (joinCodec.hasPartialFailure) {
522+
joinCodecCache.refresh(joinName)
523+
}
524+
ctx.distribution("response.latency.millis", System.currentTimeMillis() - startTime)
525+
response
505526
}
506-
.recover { case exception: Throwable =>
527+
.recover { case exception =>
507528
logger.error(s"Failed to fetch join schema for $joinName", exception)
508529
ctx.incrementException(exception)
509530
throw exception
510531
}
511532

512-
joinSchemaResponse.foreach(_ => ctx.distribution("response.latency.millis", System.currentTimeMillis() - startTime))
513533
joinSchemaResponse
514534
}
515535

online/src/main/scala/ai/chronon/online/fetcher/GroupByFetcher.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ class GroupByFetcher(fetchContext: FetchContext, metadataStore: MetadataStore)
4343
*/
4444
private def toLambdaKvRequest(request: Fetcher.Request): Try[LambdaKvRequest] = metadataStore
4545
.getGroupByServingInfo(request.name)
46+
.recover { case ex: Throwable =>
47+
metadataStore.getGroupByServingInfo.refresh(request.name)
48+
logger.error(s"Couldn't fetch GroupByServingInfo for ${request.name}", ex)
49+
request.context.foreach(_.incrementException(ex))
50+
throw ex
51+
}
4652
.map { groupByServingInfo =>
4753
val context =
4854
request.context.getOrElse(

online/src/main/scala/ai/chronon/online/fetcher/JoinPartFetcher.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,16 @@ class JoinPartFetcher(fetchContext: FetchContext, metadataStore: MetadataStore)
6161
val joinDecomposed: Seq[(Request, Try[Seq[Either[PrefixedRequest, KeyMissingException]]])] =
6262
requests.map { request =>
6363
// use passed-in join or fetch one
64-
import ai.chronon.online.metrics
65-
val joinTry: Try[JoinOps] = joinConf
66-
.map(conf => Success(JoinOps(conf)))
67-
.getOrElse(metadataStore.getJoinConf(request.name))
64+
val joinTry: Try[JoinOps] = if (joinConf.isEmpty) {
65+
val joinConfTry = metadataStore.getJoinConf(request.name)
66+
if (joinConfTry.isFailure) {
67+
metadataStore.getJoinConf.refresh(request.name)
68+
}
69+
joinConfTry
70+
} else {
71+
logger.debug(s"Using passed in join configuration: ${joinConf.get.metaData.getName}")
72+
Success(JoinOps(joinConf.get))
73+
}
6874

6975
var joinContext: Option[metrics.Metrics.Context] = None
7076

@@ -163,7 +169,7 @@ class JoinPartFetcher(fetchContext: FetchContext, metadataStore: MetadataStore)
163169
if (fetchContext.debug || Math.random() < 0.001) {
164170
println(s"Failed to fetch $groupByRequest with \n${ex.traceString}")
165171
}
166-
Map(groupByRequest.name + "_exception" -> ex.traceString)
172+
Map(prefix + "_exception" -> ex.traceString)
167173
}
168174
.get
169175
}

online/src/main/scala/ai/chronon/online/fetcher/MetadataStore.scala

Lines changed: 69 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ class MetadataStore(fetchContext: FetchContext) {
160160
if (result.isSuccess) metrics.Metrics.Context(metrics.Metrics.Environment.MetaDataFetching, result.get.join)
161161
else metrics.Metrics.Context(metrics.Metrics.Environment.MetaDataFetching, join = name)
162162
// Throw exception after metrics. No join metadata is bound to be a critical failure.
163+
// This will ensure that a Failure is never cached in the getJoinConf TTLCache
163164
if (result.isFailure) {
164-
import ai.chronon.online.metrics
165-
context.withSuffix("join").increment(metrics.Metrics.Name.Exception)
165+
context.withSuffix("join").incrementException(result.failed.get)
166166
throw result.failed.get
167167
}
168168
context
@@ -239,20 +239,56 @@ class MetadataStore(fetchContext: FetchContext) {
239239
doRetrieveAllListConfs(new mutable.ArrayBuffer[String]())
240240
}
241241

242+
private def buildJoinPartCodec(
243+
joinPart: JoinPartOps,
244+
servingInfo: GroupByServingInfoParsed): (Iterable[StructField], Iterable[StructField]) = {
245+
val keySchema = servingInfo.keyCodec.chrononSchema.asInstanceOf[StructType]
246+
val joinKeyFields = joinPart.leftToRight
247+
.map { case (leftKey, rightKey) =>
248+
StructField(leftKey, keySchema.fields.find(_.name == rightKey).get.fieldType)
249+
}
250+
251+
val baseValueSchema: StructType = if (servingInfo.groupBy.aggregations == null) {
252+
servingInfo.selectedChrononSchema
253+
} else {
254+
servingInfo.outputChrononSchema
255+
}
256+
val valueFields = if (!servingInfo.groupBy.hasDerivations) {
257+
baseValueSchema.fields
258+
} else {
259+
buildDerivedFields(servingInfo.groupBy.derivationsScala, keySchema, baseValueSchema).toArray
260+
}
261+
val joinValueFields = valueFields.map(joinPart.constructJoinPartSchema)
262+
263+
(joinKeyFields, joinValueFields)
264+
}
265+
242266
// key and value schemas
243267
def buildJoinCodecCache(onCreateFunc: Option[Try[JoinCodec] => Unit]): TTLCache[String, Try[JoinCodec]] = {
244268

245269
val codecBuilder = { joinName: String =>
246-
getJoinConf(joinName)
247-
.map(_.join)
248-
.map(buildJoinCodec)
249-
.recoverWith { case th: Throwable =>
250-
Failure(
251-
new RuntimeException(
252-
s"Couldn't fetch joinName = ${joinName} or build join codec due to ${th.traceString}",
253-
th
254-
))
270+
val startTimeMs = System.currentTimeMillis()
271+
val result: Try[JoinCodec] =
272+
try {
273+
getJoinConf(joinName)
274+
.map(_.join)
275+
.map(join => buildJoinCodec(join, refreshOnFail = true))
276+
} catch {
277+
case th: Throwable =>
278+
getJoinConf.refresh(joinName)
279+
Failure(
280+
new RuntimeException(
281+
s"Couldn't fetch joinName = ${joinName} or build join codec due to ${th.traceString}",
282+
th
283+
))
255284
}
285+
val context = Metrics.Context(Metrics.Environment.MetaDataFetching, join = joinName).withSuffix("join_codec")
286+
if (result.isFailure) {
287+
context.incrementException(result.failed.get)
288+
} else {
289+
context.distribution(Metrics.Name.LatencyMillis, System.currentTimeMillis() - startTimeMs)
290+
}
291+
result
256292
}
257293

258294
new TTLCache[String, Try[JoinCodec]](
@@ -265,38 +301,32 @@ class MetadataStore(fetchContext: FetchContext) {
265301
)
266302
}
267303

268-
def buildJoinCodec(joinConf: Join): JoinCodec = {
304+
def buildJoinCodec(joinConf: Join, refreshOnFail: Boolean): JoinCodec = {
269305
val keyFields = new mutable.LinkedHashSet[StructField]
270306
val valueFields = new mutable.ListBuffer[StructField]
307+
var hasPartialFailure = false
271308
// collect keyFields and valueFields from joinParts/GroupBys
272309
joinConf.joinPartOps.foreach { joinPart =>
273-
val servingInfoTry = getGroupByServingInfo(joinPart.groupBy.metaData.getName)
274-
servingInfoTry
310+
getGroupByServingInfo(joinPart.groupBy.metaData.getName)
275311
.map { servingInfo =>
276-
val keySchema = servingInfo.keyCodec.chrononSchema.asInstanceOf[StructType]
277-
joinPart.leftToRight
278-
.mapValues(right => keySchema.fields.find(_.name == right).get.fieldType)
279-
.foreach { case (name, dType) =>
280-
val keyField = StructField(name, dType)
281-
keyFields.add(keyField)
282-
}
283-
val groupBySchemaBeforeDerivation: StructType =
284-
if (servingInfo.groupBy.aggregations == null) {
285-
servingInfo.selectedChrononSchema
312+
val (keys, values) = buildJoinPartCodec(joinPart, servingInfo)
313+
keys.foreach(k => keyFields.add(k))
314+
values.foreach(v => valueFields.append(v))
315+
}
316+
.recoverWith {
317+
case exception: Throwable => {
318+
if (refreshOnFail) {
319+
getGroupByServingInfo.refresh(joinPart.groupBy.metaData.getName)
320+
hasPartialFailure = true
321+
Success(())
286322
} else {
287-
servingInfo.outputChrononSchema
323+
Failure(new Exception(
324+
s"Failure to build join codec for join ${joinConf.metaData.name} due to bad groupBy serving info for ${joinPart.groupBy.metaData.name}",
325+
exception))
288326
}
289-
val baseValueSchema: StructType = if (!servingInfo.groupBy.hasDerivations) {
290-
groupBySchemaBeforeDerivation
291-
} else {
292-
val fields =
293-
buildDerivedFields(servingInfo.groupBy.derivationsScala, keySchema, groupBySchemaBeforeDerivation)
294-
StructType(s"groupby_derived_${servingInfo.groupBy.metaData.cleanName}", fields.toArray)
295-
}
296-
baseValueSchema.fields.foreach { sf =>
297-
valueFields.append(joinPart.constructJoinPartSchema(sf))
298327
}
299328
}
329+
.get
300330
}
301331

302332
// gather key schema and value schema from external sources.
@@ -325,8 +355,7 @@ class MetadataStore(fetchContext: FetchContext) {
325355
val keyCodec = AvroCodec.of(AvroConversions.fromChrononSchema(keySchema).toString)
326356
val baseValueSchema = StructType(s"${joinName.sanitize}_value", valueFields.toArray)
327357
val baseValueCodec = serde.AvroCodec.of(AvroConversions.fromChrononSchema(baseValueSchema).toString)
328-
val joinCodec = JoinCodec(joinConf, keySchema, baseValueSchema, keyCodec, baseValueCodec)
329-
joinCodec
358+
JoinCodec(joinConf, keySchema, baseValueSchema, keyCodec, baseValueCodec, hasPartialFailure)
330359
}
331360

332361
def getSchemaFromKVStore(dataset: String, key: String): serde.AvroCodec = {
@@ -366,6 +395,10 @@ class MetadataStore(fetchContext: FetchContext) {
366395
}
367396
logger.info(s"Fetched ${Constants.GroupByServingInfoKey} from : $batchDataset")
368397
if (metaData.isFailure) {
398+
Metrics
399+
.Context(Metrics.Environment.MetaDataFetching, groupBy = name)
400+
.withSuffix("group_by")
401+
.incrementException(metaData.failed.get)
369402
Failure(
370403
new RuntimeException(s"Couldn't fetch group by serving info for $batchDataset, " +
371404
"please make sure a batch upload was successful",

online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
244244
)
245245

246246
val result = baseFetcher.parseGroupByResponse("prefix", request, response)
247-
result.keySet shouldBe Set("name_exception")
247+
result.keySet shouldBe Set("prefix_exception")
248248
}
249249

250250
it should "check late batch data is handled correctly" in {

spark/src/main/scala/ai/chronon/spark/streaming/JoinSourceRunner.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,12 @@ class JoinSourceRunner(groupByConf: api.GroupBy, conf: Map[String, String] = Map
193193
val leftSourceSchema: StructType = outputSchema(leftStreamSchema, enrichQuery(left.query)) // apply same thing
194194

195195
// joinSchema = leftSourceSchema ++ joinCodec.valueSchema
196-
val joinCodec: JoinCodec = apiImpl.buildFetcher(debug).metadataStore.buildJoinCodec(joinSource.getJoin)
196+
val joinCodec: JoinCodec =
197+
apiImpl
198+
.buildFetcher(debug)
199+
.metadataStore
200+
// immediately fails if the codec has partial error to avoid using stale codec
201+
.buildJoinCodec(joinSource.getJoin, refreshOnFail = false)
197202
val joinValueSchema: StructType = SparkConversions.fromChrononSchema(joinCodec.valueSchema)
198203
val joinSchema: StructType = StructType(leftSourceSchema ++ joinValueSchema)
199204
val joinSourceSchema: StructType = outputSchema(joinSchema, enrichQuery(joinSource.query))

0 commit comments

Comments
 (0)