
Tweak spark test setup to tags and run tests appropriately #56

Merged · 2 commits · Oct 31, 2024

8 changes: 4 additions & 4 deletions .github/workflows/test_scala_and_python.yaml
@@ -65,7 +65,7 @@ jobs:
- name: Run other spark tests
run: |
export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
sbt "spark/testOnly -- -l ai.chronon.spark.JoinTest -l ai.chronon.spark.test.MutationsTest -l ai.chronon.spark.test.FetcherTest"
sbt "spark/testOnly"

join_spark_tests:
runs-on: ubuntu-latest
@@ -84,7 +84,7 @@ jobs:
- name: Run other spark tests
run: |
export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
sbt "spark/testOnly ai.chronon.spark.JoinTest"
sbt "spark/testOnly -- -n jointest"

mutation_spark_tests:
runs-on: ubuntu-latest
@@ -103,7 +103,7 @@ jobs:
- name: Run other spark tests
run: |
export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
sbt "spark/testOnly ai.chronon.spark.test.MutationsTest"
sbt "spark/testOnly -- -n mutationstest"

fetcher_spark_tests:
runs-on: ubuntu-latest
@@ -122,7 +122,7 @@ jobs:
- name: Run other spark tests
run: |
export SBT_OPTS="-Xmx8G -Xms2G --add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
sbt "spark/testOnly ai.chronon.spark.test.FetcherTest"
sbt "spark/testOnly -- -n fetchertest"

scala_compile_fmt_fix :
runs-on: ubuntu-latest
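The CI jobs now select Spark suites by ScalaTest tag: `-n <tag>` includes only matching suites, and the catch-all job can drop its `-l` exclusion list because tagged suites exclude themselves from untagged runs. The TaggedFilterSuite trait that makes this work is referenced by the Scala changes below but is not itself part of this diff; here is a minimal sketch of one way such a trait could behave, with the gating logic assumed from its usage rather than confirmed by the PR:

import org.scalatest.{Args, Filter, Status, SucceededStatus}
import org.scalatest.funsuite.AnyFunSuite

// Hypothetical sketch of TaggedFilterSuite (the real trait is not in this diff).
// Idea: a tagged suite runs only when its tag is explicitly requested via -n,
// so a plain `sbt "spark/testOnly"` skips it without needing -l exclusions.
trait TaggedFilterSuite extends AnyFunSuite {
  // e.g. "jointest", "fetchertest", "mutationstest"
  def tagName: String

  override def run(testName: Option[String], args: Args): Status = {
    val explicitlyIncluded = args.filter.tagsToInclude.exists(_.contains(tagName))
    if (explicitlyIncluded)
      // Reset the filter so the suite's untagged tests are not filtered out
      // by the same -n flag that selected the suite.
      super.run(testName, args.copy(filter = Filter.default))
    else
      SucceededStatus // tag not requested: skip the suite entirely
  }
}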
27 changes: 15 additions & 12 deletions spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
@@ -37,7 +37,6 @@ import ai.chronon.spark.Extensions._
import ai.chronon.spark.stats.ConsistencyJob
import ai.chronon.spark.{Join => _, _}
import com.google.gson.GsonBuilder
-import junit.framework.TestCase
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
@@ -48,6 +47,7 @@ import org.apache.spark.sql.functions.lit
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
+import org.scalatest.funsuite.AnyFunSuite
import org.slf4j.Logger
import org.slf4j.LoggerFactory

@@ -64,7 +64,11 @@ import scala.concurrent.duration.SECONDS
import scala.io.Source
import scala.util.ScalaJavaConversions._

-class FetcherTest extends TestCase {
+// Run as follows: sbt "spark/testOnly -- -n fetchertest"
+class FetcherTest extends AnyFunSuite with TaggedFilterSuite {
+
+  override def tagName: String = "fetchertest"

@transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
val sessionName = "FetcherTest"
val spark: SparkSession = SparkSessionBuilder.build(sessionName, local = true)
@@ -74,7 +78,7 @@ class FetcherTest extends TestCase {
private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
private val yesterday = tableUtils.partitionSpec.before(today)

-  def testMetadataStore(): Unit = {
+  test("test metadata store") {
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
implicit val tableUtils: TableUtils = TableUtils(spark)

@@ -114,7 +118,8 @@ class FetcherTest extends TestCase {
val directoryDataSetDataSet = ChrononMetadataKey + "_directory_test"
val directoryMetadataStore = new MetadataStore(inMemoryKvStore, directoryDataSetDataSet, timeoutMillis = 10000)
inMemoryKvStore.create(directoryDataSetDataSet)
-    val directoryDataDirWalker = new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints)
+    val directoryDataDirWalker =
+      new MetadataDirWalker(confResource.getPath.replace(s"/$joinPath", ""), acceptedEndPoints)
val directoryDataKvMap = directoryDataDirWalker.run
val directoryPut = directoryDataKvMap.toSeq.map {
case (_, kvMap) => directoryMetadataStore.put(kvMap, directoryDataSetDataSet)
@@ -385,9 +390,8 @@
sources = Seq(Builders.Source.entities(query = Builders.Query(), snapshotTable = creditTable)),
keyColumns = Seq("vendor_id"),
aggregations = Seq(
-      Builders.Aggregation(operation = Operation.SUM,
-                           inputColumn = "credit",
-                           windows = Seq(new Window(3, TimeUnit.DAYS)))),
+      Builders
+        .Aggregation(operation = Operation.SUM, inputColumn = "credit", windows = Seq(new Window(3, TimeUnit.DAYS)))),
metaData = Builders.MetaData(name = "unit_test/vendor_credit_derivation", namespace = namespace),
derivations = Seq(
Builders.Derivation("credit_sum_3d_test_rename", "credit_sum_3d"),
@@ -527,7 +531,6 @@ class FetcherTest extends TestCase {
println("saved all data hand written for fetcher test")

    val startPartition = "2021-04-07"
-

val leftSource =
Builders.Source.events(
@@ -717,13 +720,13 @@
assertEquals(0, diff.count())
}

-  def testTemporalFetchJoinDeterministic(): Unit = {
+  test("test temporal fetch join deterministic") {
val namespace = "deterministic_fetch"
val joinConf = generateMutationData(namespace)
compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
}

-  def testTemporalFetchJoinGenerated(): Unit = {
+  test("test temporal fetch join generated") {
val namespace = "generated_fetch"
val joinConf = generateRandomData(namespace)
compareTemporalFetch(joinConf,
Expand All @@ -733,14 +736,14 @@ class FetcherTest extends TestCase {
dropDsOnWrite = false)
}

-  def testTemporalTiledFetchJoinDeterministic(): Unit = {
+  test("test temporal tiled fetch join deterministic") {
val namespace = "deterministic_tiled_fetch"
val joinConf = generateEventOnlyData(namespace, groupByCustomJson = Some("{\"enable_tiling\": true}"))
compareTemporalFetch(joinConf, "2021-04-10", namespace, consistencyCheck = false, dropDsOnWrite = true)
}

// test soft-fail on missing keys
-  def testEmptyRequest(): Unit = {
+  test("test empty request") {
val namespace = "empty_request"
val joinConf = generateRandomData(namespace, 5, 5)
implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
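For reference, the `--` in these sbt commands hands the remaining arguments to ScalaTest's runner, where `-n` includes tags and `-l` excludes them. A rough sketch of the Filter values those flags correspond to (an illustration of the ScalaTest API, not code from this PR):

import org.scalatest.Filter

// Rough illustration only (assumption: this mirrors how ScalaTest's runner
// translates the CLI flags used in the workflow above).
object TagFilterExamples {
  // sbt "spark/testOnly -- -n fetchertest" -> include only the fetchertest tag
  val includeOnlyFetcher: Filter = Filter(tagsToInclude = Some(Set("fetchertest")))

  // sbt "spark/testOnly -- -l fetchertest" -> exclude the fetchertest tag
  val excludeFetcher: Filter = Filter(tagsToExclude = Set("fetchertest"))

  // sbt "spark/testOnly" -> default filter, no tag selection either way
  val runEverything: Filter = Filter.default
}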
58 changes: 23 additions & 35 deletions spark/src/test/scala/ai/chronon/spark/test/JoinTest.scala
@@ -40,13 +40,13 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.{StringType => SparkStringType}
import org.junit.Assert._
-import org.junit.Test
import org.scalatest.Assertions.intercept
+import org.scalatest.funsuite.AnyFunSuite

import scala.collection.JavaConverters._
import scala.util.ScalaJavaConversions.ListOps

-class JoinTest {
+// Run as follows: sbt "spark/testOnly -- -n jointest"
+class JoinTest extends AnyFunSuite with TaggedFilterSuite {

val spark: SparkSession = SparkSessionBuilder.build("JoinTest", local = true)
private implicit val tableUtils = TableUtils(spark)
@@ -59,8 +59,9 @@ class JoinTest {
private val namespace = "test_namespace_jointest"
tableUtils.createDatabase(namespace)

-  @Test
-  def testEventsEntitiesSnapshot(): Unit = {
+  override def tagName: String = "jointest"
+
+  test("test events entities snapshot") {
val dollarTransactions = List(
Column("user", StringType, 100),
Column("user_name", api.StringType, 100),
Expand Down Expand Up @@ -263,8 +264,7 @@ class JoinTest {
assertEquals(0, diff2.count())
}

-  @Test
-  def testEntitiesEntities(): Unit = {
+  test("test entities entities") {
// untimed/unwindowed entities on right
// right side
val weightSchema = List(
@@ -384,8 +384,7 @@
*/
}

-  @Test
-  def testEntitiesEntitiesNoHistoricalBackfill(): Unit = {
+  test("test entities entities no historical backfill") {
// Only backfill latest partition if historical_backfill is turned off
val weightSchema = List(
Column("user", api.StringType, 1000),
@@ -438,8 +437,7 @@
assertEquals(allPartitions.toList(0), end)
}

-  @Test
-  def testEventsEventsSnapshot(): Unit = {
+  test("test events events snapshot") {
val viewsSchema = List(
Column("user", api.StringType, 10000),
Column("item", api.StringType, 100),
@@ -508,8 +506,7 @@
assertEquals(diff.count(), 0)
}

-  @Test
-  def testEventsEventsTemporal(): Unit = {
+  test("test events events temporal") {

val joinConf = getEventsEventsTemporal("temporal")
val viewsSchema = List(
@@ -586,8 +583,7 @@
assertEquals(diff.count(), 0)
}

-  @Test
-  def testEventsEventsCumulative(): Unit = {
+  test("test events events cumulative") {
// Create a cumulative source GroupBy
val viewsTable = s"$namespace.view_cumulative"
val viewsGroupBy = getViewsGroupBy(suffix = "cumulative", makeCumulative = true)
@@ -686,8 +682,7 @@

}

-  @Test
-  def testNoAgg(): Unit = {
+  test("test no agg") {
// Left side entities, right side entities no agg
// Also testing specific select statement (rather than select *)
val namesSchema = List(
@@ -767,8 +762,7 @@
assertEquals(diff.count(), 0)
}

-  @Test
-  def testVersioning(): Unit = {
+  test("test versioning") {
val joinConf = getEventsEventsTemporal("versioning")

// Run the old join to ensure that tables exist
@@ -922,8 +916,7 @@

}

-  @Test
-  def testEndPartitionJoin(): Unit = {
+  test("test end partition join") {
val join = getEventsEventsTemporal("end_partition_test")
val start = join.getLeft.query.startPartition
val end = tableUtils.partitionSpec.after(start)
@@ -940,12 +933,11 @@
assertTrue(ds.first().getString(0) < today)
}

-  @Test
-  def testSkipBloomFilterJoinBackfill(): Unit = {
-    val testSpark: SparkSession = SparkSessionBuilder.build(
-      "JoinTest",
-      local = true,
-      additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
+  test("test skip bloom filter join backfill") {
+    val testSpark: SparkSession =
+      SparkSessionBuilder.build("JoinTest",
+                                local = true,
+                                additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
Comment on lines +937 to +940

@coderabbitai bot · Oct 30, 2024

🛠️ Refactor suggestion

Consider adding error handling for SparkSession creation.

The SparkSession creation could benefit from explicit error handling to gracefully handle initialization failures.

Consider wrapping the SparkSession creation in a try-catch block:

-    val testSpark: SparkSession =
-      SparkSessionBuilder.build("JoinTest",
-                                local = true,
-                                additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
+    val testSpark: SparkSession = try {
+      SparkSessionBuilder.build("JoinTest",
+                                local = true,
+                                additionalConfig = Some(Map("spark.chronon.backfill.bloomfilter.threshold" -> "100")))
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException("Failed to initialize SparkSession for testing", e)
+    }

@piyush-zlai (Contributor, Author):

existing code, leaving as-is

@coderabbitai bot:

@piyush-zlai, understood, leaving the existing code as-is.

(_/)
(._.)
/ \

val testTableUtils = TableUtils(testSpark)
val viewsSchema = List(
Column("user", api.StringType, 10000),
@@ -990,8 +982,7 @@
assertEquals(leftSideCount, skipBloomComputed.count())
}

-  @Test
-  def testStructJoin(): Unit = {
+  test("test struct join") {
val nameSuffix = "_struct_test"
val itemQueries = List(Column("item", api.StringType, 100))
val itemQueriesTable = s"$namespace.item_queries_$nameSuffix"
@@ -1049,8 +1040,7 @@
new SummaryJob(spark, join, today).dailyRun(stepDays = Some(30))
}

-  @Test
-  def testMigration(): Unit = {
+  test("test migration") {

// Left
val itemQueriesTable = s"$namespace.item_queries"
@@ -1099,8 +1089,7 @@
assertEquals(0, join.tablesToDrop(productionHashV2).length)
}

-  @Test
-  def testKeyMappingOverlappingFields(): Unit = {
+  test("testKeyMappingOverlappingFields") {
// test the scenario when a key_mapping is a -> b, (right key b is mapped to left key a) and
// a happens to be another field in the same group by

@@ -1158,8 +1147,7 @@
* Run computeJoin().
* Check if the selected join part is computed and the other join parts are not computed.
*/
-  @Test
-  def testSelectedJoinParts(): Unit = {
+  test("test selected join parts") {
// Left
val itemQueries = List(
Column("item", api.StringType, 100),
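The workflow's mutation job filters on `-n mutationstest`, so MutationsTest presumably received the same conversion even though its diff is not shown on this page. A minimal sketch of the shape a tagged suite takes under this setup (class name and test body are placeholders, not code from the PR):

import org.scalatest.funsuite.AnyFunSuite

// Run as follows: sbt "spark/testOnly -- -n mutationstest"
// Hypothetical sketch: MutationsTest's actual changes are not in this diff.
class MutationsTestSketch extends AnyFunSuite with TaggedFilterSuite {

  override def tagName: String = "mutationstest"

  test("placeholder mutation test") {
    // Real suites build Spark fixtures here; this only shows the structure.
    assert(tagName.nonEmpty)
  }
}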