Skip to content

Commit e2e82f7

Browse files
author
Haozhen Ding
committed
add UT
1 parent 875603e commit e2e82f7

File tree

2 files changed

+89
-20
lines changed

2 files changed

+89
-20
lines changed

online/src/main/scala/ai/chronon/online/MetadataStore.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class MetadataStore(kvStore: KVStore,
134134
if (result.isSuccess) Metrics.Context(Metrics.Environment.MetaDataFetching, result.get.join)
135135
else Metrics.Context(Metrics.Environment.MetaDataFetching, join = name)
136136
// Throw exception after metrics. No join metadata is bound to be a critical failure.
137+
// This will ensure that a Failure is never cached in the getJoinConf TTLCache
137138
if (result.isFailure) {
138139
context.withSuffix("join").incrementException(result.failed.get)
139140
throw result.failed.get

spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala

Lines changed: 88 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,14 @@
1616

1717
package ai.chronon.spark.test
1818

19-
import org.slf4j.LoggerFactory
2019
import ai.chronon.aggregator.test.Column
2120
import ai.chronon.aggregator.windowing.TsUtils
2221
import ai.chronon.api
2322
import ai.chronon.api.Constants.ChrononMetadataKey
2423
import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
2524
import ai.chronon.api._
2625
import ai.chronon.online.Fetcher.{Request, Response, StatsRequest}
27-
import ai.chronon.online.{JavaRequest, KVStore, LoggableResponseBase64, MetadataStore, SparkConversions}
26+
import ai.chronon.online.{JavaRequest, LoggableResponseBase64, MetadataStore, SparkConversions}
2827
import ai.chronon.spark.Extensions._
2928
import ai.chronon.spark.stats.ConsistencyJob
3029
import ai.chronon.spark.{Join => _, _}
@@ -34,16 +33,19 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow
3433
import org.apache.spark.sql.functions.{avg, col, lit}
3534
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
3635
import org.junit.Assert.{assertEquals, assertTrue}
36+
import org.mockito.ArgumentMatchers.{any, anyString}
37+
import org.mockito.Mockito.{reset, spy, when}
38+
import org.slf4j.LoggerFactory
3739

3840
import java.lang
3941
import java.util.TimeZone
4042
import java.util.concurrent.Executors
4143
import scala.collection.Seq
4244
import scala.compat.java8.FutureConverters
4345
import scala.concurrent.duration.{Duration, SECONDS}
44-
import scala.concurrent.{Await, ExecutionContext, Future}
45-
import scala.util.Random
46+
import scala.concurrent.{Await, ExecutionContext}
4647
import scala.util.ScalaJavaConversions._
48+
import scala.util.{Failure, Random, Try}
4749

4850
class FetcherTest extends TestCase {
4951
@transient lazy val logger = LoggerFactory.getLogger(getClass)
@@ -55,12 +57,15 @@ class FetcherTest extends TestCase {
5557
private val today = dummyTableUtils.partitionSpec.at(System.currentTimeMillis())
5658
private val yesterday = dummyTableUtils.partitionSpec.before(today)
5759

60+
private def createSparkSession(): SparkSession = {
61+
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
62+
}
63+
5864
/**
5965
* Generate deterministic data for testing and checkpointing IRs and streaming data.
6066
*/
61-
def generateMutationData(namespace: String): api.Join = {
62-
val spark: SparkSession =
63-
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
67+
def generateMutationData(namespace: String, sparkOpt: Option[SparkSession] = None): api.Join = {
68+
val spark: SparkSession = sparkOpt.getOrElse(createSparkSession())
6469
val tableUtils = TableUtils(spark)
6570
tableUtils.createDatabase(namespace)
6671
def toTs(arg: String): Long = TsUtils.datetimeToTs(arg)
@@ -203,8 +208,7 @@ class FetcherTest extends TestCase {
203208
}
204209

205210
def generateRandomData(namespace: String, keyCount: Int = 10, cardinality: Int = 100): api.Join = {
206-
val spark: SparkSession =
207-
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
211+
val spark: SparkSession = createSparkSession()
208212
val tableUtils = TableUtils(spark)
209213
tableUtils.createDatabase(namespace)
210214
val rowCount = cardinality * keyCount
@@ -390,8 +394,7 @@ class FetcherTest extends TestCase {
390394
}
391395

392396
def generateEventOnlyData(namespace: String, groupByCustomJson: Option[String] = None): api.Join = {
393-
val spark: SparkSession =
394-
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
397+
val spark: SparkSession = createSparkSession()
395398
val tableUtils = TableUtils(spark)
396399
tableUtils.createDatabase(namespace)
397400
def toTs(arg: String): Long = TsUtils.datetimeToTs(arg)
@@ -519,8 +522,7 @@ class FetcherTest extends TestCase {
519522
consistencyCheck: Boolean,
520523
dropDsOnWrite: Boolean): Unit = {
521524
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
522-
val spark: SparkSession =
523-
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
525+
val spark: SparkSession = createSparkSession()
524526
val tableUtils = TableUtils(spark)
525527
val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
526528
val inMemoryKvStore = kvStoreFunc()
@@ -696,8 +698,7 @@ class FetcherTest extends TestCase {
696698

697699
// test soft-fail on missing keys
698700
def testEmptyRequest(): Unit = {
699-
val spark: SparkSession =
700-
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
701+
val spark: SparkSession = createSparkSession()
701702
val namespace = "empty_request"
702703
val joinConf = generateRandomData(namespace, 5, 5)
703704
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
@@ -723,10 +724,9 @@ class FetcherTest extends TestCase {
723724

724725
def testTemporalFetchGroupByNonExistKey(): Unit = {
725726
val namespace = "non_exist_key_group_by_fetch"
726-
val joinConf = generateMutationData(namespace)
727+
val spark: SparkSession = createSparkSession()
728+
val joinConf = generateMutationData(namespace, Some(spark))
727729
val endDs = "2021-04-10"
728-
val spark: SparkSession =
729-
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
730730
val tableUtils = TableUtils(spark)
731731
val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
732732
val inMemoryKvStore = kvStoreFunc()
@@ -749,8 +749,7 @@ class FetcherTest extends TestCase {
749749

750750
def testKVStorePartialFailure(): Unit = {
751751

752-
val spark: SparkSession =
753-
SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
752+
val spark: SparkSession = createSparkSession()
754753
val namespace = "test_kv_store_partial_failure"
755754
val joinConf = generateRandomData(namespace, 5, 5)
756755
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
@@ -777,6 +776,75 @@ class FetcherTest extends TestCase {
777776
exceptionKeys.foreach(k => assertTrue(responseMap.contains(k)))
778777
}
779778

779+
def testGroupByServingInfoTtlCacheRefresh(): Unit = {
780+
val namespace = "test_group_by_serving_info_ttl_cache_refresh"
781+
val spark: SparkSession = createSparkSession()
782+
val joinConf = generateMutationData(namespace, Some(spark))
783+
val groupByConf = joinConf.joinParts.toScala.head.groupBy
784+
val endDs = "2021-04-10"
785+
val tableUtils = TableUtils(spark)
786+
val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
787+
OnlineUtils.serve(tableUtils, kvStoreFunc(), kvStoreFunc, namespace, endDs, groupByConf, dropDsOnWrite = true)
788+
789+
val spyKvStore = spy(kvStoreFunc())
790+
val mockApi = new MockApi(() => spyKvStore, namespace)
791+
@transient lazy val fetcher = mockApi.buildFetcher()
792+
793+
/* 1st request: kv store failure */
794+
when(spyKvStore.getString(anyString(), anyString(), any()))
795+
.thenReturn(Failure(new Exception("kvstore error")))
796+
val request = Seq(Request(groupByConf.metaData.name, Map("listing_id" -> 1L.asInstanceOf[AnyRef])))
797+
def fetch(): Response = Await.result(fetcher.fetchGroupBys(request), Duration(10, SECONDS)).head
798+
val response1 = fetch()
799+
assertTrue(response1.values.isFailure)
800+
801+
Thread.sleep(10000) // Wait for ttl cache refresh interval to expire
802+
803+
/* kv store recovers, 2nd request still fails, but will run async refresh */
804+
reset(spyKvStore)
805+
val response2 = fetch()
806+
assertTrue(response2.values.isFailure)
807+
808+
Thread.sleep(10000) // Wait for ttl cache async update to finish
809+
810+
/* 3rd request uses the refreshed kvstore result */
811+
val response3 = fetch()
812+
assertTrue(response3.values.isSuccess)
813+
}
814+
815+
def testJoinConfTtlCacheRefresh(): Unit = {
816+
val namespace = "test_join_conf_ttl_cache_refresh"
817+
val spark: SparkSession = createSparkSession()
818+
val joinConf = generateMutationData(namespace, Some(spark))
819+
val endDs = "2021-04-10"
820+
val tableUtils = TableUtils(spark)
821+
val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
822+
val inMemoryKvStore = kvStoreFunc()
823+
joinConf.joinParts.toScala.foreach(jp =>
824+
OnlineUtils.serve(tableUtils, inMemoryKvStore, kvStoreFunc, namespace, endDs, jp.groupBy, dropDsOnWrite = true))
825+
val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000)
826+
inMemoryKvStore.create(ChrononMetadataKey)
827+
metadataStore.putJoinConf(joinConf)
828+
829+
val spyKvStore = spy(inMemoryKvStore)
830+
val mockApi = new MockApi(() => spyKvStore, namespace)
831+
@transient lazy val fetcher = mockApi.buildFetcher()
832+
833+
/* 1st request: kv store failure. */
834+
when(spyKvStore.getString(anyString(), anyString(), any()))
835+
.thenReturn(Failure(new Exception("kvstore error")))
836+
val request = Seq(Request(joinConf.metaData.name, Map("listing_id" -> 1L.asInstanceOf[AnyRef])))
837+
def fetch(): Try[Response] = Try(Await.result(fetcher.fetchJoin(request), Duration(10, SECONDS)).head)
838+
val response1 = fetch()
839+
assertTrue(response1.isFailure)
840+
841+
Thread.sleep(10000) // Wait for ttl cache refresh interval to expire
842+
843+
/* kv store recovers, 2nd request should immediately succeed, since getJoinConf never caches Failure */
844+
reset(spyKvStore)
845+
val response2 = fetch()
846+
assertTrue(response2.isSuccess)
847+
}
780848
}
781849

782850
object FetcherTestUtil {

0 commit comments

Comments
 (0)