Commit 110808e
Tweak spark test setup to tags and run tests appropriately (#56)
## Summary

As of today, our Spark tests CI action isn't running the right set of Spark tests. The `testOnly` option seems to only include and not exclude tests. To get around this, I've set up a [SuiteMixin](https://www.scalatest.org/scaladoc/3.0.6/org/scalatest/SuiteMixin.html) which runs the tests in a suite only if sbt was invoked with the suite's tag; otherwise we skip them all. This allows us to:

* Trigger `sbt test` or `sbt spark/test` and run all the tests barring the ones that include this suite mixin.
* Selectively run these tests using an incantation like `sbt "spark/testOnly -- -n jointest"`. This lets us run very long-running tests like the Join / Fetcher / Mutations tests separately in different CI JVMs in parallel, to keep our build times short.

There are a couple of other alternative options we could pursue to wire up our tests:

* Trigger all Spark tests at once using `sbt spark/test` (this will probably bring our test runtime to ~1 hour).
* Set up per-test [Tags](https://www.scalatest.org/scaladoc/3.0.6/org/scalatest/Tag.html). We could either set up individual tags for the JoinTests, MutationTests, and FetcherTests, OR create a single "Slow" tag and mark the Join, Mutations, and Fetcher tests with it. It seems this requires the tags to be defined in Java, but it's a viable option.

## Checklist
- [ ] Added Unit Tests
- [X] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

Verified that our other Spark tests run as a bunch now (and our CI takes ~30-40 mins thanks to that :-) ):

```
[info] All tests passed.
[info] Passed: Total 127, Failed 0, Errors 0, Passed 127
[success] Total time: 2040 s (34:00), completed Oct 30, 2024, 11:27:39 PM
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced a new `TaggedFilterSuite` trait for selective test execution based on specified tags.
  - Enhanced Spark test execution commands for better manageability.
- **Refactor**
  - Transitioned multiple test classes from JUnit to ScalaTest, improving readability and consistency.
  - Updated test methods to utilize ScalaTest's syntax and structure.
- **Bug Fixes**
  - Improved test logic and assertions in the `FetcherTest`, `JoinTest`, and `MutationsTest` classes to ensure expected behavior.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
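The `TaggedFilterSuite` trait itself is not visible in the diffs below. As a rough illustration of how such a `SuiteMixin` can gate an entire suite on the tags sbt was invoked with, here is a minimal sketch using ScalaTest's stackable-trait pattern; the committed implementation may differ in detail:

```scala
import org.scalatest.{Args, Status, SucceededStatus, Suite, SuiteMixin}

// Hypothetical sketch of a tag-gated suite mixin. A suite mixing this in
// declares a tagName; its tests run only when the runner was invoked with
// that tag via `sbt "spark/testOnly -- -n <tagName>"`.
trait TaggedFilterSuite extends SuiteMixin { self: Suite =>

  // Tag that must be passed with `-n` for this suite's tests to execute.
  def tagName: String

  abstract override def run(testName: Option[String], args: Args): Status = {
    // args.filter.tagsToInclude reflects the `-n` tags passed to the runner;
    // it is None when no `-n` argument was given (e.g. plain `sbt spark/test`).
    val invokedWithTag = args.filter.tagsToInclude.exists(_.contains(tagName))
    if (invokedWithTag) super.run(testName, args)
    else SucceededStatus // skip the whole suite when the tag wasn't requested
  }
}
```

With this shape, a plain `sbt spark/test` sees no included tags and skips every suite that mixes in the trait, while `sbt "spark/testOnly -- -n jointest"` runs only the suite whose `tagName` is `"jointest"`.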
1 parent cf2cb3d commit 110808e

5 files changed: +98 −70 lines changed

.github/workflows/test_scala_and_python.yaml

Lines changed: 4 additions & 4 deletions

```diff
@@ -65,7 +65,7 @@ jobs:
       - name: Run other spark tests
         run: |
           export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
-          sbt "spark/testOnly -- -l ai.chronon.spark.JoinTest -l ai.chronon.spark.test.MutationsTest -l ai.chronon.spark.test.FetcherTest"
+          sbt "spark/testOnly"

   join_spark_tests:
     runs-on: ubuntu-latest
@@ -84,7 +84,7 @@ jobs:
       - name: Run other spark tests
         run: |
           export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
-          sbt "spark/testOnly ai.chronon.spark.JoinTest"
+          sbt "spark/testOnly -- -n jointest"

   mutation_spark_tests:
     runs-on: ubuntu-latest
@@ -103,7 +103,7 @@ jobs:
       - name: Run other spark tests
         run: |
           export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
-          sbt "spark/testOnly ai.chronon.spark.test.MutationsTest"
+          sbt "spark/testOnly -- -n mutationstest"

   fetcher_spark_tests:
     runs-on: ubuntu-latest
@@ -122,7 +122,7 @@ jobs:
       - name: Run other spark tests
         run: |
           export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
-          sbt "spark/testOnly ai.chronon.spark.test.FetcherTest"
+          sbt "spark/testOnly -- -n fetchertest"

   scala_compile_fmt_fix :
     runs-on: ubuntu-latest
```
spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala

Lines changed: 15 additions & 12 deletions

```diff
@@ -37,7 +37,6 @@ import ai.chronon.spark.Extensions._
 import ai.chronon.spark.stats.ConsistencyJob
 import ai.chronon.spark.{Join => _, _}
 import com.google.gson.GsonBuilder
-import junit.framework.TestCase
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.SparkSession
@@ -48,6 +47,7 @@ import org.apache.spark.sql.functions.lit
 import org.junit.Assert.assertEquals
 import org.junit.Assert.assertFalse
 import org.junit.Assert.assertTrue
+import org.scalatest.funsuite.AnyFunSuite
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
 
@@ -64,7 +64,11 @@ import scala.concurrent.duration.SECONDS
 import scala.io.Source
 import scala.util.ScalaJavaConversions._
 
-class FetcherTest extends TestCase {
+// Run as follows: sbt "spark/testOnly -- -n fetchertest"
+class FetcherTest extends AnyFunSuite with TaggedFilterSuite {
+
+  override def tagName: String = "fetchertest"
+
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
   val sessionName = "FetcherTest"
   val spark: SparkSession = SparkSessionBuilder.build(sessionName, local = true)
@@ -74,7 +78,7 @@ class FetcherTest extends TestCase {
   private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
   private val yesterday = tableUtils.partitionSpec.before(today)
 
-  def testMetadataStore(): Unit = {
+  test("test metadata store") {
     implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
     implicit val tableUtils: TableUtils = TableUtils(spark)
 
@@ -114,7 +118,8 @@ class FetcherTest extends TestCase {
     val directoryDataSetDataSet = ChrononMetadataKey + "_directory_test"
     val directoryMetadataStore = new MetadataStore(inMemoryKvStore, directoryDataSetDataSet, timeoutMillis = 10000)
     inMemoryKvStore.create(directoryDataSetDataSet)
-    val directoryDataDirWalker = new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints)
+    val directoryDataDirWalker =
+      new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints)
     val directoryDataKvMap = directoryDataDirWalker.run
     val directoryPut = directoryDataKvMap.toSeq.map {
       case (_, kvMap) => directoryMetadataStore.put(kvMap, directoryDataSetDataSet)
@@ -385,9 +390,8 @@ class FetcherTest extends TestCase {
       sources = Seq(Builders.Source.entities(query = Builders.Query(), snapshotTable = creditTable)),
       keyColumns = Seq("vendor_id"),
       aggregations = Seq(
-        Builders.Aggregation(operation = Operation.SUM,
-                             inputColumn = "credit",
-                             windows = Seq(new Window(3, TimeUnit.DAYS)))),
+        Builders
+          .Aggregation(operation = Operation.SUM, inputColumn = "credit", windows = Seq(new Window(3, TimeUnit.DAYS)))),
       metaData = Builders.MetaData(name = "unit_test/vendor_credit_derivation", namespace = namespace),
       derivations = Seq(
         Builders.Derivation("credit_sum_3d_test_rename", "credit_sum_3d"),
@@ -527,7 +531,6 @@ class FetcherTest extends TestCase {
     println("saved all data hand written for fetcher test")
 
     val startPartition = "2021-04-07"
-
 
     val leftSource =
       Builders.Source.events(
@@ -717,13 +720,13 @@ class FetcherTest extends TestCase {
     assertEquals(0, diff.count())
   }
 
-  def testTemporalFetchJoinDeterministic(): Unit = {
+  test("test temporal fetch join deterministic") {
     val namespace = "deterministic_fetch"
     val joinConf = generateMutationData(namespace)
     compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
   }
 
-  def testTemporalFetchJoinGenerated(): Unit = {
+  test("test temporal fetch join generated") {
     val namespace = "generated_fetch"
     val joinConf = generateRandomData(namespace)
     compareTemporalFetch(joinConf,
@@ -733,14 +736,14 @@ class FetcherTest extends TestCase {
                          dropDsOnWrite = false)
   }
 
-  def testTemporalTiledFetchJoinDeterministic(): Unit = {
+  test("test temporal tiled fetch join deterministic") {
     val namespace = "deterministic_tiled_fetch"
     val joinConf = generateEventOnlyData(namespace, groupByCustomJson = Some("{\"enable_tiling\": true}"))
     compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
   }
 
   // test soft-fail on missing keys
-  def testEmptyRequest(): Unit = {
+  test("test empty request") {
     val namespace = "empty_request"
     val joinConf = generateRandomData(namespace, 5, 5)
     implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
```
spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala

Lines changed: 23 additions & 35 deletions

```diff
@@ -40,13 +40,13 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.{StringType => SparkStringType}
 import org.junit.Assert._
-import org.junit.Test
-import org.scalatest.Assertions.intercept
+import org.scalatest.funsuite.AnyFunSuite
 
 import scala.collection.JavaConverters._
 import scala.util.ScalaJavaConversions.ListOps
 
-class JoinTest {
+// Run as follows: sbt "spark/testOnly -- -n jointest"
+class JoinTest extends AnyFunSuite with TaggedFilterSuite {
 
   val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true)
   private implicit val tableUtils = TableUtils(spark)
@@ -59,8 +59,9 @@ class JoinTest {
   private val namespace = "test_namespace_jointest"
   tableUtils.createDatabase(namespace)
 
-  @Test
-  def testEventsEntitiesSnapshot(): Unit = {
+  override def tagName: String = "jointest"
+
+  test("test events entities snapshot") {
     val dollarTransactions = List(
       Column("user", StringType, 100),
       Column("user_name", api.StringType, 100),
@@ -263,8 +264,7 @@ class JoinTest {
     assertEquals(0, diff2.count())
   }
 
-  @Test
-  def testEntitiesEntities(): Unit = {
+  test("test entities entities") {
     // untimed/unwindowed entities on right
     // right side
     val weightSchema = List(
@@ -384,8 +384,7 @@ class JoinTest {
     */
   }
 
-  @Test
-  def testEntitiesEntitiesNoHistoricalBackfill(): Unit = {
+  test("test entities entities no historical backfill") {
     // Only backfill latest partition if historical_backfill is turned off
     val weightSchema = List(
       Column("user", api.StringType, 1000),
@@ -438,8 +437,7 @@ class JoinTest {
     assertEquals(allPartitions.toList(0), end)
   }
 
-  @Test
-  def testEventsEventsSnapshot(): Unit = {
+  test("test events events snapshot") {
     val viewsSchema = List(
       Column("user", api.StringType, 10000),
       Column("item", api.StringType, 100),
@@ -508,8 +506,7 @@ class JoinTest {
     assertEquals(diff.count(), 0)
   }
 
-  @Test
-  def testEventsEventsTemporal(): Unit = {
+  test("test events events temporal") {
 
     val joinConf = getEventsEventsTemporal("temporal")
     val viewsSchema = List(
@@ -586,8 +583,7 @@ class JoinTest {
     assertEquals(diff.count(), 0)
   }
 
-  @Test
-  def testEventsEventsCumulative(): Unit = {
+  test("test events events cumulative") {
     // Create a cumulative source GroupBy
     val viewsTable = s"$namespace.view_cumulative"
     val viewsGroupBy = getViewsGroupBy(suffix = "cumulative", makeCumulative = true)
@@ -686,8 +682,7 @@ class JoinTest {
 
   }
 
-  @Test
-  def testNoAgg(): Unit = {
+  test("test no agg") {
     // Left side entities, right side entities no agg
     // Also testing specific select statement (rather than select *)
     val namesSchema = List(
@@ -767,8 +762,7 @@ class JoinTest {
     assertEquals(diff.count(), 0)
   }
 
-  @Test
-  def testVersioning(): Unit = {
+  test("test versioning") {
     val joinConf = getEventsEventsTemporal("versioning")
 
     // Run the old join to ensure that tables exist
@@ -922,8 +916,7 @@ class JoinTest {
 
   }
 
-  @Test
-  def testEndPartitionJoin(): Unit = {
+  test("test end partition join") {
     val join = getEventsEventsTemporal("end_partition_test")
     val start = join.getLeft.query.startPartition
     val end = tableUtils.partitionSpec.after(start)
@@ -940,12 +933,11 @@ class JoinTest {
     assertTrue(ds.first().getString(0) < today)
   }
 
-  @Test
-  def testSkipBloomFilterJoinBackfill(): Unit = {
-    val testSpark: SparkSession = SparkSessionBuilder.build(
-      "JoinTest",
-      local = true,
-      additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
+  test("test skip bloom filter join backfill") {
+    val testSpark: SparkSession =
+      SparkSessionBuilder.build("JoinTest",
+                                local = true,
+                                additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
     val testTableUtils = TableUtils(testSpark)
     val viewsSchema = List(
       Column("user", api.StringType, 10000),
@@ -990,8 +982,7 @@ class JoinTest {
     assertEquals(leftSideCount, skipBloomComputed.count())
   }
 
-  @Test
-  def testStructJoin(): Unit = {
+  test("test struct join") {
     val nameSuffix = "_struct_test"
     val itemQueries = List(Column("item", api.StringType, 100))
     val itemQueriesTable = s"$namespace.item_queries_$nameSuffix"
@@ -1049,8 +1040,7 @@ class JoinTest {
     new SummaryJob(spark, join, today).dailyRun(stepDays = Some(30))
   }
 
-  @Test
-  def testMigration(): Unit = {
+  test("test migration") {
 
     // Left
     val itemQueriesTable = s"$namespace.item_queries"
@@ -1099,8 +1089,7 @@ class JoinTest {
     assertEquals(0, join.tablesToDrop(productionHashV2).length)
   }
 
-  @Test
-  def testKeyMappingOverlappingFields(): Unit = {
+  test("testKeyMappingOverlappingFields") {
     // test the scenario when a key_mapping is a -> b, (right key b is mapped to left key a) and
     // a happens to be another field in the same group by
 
@@ -1158,8 +1147,7 @@ class JoinTest {
     * Run computeJoin().
     * Check if the selected join part is computed and the other join parts are not computed.
     */
-  @Test
-  def testSelectedJoinParts(): Unit = {
+  test("test selected join parts") {
     // Left
     val itemQueries = List(
       Column("item", api.StringType, 100),
```