
Commit 2ac01a7

chore: split out expensive spark tests to parallelize (#382)
## Summary

## Checklist
- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

- **New Features**
  - Introduced dedicated testing workflows covering multiple system components to enhance overall reliability.
  - Added new test suites for various components to enhance testing granularity.

- **Refactor**
  - Streamlined code organization with improved package structures and consolidated imports across test modules.

- **Chores**
  - Upgraded automated testing configurations with optimized resource settings for improved performance and stability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

<!-- av pr metadata
This information is embedded by the av CLI when creating PRs to track the status of stacks when using Aviator. Please do not delete or edit this section of the PR.
```
{"parent":"main","parentHead":"","trunk":"main"}
```
-->

---------

Co-authored-by: Thomas Chow <[email protected]>
1 parent b40ae8b commit 2ac01a7

17 files changed: 247 additions & 265 deletions

spark/BUILD.bazel

Lines changed: 56 additions & 2 deletions
```diff
@@ -90,13 +90,67 @@ scala_library(
     name = "test_lib",
     srcs = glob(["src/test/**/*.scala"]),
     format = True,
-    visibility = ["//visibility:public"],
     deps = test_deps,
 )
 
 scala_test_suite(
     name = "tests",
-    srcs = glob(["src/test/**/*.scala"]),
+    tags = ["large"],
+    srcs = glob(["src/test/scala/ai/chronon/spark/test/*.scala",
+                 "src/test/scala/ai/chronon/spark/test/udafs/*.scala",
+                 "src/test/scala/ai/chronon/spark/test/stats/drift/*.scala",
+                 "src/test/scala/ai/chronon/spark/test/bootstrap/*.scala"]),
+    data = glob(["spark/src/test/resources/**/*"]),
+    # defined in prelude_bazel file
+    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
+    visibility = ["//visibility:public"],
+    deps = test_deps + [":test_lib"],
+)
+
+scala_test_suite(
+    name = "fetcher_test",
+    srcs = glob(["src/test/scala/ai/chronon/spark/test/fetcher/*.scala"]),
+    resources = ["//spark/src/test/resources:test-resources"],
+    # defined in prelude_bazel file
+    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
+    visibility = ["//visibility:public"],
+    deps = test_deps + [":test_lib"],
+)
+
+scala_test_suite(
+    name = "groupby_test",
+    srcs = glob(["src/test/scala/ai/chronon/spark/test/groupby/*.scala"]),
+    data = glob(["spark/src/test/resources/**/*"]),
+    # defined in prelude_bazel file
+    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
+    visibility = ["//visibility:public"],
+    deps = test_deps + [":test_lib"],
+)
+
+scala_test_suite(
+    name = "join_test",
+    srcs = glob(["src/test/scala/ai/chronon/spark/test/join/*.scala"]),
+    tags = ["large"],
+    data = glob(["spark/src/test/resources/**/*"]),
+    # defined in prelude_bazel file
+    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
+    visibility = ["//visibility:public"],
+    deps = test_deps + [":test_lib"],
+)
+
+scala_test_suite(
+    name = "analyzer_test",
+    srcs = glob(["src/test/scala/ai/chronon/spark/test/analyzer/*.scala"]),
+    data = glob(["spark/src/test/resources/**/*"]),
+    # defined in prelude_bazel file
+    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
+    visibility = ["//visibility:public"],
+    deps = test_deps + [":test_lib"],
+)
+
+scala_test_suite(
+    name = "streaming_test",
+    srcs = glob(["src/test/scala/ai/chronon/spark/test/streaming/*.scala"]),
     data = glob(["spark/src/test/resources/**/*"]),
     # defined in prelude_bazel file
     jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
```

spark/src/test/scala/ai/chronon/spark/test/AnalyzerTest.scala renamed to spark/src/test/scala/ai/chronon/spark/test/analyzer/AnalyzerTest.scala

Lines changed: 5 additions & 9 deletions
```diff
@@ -14,25 +14,21 @@
  * limitations under the License.
  */
 
-package ai.chronon.spark.test
+package ai.chronon.spark.test.analyzer
 
 import ai.chronon.aggregator.test.Column
 import ai.chronon.api
 import ai.chronon.api._
-import ai.chronon.spark.Analyzer
 import ai.chronon.spark.Extensions._
-import ai.chronon.spark.Join
-import ai.chronon.spark.SparkSessionBuilder
-import ai.chronon.spark.TableUtils
+import ai.chronon.spark.{Analyzer, Join, SparkSessionBuilder, TableUtils}
+import ai.chronon.spark.test.DataFrameGen
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{col, lit}
 import org.junit.Assert.assertTrue
 import org.scalatest.BeforeAndAfter
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
-import org.slf4j.Logger
-import org.slf4j.LoggerFactory
+import org.slf4j.{Logger, LoggerFactory}
 
 class AnalyzerTest extends AnyFlatSpec with BeforeAndAfter {
   @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
```
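One consequence of these renames shows up in the added `import ai.chronon.spark.test.DataFrameGen`: once `AnalyzerTest` moves from `ai.chronon.spark.test` into the `analyzer` subpackage, members of the old package are no longer in scope and must be imported explicitly. A minimal, self-contained sketch (this `DataFrameGen` is a hypothetical stand-in for the real helper):

```scala
// Two package blocks in one file, mirroring the move above.
package ai.chronon.spark.test {
  // Hypothetical stand-in for the real DataFrameGen helper.
  object DataFrameGen {
    val description = "shared test helper"
  }
}

package ai.chronon.spark.test.analyzer {
  // A dotted package clause brings only `analyzer`'s own members into scope,
  // so the parent package's helper now needs an explicit import:
  import ai.chronon.spark.test.DataFrameGen

  object MoveSketch {
    def main(args: Array[String]): Unit = println(DataFrameGen.description)
  }
}
```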

spark/src/test/scala/ai/chronon/spark/test/bootstrap/DerivationTest.scala renamed to spark/src/test/scala/ai/chronon/spark/test/analyzer/DerivationTest.scala

Lines changed: 5 additions & 8 deletions
```diff
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-package ai.chronon.spark.test.bootstrap
+package ai.chronon.spark.test.analyzer
 
 import ai.chronon.api.Builders.Derivation
 import ai.chronon.api.Extensions._
@@ -24,17 +24,14 @@ import ai.chronon.online.Fetcher.Request
 import ai.chronon.online.MetadataStore
 import ai.chronon.spark.Extensions.DataframeOps
 import ai.chronon.spark._
-import ai.chronon.spark.test.OnlineUtils
-import ai.chronon.spark.test.SchemaEvolutionUtils
+import ai.chronon.spark.test.{OnlineUtils, SchemaEvolutionUtils}
+import ai.chronon.spark.test.bootstrap.BootstrapUtils
 import ai.chronon.spark.utils.MockApi
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions._
-import org.junit.Assert.assertEquals
-import org.junit.Assert.assertFalse
-import org.junit.Assert.assertTrue
+import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
 import org.scalatest.flatspec.AnyFlatSpec
-import org.slf4j.Logger
-import org.slf4j.LoggerFactory
+import org.slf4j.{Logger, LoggerFactory}
 
 import scala.concurrent.Await
 import scala.concurrent.duration.Duration
```

spark/src/test/scala/ai/chronon/spark/test/ChainingFetcherTest.scala renamed to spark/src/test/scala/ai/chronon/spark/test/fetcher/ChainingFetcherTest.scala

Lines changed: 7 additions & 12 deletions
```diff
@@ -14,31 +14,26 @@
  * limitations under the License.
  */
 
-package ai.chronon.spark.test
+package ai.chronon.spark.test.fetcher
 
 import ai.chronon.aggregator.windowing.TsUtils
 import ai.chronon.api
 import ai.chronon.api.Constants.MetadataDataset
-import ai.chronon.api.Extensions.JoinOps
-import ai.chronon.api.Extensions.MetadataOps
+import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
 import ai.chronon.api.ScalaJavaConversions._
 import ai.chronon.api._
 import ai.chronon.online.Fetcher.Request
-import ai.chronon.online.MetadataStore
-import ai.chronon.online.SparkConversions
+import ai.chronon.online.{MetadataStore, SparkConversions}
 import ai.chronon.spark.Extensions._
+import ai.chronon.spark.test.{OnlineUtils, TestUtils}
 import ai.chronon.spark.utils.MockApi
 import ai.chronon.spark.{Join => _, _}
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.functions.lit
-import org.junit.Assert.assertEquals
-import org.junit.Assert.assertTrue
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.junit.Assert.{assertEquals, assertTrue}
 import org.scalatest.flatspec.AnyFlatSpec
-import org.slf4j.Logger
-import org.slf4j.LoggerFactory
+import org.slf4j.{Logger, LoggerFactory}
 
 import java.lang
 import java.util.TimeZone
```
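The context line `import ai.chronon.spark.{Join => _, _}` kept by this diff uses Scala's import-hiding form: wildcard-import the package but rename `Join` to `_`, which excludes it, presumably so that `Join` can resolve to `ai.chronon.api.Join` via the `ai.chronon.api._` wildcard above. A toy sketch with stand-in objects:

```scala
object ImportHidingSketch {
  object sparkPkg { class Join { override def toString = "sparkPkg.Join" }; class TableUtils }
  object apiPkg   { class Join { override def toString = "apiPkg.Join" } }

  import sparkPkg.{Join => _, _} // import everything from sparkPkg, but hide its Join
  import apiPkg.Join             // now `Join` unambiguously means apiPkg.Join

  def main(args: Array[String]): Unit = {
    println(new Join)       // prints "apiPkg.Join"
    println(new TableUtils) // still available via the wildcard
  }
}
```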

spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala renamed to spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTest.scala

Lines changed: 12 additions & 132 deletions
```diff
@@ -14,58 +14,37 @@
  * limitations under the License.
  */
 
-package ai.chronon.spark.test
+package ai.chronon.spark.test.fetcher
 
 import ai.chronon.aggregator.test.Column
 import ai.chronon.aggregator.windowing.TsUtils
 import ai.chronon.api
 import ai.chronon.api.Constants.MetadataDataset
-import ai.chronon.api.Extensions.JoinOps
-import ai.chronon.api.Extensions.MetadataOps
+import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
 import ai.chronon.api.ScalaJavaConversions._
 import ai.chronon.api._
-import ai.chronon.online.Fetcher.Request
-import ai.chronon.online.Fetcher.Response
-import ai.chronon.online.Fetcher.StatsRequest
-import ai.chronon.online.FlagStore
-import ai.chronon.online.FlagStoreConstants
-import ai.chronon.online.JavaRequest
+import ai.chronon.online.Fetcher.{Request, StatsRequest}
 import ai.chronon.online.KVStore.GetRequest
-import ai.chronon.online.LoggableResponseBase64
-import ai.chronon.online.MetadataDirWalker
-import ai.chronon.online.MetadataEndPoint
-import ai.chronon.online.MetadataStore
-import ai.chronon.online.SparkConversions
+import ai.chronon.online._
 import ai.chronon.spark.Extensions._
 import ai.chronon.spark.stats.ConsistencyJob
+import ai.chronon.spark.test.{DataFrameGen, OnlineUtils, SchemaEvolutionUtils}
 import ai.chronon.spark.utils.MockApi
 import ai.chronon.spark.{Join => _, _}
 import com.google.gson.GsonBuilder
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.functions.avg
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.lit
-import org.junit.Assert.assertEquals
-import org.junit.Assert.assertFalse
-import org.junit.Assert.assertTrue
+import org.apache.spark.sql.functions.{avg, col, lit}
+import org.apache.spark.sql.{Row, SparkSession}
+import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
 import org.scalatest.flatspec.AnyFlatSpec
-import org.slf4j.Logger
-import org.slf4j.LoggerFactory
+import org.slf4j.{Logger, LoggerFactory}
 
-import java.lang
-import java.util
 import java.util.TimeZone
 import java.util.concurrent.Executors
+import java.{lang, util}
 import scala.collection.Seq
-import scala.compat.java8.FutureConverters
-import scala.concurrent.Await
-import scala.concurrent.ExecutionContext
-import scala.concurrent.Future
 import scala.concurrent.duration.Duration
-import scala.concurrent.duration.SECONDS
+import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.io.Source
 
 class FetcherTest extends AnyFlatSpec {
@@ -86,8 +65,8 @@ class FetcherTest extends AnyFlatSpec {
 
     val joinPath = "joins/team/example_join.v1"
     val confResource = getClass.getResource(s"/$joinPath")
+    val src = Source.fromResource(joinPath)
     println(s"conf resource path for dir walker: ${confResource.getPath}")
-    val src = Source.fromFile(confResource.getPath)
 
     val expected = {
      try src.mkString
@@ -785,102 +764,3 @@ class FetcherTest extends AnyFlatSpec {
     assertTrue(responseMap.keys.forall(_.endsWith("_exception")))
   }
 }
-
-object FetcherTestUtil {
-  @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
-  def joinResponses(spark: SparkSession,
-                    requests: Array[Request],
-                    mockApi: MockApi,
-                    useJavaFetcher: Boolean = false,
-                    runCount: Int = 1,
-                    samplePercent: Double = -1,
-                    logToHive: Boolean = false,
-                    debug: Boolean = false)(implicit ec: ExecutionContext): (List[Response], DataFrame) = {
-    val chunkSize = 100
-    @transient lazy val fetcher = mockApi.buildFetcher(debug)
-    @transient lazy val javaFetcher = mockApi.buildJavaFetcher()
-
-    def fetchOnce = {
-      var latencySum: Long = 0
-      var latencyCount = 0
-      val blockStart = System.currentTimeMillis()
-      val result = requests.iterator
-        .grouped(chunkSize)
-        .map { oldReqs =>
-          // deliberately mis-type a few keys
-          val r = oldReqs
-            .map(r =>
-              r.copy(keys = r.keys.mapValues { v =>
-                if (v.isInstanceOf[java.lang.Long]) v.toString else v
-              }.toMap))
-          val responses = if (useJavaFetcher) {
-            // Converting to java request and using the toScalaRequest functionality to test conversion
-            val convertedJavaRequests = r.map(new JavaRequest(_)).toJava
-            val javaResponse = javaFetcher.fetchJoin(convertedJavaRequests)
-            FutureConverters
-              .toScala(javaResponse)
-              .map(
-                _.toScala.map(jres =>
-                  Response(
-                    Request(jres.request.name, jres.request.keys.toScala.toMap, Option(jres.request.atMillis)),
-                    jres.values.toScala.map(_.toScala)
-                  )))
-          } else {
-            fetcher.fetchJoin(r)
-          }
-
-          // fix mis-typed keys in the request
-          val fixedResponses =
-            responses.map(resps => resps.zip(oldReqs).map { case (resp, req) => resp.copy(request = req) })
-          System.currentTimeMillis() -> fixedResponses
-        }
-        .flatMap { case (start, future) =>
-          val result = Await.result(future, Duration(10000, SECONDS)) // todo: change back to millis
-          val latency = System.currentTimeMillis() - start
-          latencySum += latency
-          latencyCount += 1
-          result
-        }
-        .toList
-      val latencyMillis = latencySum.toFloat / latencyCount.toFloat
-      val qps = (requests.length * 1000.0) / (System.currentTimeMillis() - blockStart).toFloat
-      (latencyMillis, qps, result)
-    }
-
-    // to overwhelm the profiler with fetching code path
-    // so as to make it prominent in the flamegraph & collect enough stats
-
-    var latencySum = 0.0
-    var qpsSum = 0.0
-    var loggedValues: Seq[LoggableResponseBase64] = null
-    var result: List[Response] = null
-    (0 until runCount).foreach { _ =>
-      val (latency, qps, resultVal) = fetchOnce
-      result = resultVal
-      loggedValues = mockApi.flushLoggedValues
-      latencySum += latency
-      qpsSum += qps
-    }
-    val fetcherNameString = if (useJavaFetcher) "Java" else "Scala"
-
-    logger.info(s"""
-                   |Averaging fetching stats for $fetcherNameString Fetcher over ${requests.length} requests $runCount times
-                   |with batch size: $chunkSize
-                   |average qps: ${qpsSum / runCount}
-                   |average latency: ${latencySum / runCount}
-                   |""".stripMargin)
-    val loggedDf = mockApi.loggedValuesToDf(loggedValues, spark)
-    if (logToHive) {
-      TableUtils(spark).insertPartitions(
-        loggedDf,
-        mockApi.logTable,
-        partitionColumns = Seq("ds", "name")
-      )
-    }
-    if (samplePercent > 0) {
-      logger.info(s"logged count: ${loggedDf.count()}")
-      loggedDf.show()
-    }
-    result -> loggedDf
-  }
-}
```
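The one behavioral change in this file, beyond the package move, is swapping `Source.fromFile(confResource.getPath)` for `Source.fromResource(joinPath)`. A minimal sketch of why, assuming the test resource at the path used above sits on the classpath:

```scala
import scala.io.Source

object ResourceLoadingSketch {
  def main(args: Array[String]): Unit = {
    val joinPath = "joins/team/example_join.v1"

    // Old approach: getResource(...).getPath is only a usable filesystem path
    // when the resource lives in a plain directory. When test resources are
    // packaged into a jar (as Bazel does), the URL is "jar:file:...!/..."
    // and Source.fromFile on its path fails.
    // val src = Source.fromFile(getClass.getResource(s"/$joinPath").getPath)

    // New approach: read directly from the classpath, jar or directory alike.
    val src = Source.fromResource(joinPath)
    try println(src.mkString.take(100))
    finally src.close()
  }
}
```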
