Skip to content

Commit e8dcb0e

Browse files
varant-zlaiezvz and authored
cherry pick Don't request null keys from KVStore #774 (#297)
## Summary Cherry pick: https://github.com/airbnb/chronon/pull/774/files ## Checklist - [x] Added Unit Tests - [ ] Covered by existing CI - [ ] Integration tested - [ ] Documentation update <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced group-by request processing with improved key validation - Added a new method for parsing group-by responses with more robust error handling - **Tests** - Added comprehensive test cases for new response parsing method - Verified handling of scenarios with null keys and missing keys <!-- end of auto-generated comment: release notes by coderabbit.ai --> Co-authored-by: ezvz <[email protected]>
1 parent be363f6 commit e8dcb0e

File tree

2 files changed

+133
-70
lines changed

2 files changed

+133
-70
lines changed

online/src/main/scala/ai/chronon/online/FetcherBase.scala

Lines changed: 89 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -323,57 +323,63 @@ class FetcherBase(kvStore: KVStore,
323323
// 4. Finally converted to outputSchema
324324
def fetchGroupBys(requests: scala.collection.Seq[Request]): Future[scala.collection.Seq[Response]] = {
325325
// split a groupBy level request into its kvStore level requests
326-
val groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])] = requests.iterator.map { request =>
327-
val groupByRequestMetaTry: Try[GroupByRequestMeta] = getGroupByServingInfo(request.name)
328-
.map { groupByServingInfo =>
329-
val context =
330-
request.context.getOrElse(Metrics.Context(Metrics.Environment.GroupByFetching, groupByServingInfo.groupBy))
331-
context.increment("group_by_request.count")
332-
var batchKeyBytes: Array[Byte] = null
333-
var streamingKeyBytes: Array[Byte] = null
334-
try {
335-
// The formats of key bytes for batch requests and key bytes for streaming requests may differ based
336-
// on the KVStore implementation, so we encode each distinctly.
337-
batchKeyBytes =
338-
kvStore.createKeyBytes(request.keys, groupByServingInfo, groupByServingInfo.groupByOps.batchDataset)
339-
streamingKeyBytes =
340-
kvStore.createKeyBytes(request.keys, groupByServingInfo, groupByServingInfo.groupByOps.streamingDataset)
341-
} catch {
342-
// TODO: only gets hit in cli path - make this code path just use avro schema to decode keys directly in cli
343-
// TODO: Remove this code block
344-
case ex: Exception =>
345-
val castedKeys = groupByServingInfo.keyChrononSchema.fields.map {
346-
case StructField(name, typ) => name -> ColumnAggregator.castTo(request.keys.getOrElse(name, null), typ)
347-
}.toMap
348-
try {
349-
batchKeyBytes =
350-
kvStore.createKeyBytes(castedKeys, groupByServingInfo, groupByServingInfo.groupByOps.batchDataset)
351-
streamingKeyBytes =
352-
kvStore.createKeyBytes(castedKeys, groupByServingInfo, groupByServingInfo.groupByOps.streamingDataset)
353-
} catch {
354-
case exInner: Exception =>
355-
exInner.addSuppressed(ex)
356-
throw new RuntimeException("Couldn't encode request keys or casted keys", exInner)
357-
}
358-
}
359-
val batchRequest = GetRequest(batchKeyBytes, groupByServingInfo.groupByOps.batchDataset)
360-
val streamingRequestOpt = groupByServingInfo.groupByOps.inferredAccuracy match {
361-
// fetch batch(ir) and streaming(input) and aggregate
362-
case Accuracy.TEMPORAL =>
363-
Some(
364-
GetRequest(streamingKeyBytes,
365-
groupByServingInfo.groupByOps.streamingDataset,
366-
Some(groupByServingInfo.batchEndTsMillis)))
367-
// no further aggregation is required - the value in KvStore is good as is
368-
case Accuracy.SNAPSHOT => None
326+
val groupByRequestToKvRequest: Seq[(Request, Try[GroupByRequestMeta])] = requests.iterator
327+
.filter(r => r.keys == null || r.keys.values == null || r.keys.values.exists(_ != null))
328+
.map { request =>
329+
val groupByRequestMetaTry: Try[GroupByRequestMeta] = getGroupByServingInfo(request.name)
330+
.map { groupByServingInfo =>
331+
val context =
332+
request.context.getOrElse(
333+
Metrics.Context(Metrics.Environment.GroupByFetching, groupByServingInfo.groupBy))
334+
context.increment("group_by_request.count")
335+
var batchKeyBytes: Array[Byte] = null
336+
var streamingKeyBytes: Array[Byte] = null
337+
try {
338+
// The formats of key bytes for batch requests and key bytes for streaming requests may differ based
339+
// on the KVStore implementation, so we encode each distinctly.
340+
batchKeyBytes =
341+
kvStore.createKeyBytes(request.keys, groupByServingInfo, groupByServingInfo.groupByOps.batchDataset)
342+
streamingKeyBytes =
343+
kvStore.createKeyBytes(request.keys, groupByServingInfo, groupByServingInfo.groupByOps.streamingDataset)
344+
} catch {
345+
// TODO: only gets hit in cli path - make this code path just use avro schema to decode keys directly in cli
346+
// TODO: Remove this code block
347+
case ex: Exception =>
348+
val castedKeys = groupByServingInfo.keyChrononSchema.fields.map {
349+
case StructField(name, typ) =>
350+
name -> ColumnAggregator.castTo(request.keys.getOrElse(name, null), typ)
351+
}.toMap
352+
try {
353+
batchKeyBytes =
354+
kvStore.createKeyBytes(castedKeys, groupByServingInfo, groupByServingInfo.groupByOps.batchDataset)
355+
streamingKeyBytes = kvStore.createKeyBytes(castedKeys,
356+
groupByServingInfo,
357+
groupByServingInfo.groupByOps.streamingDataset)
358+
} catch {
359+
case exInner: Exception =>
360+
exInner.addSuppressed(ex)
361+
throw new RuntimeException("Couldn't encode request keys or casted keys", exInner)
362+
}
363+
}
364+
val batchRequest = GetRequest(batchKeyBytes, groupByServingInfo.groupByOps.batchDataset)
365+
val streamingRequestOpt = groupByServingInfo.groupByOps.inferredAccuracy match {
366+
// fetch batch(ir) and streaming(input) and aggregate
367+
case Accuracy.TEMPORAL =>
368+
Some(
369+
GetRequest(streamingKeyBytes,
370+
groupByServingInfo.groupByOps.streamingDataset,
371+
Some(groupByServingInfo.batchEndTsMillis)))
372+
// no further aggregation is required - the value in KvStore is good as is
373+
case Accuracy.SNAPSHOT => None
374+
}
375+
GroupByRequestMeta(groupByServingInfo, batchRequest, streamingRequestOpt, request.atMillis, context)
369376
}
370-
GroupByRequestMeta(groupByServingInfo, batchRequest, streamingRequestOpt, request.atMillis, context)
377+
if (groupByRequestMetaTry.isFailure) {
378+
request.context.foreach(_.increment("group_by_serving_info_failure.count"))
371379
}
372-
if (groupByRequestMetaTry.isFailure) {
373-
request.context.foreach(_.increment("group_by_serving_info_failure.count"))
380+
request -> groupByRequestMetaTry
374381
}
375-
request -> groupByRequestMetaTry
376-
}.toSeq
382+
.toSeq
377383

378384
// If caching is enabled, we check if any of the GetRequests are already cached. If so, we store them in a Map
379385
// and avoid the work of re-fetching them. It is mainly for batch data requests.
@@ -583,28 +589,8 @@ class FetcherBase(kvStore: KVStore,
583589
case Right(keyMissingException) => {
584590
Map(keyMissingException.requestName + "_exception" -> keyMissingException.getMessage)
585591
}
586-
case Left(PrefixedRequest(prefix, groupByRequest)) => {
587-
responseMap
588-
.getOrElse(groupByRequest,
589-
Failure(new IllegalStateException(
590-
s"Couldn't find a groupBy response for $groupByRequest in response map")))
591-
.map { valueMap =>
592-
if (valueMap != null) {
593-
valueMap.map { case (aggName, aggValue) => prefix + "_" + aggName -> aggValue }
594-
} else {
595-
Map.empty[String, AnyRef]
596-
}
597-
}
598-
// prefix feature names
599-
.recover { // capture exception as a key
600-
case ex: Throwable =>
601-
if (debug || Math.random() < 0.001) {
602-
logger.error(s"Failed to fetch $groupByRequest", ex)
603-
}
604-
Map(groupByRequest.name + "_exception" -> ex.traceString)
605-
}
606-
.get
607-
}
592+
case Left(PrefixedRequest(prefix, groupByRequest)) =>
593+
parseGroupByResponse(prefix, groupByRequest, responseMap)
608594
}.toMap
609595
}
610596
joinValuesTry match {
@@ -624,6 +610,39 @@ class FetcherBase(kvStore: KVStore,
624610
}
625611
}
626612

613+
def parseGroupByResponse(prefix: String,
614+
groupByRequest: Request,
615+
responseMap: Map[Request, Try[Map[String, AnyRef]]]): Map[String, AnyRef] = {
616+
// Group bys with all null keys won't be requested from the KV store and we don't expect a response.
617+
val isRequiredRequest = groupByRequest.keys.values.exists(_ != null) || groupByRequest.keys.isEmpty
618+
619+
val response: Try[Map[String, AnyRef]] = responseMap.get(groupByRequest) match {
620+
case Some(value) => value
621+
case None =>
622+
if (isRequiredRequest)
623+
Failure(new IllegalStateException(s"Couldn't find a groupBy response for $groupByRequest in response map"))
624+
else Success(null)
625+
}
626+
627+
response
628+
.map { valueMap =>
629+
if (valueMap != null) {
630+
valueMap.map { case (aggName, aggValue) => prefix + "_" + aggName -> aggValue }
631+
} else {
632+
Map.empty[String, AnyRef]
633+
}
634+
}
635+
// prefix feature names
636+
.recover { // capture exception as a key
637+
case ex: Throwable =>
638+
if (debug || Math.random() < 0.001) {
639+
println(s"Failed to fetch $groupByRequest with \n${ex.traceString}")
640+
}
641+
Map(groupByRequest.name + "_exception" -> ex.traceString)
642+
}
643+
.get
644+
}
645+
627646
/**
628647
* Fetch method to simulate a random access interface for Chronon
629648
* by distributing requests to relevant GroupBys. This is a batch

online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import ai.chronon.online.Fetcher.Response
2727
import ai.chronon.online.FetcherCache.BatchResponses
2828
import ai.chronon.online.KVStore.TimedValue
2929
import ai.chronon.online._
30+
import org.junit.Assert.assertEquals
3031
import org.junit.Assert.assertFalse
3132
import org.junit.Assert.assertTrue
3233
import org.mockito.Answers
@@ -224,4 +225,47 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M
224225
assertFalse(fetcherBaseWithFlagStore.isCachingEnabled(buildGroupByWithCustomJson("test_groupby_2")))
225226
assertTrue(fetcherBaseWithFlagStore.isCachingEnabled(buildGroupByWithCustomJson("test_groupby_3")))
226227
}
228+
229+
it should "fetch in the happy case" in {
230+
val baseFetcher = new FetcherBase(mock[KVStore])
231+
val request = Request(name = "name", keys = Map("email" -> "email"), atMillis = None, context = None)
232+
val response: Map[Request, Try[Map[String, AnyRef]]] = Map(
233+
request -> Success(Map(
234+
"key" -> "value"
235+
))
236+
)
237+
238+
val result = baseFetcher.parseGroupByResponse("prefix", request, response)
239+
assertEquals(result, Map("prefix_key" -> "value"))
240+
}
241+
242+
it should "Not fetch with null keys" in {
243+
val baseFetcher = new FetcherBase(mock[KVStore])
244+
val request = Request(name = "name", keys = Map("email" -> null), atMillis = None, context = None)
245+
val request2 = Request(name = "name2", keys = Map("email" -> null), atMillis = None, context = None)
246+
247+
val response: Map[Request, Try[Map[String, AnyRef]]] = Map(
248+
request2 -> Success(Map(
249+
"key" -> "value"
250+
))
251+
)
252+
253+
val result = baseFetcher.parseGroupByResponse("prefix", request, response)
254+
result shouldBe Map()
255+
}
256+
257+
it should "parse with missing keys" in {
258+
val baseFetcher = new FetcherBase(mock[KVStore])
259+
val request = Request(name = "name", keys = Map("email" -> "email"), atMillis = None, context = None)
260+
val request2 = Request(name = "name2", keys = Map("email" -> "email"), atMillis = None, context = None)
261+
262+
val response: Map[Request, Try[Map[String, AnyRef]]] = Map(
263+
request2 -> Success(Map(
264+
"key" -> "value"
265+
))
266+
)
267+
268+
val result = baseFetcher.parseGroupByResponse("prefix", request, response)
269+
result.keySet shouldBe Set("name_exception")
270+
}
227271
}

0 commit comments

Comments (0)