Wire up Flink DataProc job submission #189

Merged
merged 13 commits on Jan 15, 2025
32 changes: 30 additions & 2 deletions build.sbt
@@ -109,7 +109,8 @@ val circe = Seq(
val flink_all = Seq(
"org.apache.flink" %% "flink-streaming-scala",
"org.apache.flink" % "flink-metrics-dropwizard",
"org.apache.flink" % "flink-clients"
"org.apache.flink" % "flink-clients",
"org.apache.flink" % "flink-yarn"
).map(_ % flink_1_17)

val vertx_java = Seq(
@@ -213,6 +214,22 @@ lazy val flink = project
.settings(
libraryDependencies ++= spark_all,
libraryDependencies ++= flink_all,
assembly / assemblyMergeStrategy := {
case PathList("META-INF", "services", xs @ _*) => MergeStrategy.concat
case "reference.conf" => MergeStrategy.concat
case "application.conf" => MergeStrategy.concat
case PathList("META-INF", xs @ _*) => MergeStrategy.discard
case _ => MergeStrategy.first
},
// Exclude Hadoop, Guava & protobuf jars from the assembled JAR
// Else we hit an error - IllegalAccessError: class org.apache.hadoop.hdfs.web.HftpFileSystem cannot access its
// superinterface org.apache.hadoop.hdfs.web.TokenAspect$TokenManagementDelegator
// Or: java.lang.NoSuchMethodError: com.google.common.base.Preconditions.checkArgument(...)
// Or: 'com/google/protobuf/MapField' is not assignable to 'com/google/protobuf/MapFieldReflectionAccessor'
assembly / assemblyExcludedJars := {
val cp = (assembly / fullClasspath).value
cp filter { jar => jar.data.getName.startsWith("hadoop-") || jar.data.getName.startsWith("guava") || jar.data.getName.startsWith("protobuf")}
Collaborator:
so pretty much we don't trust any of the transitives matching these prefixes?

Contributor Author:
yeah for each of those there's a submit / runtime failure of the jobs (even with user jars first..)

},
libraryDependencies += "org.apache.flink" % "flink-test-utils" % flink_1_17 % Test excludeAll (
ExclusionRule(organization = "org.apache.logging.log4j", name = "log4j-api"),
ExclusionRule(organization = "org.apache.logging.log4j", name = "log4j-core"),
@@ -236,13 +253,24 @@ lazy val cloud_gcp = project
libraryDependencies += "org.json4s" %% "json4s-native" % "3.7.0-M11",
libraryDependencies += "org.json4s" %% "json4s-core" % "3.7.0-M11",
libraryDependencies += "org.yaml" % "snakeyaml" % "2.3",
libraryDependencies += "io.grpc" % "grpc-netty-shaded" % "1.62.2",
libraryDependencies ++= avro,
libraryDependencies ++= spark_all_provided,
dependencyOverrides ++= jackson,
// assembly merge settings to allow Flink jobs to kick off
assembly / assemblyMergeStrategy := {
case PathList("META-INF", "services", xs @ _*) => MergeStrategy.concat // concat ServiceLoader files so the gRPC channel provider is picked up
case PathList("META-INF", xs @ _*) => MergeStrategy.discard
case "reference.conf" => MergeStrategy.concat
case "application.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
},
libraryDependencies += "org.mockito" % "mockito-core" % "5.12.0" % Test,
libraryDependencies += "com.google.cloud" % "google-cloud-bigtable-emulator" % "0.178.0" % Test,
// force a newer version of reload4j to sidestep: https://security.snyk.io/vuln/SNYK-JAVA-CHQOSRELOAD4J-5731326
dependencyOverrides += "ch.qos.reload4j" % "reload4j" % "1.2.25"
dependencyOverrides ++= Seq(
"ch.qos.reload4j" % "reload4j" % "1.2.25",
)
)

lazy val cloud_gcp_submitter = project
2 changes: 0 additions & 2 deletions cloud_gcp/src/main/resources/dataproc-submitter-conf.yaml
@@ -2,5 +2,3 @@
projectId: "canary-443022"
region: "us-central1"
clusterName: "canary-2"
jarUri: "gs://zipline-jars/cloud_gcp-assembly-0.1.0-SNAPSHOT.jar"
mainClass: "ai.chronon.spark.Driver"
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package ai.chronon.integrations.cloud_gcp
import ai.chronon.spark.JobAuth
import ai.chronon.spark.JobSubmitter
import ai.chronon.spark.JobSubmitterConstants.FlinkMainJarURI
import ai.chronon.spark.JobSubmitterConstants.JarURI
import ai.chronon.spark.JobSubmitterConstants.MainClass
import ai.chronon.spark.JobType
import ai.chronon.spark.{FlinkJob => TypeFlinkJob}
import ai.chronon.spark.{SparkJob => TypeSparkJob}
import com.google.api.gax.rpc.ApiException
import com.google.cloud.dataproc.v1._
import org.json4s._
@@ -14,9 +20,7 @@ import collection.JavaConverters._
case class SubmitterConf(
projectId: String,
region: String,
clusterName: String,
jarUri: String,
mainClass: String
clusterName: String
) {

def endPoint: String = s"${region}-dataproc.googleapis.com:443"
@@ -49,38 +53,67 @@ class DataprocSubmitter(jobControllerClient: JobControllerClient, conf: Submitte
job.getDone
}

override def submit(files: List[String], args: String*): String = {

val sparkJob = SparkJob
.newBuilder()
.setMainClass(conf.mainClass)
.addJarFileUris(conf.jarUri)
.addAllFileUris(files.asJava)
.addAllArgs(args.toIterable.asJava)
.build()
override def submit(jobType: JobType,
jobProperties: Map[String, String],
files: List[String],
args: String*): String = {
val mainClass = jobProperties.getOrElse(MainClass, throw new RuntimeException("Main class not found"))
val jarUri = jobProperties.getOrElse(JarURI, throw new RuntimeException("Jar URI not found"))

val jobBuilder = jobType match {
case TypeSparkJob => buildSparkJob(mainClass, jarUri, files, args: _*)
case TypeFlinkJob =>
val mainJarUri =
jobProperties.getOrElse(FlinkMainJarURI, throw new RuntimeException(s"Missing expected $FlinkMainJarURI"))
buildFlinkJob(mainClass, mainJarUri, jarUri, args: _*)
}

val jobPlacement = JobPlacement
.newBuilder()
.setClusterName(conf.clusterName)
.build()

try {
val job = Job
.newBuilder()
val job = jobBuilder
.setReference(jobReference)
.setPlacement(jobPlacement)
.setSparkJob(sparkJob)
.build()

val submittedJob = jobControllerClient.submitJob(conf.projectId, conf.region, job)
submittedJob.getReference.getJobId

} catch {
case e: ApiException =>
throw new RuntimeException(s"Failed to submit job: ${e.getMessage}")
throw new RuntimeException(s"Failed to submit job: ${e.getMessage}", e)
}
}

private def buildSparkJob(mainClass: String, jarUri: String, files: List[String], args: String*): Job.Builder = {
val sparkJob = SparkJob
.newBuilder()
.setMainClass(mainClass)
.addJarFileUris(jarUri)
.addAllFileUris(files.asJava)
.addAllArgs(args.toIterable.asJava)
.build()
Job.newBuilder().setSparkJob(sparkJob)
}

private def buildFlinkJob(mainClass: String, mainJarUri: String, jarUri: String, args: String*): Job.Builder = {
val envProps =
Map("jobmanager.memory.process.size" -> "4G", "yarn.classpath.include-user-jar" -> "FIRST")

Comment on lines +103 to +105

🛠️ Refactor suggestion

Make Flink environment properties configurable

Move hardcoded values to configuration.

-    val envProps =
-      Map("jobmanager.memory.process.size" -> "4G", "yarn.classpath.include-user-jar" -> "FIRST")
+    val envProps = jobProperties.getOrElse("flink.properties", 
+      Map("jobmanager.memory.process.size" -> "4G", "yarn.classpath.include-user-jar" -> "FIRST"))
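
One possible way to thread this through, sketched under the assumption that callers pass Flink overrides in jobProperties under a hypothetical "flink.conf." prefix (the defaults below mirror the current hardcoded values; the prefix and helper name are illustrative, not part of this PR):

  // Sketch: overlay caller-supplied Flink overrides (hypothetical "flink.conf." prefix) on the defaults.
  private def flinkEnvProps(jobProperties: Map[String, String]): Map[String, String] = {
    val defaults = Map(
      "jobmanager.memory.process.size" -> "4G",
      "yarn.classpath.include-user-jar" -> "FIRST"
    )
    val overrides = jobProperties.collect {
      case (key, value) if key.startsWith("flink.conf.") => key.stripPrefix("flink.conf.") -> value
    }
    defaults ++ overrides // later entries win, so overrides take precedence
  }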

val flinkJob = FlinkJob
.newBuilder()
.setMainClass(mainClass)
.setMainJarFileUri(mainJarUri)
.putAllProperties(envProps.asJava)
.addJarFileUris(jarUri)
.addAllArgs(args.toIterable.asJava)
.build()
Job.newBuilder().setFlinkJob(flinkJob)
}

def jobReference: JobReference = JobReference.newBuilder().build()
}

@@ -146,14 +179,14 @@ object DataprocSubmitter {
val submitterConf = SubmitterConf(
projectId,
region,
clusterName,
chrononJarUri,
"ai.chronon.spark.Driver"
clusterName
)

val a = DataprocSubmitter(submitterConf)

val jobId = a.submit(
TypeSparkJob,
Map(MainClass -> "ai.chronon.spark.Driver", JarURI -> chrononJarUri),
gcsFiles.toList,
userArgs: _*
)
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package ai.chronon.integrations.cloud_gcp

import ai.chronon.spark
import ai.chronon.spark.JobSubmitterConstants.FlinkMainJarURI
import ai.chronon.spark.JobSubmitterConstants.JarURI
import ai.chronon.spark.JobSubmitterConstants.MainClass
import com.google.api.gax.rpc.UnaryCallable
import com.google.cloud.dataproc.v1._
import com.google.cloud.dataproc.v1.stub.JobControllerStub
@@ -37,21 +41,39 @@ class DataprocSubmitterTest extends AnyFunSuite with MockitoSugar {

val submitter = new DataprocSubmitter(
mockJobControllerClient,
SubmitterConf("test-project", "test-region", "test-cluster", "test-jar-uri", "test-main-class"))
SubmitterConf("test-project", "test-region", "test-cluster"))

val submittedJobId = submitter.submit(List.empty)
val submittedJobId = submitter.submit(spark.SparkJob, Map(MainClass -> "test-main-class", JarURI -> "test-jar-uri"), List.empty)
assertEquals(submittedJobId, jobId)
}

test("Verify classpath with spark-bigquery-connector") {
BigQueryUtilScala.validateScalaVersionCompatibility()
}

ignore("test flink job locally") {
val submitter = DataprocSubmitter()
val submittedJobId =
submitter.submit(spark.FlinkJob,
Map(MainClass -> "ai.chronon.flink.FlinkJob",
FlinkMainJarURI -> "gs://zipline-jars/flink-assembly-0.1.0-SNAPSHOT.jar",
JarURI -> "gs://zipline-jars/cloud_gcp_bigtable.jar"),
List.empty,
"--online-class=ai.chronon.integrations.cloud_gcp.GcpApiImpl",
"--groupby-name=e2e-count",
"-ZGCP_PROJECT_ID=bigtable-project-id",
"-ZGCP_INSTANCE_ID=bigtable-instance-id")
println(submittedJobId)
}

ignore("Used to iterate locally. Do not enable this in CI/CD!") {

val submitter = DataprocSubmitter()
val submittedJobId =
submitter.submit(
spark.SparkJob,
Map(MainClass -> "ai.chronon.spark.Driver",
JarURI -> "gs://zipline-jars/cloud_gcp-assembly-0.1.0-SNAPSHOT.jar"),
List("gs://zipline-jars/training_set.v1",
"gs://zipline-jars/dataproc-submitter-conf.yaml",
"gs://zipline-jars/additional-confs.yaml"),
@@ -67,7 +89,11 @@ class DataprocSubmitterTest extends AnyFunSuite with MockitoSugar {

val submitter = DataprocSubmitter()
val submittedJobId =
submitter.submit(List.empty,
submitter.submit(
spark.SparkJob,
Map(MainClass -> "ai.chronon.spark.Driver",
JarURI -> "gs://zipline-jars/cloud_gcp-assembly-0.1.0-SNAPSHOT.jar"),
List.empty,
"groupby-upload-bulk-load",
"-ZGCP_PROJECT_ID=bigtable-project-id",
"-ZGCP_INSTANCE_ID=bigtable-instance-id",
Original file line number Diff line number Diff line change
@@ -119,7 +119,7 @@ class AsyncKVStoreWriter(onlineImpl: Api, featureGroupName: String)
// in the KVStore - we log the exception and skip the object to
// not fail the app
errorCounter.inc()
logger.error(s"Caught exception writing to KVStore for object: $input - $exception")
logger.error(s"Caught exception writing to KVStore for object: $input", exception)
resultFuture.complete(util.Arrays.asList[WriteResponse](WriteResponse(input, status = false)))
}
}
2 changes: 1 addition & 1 deletion flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala
@@ -108,7 +108,7 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
case e: Exception =>
// To improve availability, we don't rethrow the exception. We just drop the event
// and track the errors in a metric. Alerts should be set up on this metric.
logger.error(s"Error converting to Avro bytes - $e")
logger.error("Error converting to Avro bytes", e)
eventProcessingErrorCounter.inc()
avroConversionErrorCounter.inc()
}
57 changes: 57 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/FlinkJob.scala
@@ -9,6 +9,7 @@ import ai.chronon.flink.window.FlinkRowAggProcessFunction
import ai.chronon.flink.window.FlinkRowAggregationFunction
import ai.chronon.flink.window.KeySelector
import ai.chronon.flink.window.TimestampedTile
import ai.chronon.online.Api
import ai.chronon.online.GroupByServingInfoParsed
import ai.chronon.online.KVStore.PutRequest
import ai.chronon.online.SparkConversions
@@ -22,6 +23,9 @@ import org.apache.flink.streaming.api.windowing.assigners.WindowAssigner
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.spark.sql.Encoder
import org.rogach.scallop.ScallopConf
import org.rogach.scallop.ScallopOption
import org.rogach.scallop.Serialization
import org.slf4j.LoggerFactory

/**
@@ -196,3 +200,56 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
)
}
}

object FlinkJob {
// Pull in the Serialization trait to sidestep: https://github.com/scallop/scallop/issues/137
class JobArgs(args: Seq[String]) extends ScallopConf(args) with Serialization {
val onlineClass: ScallopOption[String] =
opt[String](required = true,
descr = "Fully qualified Online.Api based class. We expect the jar to be on the class path")
val groupbyName: ScallopOption[String] =
opt[String](required = true, descr = "The name of the groupBy to process")
val mockSource: ScallopOption[Boolean] =
opt[Boolean](required = false, descr = "Use a mocked data source instead of a real source", default = Some(true))

val apiProps: Map[String, String] = props[String]('Z', descr = "Props to configure API / KV Store")

verify()
}
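
  // Example invocation flags (values are illustrative; see the ignored Dataproc submission test for a real run):
  //   --online-class=ai.chronon.integrations.cloud_gcp.GcpApiImpl
  //   --groupby-name=e2e-count
  //   -ZGCP_PROJECT_ID=<project-id> -ZGCP_INSTANCE_ID=<instance-id>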

def main(args: Array[String]): Unit = {
val jobArgs = new JobArgs(args)
jobArgs.groupbyName()
val onlineClassName = jobArgs.onlineClass()
val props = jobArgs.apiProps.map(identity)
val useMockedSource = jobArgs.mockSource()

val api = buildApi(onlineClassName, props)
val flinkJob =
if (useMockedSource) {
// We will yank this conditional block when we wire up our real sources etc.
TestFlinkJob.buildTestFlinkJob(api)
} else {
// TODO - what we need to do when we wire this up for real
// lookup groupByServingInfo by groupByName from the kv store
// based on the topic type (e.g. kafka / pubsub) and the schema class name:
// 1. lookup schema object using SchemaProvider (e.g SchemaRegistry / Jar based)
// 2. Create the appropriate Encoder for the given schema type
// 3. Invoke the appropriate source provider to get the source, encoder, parallelism
throw new IllegalArgumentException("We don't support non-mocked sources like Kafka / PubSub yet!")
}

val env = StreamExecutionEnvironment.getExecutionEnvironment
// TODO add useful configs
flinkJob.runGroupByJob(env).addSink(new PrintSink) // TODO wire up a metrics sink / such
env.execute(s"${flinkJob.groupByName}")
Comment on lines +242 to +245

⚠️ Potential issue

Add essential Flink configurations.

Missing critical Flink settings:

  • Checkpointing
  • Restart strategy
  • State backend
  • Watermark strategy
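
A minimal sketch of how these could be wired into main(), assuming Flink 1.17 streaming APIs (the interval, retry counts and backend choice are placeholders, not settings from this PR):

  // Hypothetical configuration for the StreamExecutionEnvironment built in main()
  import org.apache.flink.api.common.restartstrategy.RestartStrategies
  import org.apache.flink.streaming.api.CheckpointingMode

  val env = StreamExecutionEnvironment.getExecutionEnvironment
  env.enableCheckpointing(30000L, CheckpointingMode.EXACTLY_ONCE) // checkpoint every 30s, exactly-once
  env.setRestartStrategy(RestartStrategies.fixedDelayRestart(3, 10000L)) // 3 attempts, 10s apart
  // The state backend / checkpoint storage are typically set via flink-conf.yaml
  // (e.g. state.backend: rocksdb, state.checkpoints.dir: gs://<bucket>/checkpoints),
  // and a watermark strategy would be assigned on the source once real sources are wired up.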

}

def buildApi(onlineClass: String, props: Map[String, String]): Api = {
val cl = Thread.currentThread().getContextClassLoader // Use Flink's classloader
val cls = cl.loadClass(onlineClass)
val constructor = cls.getConstructors.apply(0)

🛠️ Refactor suggestion

Constructor lookup could be more robust.

Using apply(0) assumes only one constructor exists.

-    val constructor = cls.getConstructors.apply(0)
+    val constructor = cls.getConstructors.find(_.getParameterCount == 1)
+      .getOrElse(throw new IllegalArgumentException(s"No suitable constructor found for $onlineClass"))

val onlineImpl = constructor.newInstance(props)
onlineImpl.asInstanceOf[Api]
}
Comment on lines +248 to +254

⚠️ Potential issue

Add error handling for class loading.

Wrap class loading operations in try-catch to handle ClassNotFoundException and InstantiationException.

 def buildApi(onlineClass: String, props: Map[String, String]): Api = {
+  try {
     val cl = Thread.currentThread().getContextClassLoader
     val cls = cl.loadClass(onlineClass)
     val constructor = cls.getConstructors.apply(0)
     val onlineImpl = constructor.newInstance(props)
     onlineImpl.asInstanceOf[Api]
+  } catch {
+    case e: ClassNotFoundException => 
+      throw new IllegalArgumentException(s"Class $onlineClass not found", e)
+    case e: InstantiationException => 
+      throw new IllegalArgumentException(s"Failed to instantiate $onlineClass", e)
+  }
 }

}
23 changes: 23 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/SourceProvider.scala
@@ -0,0 +1,23 @@
package ai.chronon.flink

import ai.chronon.online.GroupByServingInfoParsed
import org.apache.spark.sql.Encoder

/**
* SourceProvider is an abstract class that provides a way to build a source for a Flink job.
* It takes the groupByServingInfo as an argument and, based on the configured GroupBy details, configures
* the Flink source (e.g. Kafka or PubSub) with the right parallelism etc.
*/
abstract class SourceProvider[T](maybeGroupByServingInfoParsed: Option[GroupByServingInfoParsed]) {
// Returns a tuple of the source, parallelism
def buildSource(): (FlinkSource[T], Int)
}

/**
* EncoderProvider is an abstract class that provides a way to build a Spark encoder for a Flink job.
* These encoders are used in the SparkExprEval Flink function to convert the incoming stream into types
* that are amenable to tiled / untiled processing.
*/
abstract class EncoderProvider[T] {
def buildEncoder(): Encoder[T]
}
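
As an illustration of how these abstractions are meant to be used, a minimal sketch of hypothetical implementations (the class names and fixed parallelism are invented for the example; a real provider would derive the source and its parallelism from groupByServingInfo):

// Hypothetical: wraps an existing FlinkSource and reports a fixed parallelism of 1 (e.g. for tests).
class MockSourceProvider[T](maybeInfo: Option[GroupByServingInfoParsed], source: FlinkSource[T])
    extends SourceProvider[T](maybeInfo) {
  override def buildSource(): (FlinkSource[T], Int) = (source, 1)
}

// Hypothetical: builds a Spark Encoder for plain String events using Spark's built-in encoders.
class StringEncoderProvider extends EncoderProvider[String] {
  override def buildEncoder(): Encoder[String] = org.apache.spark.sql.Encoders.STRING
}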
Original file line number Diff line number Diff line change
@@ -109,7 +109,7 @@ class SparkExpressionEvalFn[T](encoder: Encoder[T], groupBy: GroupBy) extends Ri
case e: Exception =>
// To improve availability, we don't rethrow the exception. We just drop the event
// and track the errors in a metric. Alerts should be set up on this metric.
logger.error(s"Error evaluating Spark expression - $e")
logger.error("Error evaluating Spark expression", e)
exprEvalErrorCounter.inc()
}
}