
Commit ea25fc7

feat: support providing additional confs as yaml file for Driver.scala (#164)
## Summary

We want to be able to configure the Spark jobs with additional confs when running them via `Driver.scala`. Let's thread through some conf files.

## Checklist

- [x] Added Unit Tests
- [x] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **Dependencies**
  - Updated Google Cloud Dataproc dependency to 4.52.0
  - Added JSON processing and YAML support dependencies
- **Configuration**
  - Added support for additional configuration files
  - Introduced new configuration options for Spark table format and partition format
- **Testing**
  - Enhanced test cases for configuration parsing
  - Added tests for Google Cloud runtime classes
  - Improved BigQuery catalog and Dataproc submitter tests
- **Code Refactoring**
  - Renamed `SparkSubmitter` to `JobSubmitter`
  - Renamed `SparkAuth` to `JobAuth`
  - Updated configuration loading mechanisms
  - Streamlined visibility and access of methods in various classes

---------

Co-authored-by: Thomas Chow <[email protected]>
1 parent a745bdf commit ea25fc7

File tree: 11 files changed, +153 −69 lines

build.sbt

Lines changed: 11 additions & 6 deletions
```diff
@@ -189,7 +189,11 @@ lazy val spark = project
     libraryDependencies += "jakarta.servlet" % "jakarta.servlet-api" % "4.0.3",
     libraryDependencies += "com.google.guava" % "guava" % "33.3.1-jre",
     libraryDependencies ++= log4j2,
-    libraryDependencies ++= delta.map(_ % "provided")
+    libraryDependencies ++= delta.map(_ % "provided"),
+    libraryDependencies += "org.json4s" % "json4s-jackson_2.12" % "3.7.0-M11", // This version is pinned to the one spark uses in 3.X.X - see: https://github.com/apache/spark/pull/45838
+    libraryDependencies += "org.json4s" %% "json4s-native" % "3.7.0-M11",
+    libraryDependencies += "org.json4s" %% "json4s-core" % "3.7.0-M11",
+    libraryDependencies += "org.yaml" % "snakeyaml" % "2.3"
   )
 
 lazy val flink = project
@@ -211,15 +215,16 @@ lazy val cloud_gcp = project
     libraryDependencies += "com.google.cloud" % "google-cloud-bigquery" % "2.42.0",
     libraryDependencies += "com.google.cloud" % "google-cloud-bigtable" % "2.41.0",
     libraryDependencies += "com.google.cloud" % "google-cloud-pubsub" % "1.131.0",
-    libraryDependencies += "com.google.cloud" % "google-cloud-dataproc" % "4.51.0",
-    libraryDependencies += "com.google.cloud.bigdataoss" % "gcs-connector" % "3.0.3", // it's what's on the cluster
+    libraryDependencies += "com.google.cloud" % "google-cloud-dataproc" % "4.52.0",
     libraryDependencies += "com.google.cloud.bigdataoss" % "gcs-connector" % "hadoop3-2.2.26",
     libraryDependencies += "com.google.cloud.bigdataoss" % "gcsio" % "3.0.3", // need it for https://github.com/GoogleCloudDataproc/hadoop-connectors/blob/master/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/GoogleCloudStorageFileSystem.java
     libraryDependencies += "com.google.cloud.bigdataoss" % "util-hadoop" % "3.0.0", // need it for https://github.com/GoogleCloudDataproc/hadoop-connectors/blob/master/util-hadoop/src/main/java/com/google/cloud/hadoop/util/HadoopConfigurationProperty.java
-    libraryDependencies += "io.circe" %% "circe-yaml" % "1.15.0",
-    libraryDependencies += "com.google.cloud.spark" %% s"spark-bigquery-with-dependencies" % "0.41.0",
+    libraryDependencies += "com.google.cloud.spark" %% "spark-bigquery-with-dependencies" % "0.41.0",
+    libraryDependencies += "org.json4s" % "json4s-jackson_2.12" % "3.7.0-M11", // This version is pinned to the one spark uses in 3.X.X - see: https://github.com/apache/spark/pull/45838
+    libraryDependencies += "org.json4s" %% "json4s-native" % "3.7.0-M11",
+    libraryDependencies += "org.json4s" %% "json4s-core" % "3.7.0-M11",
+    libraryDependencies += "org.yaml" % "snakeyaml" % "2.3",
     libraryDependencies += "com.google.cloud.bigtable" % "bigtable-hbase-2.x" % "2.14.2",
-    libraryDependencies ++= circe,
     libraryDependencies ++= avro,
     libraryDependencies ++= spark_all_provided,
     dependencyOverrides ++= jackson,
```
Lines changed: 2 additions & 0 deletions
```diff
@@ -0,0 +1,2 @@
+spark.chronon.table.format_provider.class: "ai.chronon.integrations.cloud_gcp.GcpFormatProvider"
+spark.chronon.partition.format: "yyyy-MM-dd"
```
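This resource file shows the shape of "additional confs" the Driver can now pick up. As a minimal sketch (assuming the confs above have already been applied to a session, for example via the new `--additional-conf-path` flag exercised in the Dataproc submitter test below), Chronon-side code can read them back from the Spark runtime conf; the local session here is only for illustration:

```scala
import org.apache.spark.sql.SparkSession

object AdditionalConfsReadBack {
  def main(args: Array[String]): Unit = {
    // Illustration only: a local session standing in for one built by Driver.scala.
    val spark = SparkSession.builder().appName("confs-demo").master("local[*]").getOrCreate()

    // Keys from the YAML resource above; they resolve to Some(...) only if the
    // additional confs were actually threaded into the session at build time.
    val formatProvider = spark.conf.getOption("spark.chronon.table.format_provider.class")
    val partitionFormat = spark.conf.getOption("spark.chronon.partition.format")
    println(s"format provider: $formatProvider, partition format: $partitionFormat")
    spark.stop()
  }
}
```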

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala

Lines changed: 22 additions & 16 deletions
```diff
@@ -1,10 +1,11 @@
 package ai.chronon.integrations.cloud_gcp
-import ai.chronon.spark.SparkAuth
-import ai.chronon.spark.SparkSubmitter
+import ai.chronon.spark.JobAuth
+import ai.chronon.spark.JobSubmitter
 import com.google.api.gax.rpc.ApiException
 import com.google.cloud.dataproc.v1._
-import io.circe.generic.auto._
-import io.circe.yaml.parser
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+import org.yaml.snakeyaml.Yaml
 
 import scala.io.Source
 
@@ -27,7 +28,7 @@ case class GeneralJob(
     mainClass: String
 )
 
-class DataprocSubmitter(jobControllerClient: JobControllerClient, conf: SubmitterConf) extends SparkSubmitter {
+class DataprocSubmitter(jobControllerClient: JobControllerClient, conf: SubmitterConf) extends JobSubmitter {
 
   override def status(jobId: String): Unit = {
     try {
@@ -87,18 +88,23 @@ object DataprocSubmitter {
     new DataprocSubmitter(jobControllerClient, conf)
   }
 
-  def loadConfig: SubmitterConf = {
-    val is = getClass.getClassLoader.getResourceAsStream("dataproc-submitter-conf.yaml")
-    val confStr = Source.fromInputStream(is).mkString
-    val res: Either[io.circe.Error, SubmitterConf] = parser
-      .parse(confStr)
-      .flatMap(_.as[SubmitterConf])
-    res match {
+  private[cloud_gcp] def loadConfig: SubmitterConf = {
+    val inputStreamOption = Option(getClass.getClassLoader.getResourceAsStream("dataproc-submitter-conf.yaml"))
+    val yamlLoader = new Yaml()
+    implicit val formats: Formats = DefaultFormats
+    inputStreamOption
+      .map(Source.fromInputStream)
+      .map((is) =>
+        try { is.mkString }
+        finally { is.close })
+      .map(yamlLoader.load(_).asInstanceOf[java.util.Map[String, Any]])
+      .map((jMap) => Extraction.decompose(jMap.asScala.toMap))
+      .map((jVal) => render(jVal))
+      .map(compact)
+      .map(parse(_).extract[SubmitterConf])
+      .getOrElse(throw new IllegalArgumentException("Yaml conf not found or invalid yaml"))
 
-      case Right(v) => v
-      case Left(e) => throw e
-    }
   }
 }
 
-object DataprocAuth extends SparkAuth {}
+object DataprocAuth extends JobAuth {}
```
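The new `loadConfig` replaces the circe-yaml parse with a SnakeYAML + json4s round trip: the YAML becomes a `java.util.Map`, which is decomposed into a `JValue` and extracted into the target case class. A standalone sketch of the same pattern follows, with an assumed `SubmitterConf` shape (the real case class lives in this file and may have different fields):

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.yaml.snakeyaml.Yaml

import scala.collection.JavaConverters._
import scala.io.Source

object YamlConfLoadingSketch {
  // Assumed fields, purely for illustration.
  case class SubmitterConf(projectId: String, region: String, clusterName: String)

  def loadConf(path: String): SubmitterConf = {
    implicit val formats: Formats = DefaultFormats
    val src = Source.fromFile(path)
    val confStr = try src.mkString finally src.close()

    // SnakeYAML returns a java.util.Map; json4s can decompose the Scala view of
    // that map into a JValue and extract the case class from it.
    val yamlMap = new Yaml().load(confStr).asInstanceOf[java.util.Map[String, Any]]
    val jValue = Extraction.decompose(yamlMap.asScala.toMap)
    parse(compact(render(jValue))).extract[SubmitterConf]
  }
}
```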

cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigQueryCatalogTest.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -5,7 +5,8 @@ import ai.chronon.spark.TableUtils
 import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS
 import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem
 import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystemConfiguration
-import com.google.cloud.hadoop.util.HadoopConfigurationProperty
+import com.google.cloud.hadoop.fs.gcs.HadoopConfigurationProperty
+import com.google.cloud.hadoop.gcsio.GoogleCloudStorageFileSystem
 import org.apache.spark.sql.SparkSession
 import org.junit.Assert.assertEquals
 import org.junit.Assert.assertTrue
@@ -37,6 +38,7 @@ class BigQueryCatalogTest extends AnyFunSuite with MockitoSugar {
     assertTrue(GoogleHadoopFileSystemConfiguration.BLOCK_SIZE.isInstanceOf[HadoopConfigurationProperty[Long]])
     assertCompiles("classOf[GoogleHadoopFileSystem]")
     assertCompiles("classOf[GoogleHadoopFS]")
+    assertCompiles("classOf[GoogleCloudStorageFileSystem]")
 
   }
 
```

cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreTest.scala

Lines changed: 23 additions & 15 deletions
```diff
@@ -52,13 +52,15 @@ class BigTableKVStoreTest {
   @Before
   def setup(): Unit = {
     // Configure settings to use emulator
-    val dataSettings = BigtableDataSettings.newBuilderForEmulator(bigtableEmulator.getPort)
+    val dataSettings = BigtableDataSettings
+      .newBuilderForEmulator(bigtableEmulator.getPort)
       .setProjectId(projectId)
       .setInstanceId(instanceId)
       .setCredentialsProvider(NoCredentialsProvider.create())
       .build()
 
-    val adminSettings = BigtableTableAdminSettings.newBuilderForEmulator(bigtableEmulator.getPort)
+    val adminSettings = BigtableTableAdminSettings
+      .newBuilderForEmulator(bigtableEmulator.getPort)
       .setProjectId(projectId)
       .setInstanceId(instanceId)
       .setCredentialsProvider(NoCredentialsProvider.create())
@@ -153,11 +155,10 @@ class BigTableKVStoreTest {
     val kvStore = new BigTableKVStoreImpl(dataClient, adminClient)
     kvStore.create(dataset)
 
-    val putReqs = (0 until 100).map {
-      i =>
-        val key = s"key-$i"
-        val value = s"""{"name": "name-$i", "age": $i}"""
-        PutRequest(key.getBytes, value.getBytes, dataset, None)
+    val putReqs = (0 until 100).map { i =>
+      val key = s"key-$i"
+      val value = s"""{"name": "name-$i", "age": $i}"""
+      PutRequest(key.getBytes, value.getBytes, dataset, None)
     }
 
     val putResults = Await.result(kvStore.multiPut(putReqs), 1.second)
@@ -185,7 +186,9 @@ class BigTableKVStoreTest {
 
     // lets collect all the keys and confirm we got everything
     val allKeys = (listValues1 ++ listValues2).map(v => new String(v.keyBytes, StandardCharsets.UTF_8))
-    allKeys.toSet shouldBe putReqs.map(r => new String(buildRowKey(r.keyBytes, r.dataset), StandardCharsets.UTF_8)).toSet
+    allKeys.toSet shouldBe putReqs
+      .map(r => new String(buildRowKey(r.keyBytes, r.dataset), StandardCharsets.UTF_8))
+      .toSet
   }
 
   @Test
@@ -227,7 +230,8 @@ class BigTableKVStoreTest {
 
     when(mockDataClient.readRowsCallable()).thenReturn(serverStreamingCallable)
     when(serverStreamingCallable.all()).thenReturn(unaryCallable)
-    val failedFuture = ApiFutures.immediateFailedFuture[util.List[Row]](new RuntimeException("some BT exception on read"))
+    val failedFuture =
+      ApiFutures.immediateFailedFuture[util.List[Row]](new RuntimeException("some BT exception on read"))
     when(unaryCallable.futureCall(any[Query])).thenReturn(failedFuture)
 
     val getResult = Await.result(kvStoreWithMocks.multiGet(Seq(getReq1, getReq2)), 1.second)
@@ -323,11 +327,15 @@ class BigTableKVStoreTest {
     val getResult1 = Await.result(kvStore.multiGet(Seq(getRequest1)), 1.second)
     getResult1.size shouldBe 1
     // we expect results to only cover the time range where we have data
-    val expectedTimeSeriesPoints = (queryStartsTs until dataEndTs by 1.hour.toMillis).toSeq
+    val expectedTimeSeriesPoints = (queryStartsTs until dataEndTs by 1.hour.toMillis).toSeq
     validateTimeSeriesValueExpectedPayload(getResult1.head, expectedTimeSeriesPoints, fakePayload)
   }
 
-  private def writeGeneratedTimeSeriesData(kvStore: BigTableKVStoreImpl, dataset: String, key: String, tsRange: Seq[Long], payload: String): Unit = {
+  private def writeGeneratedTimeSeriesData(kvStore: BigTableKVStoreImpl,
+                                           dataset: String,
+                                           key: String,
+                                           tsRange: Seq[Long],
+                                           payload: String): Unit = {
     val points = Seq.fill(tsRange.size)(payload)
     val putRequests = tsRange.zip(points).map {
       case (ts, point) =>
@@ -350,10 +358,10 @@ class BigTableKVStoreTest {
     }
   }
 
-  private def validateTimeSeriesValueExpectedPayload(response: GetResponse, expectedTimestamps: Seq[Long], expectedPayload: String): Unit = {
-    for (
-      tSeq <- response.values
-    ) {
+  private def validateTimeSeriesValueExpectedPayload(response: GetResponse,
+                                                     expectedTimestamps: Seq[Long],
+                                                     expectedPayload: String): Unit = {
+    for (tSeq <- response.values) {
       tSeq.map(_.millis).toSet shouldBe expectedTimestamps.toSet
       tSeq.map(v => new String(v.bytes, StandardCharsets.UTF_8)).foreach(v => v shouldBe expectedPayload)
       tSeq.length shouldBe expectedTimestamps.length
```

cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitterTest.scala

Lines changed: 9 additions & 4 deletions
```diff
@@ -51,10 +51,15 @@ class DataprocSubmitterTest extends AnyFunSuite with MockitoSugar {
 
     val submitter = DataprocSubmitter()
     val submittedJobId =
-      submitter.submit(List("gs://dataproc-temp-us-central1-703996152583-pqtvfptb/jars/training_set.v1"),
-                       "join",
-                       "--end-date=2024-12-10",
-                       "--conf-path=training_set.v1")
+      submitter.submit(
+        List("gs://zipline-jars/training_set.v1",
+             "gs://zipline-jars/dataproc-submitter-conf.yaml",
+             "gs://zipline-jars/additional-confs.yaml"),
+        "join",
+        "--end-date=2024-12-10",
+        "--additional-conf-path=additional-confs.yaml",
+        "--conf-path=training_set.v1"
+      )
     println(submittedJobId)
   }
 
```
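Note how the new YAML files ride along in the `files` list while `--additional-conf-path` is a relative name; presumably the listed files are staged into the job's working directory on the cluster, so the relative path resolves there. The `gs://zipline-jars/...` paths are test fixtures, not a prescribed bucket layout.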

spark/src/main/scala/ai/chronon/spark/Driver.scala

Lines changed: 33 additions & 16 deletions
```diff
@@ -49,11 +49,14 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
 import org.apache.spark.sql.streaming.StreamingQueryListener.QueryStartedEvent
 import org.apache.spark.sql.streaming.StreamingQueryListener.QueryTerminatedEvent
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
 import org.rogach.scallop.ScallopConf
 import org.rogach.scallop.ScallopOption
 import org.rogach.scallop.Subcommand
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
+import org.yaml.snakeyaml.Yaml
 
 import java.io.File
 import java.nio.file.Files
@@ -87,6 +90,9 @@ object Driver {
     this: ScallopConf =>
     val confPath: ScallopOption[String] = opt[String](required = true, descr = "Path to conf")
 
+    val additionalConfPath: ScallopOption[String] =
+      opt[String](required = false, descr = "Path to additional driver job configurations")
+
     val runFirstHole: ScallopOption[Boolean] =
       opt[Boolean](required = false,
                    default = Some(false),
@@ -144,33 +150,44 @@ object Driver {
 
     def subcommandName(): String
 
-    def isLocal: Boolean = localTableMapping.nonEmpty || localDataPath.isDefined
+    protected def isLocal: Boolean = localTableMapping.nonEmpty || localDataPath.isDefined
 
     protected def buildSparkSession(): SparkSession = {
+      implicit val formats: Formats = DefaultFormats
+      val yamlLoader = new Yaml()
+      val additionalConfs = additionalConfPath.toOption
+        .map(Source.fromFile)
+        .map((src) =>
+          try { src.mkString }
+          finally { src.close })
+        .map(yamlLoader.load(_).asInstanceOf[java.util.Map[String, Any]])
+        .map((map) => Extraction.decompose(map.asScala.toMap))
+        .map((v) => render(v))
+        .map(compact)
+        .map((str) => parse(str).extract[Map[String, String]])
+
+      // We use the KryoSerializer for group bys and joins since we serialize the IRs.
+      // But since staging query is fairly freeform, it's better to stick to the java serializer.
+      val session =
+        SparkSessionBuilder.build(
+          subcommandName(),
+          local = isLocal,
+          localWarehouseLocation = localWarehouseLocation.toOption,
+          enforceKryoSerializer = !subcommandName().contains("staging_query"),
+          additionalConfig = additionalConfs
+        )
       if (localTableMapping.nonEmpty) {
-        val localSession = SparkSessionBuilder.build(subcommandName(),
-                                                     local = true,
-                                                     localWarehouseLocation = localWarehouseLocation.toOption)
         localTableMapping.foreach {
           case (table, filePath) =>
             val file = new File(filePath)
-            LocalDataLoader.loadDataFileAsTable(file, localSession, table)
+            LocalDataLoader.loadDataFileAsTable(file, session, table)
         }
-        localSession
       } else if (localDataPath.isDefined) {
         val dir = new File(localDataPath())
         assert(dir.exists, s"Provided local data path: ${localDataPath()} doesn't exist")
-        val localSession =
-          SparkSessionBuilder.build(subcommandName(),
-                                    local = true,
-                                    localWarehouseLocation = localWarehouseLocation.toOption)
-        LocalDataLoader.loadDataRecursively(dir, localSession)
-        localSession
-      } else {
-        // We use the KryoSerializer for group bys and joins since we serialize the IRs.
-        // But since staging query is fairly freeform, it's better to stick to the java serializer.
-        SparkSessionBuilder.build(subcommandName(), enforceKryoSerializer = !subcommandName().contains("staging_query"))
+        LocalDataLoader.loadDataRecursively(dir, session)
       }
+      session
     }
 
     def buildTableUtils(): TableUtils = {
```
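`buildSparkSession` now always goes through `SparkSessionBuilder.build` and hands it the parsed YAML as an optional `Map[String, String]` via `additionalConfig`. The builder itself is not part of this diff; as a rough sketch of what applying such a map could look like (this is not the project's `SparkSessionBuilder`, and the names here are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object AdditionalConfigSketch {
  // Hypothetical stand-in for SparkSessionBuilder.build: fold the optional conf
  // map into the session builder before creating the session.
  def buildSession(name: String, additionalConfig: Option[Map[String, String]]): SparkSession = {
    val base = SparkSession.builder().appName(name).master("local[*]")
    val withConfs = additionalConfig.getOrElse(Map.empty).foldLeft(base) {
      case (builder, (key, value)) => builder.config(key, value)
    }
    withConfs.getOrCreate()
  }
}
```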
```diff
@@ -1,6 +1,6 @@
 package ai.chronon.spark
 
-trait SparkSubmitter {
+trait JobSubmitter {
 
   def submit(files: List[String], args: String*): String
 
@@ -9,6 +9,6 @@ trait SparkSubmitter {
   def kill(jobId: String): Unit
 }
 
-abstract class SparkAuth {
+abstract class JobAuth {
   def token(): Unit = {}
 }
```
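With the rename, other backends implement `JobSubmitter`/`JobAuth` rather than the Spark-specific names. A hypothetical no-op implementation is sketched below; the `status` signature is assumed from the `DataprocSubmitter` override shown earlier, since it is not visible in this hunk:

```scala
import ai.chronon.spark.{JobAuth, JobSubmitter}

// Hypothetical stub showing the shape of a non-Dataproc backend.
object EchoSubmitter extends JobSubmitter {
  override def submit(files: List[String], args: String*): String = {
    println(s"would submit files=${files.mkString(",")} args=${args.mkString(" ")}")
    "echo-job-id" // fake job id, for illustration only
  }
  override def status(jobId: String): Unit = println(s"status($jobId): no-op")
  override def kill(jobId: String): Unit = println(s"kill($jobId): no-op")
}

// JobAuth.token() already has a default no-op body, so extending it is enough.
object EchoAuth extends JobAuth
```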
Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+test.yaml.key: "test_yaml_key"
```

spark/src/test/scala/ai/chronon/spark/test/MigrationCompareTest.scala

Lines changed: 8 additions & 7 deletions
```diff
@@ -84,11 +84,12 @@ class MigrationCompareTest {
 
     //--------------------------------Staging Query-----------------------------
     val stagingQueryConf = Builders.StagingQuery(
-      query = s"select * from ${joinConf.metaData.outputTable} WHERE ds BETWEEN '{{ start_date }}' AND '{{ end_date }}'",
+      query =
+        s"select * from ${joinConf.metaData.outputTable} WHERE ds BETWEEN '{{ start_date }}' AND '{{ end_date }}'",
       startPartition = ninetyDaysAgo,
       metaData = Builders.MetaData(name = "test.item_snapshot_features_sq_3",
-        namespace = namespace,
-        tableProperties = Map("key" -> "val"))
+                                   namespace = namespace,
+                                   tableProperties = Map("key" -> "val"))
     )
 
     (joinConf, stagingQueryConf)
@@ -113,8 +114,8 @@ class MigrationCompareTest {
       query = s"select item, ts, ds from ${joinConf.metaData.outputTable}",
       startPartition = ninetyDaysAgo,
       metaData = Builders.MetaData(name = "test.item_snapshot_features_sq_4",
-        namespace = namespace,
-        tableProperties = Map("key" -> "val"))
+                                   namespace = namespace,
+                                   tableProperties = Map("key" -> "val"))
     )
 
     val (compareDf, metricsDf, metrics: DataMetrics) =
@@ -141,8 +142,8 @@ class MigrationCompareTest {
      query = s"select * from ${joinConf.metaData.outputTable} where ds BETWEEN '${monthAgo}' AND '${today}'",
      startPartition = ninetyDaysAgo,
      metaData = Builders.MetaData(name = "test.item_snapshot_features_sq_5",
-        namespace = namespace,
-        tableProperties = Map("key" -> "val"))
+                                  namespace = namespace,
+                                  tableProperties = Map("key" -> "val"))
     )
 
     val (compareDf, metricsDf, metrics: DataMetrics) =
```
