@@ -24,7 +24,7 @@ import ai.chronon.api.Constants.ChrononMetadataKey
 import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
 import ai.chronon.api._
 import ai.chronon.online.Fetcher.{Request, Response, StatsRequest}
-import ai.chronon.online.{JavaRequest, LoggableResponseBase64, MetadataStore, SparkConversions}
+import ai.chronon.online.{JavaRequest, KVStore, LoggableResponseBase64, MetadataStore, SparkConversions}
 import ai.chronon.spark.Extensions._
 import ai.chronon.spark.stats.ConsistencyJob
 import ai.chronon.spark.{Join => _, _}
@@ -41,7 +41,7 @@ import java.util.concurrent.Executors
 import scala.collection.Seq
 import scala.compat.java8.FutureConverters
 import scala.concurrent.duration.{Duration, SECONDS}
-import scala.concurrent.{Await, ExecutionContext}
+import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.util.Random
 import scala.util.ScalaJavaConversions._

@@ -55,12 +55,12 @@ class FetcherTest extends TestCase {
   private val today = dummyTableUtils.partitionSpec.at(System.currentTimeMillis())
   private val yesterday = dummyTableUtils.partitionSpec.before(today)

-
   /**
     * Generate deterministic data for testing and checkpointing IRs and streaming data.
     */
   def generateMutationData(namespace: String): api.Join = {
-    val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession =
+      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
     val tableUtils = TableUtils(spark)
     tableUtils.createDatabase(namespace)
     def toTs(arg: String): Long = TsUtils.datetimeToTs(arg)
@@ -188,7 +188,7 @@ class FetcherTest extends TestCase {
       ),
       accuracy = Accuracy.TEMPORAL,
       metaData = Builders.MetaData(name = "unit_test/fetcher_mutations_gb", namespace = namespace, team = "chronon"),
-      derivations= Seq(
+      derivations = Seq(
         Builders.Derivation(name = "*", expression = "*"),
         Builders.Derivation(name = "rating_average_1d_same", expression = "rating_average_1d")
       )
@@ -203,7 +203,8 @@ class FetcherTest extends TestCase {
   }

   def generateRandomData(namespace: String, keyCount: Int = 10, cardinality: Int = 100): api.Join = {
-    val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession =
+      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
     val tableUtils = TableUtils(spark)
     tableUtils.createDatabase(namespace)
     val rowCount = cardinality * keyCount
@@ -312,9 +313,8 @@ class FetcherTest extends TestCase {
       sources = Seq(Builders.Source.entities(query = Builders.Query(), snapshotTable = creditTable)),
       keyColumns = Seq("vendor_id"),
       aggregations = Seq(
-        Builders.Aggregation(operation = Operation.SUM,
-                             inputColumn = "credit",
-                             windows = Seq(new Window(3, TimeUnit.DAYS)))),
+        Builders
+          .Aggregation(operation = Operation.SUM, inputColumn = "credit", windows = Seq(new Window(3, TimeUnit.DAYS)))),
       metaData = Builders.MetaData(name = "unit_test/vendor_credit_derivation", namespace = namespace),
       derivations = Seq(
         Builders.Derivation("credit_sum_3d_test_rename", "credit_sum_3d"),
@@ -390,7 +390,8 @@ class FetcherTest extends TestCase {
   }

   def generateEventOnlyData(namespace: String, groupByCustomJson: Option[String] = None): api.Join = {
-    val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession =
+      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
     val tableUtils = TableUtils(spark)
     tableUtils.createDatabase(namespace)
     def toTs(arg: String): Long = TsUtils.datetimeToTs(arg)
@@ -518,7 +519,8 @@ class FetcherTest extends TestCase {
                                    consistencyCheck: Boolean,
                                    dropDsOnWrite: Boolean): Unit = {
     implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
-    val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession =
+      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
     val tableUtils = TableUtils(spark)
     val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
     val inMemoryKvStore = kvStoreFunc()
@@ -655,8 +657,10 @@ class FetcherTest extends TestCase {
   def testTemporalFetchJoinDerivation(): Unit = {
     val namespace = "derivation_fetch"
     val joinConf = generateMutationData(namespace)
-    val derivations = Seq(Builders.Derivation(name = "*", expression = "*"),
-      Builders.Derivation(name = "unit_test_fetcher_mutations_gb_rating_sum_plus", expression = "unit_test_fetcher_mutations_gb_rating_sum + 1"),
+    val derivations = Seq(
+      Builders.Derivation(name = "*", expression = "*"),
+      Builders.Derivation(name = "unit_test_fetcher_mutations_gb_rating_sum_plus",
+                          expression = "unit_test_fetcher_mutations_gb_rating_sum + 1"),
       Builders.Derivation(name = "listing_id_renamed", expression = "listing_id")
     )
     joinConf.setDerivations(derivations.toJava)
@@ -668,14 +672,12 @@ class FetcherTest extends TestCase {
     val namespace = "derivation_fetch_rename_only"
     val joinConf = generateMutationData(namespace)
     val derivations = Seq(Builders.Derivation(name = "*", expression = "*"),
-      Builders.Derivation(name = "listing_id_renamed", expression = "listing_id")
-    )
+                          Builders.Derivation(name = "listing_id_renamed", expression = "listing_id"))
     joinConf.setDerivations(derivations.toJava)

     compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
   }

-
   def testTemporalFetchJoinGenerated(): Unit = {
     val namespace = "generated_fetch"
     val joinConf = generateRandomData(namespace)
@@ -694,7 +696,8 @@ class FetcherTest extends TestCase {

   // test soft-fail on missing keys
   def testEmptyRequest(): Unit = {
-    val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession =
+      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
     val namespace = "empty_request"
     val joinConf = generateRandomData(namespace, 5, 5)
     implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
@@ -722,33 +725,58 @@ class FetcherTest extends TestCase {
     val namespace = "non_exist_key_group_by_fetch"
     val joinConf = generateMutationData(namespace)
     val endDs = "2021-04-10"
-    val spark: SparkSession = SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val spark: SparkSession =
+      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
     val tableUtils = TableUtils(spark)
     val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore("FetcherTest")
     val inMemoryKvStore = kvStoreFunc()
     val mockApi = new MockApi(kvStoreFunc, namespace)
-    @transient lazy val fetcher = mockApi.buildFetcher(debug= false)
+    @transient lazy val fetcher = mockApi.buildFetcher(debug = false)

     joinConf.joinParts.toScala.foreach(jp =>
-      OnlineUtils.serve(tableUtils,
-                        inMemoryKvStore,
-                        kvStoreFunc,
-                        namespace,
-                        endDs,
-                        jp.groupBy,
-                        dropDsOnWrite = true))
+      OnlineUtils.serve(tableUtils, inMemoryKvStore, kvStoreFunc, namespace, endDs, jp.groupBy, dropDsOnWrite = true))

     // a random key that doesn't exist
     val nonExistKey = 123L
-    val request = Request("unit_test/fetcher_mutations_gb",
-                          Map("listing_id" -> nonExistKey.asInstanceOf[AnyRef]))
+    val request = Request("unit_test/fetcher_mutations_gb", Map("listing_id" -> nonExistKey.asInstanceOf[AnyRef]))
     val response = fetcher.fetchGroupBys(Seq(request))
     val result = Await.result(response, Duration(10, SECONDS))

     // result should be "null" if the key is not found
     val expected: Map[String, AnyRef] = Map("rating_average_1d_same" -> null)
     assertEquals(expected, result.head.values.get)
   }
+
+  def testKVStorePartialFailure(): Unit = {
+
+    val spark: SparkSession =
+      SparkSessionBuilder.build(sessionName + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val namespace = "test_kv_store_partial_failure"
+    val joinConf = generateRandomData(namespace, 5, 5)
+    implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
+
+    val kvStoreFunc = () =>
+      OnlineUtils.buildInMemoryKVStore("FetcherTest#test_kv_store_partial_failure", hardFailureOnInvalidDataset = true)
+    val inMemoryKvStore = kvStoreFunc()
+    val mockApi = new MockApi(kvStoreFunc, namespace)
+
+    val metadataStore = new MetadataStore(inMemoryKvStore, timeoutMillis = 10000)
+    inMemoryKvStore.create(ChrononMetadataKey)
+    metadataStore.putJoinConf(joinConf)
+
+    val keys = joinConf.leftKeyCols
+    val keyData = spark.table(s"$namespace.queries_table").select(keys.map(col): _*).head
+    val keyMap = keys.indices.map { idx =>
+      keys(idx) -> keyData.get(idx).asInstanceOf[AnyRef]
+    }.toMap
+
+    val request = Request(joinConf.metaData.nameToFilePath, keyMap)
+    val (responses, _) = FetcherTestUtil.joinResponses(spark, Array(request), mockApi)
+    val responseMap = responses.head.values.get
+    val exceptionKeys = joinConf.joinPartOps.map(jp => jp.fullPrefix + "_exception")
+    exceptionKeys.foreach(k => assertTrue(responseMap.contains(k)))
+  }
+
 }

 object FetcherTestUtil {
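The new testKVStorePartialFailure exercises an in-memory KV store that is told to hard-fail lookups against datasets it does not know about (hardFailureOnInvalidDataset = true), and then asserts that the fetcher surfaces the failure per join part under a "<joinPart>_exception" key rather than failing the whole response. The sketch below illustrates that failure mode with simplified stand-in types; it is not the real ai.chronon.online.KVStore API, and the names SimpleKVStore, GetRequest, GetResponse, and InMemoryKVStore are assumptions made purely for illustration.

import scala.concurrent.{ExecutionContext, Future}

// Simplified stand-ins, not the real ai.chronon.online.KVStore types.
case class GetRequest(keyBytes: Array[Byte], dataset: String)
case class GetResponse(request: GetRequest, values: Option[Seq[Array[Byte]]])

trait SimpleKVStore {
  def multiGet(requests: Seq[GetRequest])(implicit ec: ExecutionContext): Future[Seq[GetResponse]]
}

// An in-memory store: with hardFailureOnInvalidDataset set, a multiGet that touches an
// unknown dataset fails the returned future instead of returning empty values. A fetcher
// that catches such failures per join part can report them under "<joinPart>_exception"
// keys while still serving the join parts whose lookups succeeded.
class InMemoryKVStore(datasets: Map[String, Map[Seq[Byte], Seq[Array[Byte]]]],
                      hardFailureOnInvalidDataset: Boolean = false)
    extends SimpleKVStore {

  override def multiGet(requests: Seq[GetRequest])(implicit ec: ExecutionContext): Future[Seq[GetResponse]] = {
    val unknown = requests.map(_.dataset).distinct.filterNot(datasets.contains)
    if (hardFailureOnInvalidDataset && unknown.nonEmpty) {
      // Hard failure: the whole lookup future fails for this batch.
      Future.failed(new IllegalArgumentException(s"Unknown datasets: ${unknown.mkString(", ")}"))
    } else {
      // Soft path: unknown keys or datasets simply come back with no values.
      Future.successful(requests.map { r =>
        GetResponse(r, datasets.get(r.dataset).flatMap(_.get(r.keyBytes.toSeq)))
      })
    }
  }
}

In the test above, OnlineUtils.buildInMemoryKVStore(..., hardFailureOnInvalidDataset = true) plays the role of such a hard-failing store, and the assertions check that every join part's "<fullPrefix>_exception" key is present in the fetched response map.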