@@ -16,15 +16,14 @@

 package ai.chronon.spark.test

-import org.slf4j.LoggerFactory
 import ai.chronon.aggregator.test.Column
 import ai.chronon.aggregator.windowing.TsUtils
 import ai.chronon.api
 import ai.chronon.api.Constants.ChrononMetadataKey
 import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
 import ai.chronon.api._
 import ai.chronon.online.Fetcher.{Request, Response, StatsRequest}
-import ai.chronon.online.{JavaRequest, KVStore, LoggableResponseBase64, MetadataStore, SparkConversions}
+import ai.chronon.online.{JavaRequest, LoggableResponseBase64, MetadataStore, SparkConversions}
 import ai.chronon.spark.Extensions._
 import ai.chronon.spark.stats.ConsistencyJob
 import ai.chronon.spark.{Join => _, _}
@@ -34,16 +33,19 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.functions.{avg, col, lit}
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.junit.Assert.{assertEquals, assertTrue}
+import org.mockito.ArgumentMatchers.{any, anyString}
+import org.mockito.Mockito.{reset, spy, when}
+import org.slf4j.LoggerFactory

 import java.lang
 import java.util.TimeZone
 import java.util.concurrent.Executors
 import scala.collection.Seq
 import scala.compat.java8.FutureConverters
 import scala.concurrent.duration.{Duration, SECONDS}
-import scala.concurrent.{Await, ExecutionContext, Future}
-import scala.util.Random
+import scala.concurrent.{Await, ExecutionContext}
 import scala.util.ScalaJavaConversions._
+import scala.util.{Failure, Random, Try}

 class FetcherTest extends TestCase {
   @transient lazy val logger = LoggerFactory.getLogger(getClass)
@@ -55,12 +57,15 @@ class FetcherTest extends TestCase {
   private val today = dummyTableUtils.partitionSpec.at(System.currentTimeMillis())
   private val yesterday = dummyTableUtils.partitionSpec.before(today)

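+  // Builds a fresh local SparkSession; the random name suffix keeps each test's session isolated.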
+  private def createSparkSession(): SparkSession = {
+    SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+  }
+
   /**
     * Generate deterministic data for testing and checkpointing IRs and streaming data.
     */
-  def generateMutationData(namespace: String): api.Join = {
-    val spark: SparkSession =
-      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+  def generateMutationData(namespace: String, sparkOpt: Option[SparkSession] = None): api.Join = {
+    val spark: SparkSession = sparkOpt.getOrElse(createSparkSession())
     val tableUtils = TableUtils(spark)
     tableUtils.createDatabase(namespace)
     def toTs(arg: String): Long = TsUtils.datetimeToTs(arg)
@@ -203,8 +208,7 @@
   }

   def generateRandomData(namespace: String, keyCount: Int = 10, cardinality: Int = 100): api.Join = {
-    val spark: SparkSession =
-      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession = createSparkSession()
     val tableUtils = TableUtils(spark)
     tableUtils.createDatabase(namespace)
     val rowCount = cardinality * keyCount
@@ -390,8 +394,7 @@
   }

   def generateEventOnlyData(namespace: String, groupByCustomJson: Option[String] = None): api.Join = {
-    val spark: SparkSession =
-      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession = createSparkSession()
     val tableUtils = TableUtils(spark)
     tableUtils.createDatabase(namespace)
     def toTs(arg: String): Long = TsUtils.datetimeToTs(arg)
@@ -519,8 +522,7 @@
                            consistencyCheck: Boolean,
                            dropDsOnWrite: Boolean): Unit = {
     implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
-    val spark: SparkSession =
-      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession = createSparkSession()
     val tableUtils = TableUtils(spark)
     val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
     val inMemoryKvStore = kvStoreFunc()
@@ -696,8 +698,7 @@

   // test soft-fail on missing keys
   def testEmptyRequest(): Unit = {
-    val spark: SparkSession =
-      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession = createSparkSession()
     val namespace = "empty_request"
     val joinConf = generateRandomData(namespace, 5, 5)
     implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
@@ -723,10 +724,9 @@

   def testTemporalFetchGroupByNonExistKey(): Unit = {
     val namespace = "non_exist_key_group_by_fetch"
-    val joinConf = generateMutationData(namespace)
+    val spark: SparkSession = createSparkSession()
+    val joinConf = generateMutationData(namespace, Some(spark))
     val endDs = "2021-04-10"
-    val spark: SparkSession =
-      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
     val tableUtils = TableUtils(spark)
     val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
     val inMemoryKvStore = kvStoreFunc()
@@ -749,8 +749,7 @@

   def testKVStorePartialFailure(): Unit = {

-    val spark: SparkSession =
-      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession = createSparkSession()
     val namespace = "test_kv_store_partial_failure"
     val joinConf = generateRandomData(namespace, 5, 5)
     implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
@@ -777,6 +776,75 @@
     exceptionKeys.foreach(k => assertTrue(responseMap.contains(k)))
   }

+  def testGroupByServingInfoTtlCacheRefresh(): Unit = {
+    val namespace = "test_group_by_serving_info_ttl_cache_refresh"
+    val spark: SparkSession = createSparkSession()
+    val joinConf = generateMutationData(namespace, Some(spark))
+    val groupByConf = joinConf.joinParts.toScala.head.groupBy
+    val endDs = "2021-04-10"
+    val tableUtils = TableUtils(spark)
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
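+    // Materialize the GroupBy's data into the in-memory KV store so the fetcher can serve it online.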
+    OnlineUtils.serve(tableUtils, kvStoreFunc(), kvStoreFunc, namespace, endDs, groupByConf, dropDsOnWrite = true)
+
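+    // The Mockito spy wraps the real in-memory store: calls pass through unless explicitly stubbed.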
+    val spyKvStore = spy(kvStoreFunc())
+    val mockApi = new MockApi(() => spyKvStore, namespace)
+    @transient lazy val fetcher = mockApi.buildFetcher()
+
+    /* 1st request: kv store failure */
+    when(spyKvStore.getString(anyString(), anyString(), any()))
+      .thenReturn(Failure(new Exception("kvstore error")))
+    val request = Seq(Request(groupByConf.metaData.name, Map("listing_id" -> 1L.asInstanceOf[AnyRef])))
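+    // The KV store failure surfaces in Response.values as a Try, not as a thrown exception.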
+    def fetch(): Response = Await.result(fetcher.fetchGroupBys(request), Duration(10, SECONDS)).head
+    val response1 = fetch()
+    assertTrue(response1.values.isFailure)
+
+    Thread.sleep(10000) // Wait for ttl cache refresh interval to expire
+
+    /* kv store recovers, 2nd request still fails, but will run async refresh */
+    reset(spyKvStore)
+    val response2 = fetch()
+    assertTrue(response2.values.isFailure)
+
+    Thread.sleep(10000) // Wait for ttl cache async update to finish
+
+    /* 3rd request uses the refreshed kvstore result */
+    val response3 = fetch()
+    assertTrue(response3.values.isSuccess)
+  }
+
+  def testJoinConfTtlCacheRefresh(): Unit = {
+    val namespace = "test_join_conf_ttl_cache_refresh"
+    val spark: SparkSession = createSparkSession()
+    val joinConf = generateMutationData(namespace, Some(spark))
+    val endDs = "2021-04-10"
+    val tableUtils = TableUtils(spark)
+    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
+    val inMemoryKvStore = kvStoreFunc()
+    joinConf.joinParts.toScala.foreach(jp =>
+      OnlineUtils.serve(tableUtils, inMemoryKvStore, kvStoreFunc, namespace, endDs, jp.groupBy, dropDsOnWrite = true))
+    val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000)
+    inMemoryKvStore.create(ChrononMetadataKey)
+    metadataStore.putJoinConf(joinConf)
+
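+    // Spy on the same store that already holds the join conf, so a stub can simulate a KV outage.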
+    val spyKvStore = spy(inMemoryKvStore)
+    val mockApi = new MockApi(() => spyKvStore, namespace)
+    @transient lazy val fetcher = mockApi.buildFetcher()
+
+    /* 1st request: kv store failure */
+    when(spyKvStore.getString(anyString(), anyString(), any()))
+      .thenReturn(Failure(new Exception("kvstore error")))
+    val request = Seq(Request(joinConf.metaData.name, Map("listing_id" -> 1L.asInstanceOf[AnyRef])))
+    def fetch(): Try[Response] = Try(Await.result(fetcher.fetchJoin(request), Duration(10, SECONDS)).head)
+    val response1 = fetch()
+    assertTrue(response1.isFailure)
+
+    Thread.sleep(10000) // Wait for ttl cache refresh interval to expire
+
+    /* kv store recovers, 2nd request should immediately succeed, since getJoinConf never caches a Failure */
+    reset(spyKvStore)
+    val response2 = fetch()
+    assertTrue(response2.isSuccess)
+  }
 }

 object FetcherTestUtil {