Skip to content

Commit b45d1c2

Browse files
hzding621Haozhen Ding
and
Haozhen Ding
authored
fix: propagate kvstore multi get exception properly (#955)
* squash multi-get failure * UT --------- Co-authored-by: Haozhen Ding <[email protected]>
1 parent 352391c commit b45d1c2

File tree

6 files changed

+88
-49
lines changed

6 files changed

+88
-49
lines changed

online/src/main/scala/ai/chronon/online/Api.scala

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,29 @@ trait KVStore {
5858

5959
// helper method to blocking read a string - used for fetching metadata & not in hotpath.
6060
def getString(key: String, dataset: String, timeoutMillis: Long): Try[String] = {
61-
val response = getResponse(key, dataset, timeoutMillis)
62-
if (response.values.isFailure) {
63-
Failure(new RuntimeException(s"Request for key ${key} in dataset ${dataset} failed", response.values.failed.get))
64-
} else {
65-
Success(new String(response.latest.get.bytes, Constants.UTF8))
66-
}
61+
val bytesTry = getResponse(key, dataset, timeoutMillis)
62+
bytesTry.map(bytes => new String(bytes, Constants.UTF8))
6763
}
6864

6965
def getStringArray(key: String, dataset: String, timeoutMillis: Long): Try[Seq[String]] = {
70-
val response = getResponse(key, dataset, timeoutMillis)
71-
if (response.values.isFailure) {
72-
Failure(new RuntimeException(s"Request for key ${key} in dataset ${dataset} failed", response.values.failed.get))
73-
} else {
74-
Success(StringArrayConverter.bytesToStrings(response.latest.get.bytes))
75-
}
66+
val bytesTry = getResponse(key, dataset, timeoutMillis)
67+
bytesTry.map(bytes => StringArrayConverter.bytesToStrings(bytes))
7668
}
7769

78-
private def getResponse(key: String, dataset: String, timeoutMillis: Long): GetResponse = {
70+
private def getResponse(key: String, dataset: String, timeoutMillis: Long): Try[Array[Byte]] = {
7971
val fetchRequest = KVStore.GetRequest(key.getBytes(Constants.UTF8), dataset)
8072
val responseFutureOpt = get(fetchRequest)
81-
Await.result(responseFutureOpt, Duration(timeoutMillis, MILLISECONDS))
73+
def buildException(e: Throwable) = new RuntimeException(s"Request for key ${key} in dataset ${dataset} failed", e)
74+
Try(Await.result(responseFutureOpt, Duration(timeoutMillis, MILLISECONDS))) match {
75+
case Failure(e) =>
76+
Failure(buildException(e))
77+
case Success(resp) =>
78+
if (resp.values.isFailure) {
79+
Failure(buildException(resp.values.failed.get))
80+
} else {
81+
Success(resp.latest.get.bytes)
82+
}
83+
}
8284
}
8385
def get(request: GetRequest): Future[GetResponse] = {
8486
multiGet(Seq(request))

online/src/main/scala/ai/chronon/online/FetcherBase.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ class FetcherBase(kvStore: KVStore,
670670
if (debug || Math.random() < 0.001) {
671671
logger.error(s"Failed to fetch $groupByRequest", ex)
672672
}
673-
Map(groupByRequest.name + "_exception" -> ex.traceString)
673+
Map(prefix + "_exception" -> ex.traceString)
674674
}
675675
.get
676676
}

spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import ai.chronon.api.Constants.ChrononMetadataKey
2424
import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
2525
import ai.chronon.api._
2626
import ai.chronon.online.Fetcher.{Request, Response, StatsRequest}
27-
import ai.chronon.online.{JavaRequest, LoggableResponseBase64, MetadataStore, SparkConversions}
27+
import ai.chronon.online.{JavaRequest, KVStore, LoggableResponseBase64, MetadataStore, SparkConversions}
2828
import ai.chronon.spark.Extensions._
2929
import ai.chronon.spark.stats.ConsistencyJob
3030
import ai.chronon.spark.{Join => _, _}
@@ -41,7 +41,7 @@ import java.util.concurrent.Executors
4141
import scala.collection.Seq
4242
import scala.compat.java8.FutureConverters
4343
import scala.concurrent.duration.{Duration, SECONDS}
44-
import scala.concurrent.{Await, ExecutionContext}
44+
import scala.concurrent.{Await, ExecutionContext, Future}
4545
import scala.util.Random
4646
import scala.util.ScalaJavaConversions._
4747

@@ -55,12 +55,12 @@ class FetcherTest extends TestCase {
5555
private val today = dummyTableUtils.partitionSpec.at(System.currentTimeMillis())
5656
private val yesterday = dummyTableUtils.partitionSpec.before(today)
5757

58-
5958
/**
6059
* Generate deterministic data for testing and checkpointing IRs and streaming data.
6160
*/
6261
def generateMutationData(namespace: String): api.Join = {
63-
val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
62+
val spark: SparkSession =
63+
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
6464
val tableUtils = TableUtils(spark)
6565
tableUtils.createDatabase(namespace)
6666
def toTs(arg: String): Long = TsUtils.datetimeToTs(arg)
@@ -188,7 +188,7 @@ class FetcherTest extends TestCase {
188188
),
189189
accuracy = Accuracy.TEMPORAL,
190190
metaData = Builders.MetaData(name = "unit_test/fetcher_mutations_gb", namespace = namespace, team = "chronon"),
191-
derivations=Seq(
191+
derivations = Seq(
192192
Builders.Derivation(name = "*", expression = "*"),
193193
Builders.Derivation(name = "rating_average_1d_same", expression = "rating_average_1d")
194194
)
@@ -203,7 +203,8 @@ class FetcherTest extends TestCase {
203203
}
204204

205205
def generateRandomData(namespace: String, keyCount: Int = 10, cardinality: Int = 100): api.Join = {
206-
val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
206+
val spark: SparkSession =
207+
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
207208
val tableUtils = TableUtils(spark)
208209
tableUtils.createDatabase(namespace)
209210
val rowCount = cardinality * keyCount
@@ -312,9 +313,8 @@ class FetcherTest extends TestCase {
312313
sources = Seq(Builders.Source.entities(query = Builders.Query(), snapshotTable = creditTable)),
313314
keyColumns = Seq("vendor_id"),
314315
aggregations = Seq(
315-
Builders.Aggregation(operation = Operation.SUM,
316-
inputColumn = "credit",
317-
windows = Seq(new Window(3, TimeUnit.DAYS)))),
316+
Builders
317+
.Aggregation(operation = Operation.SUM, inputColumn = "credit", windows = Seq(new Window(3, TimeUnit.DAYS)))),
318318
metaData = Builders.MetaData(name = "unit_test/vendor_credit_derivation", namespace = namespace),
319319
derivations = Seq(
320320
Builders.Derivation("credit_sum_3d_test_rename", "credit_sum_3d"),
@@ -390,7 +390,8 @@ class FetcherTest extends TestCase {
390390
}
391391

392392
def generateEventOnlyData(namespace: String, groupByCustomJson: Option[String] = None): api.Join = {
393-
val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
393+
val spark: SparkSession =
394+
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
394395
val tableUtils = TableUtils(spark)
395396
tableUtils.createDatabase(namespace)
396397
def toTs(arg: String): Long = TsUtils.datetimeToTs(arg)
@@ -518,7 +519,8 @@ class FetcherTest extends TestCase {
518519
consistencyCheck: Boolean,
519520
dropDsOnWrite: Boolean): Unit = {
520521
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
521-
val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
522+
val spark: SparkSession =
523+
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
522524
val tableUtils = TableUtils(spark)
523525
val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
524526
val inMemoryKvStore = kvStoreFunc()
@@ -655,8 +657,10 @@ class FetcherTest extends TestCase {
655657
def testTemporalFetchJoinDerivation(): Unit = {
656658
val namespace = "derivation_fetch"
657659
val joinConf = generateMutationData(namespace)
658-
val derivations = Seq(Builders.Derivation(name = "*", expression = "*"),
659-
Builders.Derivation(name = "unit_test_fetcher_mutations_gb_rating_sum_plus", expression = "unit_test_fetcher_mutations_gb_rating_sum + 1"),
660+
val derivations = Seq(
661+
Builders.Derivation(name = "*", expression = "*"),
662+
Builders.Derivation(name = "unit_test_fetcher_mutations_gb_rating_sum_plus",
663+
expression = "unit_test_fetcher_mutations_gb_rating_sum + 1"),
660664
Builders.Derivation(name = "listing_id_renamed", expression = "listing_id")
661665
)
662666
joinConf.setDerivations(derivations.toJava)
@@ -668,14 +672,12 @@ class FetcherTest extends TestCase {
668672
val namespace = "derivation_fetch_rename_only"
669673
val joinConf = generateMutationData(namespace)
670674
val derivations = Seq(Builders.Derivation(name = "*", expression = "*"),
671-
Builders.Derivation(name = "listing_id_renamed", expression = "listing_id")
672-
)
675+
Builders.Derivation(name = "listing_id_renamed", expression = "listing_id"))
673676
joinConf.setDerivations(derivations.toJava)
674677

675678
compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
676679
}
677680

678-
679681
def testTemporalFetchJoinGenerated(): Unit = {
680682
val namespace = "generated_fetch"
681683
val joinConf = generateRandomData(namespace)
@@ -694,7 +696,8 @@ class FetcherTest extends TestCase {
694696

695697
// test soft-fail on missing keys
696698
def testEmptyRequest(): Unit = {
697-
val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
699+
val spark: SparkSession =
700+
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
698701
val namespace = "empty_request"
699702
val joinConf = generateRandomData(namespace, 5, 5)
700703
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
@@ -722,33 +725,58 @@ class FetcherTest extends TestCase {
722725
val namespace = "non_exist_key_group_by_fetch"
723726
val joinConf = generateMutationData(namespace)
724727
val endDs = "2021-04-10"
725-
val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
728+
val spark: SparkSession =
729+
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
726730
val tableUtils = TableUtils(spark)
727731
val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
728732
val inMemoryKvStore = kvStoreFunc()
729733
val mockApi = new MockApi(kvStoreFunc, namespace)
730-
@transient lazy val fetcher = mockApi.buildFetcher(debug=false)
734+
@transient lazy val fetcher = mockApi.buildFetcher(debug = false)
731735

732736
joinConf.joinParts.toScala.foreach(jp =>
733-
OnlineUtils.serve(tableUtils,
734-
inMemoryKvStore,
735-
kvStoreFunc,
736-
namespace,
737-
endDs,
738-
jp.groupBy,
739-
dropDsOnWrite = true))
737+
OnlineUtils.serve(tableUtils, inMemoryKvStore, kvStoreFunc, namespace, endDs, jp.groupBy, dropDsOnWrite = true))
740738

741739
// a random key that doesn't exist
742740
val nonExistKey = 123L
743-
val request = Request("unit_test/fetcher_mutations_gb",
744-
Map("listing_id" -> nonExistKey.asInstanceOf[AnyRef]))
741+
val request = Request("unit_test/fetcher_mutations_gb", Map("listing_id" -> nonExistKey.asInstanceOf[AnyRef]))
745742
val response = fetcher.fetchGroupBys(Seq(request))
746743
val result = Await.result(response, Duration(10, SECONDS))
747744

748745
// result should be "null" if the key is not found
749746
val expected: Map[String, AnyRef] = Map("rating_average_1d_same" -> null)
750747
assertEquals(expected, result.head.values.get)
751748
}
749+
750+
def testKVStorePartialFailure(): Unit = {
751+
752+
val spark: SparkSession =
753+
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
754+
val namespace = "test_kv_store_partial_failure"
755+
val joinConf = generateRandomData(namespace, 5, 5)
756+
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
757+
758+
val kvStoreFunc = () =>
759+
OnlineUtils.buildInMemoryKVStore("FetcherTest#test_kv_store_partial_failure", hardFailureOnInvalidDataset = true)
760+
val inMemoryKvStore = kvStoreFunc()
761+
val mockApi = new MockApi(kvStoreFunc, namespace)
762+
763+
val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000)
764+
inMemoryKvStore.create(ChrononMetadataKey)
765+
metadataStore.putJoinConf(joinConf)
766+
767+
val keys = joinConf.leftKeyCols
768+
val keyData = spark.table(s"$namespace.queries_table").select(keys.map(col): _*).head
769+
val keyMap = keys.indices.map { idx =>
770+
keys(idx) -> keyData.get(idx).asInstanceOf[AnyRef]
771+
}.toMap
772+
773+
val request = Request(joinConf.metaData.nameToFilePath, keyMap)
774+
val (responses, _) = FetcherTestUtil.joinResponses(spark, Array(request), mockApi)
775+
val responseMap = responses.head.values.get
776+
val exceptionKeys = joinConf.joinPartOps.map(jp => jp.fullPrefix + "_exception")
777+
exceptionKeys.foreach(k => assertTrue(responseMap.contains(k)))
778+
}
779+
752780
}
753781

754782
object FetcherTestUtil {

spark/src/test/scala/ai/chronon/spark/test/InMemoryKvStore.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ import scala.collection.mutable
2929
import scala.concurrent.Future
3030
import scala.util.Try
3131

32-
class InMemoryKvStore(tableUtils: () => TableUtils) extends KVStore with Serializable {
32+
class InMemoryKvStore(tableUtils: () => TableUtils, hardFailureOnInvalidDataset: Boolean = false)
33+
extends KVStore
34+
with Serializable {
3335
//type aliases for readability
3436
type Key = String
3537
type Data = Array[Byte]
@@ -47,6 +49,9 @@ class InMemoryKvStore(tableUtils: () => TableUtils) extends KVStore with Seriali
4749
// emulate IO latency
4850
Thread.sleep(4)
4951
requests.map { req =>
52+
if (!database.containsKey(req.dataset) && hardFailureOnInvalidDataset) {
53+
throw new RuntimeException(s"Invalid dataset: ${req.dataset}")
54+
}
5055
val values = Try {
5156
database
5257
.get(req.dataset) // table
@@ -144,13 +149,15 @@ object InMemoryKvStore {
144149
// We would like to create one instance of InMemoryKVStore per executors, but share SparkContext
145150
// across them. Since SparkContext is not serializable, we wrap TableUtils that has SparkContext
146151
// in a closure and pass it around.
147-
def build(testName: String, tableUtils: () => TableUtils): InMemoryKvStore = {
152+
def build(testName: String,
153+
tableUtils: () => TableUtils,
154+
hardFailureOnInvalidDataset: Boolean = false): InMemoryKvStore = {
148155
stores.computeIfAbsent(
149156
testName,
150157
new function.Function[String, InMemoryKvStore] {
151158
override def apply(name: String): InMemoryKvStore = {
152159
logger.info(s"Missing in-memory store for name: $name. Creating one")
153-
new InMemoryKvStore(tableUtils)
160+
new InMemoryKvStore(tableUtils, hardFailureOnInvalidDataset)
154161
}
155162
}
156163
)

spark/src/test/scala/ai/chronon/spark/test/JavaFetcherTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public class JavaFetcherTest {
4141
String namespace = "java_fetcher_test";
4242
SparkSession session = SparkSessionBuilder.build(namespace, true, scala.Option.apply(null), scala.Option.apply(null), true);
4343
TableUtils tu = new TableUtils(session);
44-
InMemoryKvStore kvStore = new InMemoryKvStore(func(() -> tu));
44+
InMemoryKvStore kvStore = new InMemoryKvStore(func(() -> tu), false);
4545
MockApi mockApi = new MockApi(func(() -> kvStore), "java_fetcher_test");
4646
JavaFetcher fetcher = mockApi.buildJavaFetcher();
4747

spark/src/test/scala/ai/chronon/spark/test/OnlineUtils.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ object OnlineUtils {
190190
inMemoryKvStore.bulkPut(joinConf.metaData.consistencyUploadTable, Constants.ConsistencyMetricsDataset, null)
191191
}
192192

193-
def buildInMemoryKVStore(sessionName: String): InMemoryKvStore = {
194-
InMemoryKvStore.build(sessionName, { () => TableUtils(SparkSessionBuilder.build(sessionName, local = true)) })
193+
def buildInMemoryKVStore(sessionName: String, hardFailureOnInvalidDataset: Boolean = false): InMemoryKvStore = {
194+
InMemoryKvStore.build(sessionName,
195+
{ () => TableUtils(SparkSessionBuilder.build(sessionName, local = true)) },
196+
hardFailureOnInvalidDataset)
195197
}
196198
}

0 commit comments

Comments
 (0)