Commit 896e2fc

david-zlai, nikhil-zlai, tchow-zlai authored
Use execution spark configs from compiled confs in JobSubmitter (#549)
## Summary

Use execution Spark configs from the compiled confs in JobSubmitter. The Spark properties can be seen in the Configuration tab of this Dataproc job:
https://console.cloud.google.com/dataproc/jobs/3878fd24-6ca8-41df-8507-ecc7adcead91/configuration?region=us-central1&invt=Abt0gg&project=canary-443022

- removed the additional-confs.yaml integration

## Checklist

- [ ] Added Unit Tests
- [x] Covered by existing CI
- [x] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Enhanced configuration metadata with improved categorization.
  - Introduced dedicated submission properties for more robust cloud job configuration.
  - Added a new optional field `confType` to the metadata structure.
  - Introduced a new enumeration `ConfType` for configuration types.
- **Refactor**
  - Streamlined job submission workflows and command argument handling across platforms.
  - Simplified Spark session setup by removing extraneous configuration file processing.
- **Chores**
  - Updated logging dependencies for the cloud GCP target.
  - Modified artifact upload paths to streamline deployment processes.

---------

Co-authored-by: Nikhil Simha <[email protected]>
Co-authored-by: Thomas Chow <[email protected]>
1 parent e23a096 commit 896e2fc

File tree

15 files changed: +258 −142 lines changed


api/python/ai/chronon/cli/compile/parse_configs.py

Lines changed: 13 additions & 0 deletions

@@ -4,6 +4,8 @@
 import os
 from typing import List
 
+from ai.chronon.api.common.ttypes import ConfType
+from ai.chronon.api.ttypes import GroupBy, Join, Model, StagingQuery
 from ai.chronon.cli.compile import parse_teams, serializer
 from ai.chronon.cli.compile.compile_context import CompileContext
 from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
@@ -24,6 +26,16 @@ def from_folder(
 
     results = []
 
+    conf_type = None
+    if cls == GroupBy:
+        conf_type = ConfType.GROUP_BYS
+    elif cls == Join:
+        conf_type = ConfType.JOINS
+    elif cls == Model:
+        conf_type = ConfType.MODELS
+    elif cls == StagingQuery:
+        conf_type = ConfType.STAGING_QUERIES
+
     for f in python_files:
 
         try:
@@ -32,6 +44,7 @@ def from_folder(
         for name, obj in results_dict.items():
             parse_teams.update_metadata(obj, compile_context.teams_dict)
             obj.metaData.sourceFile = f
+            obj.metaData.confType = conf_type
 
             tjson = serializer.thrift_simple_json(obj)
 
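The new class-to-ConfType mapping above is an if/elif chain; read as a lookup table, it is equivalent to the following minimal sketch (a hypothetical refactor, not part of this commit; the helper name is invented for illustration):

```python
from ai.chronon.api.common.ttypes import ConfType
from ai.chronon.api.ttypes import GroupBy, Join, Model, StagingQuery

# Lookup-table form of the if/elif chain in from_folder above.
_CONF_TYPE_BY_CLASS = {
    GroupBy: ConfType.GROUP_BYS,
    Join: ConfType.JOINS,
    Model: ConfType.MODELS,
    StagingQuery: ConfType.STAGING_QUERIES,
}


def conf_type_for(cls):
    # Returns None for any other class, matching the original default.
    return _CONF_TYPE_BY_CLASS.get(cls)
```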

api/python/ai/chronon/repo/aws.py

Lines changed: 2 additions & 1 deletion

@@ -194,7 +194,8 @@ def generate_emr_submitter_args(
                 job_type=job_type.value,
                 main_class=main_class,
             )
-            + f" --additional-conf-path={EMR_MOUNT_FILE_PREFIX}additional-confs.yaml --files={s3_file_args}"
+            + f" --additional-conf-path={EMR_MOUNT_FILE_PREFIX}additional-confs.yaml"
+            f" --files={s3_file_args}"
         )
     else:
         raise ValueError(f"Invalid job type: {job_type}")
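Since the EMR change only splits the long f-string, the submitted suffix is unchanged. A runnable sketch, with hypothetical values for EMR_MOUNT_FILE_PREFIX and s3_file_args (neither value is taken from this diff):

```python
# Both values below are assumptions, chosen only to show the suffix shape.
EMR_MOUNT_FILE_PREFIX = "/mnt/zipline/"
s3_file_args = "s3://my-bucket/metadata/my_join"

suffix = (
    f" --additional-conf-path={EMR_MOUNT_FILE_PREFIX}additional-confs.yaml"
    f" --files={s3_file_args}"
)
assert suffix == (
    " --additional-conf-path=/mnt/zipline/additional-confs.yaml"
    " --files=s3://my-bucket/metadata/my_join"
)
```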

api/python/ai/chronon/repo/default_runner.py

Lines changed: 52 additions & 28 deletions

@@ -12,19 +12,23 @@
     ROUTES,
     SPARK_MODES,
     UNIVERSAL_ROUTES,
+    RunMode,
 )
 
 
 class Runner:
     def __init__(self, args, jar_path):
         self.repo = args["repo"]
         self.conf = args["conf"]
+        self.local_abs_conf_path = os.path.realpath(os.path.join(self.repo, self.conf))
         self.sub_help = args["sub_help"]
         self.mode = args["mode"]
         self.online_jar = args.get(ONLINE_JAR_ARG)
         self.online_class = args.get(ONLINE_CLASS_ARG)
 
-        self.conf_type = args.get("conf_type", "").replace("-", "_")  # in case user sets dash instead of underscore
+        self.conf_type = args.get("conf_type", "").replace(
+            "-", "_"
+        )  # in case user sets dash instead of underscore
 
         # streaming flink
         self.groupby_name = args.get("groupby_name")
@@ -37,32 +41,35 @@ def __init__(self, args, jar_path):
         valid_jar = args["online_jar"] and os.path.exists(args["online_jar"])
 
         # fetch online jar if necessary
-        if (self.mode in ONLINE_MODES) and (not args["sub_help"]) and not valid_jar and (
-                args.get("online_jar_fetch")):
+        if (
+            (self.mode in ONLINE_MODES)
+            and (not args["sub_help"])
+            and not valid_jar
+            and (args.get("online_jar_fetch"))
+        ):
             print("Downloading online_jar")
-            self.online_jar = utils.check_output("{}".format(args["online_jar_fetch"])).decode(
-                "utf-8"
-            )
+            self.online_jar = utils.check_output(
+                "{}".format(args["online_jar_fetch"])
+            ).decode("utf-8")
             os.environ["CHRONON_ONLINE_JAR"] = self.online_jar
             print("Downloaded jar to {}".format(self.online_jar))
 
         if self.conf:
             try:
-                self.context, self.conf_type, self.team, _ = self.conf.split(
-                    "/")[-4:]
+                self.context, self.conf_type, self.team, _ = self.conf.split("/")[-4:]
             except Exception as e:
                 logging.error(
                     "Invalid conf path: {}, please ensure to supply the relative path to zipline/ folder".format(
                         self.conf
                     )
                 )
                 raise e
-            possible_modes = list(
-                ROUTES[self.conf_type].keys()) + UNIVERSAL_ROUTES
+            possible_modes = list(ROUTES[self.conf_type].keys()) + UNIVERSAL_ROUTES
             assert (
-                args["mode"] in possible_modes), ("Invalid mode:{} for conf:{} of type:{}, please choose from {}"
-                                                  .format(args["mode"], self.conf, self.conf_type, possible_modes
-                                                          ))
+                args["mode"] in possible_modes
+            ), "Invalid mode:{} for conf:{} of type:{}, please choose from {}".format(
+                args["mode"], self.conf, self.conf_type, possible_modes
+            )
 
         self.ds = args["end_ds"] if "end_ds" in args and args["end_ds"] else args["ds"]
         self.start_ds = (
@@ -124,7 +131,9 @@ def run_spark_streaming(self):
             )
         )
         if self.mode == "streaming":
-            assert (len(filtered_apps) == 1), "More than one found, please kill them all"
+            assert (
+                len(filtered_apps) == 1
+            ), "More than one found, please kill them all"
             print("All good. No need to start a new app.")
             return
         elif self.mode == "streaming-client":
@@ -139,9 +148,7 @@ def run_spark_streaming(self):
                 jar=self.jar_path,
                 subcommand=ROUTES[self.conf_type][self.mode],
                 args=self._gen_final_args(),
-                additional_args=os.environ.get(
-                    "CHRONON_CONFIG_ADDITIONAL_ARGS", ""
-                ),
+                additional_args=os.environ.get("CHRONON_CONFIG_ADDITIONAL_ARGS", ""),
             )
             return command
 
@@ -182,23 +189,22 @@ def run(self):
                 )
                 for start_ds, end_ds in date_ranges:
                     command = (
-                        "bash {script} --class ai.chronon.spark.Driver " +
-                        "{jar} {subcommand} {args} {additional_args}"
+                        "bash {script} --class ai.chronon.spark.Driver "
+                        + "{jar} {subcommand} {args} {additional_args}"
                     ).format(
                         script=self.spark_submit,
                         jar=self.jar_path,
                         subcommand=ROUTES[self.conf_type][self.mode],
-                        args=self._gen_final_args(
-                            start_ds=start_ds, end_ds=end_ds),
+                        args=self._gen_final_args(start_ds=start_ds, end_ds=end_ds),
                         additional_args=os.environ.get(
                             "CHRONON_CONFIG_ADDITIONAL_ARGS", ""
                         ),
                     )
                     command_list.append(command)
             else:
                 command = (
-                    "bash {script} --class ai.chronon.spark.Driver "
-                    + "{jar} {subcommand} {args} {additional_args}"
+                    "bash {script} --class ai.chronon.spark.Driver "
+                    + "{jar} {subcommand} {args} {additional_args}"
                 ).format(
                     script=self.spark_submit,
                     jar=self.jar_path,
@@ -222,21 +228,39 @@ def run(self):
         elif len(command_list) == 1:
             utils.check_call(command_list[0])
 
-    def _gen_final_args(self, start_ds=None, end_ds=None, override_conf_path=None, **kwargs):
+    def _gen_final_args(
+        self, start_ds=None, end_ds=None, override_conf_path=None, **kwargs
+    ):
         base_args = MODE_ARGS[self.mode].format(
             conf_path=override_conf_path if override_conf_path else self.conf,
             ds=end_ds if end_ds else self.ds,
             online_jar=self.online_jar,
-            online_class=self.online_class
+            online_class=self.online_class,
         )
-        base_args = base_args + f" --conf-type={self.conf_type} " if self.conf_type else base_args
+
+        base_args = (
+            base_args + f" --conf-type={self.conf_type} "
+            if self.conf_type
+            else base_args
+        )
+
+        if self.mode != RunMode.FETCH:
+            base_args += " --local-conf-path={conf}".format(
+                conf=self.local_abs_conf_path
+            ) + " --original-mode={mode}".format(mode=self.mode)
 
         override_start_partition_arg = (
             "--start-partition-override=" + start_ds if start_ds else ""
         )
 
-        additional_args = " ".join(f"--{key.replace('_', '-')}={value}" for key, value in kwargs.items() if value)
+        additional_args = " ".join(
+            f"--{key.replace('_', '-')}={value}"
+            for key, value in kwargs.items()
+            if value
+        )
 
-        final_args = " ".join([base_args, str(self.args), override_start_partition_arg, additional_args])
+        final_args = " ".join(
+            [base_args, str(self.args), override_start_partition_arg, additional_args]
+        )
 
         return final_args
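The net effect of the _gen_final_args change: every non-fetch run now passes the compiled conf's absolute local path and the originating mode to the Driver. A sketch of just the appended flags, with hypothetical repo, conf, and mode values:

```python
import os

# Hypothetical values, to show the flags appended for non-fetch modes.
repo = "/home/me/zipline"
conf = "compiled/joins/my_team/my_join.v1"
mode = "backfill"

local_abs_conf_path = os.path.realpath(os.path.join(repo, conf))
extra_flags = " --local-conf-path={conf}".format(
    conf=local_abs_conf_path
) + " --original-mode={mode}".format(mode=mode)

print(extra_flags)
# e.g. " --local-conf-path=/home/me/zipline/compiled/joins/my_team/my_join.v1 --original-mode=backfill"
```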

api/python/ai/chronon/repo/gcp.py

Lines changed: 2 additions & 2 deletions

@@ -260,8 +260,8 @@ def generate_dataproc_submitter_args(
                 jar_uri=jar_uri,
                 job_type=job_type.value,
                 main_class=main_class,
-            )
-            + f" --additional-conf-path=additional-confs.yaml --gcs-files={gcs_file_args}"
+            ) + f" --files={gcs_file_args}"
+
         )
     else:
         raise ValueError(f"Invalid job type: {job_type}")
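On Dataproc the additional-confs.yaml flag is dropped outright; only the --files list remains. With a hypothetical gcs_file_args value:

```python
# Hypothetical file list; the only remaining dynamic piece of the suffix.
gcs_file_args = "gs://my-bucket/metadata/my_join"
suffix = f" --files={gcs_file_args}"
assert suffix == " --files=gs://my-bucket/metadata/my_join"
```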

api/thrift/api.thrift

Lines changed: 2 additions & 0 deletions

@@ -290,6 +290,8 @@ struct MetaData {
 
     # information that needs to be present on every physical node
     204: optional common.ExecutionInfo executionInfo
+
+    205: optional common.ConfType confType
 }
 
 // Equivalent to a FeatureSet in chronon terms

api/thrift/common.thrift

Lines changed: 7 additions & 0 deletions

@@ -132,4 +132,11 @@ struct ExecutionInfo {
     # note that batch jobs could in theory also depend on model training runs
     # in which case we will be polling
     # in the future we will add other types of dependencies
+}
+
+enum ConfType {
+    JOINS = 0
+    GROUP_BYS = 1
+    MODELS = 2
+    STAGING_QUERIES = 3
 }
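Taken together, the two Thrift changes let every compiled conf carry its own type. A minimal sketch of how the generated Python types line up with the schema (field and enum names are from this diff; the construction style assumes standard Thrift Python codegen):

```python
from ai.chronon.api.common.ttypes import ConfType
from ai.chronon.api.ttypes import MetaData

md = MetaData()
md.confType = ConfType.JOINS  # new optional field 205 on MetaData

# Enum values as declared in common.thrift above.
assert ConfType.JOINS == 0
assert ConfType.GROUP_BYS == 1
assert ConfType.MODELS == 2
assert ConfType.STAGING_QUERIES == 3
```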

cloud_aws/src/main/scala/ai/chronon/integrations/aws/EmrSubmitter.scala

Lines changed: 39 additions & 39 deletions

@@ -158,20 +158,23 @@ class EmrSubmitter(customerId: String, emrClient: EmrClient) extends JobSubmitte
   }
 
   override def submit(jobType: JobType,
+                      submissionProperties: Map[String, String],
                       jobProperties: Map[String, String],
                       files: List[String],
                       args: String*): String = {
-    if (jobProperties.get(ShouldCreateCluster).exists(_.toBoolean)) {
+    if (submissionProperties.get(ShouldCreateCluster).exists(_.toBoolean)) {
       // create cluster
       val runJobFlowBuilder = createClusterRequestBuilder(
-        emrReleaseLabel = jobProperties.getOrElse(EmrReleaseLabel, DefaultEmrReleaseLabel),
-        clusterIdleTimeout = jobProperties.getOrElse(ClusterIdleTimeout, DefaultClusterIdleTimeout.toString).toInt,
-        masterInstanceType = jobProperties.getOrElse(ClusterInstanceType, DefaultClusterInstanceType),
-        slaveInstanceType = jobProperties.getOrElse(ClusterInstanceType, DefaultClusterInstanceType),
-        instanceCount = jobProperties.getOrElse(ClusterInstanceCount, DefaultClusterInstanceCount.toString).toInt
+        emrReleaseLabel = submissionProperties.getOrElse(EmrReleaseLabel, DefaultEmrReleaseLabel),
+        clusterIdleTimeout =
+          submissionProperties.getOrElse(ClusterIdleTimeout, DefaultClusterIdleTimeout.toString).toInt,
+        masterInstanceType = submissionProperties.getOrElse(ClusterInstanceType, DefaultClusterInstanceType),
+        slaveInstanceType = submissionProperties.getOrElse(ClusterInstanceType, DefaultClusterInstanceType),
+        instanceCount = submissionProperties.getOrElse(ClusterInstanceCount, DefaultClusterInstanceCount.toString).toInt
       )
 
-      runJobFlowBuilder.steps(createStepConfig(files, jobProperties(MainClass), jobProperties(JarURI), args: _*))
+      runJobFlowBuilder.steps(
+        createStepConfig(files, submissionProperties(MainClass), submissionProperties(JarURI), args: _*))
 
       val responseJobId = emrClient.runJobFlow(runJobFlowBuilder.build()).jobFlowId()
       println("EMR job id: " + responseJobId)
@@ -181,11 +184,11 @@ class EmrSubmitter(customerId: String, emrClient: EmrClient) extends JobSubmitte
 
     } else {
       // use existing cluster
-      val existingJobId = jobProperties.getOrElse(ClusterId, throw new RuntimeException("JobFlowId not found"))
+      val existingJobId = submissionProperties.getOrElse(ClusterId, throw new RuntimeException("JobFlowId not found"))
       val request = AddJobFlowStepsRequest
         .builder()
         .jobFlowId(existingJobId)
-        .steps(createStepConfig(files, jobProperties(MainClass), jobProperties(JarURI), args: _*))
+        .steps(createStepConfig(files, submissionProperties(MainClass), submissionProperties(JarURI), args: _*))
         .build()
 
       val responseStepId = emrClient.addJobFlowSteps(request).stepIds().get(0)
@@ -230,40 +233,35 @@ object EmrSubmitter {
   def main(args: Array[String]): Unit = {
     // List of args that are not application args
     val internalArgs = Set(
-      JarUriArgKeyword,
-      JobTypeArgKeyword,
-      MainClassKeyword,
-      FlinkMainJarUriArgKeyword,
-      FlinkSavepointUriArgKeyword,
       ClusterInstanceTypeArgKeyword,
       ClusterInstanceCountArgKeyword,
       ClusterIdleTimeoutArgKeyword,
-      FilesArgKeyword,
       CreateClusterArgKeyword
-    )
+    ) ++ SharedInternalArgs
 
     val userArgs = args.filter(arg => !internalArgs.exists(arg.startsWith))
 
-    val jarUri =
-      args.find(_.startsWith(JarUriArgKeyword)).map(_.split("=")(1)).getOrElse(throw new Exception("Jar URI not found"))
-    val mainClass = args
-      .find(_.startsWith(MainClassKeyword))
-      .map(_.split("=")(1))
-      .getOrElse(throw new Exception("Main class not found"))
-    val jobTypeValue = args
-      .find(_.startsWith(JobTypeArgKeyword))
-      .map(_.split("=")(1))
-      .getOrElse(throw new Exception("Job type not found"))
-    val clusterInstanceType =
-      args.find(_.startsWith(ClusterInstanceTypeArgKeyword)).map(_.split("=")(1)).getOrElse(DefaultClusterInstanceType)
-    val clusterInstanceCount = args
-      .find(_.startsWith(ClusterInstanceCountArgKeyword))
-      .map(_.split("=")(1))
+    val jarUri = JobSubmitter
+      .getArgValue(args, JarUriArgKeyword)
+      .getOrElse(throw new Exception("Missing required argument: " + JarUriArgKeyword))
+    val mainClass = JobSubmitter
+      .getArgValue(args, MainClassKeyword)
+      .getOrElse(throw new Exception("Missing required argument: " + MainClassKeyword))
+    val jobTypeValue =
+      JobSubmitter
+        .getArgValue(args, JobTypeArgKeyword)
+        .getOrElse(throw new Exception("Missing required argument: " + JobTypeArgKeyword))
+
+    val clusterInstanceType = JobSubmitter
+      .getArgValue(args, ClusterInstanceTypeArgKeyword)
+      .getOrElse(DefaultClusterInstanceType)
+    val clusterInstanceCount = JobSubmitter
+      .getArgValue(args, ClusterInstanceCountArgKeyword)
       .getOrElse(DefaultClusterInstanceCount.toString)
-    val clusterIdleTimeout = args
-      .find(_.startsWith(ClusterIdleTimeoutArgKeyword))
-      .map(_.split("=")(1))
+    val clusterIdleTimeout = JobSubmitter
+      .getArgValue(args, ClusterIdleTimeoutArgKeyword)
       .getOrElse(DefaultClusterIdleTimeout.toString)
+
     val createCluster = args.exists(_.startsWith(CreateClusterArgKeyword))
 
     val clusterId = sys.env.get("EMR_CLUSTER_ID")
@@ -278,7 +276,7 @@ object EmrSubmitter {
       filesArgs(0).split("=")(1).split(",")
     }
 
-    val (jobType, jobProps) = jobTypeValue.toLowerCase match {
+    val (jobType, submissionProps) = jobTypeValue.toLowerCase match {
       case "spark" => {
         val baseProps = Map(
           MainClass -> mainClass,
@@ -299,13 +297,15 @@ object EmrSubmitter {
       case _ => throw new Exception("Invalid job type")
     }
 
-    val finalArgs = userArgs
+    val finalArgs = userArgs.toSeq
+    val modeConfigProperties = JobSubmitter.getModeConfigProperties(args)
 
     val emrSubmitter = EmrSubmitter()
     emrSubmitter.submit(
-      jobType,
-      jobProps,
-      files.toList,
+      jobType = jobType,
+      submissionProperties = submissionProps,
+      jobProperties = modeConfigProperties.getOrElse(Map.empty),
+      files = files.toList,
       finalArgs: _*
    )
  }
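The repeated `args.find(_.startsWith(...)).map(_.split("=")(1))` pattern is now centralized in `JobSubmitter.getArgValue`. A cross-language sketch of that parsing logic in Python (the split-once semantics here are an assumption; the Scala original takes index 1 of a full split on `=`):

```python
def get_arg_value(args, keyword):
    """Return the value of the first '<keyword>=<value>' arg, or None if absent."""
    match = next((arg for arg in args if arg.startswith(keyword)), None)
    return match.split("=", 1)[1] if match else None


args = ["--jar-uri=s3://bucket/chronon.jar", "--job-type=spark"]
assert get_arg_value(args, "--jar-uri") == "s3://bucket/chronon.jar"
assert get_arg_value(args, "--missing") is None
```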

cloud_aws/src/main/scala/ai/chronon/integrations/aws/LivySubmitter.scala

Lines changed: 0 additions & 15 deletions
This file was deleted.
