Commit 675d0f7
Tweak spark test setup to tag and run tests appropriately (#56)
## Summary

As of today our Spark tests CI action isn't running the right set of Spark tests. The `testOnly` option seems to only include tests, not exclude them. To get around this, I've set up a [SuiteMixin](https://www.scalatest.org/scaladoc/3.0.6/org/scalatest/SuiteMixin.html) which we can use to run the tests in a suite only if sbt was invoked with that suite's tag; otherwise we skip them all. This allows us to:

* Trigger `sbt test` or `sbt spark/test` and run all the tests barring the ones that include this suite mixin.
* Selectively run these tests using an incantation like `sbt "spark/testOnly -- -n jointest"`. This lets us run really long-running tests like the Join / Fetcher / Mutations tests separately, in different CI JVMs in parallel, to keep our build times short.

There are a couple of other options we could pursue to wire up our tests:

* Trigger all Spark tests at once using `sbt spark/test` (this would probably bring our test runtime to ~1 hour).
* Set up per-test [Tags](https://www.scalatest.org/scaladoc/3.0.6/org/scalatest/Tag.html) - we could either set up individual tags for the JoinTests, MutationTests, and FetcherTests, or just create a "Slow" test tag and mark the Join, Mutations, and Fetcher tests with it. This seems to require the tags to be written in Java, but it's a viable option.

## Checklist

- [ ] Added Unit Tests
- [X] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

Verified that the rest of our Spark tests all run together now (and our CI takes ~30-40 mins thanks to that :-) ):

```
[info] All tests passed.
[info] Passed: Total 127, Failed 0, Errors 0, Passed 127
[success] Total time: 2040 s (34:00), completed Oct 30, 2024, 11:27:39 PM
```

## Summary by CodeRabbit

- **New Features**
  - Introduced a new `TaggedFilterSuite` trait for selective test execution based on specified tags.
  - Enhanced Spark test execution commands for better manageability.
- **Refactor**
  - Transitioned multiple test classes from JUnit to ScalaTest, improving readability and consistency.
  - Updated test methods to utilize ScalaTest's syntax and structure.
- **Bug Fixes**
  - Improved test logic and assertions in the `FetcherTest`, `JoinTest`, and `MutationsTest` classes to ensure expected behavior.
1 parent 0fc3d8d
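The `TaggedFilterSuite` mixin referenced above is the fourth changed file and isn't shown in the diffs below; the suites just mix it in and override `tagName`. For illustration, here is a minimal sketch of what such a mixin could look like, assuming ScalaTest 3.x's `SuiteMixin` / `Args` / `Filter` API (hypothetical code, not necessarily the committed implementation):

```scala
import org.scalatest.{Args, Filter, Status, Suite, SuiteMixin, SucceededStatus}

// Hypothetical sketch: run this suite's tests only when the runner was
// invoked with this suite's tag, e.g. `sbt "spark/testOnly -- -n jointest"`.
trait TaggedFilterSuite extends SuiteMixin { self: Suite =>

  // Each suite declares the tag it answers to (e.g. "jointest").
  def tagName: String

  abstract override def run(testName: Option[String], args: Args): Status = {
    // A `-n foo` argument surfaces as Some(Set("foo")) on the runner's filter.
    val includedTags = args.filter.tagsToInclude.getOrElse(Set.empty[String])
    if (includedTags.contains(tagName))
      // Clear the tag filter so the suite's individually untagged tests run.
      super.run(testName, args.copy(filter = Filter.default))
    else
      // Plain `sbt test` / `sbt spark/test`, or some other suite's tag:
      // report success without running anything.
      SucceededStatus
  }
}
```

With this mixed in, `sbt "spark/testOnly -- -n fetchertest"` runs only `FetcherTest`, while a plain `sbt spark/test` skips the three tagged suites entirely, which is what keeps the default CI job fast.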

File tree: 4 files changed, +94 −66 lines changed

spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala

Lines changed: 15 additions & 12 deletions
```diff
@@ -37,7 +37,6 @@ import ai.chronon.spark.Extensions._
 import ai.chronon.spark.stats.ConsistencyJob
 import ai.chronon.spark.{Join => _, _}
 import com.google.gson.GsonBuilder
-import junit.framework.TestCase
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.SparkSession
@@ -48,6 +47,7 @@ import org.apache.spark.sql.functions.lit
 import org.junit.Assert.assertEquals
 import org.junit.Assert.assertFalse
 import org.junit.Assert.assertTrue
+import org.scalatest.funsuite.AnyFunSuite
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
@@ -64,7 +64,11 @@ import scala.concurrent.duration.SECONDS
 import scala.io.Source
 import scala.util.ScalaJavaConversions._
 
-class FetcherTest extends TestCase {
+// Run as follows: sbt "spark/testOnly -- -n fetchertest"
+class FetcherTest extends AnyFunSuite with TaggedFilterSuite {
+
+  override def tagName: String = "fetchertest"
+
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
   val sessionName = "FetcherTest"
   val spark: SparkSession = SparkSessionBuilder.build(sessionName, local = true)
@@ -74,7 +78,7 @@ class FetcherTest extends TestCase {
   private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
   private val yesterday = tableUtils.partitionSpec.before(today)
 
-  def testMetadataStore(): Unit = {
+  test("test metadata store") {
     implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
     implicit val tableUtils: TableUtils = TableUtils(spark)
 
@@ -114,7 +118,8 @@
     val directoryDataSetDataSet = ChrononMetadataKey + "_directory_test"
     val directoryMetadataStore = new MetadataStore(inMemoryKvStore, directoryDataSetDataSet, timeoutMillis = 10000)
     inMemoryKvStore.create(directoryDataSetDataSet)
-    val directoryDataDirWalker = new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints)
+    val directoryDataDirWalker =
+      new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints)
     val directoryDataKvMap = directoryDataDirWalker.run
     val directoryPut = directoryDataKvMap.toSeq.map {
       case (_, kvMap) => directoryMetadataStore.put(kvMap, directoryDataSetDataSet)
@@ -385,9 +390,8 @@
       sources = Seq(Builders.Source.entities(query = Builders.Query(), snapshotTable = creditTable)),
       keyColumns = Seq("vendor_id"),
       aggregations = Seq(
-        Builders.Aggregation(operation = Operation.SUM,
-                             inputColumn = "credit",
-                             windows = Seq(new Window(3, TimeUnit.DAYS)))),
+        Builders
+          .Aggregation(operation = Operation.SUM, inputColumn = "credit", windows = Seq(new Window(3, TimeUnit.DAYS)))),
       metaData = Builders.MetaData(name = "unit_test/vendor_credit_derivation", namespace = namespace),
       derivations = Seq(
         Builders.Derivation("credit_sum_3d_test_rename", "credit_sum_3d"),
@@ -527,7 +531,6 @@
     println("saved all data hand written for fetcher test")
 
     val startPartition = "2021-04-07"
-
 
     val leftSource =
       Builders.Source.events(
@@ -717,13 +720,13 @@
     assertEquals(0, diff.count())
   }
 
-  def testTemporalFetchJoinDeterministic(): Unit = {
+  test("test temporal fetch join deterministic") {
     val namespace = "deterministic_fetch"
     val joinConf = generateMutationData(namespace)
     compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
   }
 
-  def testTemporalFetchJoinGenerated(): Unit = {
+  test("test temporal fetch join generated") {
     val namespace = "generated_fetch"
     val joinConf = generateRandomData(namespace)
     compareTemporalFetch(joinConf,
@@ -733,14 +736,14 @@
                          dropDsOnWrite = false)
   }
 
-  def testTemporalTiledFetchJoinDeterministic(): Unit = {
+  test("test temporal tiled fetch join deterministic") {
     val namespace = "deterministic_tiled_fetch"
     val joinConf = generateEventOnlyData(namespace, groupByCustomJson = Some("{\"enable_tiling\": true}"))
     compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
   }
 
   // test soft-fail on missing keys
-  def testEmptyRequest(): Unit = {
+  test("test empty request") {
     val namespace = "empty_request"
     val joinConf = generateRandomData(namespace, 5, 5)
     implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
```

spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala

Lines changed: 23 additions & 35 deletions
```diff
@@ -40,13 +40,13 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.{StringType => SparkStringType}
 import org.junit.Assert._
-import org.junit.Test
-import org.scalatest.Assertions.intercept
+import org.scalatest.funsuite.AnyFunSuite
 
 import scala.collection.JavaConverters._
 import scala.util.ScalaJavaConversions.ListOps
 
-class JoinTest {
+// Run as follows: sbt "spark/testOnly -- -n jointest"
+class JoinTest extends AnyFunSuite with TaggedFilterSuite {
 
   val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true)
   private implicit val tableUtils = TableUtils(spark)
@@ -59,8 +59,9 @@ class JoinTest {
   private val namespace = "test_namespace_jointest"
   tableUtils.createDatabase(namespace)
 
-  @Test
-  def testEventsEntitiesSnapshot(): Unit = {
+  override def tagName: String = "jointest"
+
+  test("test events entities snapshot") {
     val dollarTransactions = List(
       Column("user", StringType, 100),
       Column("user_name", api.StringType, 100),
@@ -263,8 +264,7 @@
     assertEquals(0, diff2.count())
   }
 
-  @Test
-  def testEntitiesEntities(): Unit = {
+  test("test entities entities") {
     // untimed/unwindowed entities on right
     // right side
     val weightSchema = List(
@@ -384,8 +384,7 @@
      */
   }
 
-  @Test
-  def testEntitiesEntitiesNoHistoricalBackfill(): Unit = {
+  test("test entities entities no historical backfill") {
     // Only backfill latest partition if historical_backfill is turned off
     val weightSchema = List(
       Column("user", api.StringType, 1000),
@@ -438,8 +437,7 @@
     assertEquals(allPartitions.toList(0), end)
   }
 
-  @Test
-  def testEventsEventsSnapshot(): Unit = {
+  test("test events events snapshot") {
     val viewsSchema = List(
       Column("user", api.StringType, 10000),
       Column("item", api.StringType, 100),
@@ -508,8 +506,7 @@
     assertEquals(diff.count(), 0)
   }
 
-  @Test
-  def testEventsEventsTemporal(): Unit = {
+  test("test events events temporal") {
 
     val joinConf = getEventsEventsTemporal("temporal")
     val viewsSchema = List(
@@ -586,8 +583,7 @@
     assertEquals(diff.count(), 0)
   }
 
-  @Test
-  def testEventsEventsCumulative(): Unit = {
+  test("test events events cumulative") {
     // Create a cumulative source GroupBy
     val viewsTable = s"$namespace.view_cumulative"
     val viewsGroupBy = getViewsGroupBy(suffix = "cumulative", makeCumulative = true)
@@ -686,8 +682,7 @@
 
   }
 
-  @Test
-  def testNoAgg(): Unit = {
+  test("test no agg") {
     // Left side entities, right side entities no agg
     // Also testing specific select statement (rather than select *)
     val namesSchema = List(
@@ -767,8 +762,7 @@
     assertEquals(diff.count(), 0)
   }
 
-  @Test
-  def testVersioning(): Unit = {
+  test("test versioning") {
     val joinConf = getEventsEventsTemporal("versioning")
 
     // Run the old join to ensure that tables exist
@@ -922,8 +916,7 @@
 
   }
 
-  @Test
-  def testEndPartitionJoin(): Unit = {
+  test("test end partition join") {
     val join = getEventsEventsTemporal("end_partition_test")
     val start = join.getLeft.query.startPartition
     val end = tableUtils.partitionSpec.after(start)
@@ -940,12 +933,11 @@
     assertTrue(ds.first().getString(0) < today)
   }
 
-  @Test
-  def testSkipBloomFilterJoinBackfill(): Unit = {
-    val testSpark: SparkSession = SparkSessionBuilder.build(
-      "JoinTest",
-      local = true,
-      additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
+  test("test skip bloom filter join backfill") {
+    val testSpark: SparkSession =
+      SparkSessionBuilder.build("JoinTest",
+                                local = true,
+                                additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
     val testTableUtils = TableUtils(testSpark)
     val viewsSchema = List(
       Column("user", api.StringType, 10000),
@@ -990,8 +982,7 @@
     assertEquals(leftSideCount, skipBloomComputed.count())
   }
 
-  @Test
-  def testStructJoin(): Unit = {
+  test("test struct join") {
     val nameSuffix = "_struct_test"
     val itemQueries = List(Column("item", api.StringType, 100))
     val itemQueriesTable = s"$namespace.item_queries_$nameSuffix"
@@ -1049,8 +1040,7 @@
     new SummaryJob(spark, join, today).dailyRun(stepDays = Some(30))
   }
 
-  @Test
-  def testMigration(): Unit = {
+  test("test migration") {
 
     // Left
     val itemQueriesTable = s"$namespace.item_queries"
@@ -1099,8 +1089,7 @@
     assertEquals(0, join.tablesToDrop(productionHashV2).length)
   }
 
-  @Test
-  def testKeyMappingOverlappingFields(): Unit = {
+  test("testKeyMappingOverlappingFields") {
     // test the scenario when a key_mapping is a -> b, (right key b is mapped to left key a) and
     // a happens to be another field in the same group by
 
@@ -1158,8 +1147,7 @@
     * Run computeJoin().
     * Check if the selected join part is computed and the other join parts are not computed.
     */
-  @Test
-  def testSelectedJoinParts(): Unit = {
+  test("test selected join parts") {
     // Left
     val itemQueries = List(
       Column("item", api.StringType, 100),
```

spark/src/test/scala/ai/chronon/spark/test/MutationsTest.scala

Lines changed: 17 additions & 19 deletions
```diff
@@ -38,19 +38,25 @@ import org.apache.spark.sql.types.LongType
 import org.apache.spark.sql.types.StringType
 import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.types.StructType
-import org.junit.Test
+import org.scalatest.funsuite.AnyFunSuite
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
 
 /** Tests for the temporal join of entities.
   * Left is an event source with definite ts.
   * Right is an entity with snapshots and mutation values through the day.
   * Join is the events and the entity value at the exact timestamp of the ts.
+  * To run: sbt "spark/testOnly -- -n mutationstest"
   */
-class MutationsTest {
+class MutationsTest extends AnyFunSuite with TaggedFilterSuite {
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
 
-  val spark: SparkSession = SparkSessionBuilder.build("MutationsTest", local = true) //, additionalConfig = Some(Map("spark.chronon.backfill.validation.enabled" -> "false")))
+  override def tagName: String = "mutationstest"
+
+  val spark: SparkSession =
+    SparkSessionBuilder.build("MutationsTest",
+                              local = true
+    ) //, additionalConfig = Some(Map("spark.chronon.backfill.validation.enabled" -> "false")))
   private implicit val tableUtils: TableUtils = TableUtils(spark)
 
   private def namespace(suffix: String) = s"test_mutations_$suffix"
@@ -443,8 +449,7 @@
     *
     * Compute Join for when mutations are just insert on values.
     */
-  @Test
-  def testSimplestCase(): Unit = {
+  test("test simplest case") {
     val suffix = "simple"
     val leftData = Seq(
       // {listing_id, some_col, ts, ds}
@@ -502,8 +507,7 @@
     *
     * Compute Join when mutations have an update on values.
     */
-  @Test
-  def testUpdateValueCase(): Unit = {
+  test("test update value case") {
     val suffix = "update_value"
     val leftData = Seq(
       // {listing_id, ts, event, ds}
@@ -554,8 +558,7 @@
     *
     * Compute Join when mutations have an update on keys.
     */
-  @Test
-  def testUpdateKeyCase(): Unit = {
+  test("test update key case") {
     val suffix = "update_key"
     val leftData = Seq(
       Row(1, 1, millis("2021-04-10 01:00:00"), "2021-04-10"),
@@ -612,8 +615,7 @@
     * For this test we request a value for id 2, w/ mutations happening in the day before and after the time requested.
     * The consistency constraint here is that snapshot 4/8 + mutations 4/8 = snapshot 4/9
     */
-  @Test
-  def testInconsistentTsLeftCase(): Unit = {
+  test("test inconsistent ts left case") {
     val suffix = "inconsistent_ts"
     val leftData = Seq(
       Row(1, 1, millis("2021-04-10 01:00:00"), "2021-04-10"),
@@ -682,8 +684,7 @@
     * Compute Join, the snapshot aggregation should decay, this is the main reason to have
     * resolution in snapshot IR
     */
-  @Test
-  def testDecayedWindowCase(): Unit = {
+  test("test decayed window case") {
     val suffix = "decayed"
     val leftData = Seq(
       Row(2, 1, millis("2021-04-09 01:30:00"), "2021-04-10"),
@@ -754,8 +755,7 @@
     * Compute Join, the snapshot aggregation should decay.
     * When there are no mutations returning the collapsed is not enough depending on the time.
     */
-  @Test
-  def testDecayedWindowCaseNoMutation(): Unit = {
+  test("test decayed window case no mutation") {
     val suffix = "decayed_v2"
     val leftData = Seq(
       Row(2, 1, millis("2021-04-10 01:00:00"), "2021-04-10"),
@@ -803,8 +803,7 @@
     * Compute Join, the snapshot aggregation should decay.
     * When there's no snapshot the value would depend only on mutations of the day.
     */
-  @Test
-  def testNoSnapshotJustMutation(): Unit = {
+  test("test no snapshot just mutation") {
     val suffix = "no_mutation"
     val leftData = Seq(
       Row(2, 1, millis("2021-04-10 00:07:00"), "2021-04-10"),
@@ -844,8 +843,7 @@
     assert(compareResult(result, expected))
   }
 
-  @Test
-  def testWithGeneratedData(): Unit = {
+  test("test with generated data") {
     val suffix = "generated"
     val reviews = List(
       Column("listing_id", api.StringType, 10),
```
