
chore: split out expensive spark tests to parallelize #382

Merged: 15 commits, Feb 14, 2025
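
This PR splits the expensive Spark tests out of the single //spark:tests target into dedicated Bazel suites (fetcher_test, join_test, groupby_test, analyzer_test, streaming_test). Each suite gets its own GitHub Actions job on an 8-core runner with a larger test heap (-Xmx16G, up from -Xmx8G for the combined suite), so the expensive suites run in parallel instead of serially.
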
130 changes: 130 additions & 0 deletions .github/workflows/test_scala_spark.yaml
@@ -44,3 +44,133 @@ jobs:
            --google_credentials=bazel-cache-key.json \
            --test_env=JAVA_OPTS="-Xmx8G -Xms2G" \
            //spark:tests

  fetcher_tests:
    runs-on: ubuntu-8_cores-32_gb
    container:
      image: ghcr.io/${{ github.repository }}-ci:latest
      credentials:
        username: ${{ github.actor }}
        password: ${{ secrets.GITHUB_TOKEN }}
    defaults:
      run:
        working-directory: ${{ github.workspace }}

    steps:
      - uses: actions/checkout@v4

      - name: Setup Bazel cache credentials
        run: |
          echo "${{ secrets.BAZEL_CACHE_CREDENTIALS }}" | base64 -d > bazel-cache-key.json

      - name: Run Fetcher tests
        run: |
          bazel test \
            --remote_cache=https://storage.googleapis.com/zipline-bazel-cache \
            --google_credentials=bazel-cache-key.json \
            --test_env=JAVA_OPTS="-Xmx16G -Xms8G" \
            //spark:fetcher_test

  join_tests:
    runs-on: ubuntu-8_cores-32_gb
    container:
      image: ghcr.io/${{ github.repository }}-ci:latest
      credentials:
        username: ${{ github.actor }}
        password: ${{ secrets.GITHUB_TOKEN }}
    defaults:
      run:
        working-directory: ${{ github.workspace }}

    steps:
      - uses: actions/checkout@v4

      - name: Setup Bazel cache credentials
        run: |
          echo "${{ secrets.BAZEL_CACHE_CREDENTIALS }}" | base64 -d > bazel-cache-key.json

      - name: Run Join tests
        run: |
          bazel test \
            --remote_cache=https://storage.googleapis.com/zipline-bazel-cache \
            --google_credentials=bazel-cache-key.json \
            --test_env=JAVA_OPTS="-Xmx16G -Xms8G" \
            //spark:join_test

  groupby_tests:
    runs-on: ubuntu-8_cores-32_gb
    container:
      image: ghcr.io/${{ github.repository }}-ci:latest
      credentials:
        username: ${{ github.actor }}
        password: ${{ secrets.GITHUB_TOKEN }}
    defaults:
      run:
        working-directory: ${{ github.workspace }}

    steps:
      - uses: actions/checkout@v4

      - name: Setup Bazel cache credentials
        run: |
          echo "${{ secrets.BAZEL_CACHE_CREDENTIALS }}" | base64 -d > bazel-cache-key.json

      - name: Run GroupBy tests
        run: |
          bazel test \
            --remote_cache=https://storage.googleapis.com/zipline-bazel-cache \
            --google_credentials=bazel-cache-key.json \
            --test_env=JAVA_OPTS="-Xmx16G -Xms8G" \
            //spark:groupby_test

  analyzer_tests:
    runs-on: ubuntu-8_cores-32_gb
    container:
      image: ghcr.io/${{ github.repository }}-ci:latest
      credentials:
        username: ${{ github.actor }}
        password: ${{ secrets.GITHUB_TOKEN }}
    defaults:
      run:
        working-directory: ${{ github.workspace }}

    steps:
      - uses: actions/checkout@v4

      - name: Setup Bazel cache credentials
        run: |
          echo "${{ secrets.BAZEL_CACHE_CREDENTIALS }}" | base64 -d > bazel-cache-key.json

      - name: Run Analyzer tests
        run: |
          bazel test \
            --remote_cache=https://storage.googleapis.com/zipline-bazel-cache \
            --google_credentials=bazel-cache-key.json \
            --test_env=JAVA_OPTS="-Xmx16G -Xms8G" \
            //spark:analyzer_test

  streaming_tests:
    runs-on: ubuntu-8_cores-32_gb
    container:
      image: ghcr.io/${{ github.repository }}-ci:latest
      credentials:
        username: ${{ github.actor }}
        password: ${{ secrets.GITHUB_TOKEN }}
    defaults:
      run:
        working-directory: ${{ github.workspace }}

    steps:
      - uses: actions/checkout@v4

      - name: Setup Bazel cache credentials
        run: |
          echo "${{ secrets.BAZEL_CACHE_CREDENTIALS }}" | base64 -d > bazel-cache-key.json

      - name: Run Streaming tests
        run: |
          bazel test \
            --remote_cache=https://storage.googleapis.com/zipline-bazel-cache \
            --google_credentials=bazel-cache-key.json \
            --test_env=JAVA_OPTS="-Xmx16G -Xms8G" \
            //spark:streaming_test
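
Each CI job above wraps a plain Bazel target, so any one suite can be reproduced locally. A minimal sketch (remote-cache flags dropped; the heap sizing is copied from the CI jobs):

    # Run only the join suite with the same JVM sizing CI uses
    bazel test \
      --test_env=JAVA_OPTS="-Xmx16G -Xms8G" \
      //spark:join_test
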
58 changes: 56 additions & 2 deletions spark/BUILD.bazel
@@ -90,13 +90,67 @@ scala_library(
name = "test_lib",
srcs = glob(["src/test/**/*.scala"]),
format = True,
visibility = ["//visibility:public"],
deps = test_deps,
)

scala_test_suite(
    name = "tests",
-    srcs = glob(["src/test/**/*.scala"]),
-    tags = ["large"],
+    srcs = glob(["src/test/scala/ai/chronon/spark/test/*.scala",
+                 "src/test/scala/ai/chronon/spark/test/udafs/*.scala",
+                 "src/test/scala/ai/chronon/spark/test/stats/drift/*.scala",
+                 "src/test/scala/ai/chronon/spark/test/bootstrap/*.scala"]),
+    data = glob(["spark/src/test/resources/**/*"]),
    # defined in prelude_bazel file
    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
    visibility = ["//visibility:public"],
    deps = test_deps + [":test_lib"],
)
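
Side note, not part of this change: the umbrella "tests" suite now enumerates the packages it still owns, so a brand-new test subpackage must be added here (or get its own suite) by hand. A sketch of an alternative that keeps the broad glob and excludes the carved-out packages instead:

    scala_test_suite(
        name = "tests",
        srcs = glob(
            ["src/test/**/*.scala"],
            exclude = [
                "src/test/scala/ai/chronon/spark/test/fetcher/**",
                "src/test/scala/ai/chronon/spark/test/groupby/**",
                "src/test/scala/ai/chronon/spark/test/join/**",
                "src/test/scala/ai/chronon/spark/test/analyzer/**",
                "src/test/scala/ai/chronon/spark/test/streaming/**",
            ],
        ),
        # remaining attributes as in the target above
    )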

scala_test_suite(
    name = "fetcher_test",
    srcs = glob(["src/test/scala/ai/chronon/spark/test/fetcher/*.scala"]),
    resources = ["//spark/src/test/resources:test-resources"],
    # defined in prelude_bazel file
    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
    visibility = ["//visibility:public"],
    deps = test_deps + [":test_lib"],
)

scala_test_suite(
    name = "groupby_test",
    srcs = glob(["src/test/scala/ai/chronon/spark/test/groupby/*.scala"]),
    data = glob(["spark/src/test/resources/**/*"]),
    # defined in prelude_bazel file
    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
    visibility = ["//visibility:public"],
    deps = test_deps + [":test_lib"],
)

scala_test_suite(
    name = "join_test",
    srcs = glob(["src/test/scala/ai/chronon/spark/test/join/*.scala"]),
    tags = ["large"],
    data = glob(["spark/src/test/resources/**/*"]),
    # defined in prelude_bazel file
    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
    visibility = ["//visibility:public"],
    deps = test_deps + [":test_lib"],
)

scala_test_suite(
    name = "analyzer_test",
    srcs = glob(["src/test/scala/ai/chronon/spark/test/analyzer/*.scala"]),
    data = glob(["spark/src/test/resources/**/*"]),
    # defined in prelude_bazel file
    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
    visibility = ["//visibility:public"],
    deps = test_deps + [":test_lib"],
)

scala_test_suite(
    name = "streaming_test",
    srcs = glob(["src/test/scala/ai/chronon/spark/test/streaming/*.scala"]),
    data = glob(["spark/src/test/resources/**/*"]),
    # defined in prelude_bazel file
    jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,

@@ -14,25 +14,21 @@
* limitations under the License.
*/

-package ai.chronon.spark.test
+package ai.chronon.spark.test.analyzer

import ai.chronon.aggregator.test.Column
import ai.chronon.api
import ai.chronon.api._
-import ai.chronon.spark.Analyzer
import ai.chronon.spark.Extensions._
-import ai.chronon.spark.Join
-import ai.chronon.spark.SparkSessionBuilder
-import ai.chronon.spark.TableUtils
+import ai.chronon.spark.{Analyzer, Join, SparkSessionBuilder, TableUtils}
+import ai.chronon.spark.test.DataFrameGen
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{col, lit}
import org.junit.Assert.assertTrue
import org.scalatest.BeforeAndAfter
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
-import org.slf4j.Logger
-import org.slf4j.LoggerFactory
+import org.slf4j.{Logger, LoggerFactory}

class AnalyzerTest extends AnyFlatSpec with BeforeAndAfter {
  @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)

@@ -14,7 +14,7 @@
* limitations under the License.
*/

-package ai.chronon.spark.test.bootstrap
+package ai.chronon.spark.test.analyzer

import ai.chronon.api.Builders.Derivation
import ai.chronon.api.Extensions._
@@ -24,17 +24,14 @@ import ai.chronon.online.Fetcher.Request
import ai.chronon.online.MetadataStore
import ai.chronon.spark.Extensions.DataframeOps
import ai.chronon.spark._
-import ai.chronon.spark.test.OnlineUtils
-import ai.chronon.spark.test.SchemaEvolutionUtils
+import ai.chronon.spark.test.{OnlineUtils, SchemaEvolutionUtils}
+import ai.chronon.spark.test.bootstrap.BootstrapUtils
import ai.chronon.spark.utils.MockApi
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
-import org.junit.Assert.assertEquals
-import org.junit.Assert.assertFalse
-import org.junit.Assert.assertTrue
+import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
import org.scalatest.flatspec.AnyFlatSpec
-import org.slf4j.Logger
-import org.slf4j.LoggerFactory
+import org.slf4j.{Logger, LoggerFactory}

import scala.concurrent.Await
import scala.concurrent.duration.Duration

@@ -14,31 +14,26 @@
* limitations under the License.
*/

-package ai.chronon.spark.test
+package ai.chronon.spark.test.fetcher

import ai.chronon.aggregator.windowing.TsUtils
import ai.chronon.api
import ai.chronon.api.Constants.MetadataDataset
-import ai.chronon.api.Extensions.JoinOps
-import ai.chronon.api.Extensions.MetadataOps
+import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
import ai.chronon.api.ScalaJavaConversions._
import ai.chronon.api._
import ai.chronon.online.Fetcher.Request
-import ai.chronon.online.MetadataStore
-import ai.chronon.online.SparkConversions
+import ai.chronon.online.{MetadataStore, SparkConversions}
import ai.chronon.spark.Extensions._
+import ai.chronon.spark.test.{OnlineUtils, TestUtils}
import ai.chronon.spark.utils.MockApi
import ai.chronon.spark.{Join => _, _}
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.functions.lit
-import org.junit.Assert.assertEquals
-import org.junit.Assert.assertTrue
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.junit.Assert.{assertEquals, assertTrue}
import org.scalatest.flatspec.AnyFlatSpec
-import org.slf4j.Logger
-import org.slf4j.LoggerFactory
+import org.slf4j.{Logger, LoggerFactory}

import java.lang
import java.util.TimeZone