feat: support providing additional confs as yaml file for Driver.scala #164

Merged · 23 commits · Jan 7, 2025
17 changes: 11 additions & 6 deletions build.sbt
@@ -189,7 +189,11 @@ lazy val spark = project
libraryDependencies += "jakarta.servlet" % "jakarta.servlet-api" % "4.0.3",
libraryDependencies += "com.google.guava" % "guava" % "33.3.1-jre",
libraryDependencies ++= log4j2,
libraryDependencies ++= delta.map(_ % "provided")
libraryDependencies ++= delta.map(_ % "provided"),
libraryDependencies += "org.json4s" % "json4s-jackson_2.12" % "3.7.0-M11", // This version is pinned to the one spark uses in 3.X.X - see: https://github.com/apache/spark/pull/45838
libraryDependencies += "org.json4s" %% "json4s-native" % "3.7.0-M11",
libraryDependencies += "org.json4s" %% "json4s-core" % "3.7.0-M11",
libraryDependencies += "org.yaml" % "snakeyaml" % "2.3"
)

lazy val flink = project
@@ -211,15 +215,16 @@ lazy val cloud_gcp = project
libraryDependencies += "com.google.cloud" % "google-cloud-bigquery" % "2.42.0",
libraryDependencies += "com.google.cloud" % "google-cloud-bigtable" % "2.41.0",
libraryDependencies += "com.google.cloud" % "google-cloud-pubsub" % "1.131.0",
libraryDependencies += "com.google.cloud" % "google-cloud-dataproc" % "4.51.0",
libraryDependencies += "com.google.cloud.bigdataoss" % "gcs-connector" % "3.0.3", // it's what's on the cluster
libraryDependencies += "com.google.cloud" % "google-cloud-dataproc" % "4.52.0",
libraryDependencies += "com.google.cloud.bigdataoss" % "gcs-connector" % "hadoop3-2.2.26",
libraryDependencies += "com.google.cloud.bigdataoss" % "gcsio" % "3.0.3", // need it for https://github.com/GoogleCloudDataproc/hadoop-connectors/blob/master/gcsio/src/main/java/com/google/cloud/hadoop/gcsio/GoogleCloudStorageFileSystem.java
libraryDependencies += "com.google.cloud.bigdataoss" % "util-hadoop" % "3.0.0", // need it for https://github.com/GoogleCloudDataproc/hadoop-connectors/blob/master/util-hadoop/src/main/java/com/google/cloud/hadoop/util/HadoopConfigurationProperty.java
libraryDependencies += "io.circe" %% "circe-yaml" % "1.15.0",
libraryDependencies += "com.google.cloud.spark" %% s"spark-bigquery-with-dependencies" % "0.41.0",
libraryDependencies += "com.google.cloud.spark" %% "spark-bigquery-with-dependencies" % "0.41.0",
libraryDependencies += "org.json4s" % "json4s-jackson_2.12" % "3.7.0-M11", // This version is pinned to the one spark uses in 3.X.X - see: https://github.com/apache/spark/pull/45838
libraryDependencies += "org.json4s" %% "json4s-native" % "3.7.0-M11",
libraryDependencies += "org.json4s" %% "json4s-core" % "3.7.0-M11",
libraryDependencies += "org.yaml" % "snakeyaml" % "2.3",
libraryDependencies += "com.google.cloud.bigtable" % "bigtable-hbase-2.x" % "2.14.2",
libraryDependencies ++= circe,
libraryDependencies ++= avro,
libraryDependencies ++= spark_all_provided,
dependencyOverrides ++= jackson,
2 changes: 2 additions & 0 deletions cloud_gcp/src/main/resources/additional-confs.yaml
@@ -0,0 +1,2 @@
spark.chronon.table.format_provider.class: "ai.chronon.integrations.cloud_gcp.GcpFormatProvider"
spark.chronon.partition.format: "yyyy-MM-dd"
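
These two defaults ride along in the cloud_gcp jar; SnakeYAML (the parser this PR standardizes on) reads the file as a flat key-to-value map. A minimal sketch of loading it from the classpath, mirroring the pattern used in DataprocSubmitter and Driver below; only the resource name comes from this PR, the rest is illustrative:

import org.yaml.snakeyaml.Yaml
import scala.collection.JavaConverters._
import scala.io.Source

object AdditionalConfsSketch {
  def main(args: Array[String]): Unit = {
    // Read the bundled resource and print the two settings above.
    val src = Source.fromInputStream(
      getClass.getClassLoader.getResourceAsStream("additional-confs.yaml"))
    val confs: Map[String, Any] =
      try new Yaml().load(src.mkString).asInstanceOf[java.util.Map[String, Any]].asScala.toMap
      finally src.close()
    confs.foreach { case (k, v) => println(s"$k -> $v") }
    // spark.chronon.table.format_provider.class -> ai.chronon.integrations.cloud_gcp.GcpFormatProvider
    // spark.chronon.partition.format -> yyyy-MM-dd
  }
}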
@@ -1,10 +1,11 @@
package ai.chronon.integrations.cloud_gcp
import ai.chronon.spark.SparkAuth
import ai.chronon.spark.SparkSubmitter
import ai.chronon.spark.JobAuth
import ai.chronon.spark.JobSubmitter
import com.google.api.gax.rpc.ApiException
import com.google.cloud.dataproc.v1._
import io.circe.generic.auto._
import io.circe.yaml.parser
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.yaml.snakeyaml.Yaml

import scala.io.Source

@@ -27,7 +28,7 @@ case class GeneralJob(
mainClass: String
)

class DataprocSubmitter(jobControllerClient: JobControllerClient, conf: SubmitterConf) extends SparkSubmitter {
class DataprocSubmitter(jobControllerClient: JobControllerClient, conf: SubmitterConf) extends JobSubmitter {

override def status(jobId: String): Unit = {
try {
@@ -87,18 +88,23 @@ object DataprocSubmitter {
new DataprocSubmitter(jobControllerClient, conf)
}

def loadConfig: SubmitterConf = {
val is = getClass.getClassLoader.getResourceAsStream("dataproc-submitter-conf.yaml")
val confStr = Source.fromInputStream(is).mkString
val res: Either[io.circe.Error, SubmitterConf] = parser
.parse(confStr)
.flatMap(_.as[SubmitterConf])
res match {
private[cloud_gcp] def loadConfig: SubmitterConf = {
val inputStreamOption = Option(getClass.getClassLoader.getResourceAsStream("dataproc-submitter-conf.yaml"))
val yamlLoader = new Yaml()
implicit val formats: Formats = DefaultFormats
inputStreamOption
.map(Source.fromInputStream)
.map((is) =>
try { is.mkString }
finally { is.close })
.map(yamlLoader.load(_).asInstanceOf[java.util.Map[String, Any]])
.map((jMap) => Extraction.decompose(jMap.asScala.toMap))
.map((jVal) => render(jVal))
.map(compact)
.map(parse(_).extract[SubmitterConf])
.getOrElse(throw new IllegalArgumentException("Yaml conf not found or invalid yaml"))

case Right(v) => v
case Left(e) => throw e
}
}
}

object DataprocAuth extends SparkAuth {}
object DataprocAuth extends JobAuth {}
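
To make the circe-to-SnakeYAML/json4s switch easier to follow, here is the same decoding chain as a standalone sketch; ExampleConf and its fields are illustrative stand-ins for SubmitterConf, not part of the PR:

import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.yaml.snakeyaml.Yaml
import scala.collection.JavaConverters._

case class ExampleConf(projectId: String, region: String)

object YamlDecodeSketch {
  implicit val formats: Formats = DefaultFormats

  def decode(yamlStr: String): ExampleConf = {
    // SnakeYAML hands back a java.util.Map; convert it to a Scala Map first.
    val jMap = new Yaml().load(yamlStr).asInstanceOf[java.util.Map[String, Any]]
    // json4s round-trip: Scala Map -> JValue -> JSON string -> JValue -> case class.
    val jValue = Extraction.decompose(jMap.asScala.toMap)
    parse(compact(render(jValue))).extract[ExampleConf]
  }

  def main(args: Array[String]): Unit = {
    // prints ExampleConf(my-project,us-central1)
    println(decode("projectId: my-project\nregion: us-central1\n"))
  }
}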
@@ -5,7 +5,8 @@ import ai.chronon.spark.TableUtils
import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS
import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem
import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystemConfiguration
import com.google.cloud.hadoop.util.HadoopConfigurationProperty
import com.google.cloud.hadoop.fs.gcs.HadoopConfigurationProperty
import com.google.cloud.hadoop.gcsio.GoogleCloudStorageFileSystem
import org.apache.spark.sql.SparkSession
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
@@ -37,6 +38,7 @@ class BigQueryCatalogTest extends AnyFunSuite with MockitoSugar {
assertTrue(GoogleHadoopFileSystemConfiguration.BLOCK_SIZE.isInstanceOf[HadoopConfigurationProperty[Long]])
assertCompiles("classOf[GoogleHadoopFileSystem]")
assertCompiles("classOf[GoogleHadoopFS]")
assertCompiles("classOf[GoogleCloudStorageFileSystem]")

}

@@ -52,13 +52,15 @@ class BigTableKVStoreTest {
@Before
def setup(): Unit = {
// Configure settings to use emulator
val dataSettings = BigtableDataSettings.newBuilderForEmulator(bigtableEmulator.getPort)
val dataSettings = BigtableDataSettings
.newBuilderForEmulator(bigtableEmulator.getPort)
.setProjectId(projectId)
.setInstanceId(instanceId)
.setCredentialsProvider(NoCredentialsProvider.create())
.build()

val adminSettings = BigtableTableAdminSettings.newBuilderForEmulator(bigtableEmulator.getPort)
val adminSettings = BigtableTableAdminSettings
.newBuilderForEmulator(bigtableEmulator.getPort)
.setProjectId(projectId)
.setInstanceId(instanceId)
.setCredentialsProvider(NoCredentialsProvider.create())
@@ -153,11 +155,10 @@ class BigTableKVStoreTest {
val kvStore = new BigTableKVStoreImpl(dataClient, adminClient)
kvStore.create(dataset)

val putReqs = (0 until 100).map {
i =>
val key = s"key-$i"
val value = s"""{"name": "name-$i", "age": $i}"""
PutRequest(key.getBytes, value.getBytes, dataset, None)
val putReqs = (0 until 100).map { i =>
val key = s"key-$i"
val value = s"""{"name": "name-$i", "age": $i}"""
PutRequest(key.getBytes, value.getBytes, dataset, None)
}

val putResults = Await.result(kvStore.multiPut(putReqs), 1.second)
@@ -185,7 +186,9 @@ class BigTableKVStoreTest {

// lets collect all the keys and confirm we got everything
val allKeys = (listValues1 ++ listValues2).map(v => new String(v.keyBytes, StandardCharsets.UTF_8))
allKeys.toSet shouldBe putReqs.map(r => new String(buildRowKey(r.keyBytes, r.dataset), StandardCharsets.UTF_8)).toSet
allKeys.toSet shouldBe putReqs
.map(r => new String(buildRowKey(r.keyBytes, r.dataset), StandardCharsets.UTF_8))
.toSet
}

@Test
@@ -227,7 +230,8 @@ class BigTableKVStoreTest {

when(mockDataClient.readRowsCallable()).thenReturn(serverStreamingCallable)
when(serverStreamingCallable.all()).thenReturn(unaryCallable)
val failedFuture = ApiFutures.immediateFailedFuture[util.List[Row]](new RuntimeException("some BT exception on read"))
val failedFuture =
ApiFutures.immediateFailedFuture[util.List[Row]](new RuntimeException("some BT exception on read"))
when(unaryCallable.futureCall(any[Query])).thenReturn(failedFuture)

val getResult = Await.result(kvStoreWithMocks.multiGet(Seq(getReq1, getReq2)), 1.second)
@@ -323,11 +327,15 @@ class BigTableKVStoreTest {
val getResult1 = Await.result(kvStore.multiGet(Seq(getRequest1)), 1.second)
getResult1.size shouldBe 1
// we expect results to only cover the time range where we have data
val expectedTimeSeriesPoints = (queryStartsTs until dataEndTs by 1.hour.toMillis).toSeq
val expectedTimeSeriesPoints = (queryStartsTs until dataEndTs by 1.hour.toMillis).toSeq
validateTimeSeriesValueExpectedPayload(getResult1.head, expectedTimeSeriesPoints, fakePayload)
}

private def writeGeneratedTimeSeriesData(kvStore: BigTableKVStoreImpl, dataset: String, key: String, tsRange: Seq[Long], payload: String): Unit = {
private def writeGeneratedTimeSeriesData(kvStore: BigTableKVStoreImpl,
dataset: String,
key: String,
tsRange: Seq[Long],
payload: String): Unit = {
val points = Seq.fill(tsRange.size)(payload)
val putRequests = tsRange.zip(points).map {
case (ts, point) =>
Expand All @@ -350,10 +358,10 @@ class BigTableKVStoreTest {
}
}

private def validateTimeSeriesValueExpectedPayload(response: GetResponse, expectedTimestamps: Seq[Long], expectedPayload: String): Unit = {
for (
tSeq <- response.values
) {
private def validateTimeSeriesValueExpectedPayload(response: GetResponse,
expectedTimestamps: Seq[Long],
expectedPayload: String): Unit = {
for (tSeq <- response.values) {
tSeq.map(_.millis).toSet shouldBe expectedTimestamps.toSet
tSeq.map(v => new String(v.bytes, StandardCharsets.UTF_8)).foreach(v => v shouldBe expectedPayload)
tSeq.length shouldBe expectedTimestamps.length
@@ -51,10 +51,15 @@ class DataprocSubmitterTest extends AnyFunSuite with MockitoSugar {

val submitter = DataprocSubmitter()
val submittedJobId =
submitter.submit(List("gs://dataproc-temp-us-central1-703996152583-pqtvfptb/jars/training_set.v1"),
"join",
"--end-date=2024-12-10",
"--conf-path=training_set.v1")
submitter.submit(
List("gs://zipline-jars/training_set.v1",
"gs://zipline-jars/dataproc-submitter-conf.yaml",
"gs://zipline-jars/additional-confs.yaml"),
"join",
"--end-date=2024-12-10",
"--additional-conf-path=additional-confs.yaml",
"--conf-path=training_set.v1"
)
println(submittedJobId)
}

49 changes: 33 additions & 16 deletions spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -49,11 +49,14 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryStartedEvent
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryTerminatedEvent
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.rogach.scallop.ScallopConf
import org.rogach.scallop.ScallopOption
import org.rogach.scallop.Subcommand
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.yaml.snakeyaml.Yaml

import java.io.File
import java.nio.file.Files
@@ -87,6 +90,9 @@ object Driver {
this: ScallopConf =>
val confPath: ScallopOption[String] = opt[String](required = true, descr = "Path to conf")

val additionalConfPath: ScallopOption[String] =
opt[String](required = false, descr = "Path to additional driver job configurations")

val runFirstHole: ScallopOption[Boolean] =
opt[Boolean](required = false,
default = Some(false),
@@ -144,33 +150,44 @@

def subcommandName(): String

def isLocal: Boolean = localTableMapping.nonEmpty || localDataPath.isDefined
protected def isLocal: Boolean = localTableMapping.nonEmpty || localDataPath.isDefined

protected def buildSparkSession(): SparkSession = {
implicit val formats: Formats = DefaultFormats
val yamlLoader = new Yaml()
val additionalConfs = additionalConfPath.toOption
.map(Source.fromFile)
.map((src) =>
try { src.mkString }
finally { src.close })
.map(yamlLoader.load(_).asInstanceOf[java.util.Map[String, Any]])
.map((map) => Extraction.decompose(map.asScala.toMap))
.map((v) => render(v))
.map(compact)
.map((str) => parse(str).extract[Map[String, String]])

// We use the KryoSerializer for group bys and joins since we serialize the IRs.
// But since staging query is fairly freeform, it's better to stick to the java serializer.
val session =
SparkSessionBuilder.build(
subcommandName(),
local = isLocal,
localWarehouseLocation = localWarehouseLocation.toOption,
enforceKryoSerializer = !subcommandName().contains("staging_query"),
additionalConfig = additionalConfs
)
if (localTableMapping.nonEmpty) {
val localSession = SparkSessionBuilder.build(subcommandName(),
local = true,
localWarehouseLocation = localWarehouseLocation.toOption)
localTableMapping.foreach {
case (table, filePath) =>
val file = new File(filePath)
LocalDataLoader.loadDataFileAsTable(file, localSession, table)
LocalDataLoader.loadDataFileAsTable(file, session, table)
}
localSession
} else if (localDataPath.isDefined) {
val dir = new File(localDataPath())
assert(dir.exists, s"Provided local data path: ${localDataPath()} doesn't exist")
val localSession =
SparkSessionBuilder.build(subcommandName(),
local = true,
localWarehouseLocation = localWarehouseLocation.toOption)
LocalDataLoader.loadDataRecursively(dir, localSession)
localSession
} else {
// We use the KryoSerializer for group bys and joins since we serialize the IRs.
// But since staging query is fairly freeform, it's better to stick to the java serializer.
SparkSessionBuilder.build(subcommandName(), enforceKryoSerializer = !subcommandName().contains("staging_query"))
LocalDataLoader.loadDataRecursively(dir, session)
}
session
}

def buildTableUtils(): TableUtils = {
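
As a rough illustration of what the new path produces: the GCP defaults from additional-confs.yaml shown earlier end up as an Option[Map[String, String]] handed to the session builder. A sketch using the named arguments as they appear in the diff above; the exact defaults of SparkSessionBuilder.build are assumed:

// What --additional-conf-path parsing boils down to for the GCP resource file:
val additionalConfs: Option[Map[String, String]] = Some(
  Map(
    "spark.chronon.table.format_provider.class" -> "ai.chronon.integrations.cloud_gcp.GcpFormatProvider",
    "spark.chronon.partition.format" -> "yyyy-MM-dd"
  ))

// Passed straight through when the subcommand builds its session.
val session = SparkSessionBuilder.build(
  "join",
  local = false,
  enforceKryoSerializer = true,
  additionalConfig = additionalConfs
)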
@@ -1,6 +1,6 @@
package ai.chronon.spark

trait SparkSubmitter {
trait JobSubmitter {

def submit(files: List[String], args: String*): String

@@ -9,6 +9,6 @@ trait SparkSubmitter {
def kill(jobId: String): Unit
}

abstract class SparkAuth {
abstract class JobAuth {
def token(): Unit = {}
}
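
The SparkSubmitter/SparkAuth to JobSubmitter/JobAuth rename keeps the abstraction engine-agnostic; any backend implements the same small surface. An illustrative stub, not part of the PR (the status method is inferred from the Dataproc implementation above, since the middle of the trait is collapsed in this diff):

class LoggingSubmitter extends JobSubmitter {
  override def submit(files: List[String], args: String*): String = {
    // A real backend would ship the files and return the engine's job id.
    println(s"would submit ${files.size} files with args: ${args.mkString(" ")}")
    "stub-job-id"
  }
  override def status(jobId: String): Unit = println(s"status($jobId): UNKNOWN")
  override def kill(jobId: String): Unit = println(s"kill($jobId)")
}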
1 change: 1 addition & 0 deletions spark/src/test/resources/test-driver-additional-confs.yaml
@@ -0,0 +1 @@
test.yaml.key: "test_yaml_key"
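
A test exercising this resource would presumably build a local session with the file as the additional conf and assert the key survives into the Spark conf; a rough sketch under that assumption, not the actual test in this PR:

// Hypothetical check against the key defined in test-driver-additional-confs.yaml.
val session = SparkSessionBuilder.build(
  "test-additional-confs",
  local = true,
  additionalConfig = Some(Map("test.yaml.key" -> "test_yaml_key"))
)
assert(session.conf.get("test.yaml.key") == "test_yaml_key")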
@@ -84,11 +84,12 @@ class MigrationCompareTest {

//--------------------------------Staging Query-----------------------------
val stagingQueryConf = Builders.StagingQuery(
query = s"select * from ${joinConf.metaData.outputTable} WHERE ds BETWEEN '{{ start_date }}' AND '{{ end_date }}'",
query =
s"select * from ${joinConf.metaData.outputTable} WHERE ds BETWEEN '{{ start_date }}' AND '{{ end_date }}'",
startPartition = ninetyDaysAgo,
metaData = Builders.MetaData(name = "test.item_snapshot_features_sq_3",
namespace = namespace,
tableProperties = Map("key" -> "val"))
namespace = namespace,
tableProperties = Map("key" -> "val"))
)

(joinConf, stagingQueryConf)
Expand All @@ -113,8 +114,8 @@ class MigrationCompareTest {
query = s"select item, ts, ds from ${joinConf.metaData.outputTable}",
startPartition = ninetyDaysAgo,
metaData = Builders.MetaData(name = "test.item_snapshot_features_sq_4",
namespace = namespace,
tableProperties = Map("key" -> "val"))
namespace = namespace,
tableProperties = Map("key" -> "val"))
)

val (compareDf, metricsDf, metrics: DataMetrics) =
Expand All @@ -141,8 +142,8 @@ class MigrationCompareTest {
query = s"select * from ${joinConf.metaData.outputTable} where ds BETWEEN '${monthAgo}' AND '${today}'",
startPartition = ninetyDaysAgo,
metaData = Builders.MetaData(name = "test.item_snapshot_features_sq_5",
namespace = namespace,
tableProperties = Map("key" -> "val"))
namespace = namespace,
tableProperties = Map("key" -> "val"))
)

val (compareDf, metricsDf, metrics: DataMetrics) =