Commit d0b3a34

Connect GroupByUploadToKVBulkLoad from Driver.scala to run.py
1 parent 6369d55 commit d0b3a34
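
In short: run.py gains an "upload-to-kv" mode that dispatches to the "groupby-upload-bulk-load" subcommand (GroupByUploadToKVBulkLoad in Driver.scala). A minimal sketch of that dispatch follows, using an illustrative conf path; the spark-submit flags shown are assumptions, not part of this diff.

# Sketch only: how the new mode resolves to a Driver.scala subcommand per this commit.
ROUTES = {"group_bys": {"upload-to-kv": "groupby-upload-bulk-load"}}

conf = "production/group_bys/my_team/my_group_by.v1"   # hypothetical conf path
conf_type, mode = "group_bys", "upload-to-kv"
subcommand = ROUTES[conf_type][mode]                    # -> "groupby-upload-bulk-load"
print(f"spark-submit ... {subcommand} --conf-path={conf}")  # flags are illustrative only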

File tree

7 files changed: +214 -107 lines changed

api/py/ai/chronon/repo/run.py

Lines changed: 75 additions & 27 deletions
@@ -47,6 +47,7 @@
     "backfill-left",
     "backfill-final",
     "upload",
+    "upload-to-kv",
     "streaming",
     "streaming-client",
     "consistency-metrics-compute",
@@ -62,13 +63,15 @@
 
 # Constants for supporting multiple spark versions.
 SUPPORTED_SPARK = ["2.4.0", "3.1.1", "3.2.1", "3.5.1"]
-SCALA_VERSION_FOR_SPARK = {"2.4.0": "2.11", "3.1.1": "2.12", "3.2.1": "2.13", "3.5.1": "2.12"}
+SCALA_VERSION_FOR_SPARK = {"2.4.0": "2.11",
+                           "3.1.1": "2.12", "3.2.1": "2.13", "3.5.1": "2.12"}
 
 MODE_ARGS = {
     "backfill": OFFLINE_ARGS,
     "backfill-left": OFFLINE_ARGS,
     "backfill-final": OFFLINE_ARGS,
     "upload": OFFLINE_ARGS,
+    "upload-to-kv": ONLINE_WRITE_ARGS,
     "stats-summary": OFFLINE_ARGS,
     "log-summary": OFFLINE_ARGS,
     "analyze": OFFLINE_ARGS,
@@ -88,6 +91,7 @@
 ROUTES = {
     "group_bys": {
         "upload": "group-by-upload",
+        "upload-to-kv": "groupby-upload-bulk-load",
         "backfill": "group-by-backfill",
         "streaming": "group-by-streaming",
         "metadata-upload": "metadata-upload",
@@ -123,7 +127,10 @@
 APP_NAME_TEMPLATE = "chronon_{conf_type}_{mode}_{context}_{name}"
 RENDER_INFO_DEFAULT_SCRIPT = "scripts/render_info.py"
 
+# GCP DATAPROC SPECIFIC CONSTANTS
 DATAPROC_ENTRY = "ai.chronon.integrations.cloud_gcp.DataprocSubmitter"
+ZIPLINE_ONLINE_JAR_DEFAULT = "cloud_gcp-assembly-0.1.0-SNAPSHOT.jar"
+ZIPLINE_ONLINE_CLASS_DEFAULT = "ai.chronon.integrations.cloud_gcp.GcpApiImpl"
 
 
 def retry_decorator(retries=3, backoff=20):
@@ -175,7 +182,8 @@ def download_only_once(url, path, skip_download=False):
     path = path.strip()
     if os.path.exists(path):
         content_output = check_output("curl -sI " + url).decode("utf-8")
-        content_length = re.search("(content-length:\\s)(\\d+)", content_output.lower())
+        content_length = re.search(
+            "(content-length:\\s)(\\d+)", content_output.lower())
         remote_size = int(content_length.group().split()[-1])
         local_size = int(check_output("wc -c " + path).split()[0])
         print(
@@ -189,7 +197,8 @@ def download_only_once(url, path, skip_download=False):
             print("Sizes match. Assuming it's already downloaded.")
             should_download = False
         if should_download:
-            print("Different file from remote at local: " + path + ". Re-downloading..")
+            print("Different file from remote at local: " +
+                  path + ". Re-downloading..")
             check_call("curl {} -o {} --connect-timeout 10".format(url, path))
     else:
         print("No file at: " + path + ". Downloading..")
@@ -213,7 +222,8 @@ def download_jar(
         "https://s01.oss.sonatype.org/service/local/repositories/public/content"
     )
     url_prefix = maven_url_prefix if maven_url_prefix else default_url_prefix
-    base_url = "{}/ai/chronon/spark_{}_{}".format(url_prefix, jar_type, scala_version)
+    base_url = "{}/ai/chronon/spark_{}_{}".format(
+        url_prefix, jar_type, scala_version)
     print("Downloading jar from url: " + base_url)
     jar_path = os.environ.get("CHRONON_DRIVER_JAR", None)
     if jar_path is None:
@@ -241,11 +251,15 @@ def download_jar(
             scala_version=scala_version,
             jar_type=jar_type,
         )
-        jar_path = os.path.join("/tmp", jar_url.split("/")[-1])
+        jar_path = os.path.join("/tmp", extract_filename_from_path(jar_url))
         download_only_once(jar_url, jar_path, skip_download)
     return jar_path
 
 
+def get_teams_json_file_path(repo_path):
+    return os.path.join(repo_path, "teams.json")
+
+
 def set_runtime_env(params):
     """
     Setting the runtime environment variables.
@@ -276,10 +290,11 @@ def set_runtime_env(params):
     if effective_mode and "streaming" in effective_mode:
         effective_mode = "streaming"
     if params["repo"]:
-        teams_file = os.path.join(params["repo"], "teams.json")
+        teams_file = get_teams_json_file_path(params["repo"])
         if os.path.exists(teams_file):
             with open(teams_file, "r") as infile:
                 teams_json = json.load(infile)
+            # we should have a fallback if user wants to set to something else `default`
             environment["common_env"] = teams_json.get("default", {}).get(
                 "common_env", {}
             )
@@ -320,7 +335,8 @@ def set_runtime_env(params):
                 "backfill-final",
             ]:
                 environment["conf_env"]["CHRONON_CONFIG_ADDITIONAL_ARGS"] = (
-                    " ".join(custom_json(conf_json).get("additional_args", []))
+                    " ".join(custom_json(conf_json).get(
+                        "additional_args", []))
                 )
                 environment["cli_args"]["APP_NAME"] = APP_NAME_TEMPLATE.format(
                     mode=effective_mode,
@@ -333,7 +349,8 @@ def set_runtime_env(params):
                 )
                 # fall-back to prod env even in dev mode when dev env is undefined.
                 environment["production_team_env"] = (
-                    teams_json[team].get("production", {}).get(effective_mode, {})
+                    teams_json[team].get("production", {}).get(
+                        effective_mode, {})
                 )
                 # By default use production env.
                 environment["default_env"] = (
@@ -354,7 +371,8 @@ def set_runtime_env(params):
             for k in [
                 "chronon",
                 conf_type,
-                params["mode"].replace("-", "_") if params["mode"] else None,
+                params["mode"].replace(
+                    "-", "_") if params["mode"] else None,
             ]
             if k is not None
         ]
@@ -402,15 +420,17 @@ def __init__(self, args, jar_path):
 
         if self.conf:
             try:
-                self.context, self.conf_type, self.team, _ = self.conf.split("/")[-4:]
+                self.context, self.conf_type, self.team, _ = self.conf.split(
+                    "/")[-4:]
             except Exception as e:
                 logging.error(
                     "Invalid conf path: {}, please ensure to supply the relative path to zipline/ folder".format(
                         self.conf
                     )
                 )
                 raise e
-            possible_modes = list(ROUTES[self.conf_type].keys()) + UNIVERSAL_ROUTES
+            possible_modes = list(
+                ROUTES[self.conf_type].keys()) + UNIVERSAL_ROUTES
             assert (
                 args["mode"] in possible_modes
             ), "Invalid mode:{} for conf:{} of type:{}, please choose from {}".format(
@@ -520,8 +540,6 @@ def run(self):
                 )
                 command_list.append(command)
         else:
-            # offline mode
-
             # we'll always download the jar for now so that we can pull
             # in any fixes or latest changes
             dataproc_jar = download_dataproc_jar(temp_dir,
@@ -544,7 +562,8 @@ def run(self):
                     script=self.spark_submit,
                     jar=self.jar_path,
                     subcommand=ROUTES[self.conf_type][self.mode],
-                    args=self._gen_final_args(start_ds=start_ds, end_ds=end_ds),
+                    args=self._gen_final_args(
+                        start_ds=start_ds, end_ds=end_ds),
                     additional_args=os.environ.get(
                         "CHRONON_CONFIG_ADDITIONAL_ARGS", ""
                     ),
@@ -563,11 +582,19 @@ def run(self):
                         # when we include the gcs file path as part of dataproc,
                         # the file is copied to root and not the complete path
                         # is copied.
-                        override_conf_path=self.conf.split("/")[-1]),
+                        override_conf_path=extract_filename_from_path(
+                            self.conf) if self.conf else None),
                     additional_args=os.environ.get(
                         "CHRONON_CONFIG_ADDITIONAL_ARGS", ""
                     ),
                 )
+                local_files_to_upload_to_gcs = []
+                if self.conf:
+                    local_files_to_upload_to_gcs.append(
+                        self.conf)
+                # upload teams.json to gcs
+                local_files_to_upload_to_gcs.append(
+                    get_teams_json_file_path(self.repo))
 
                 dataproc_command = generate_dataproc_submitter_args(
                     local_files_to_upload_to_gcs=[self.conf],
@@ -603,19 +630,28 @@ def run(self):
                 # does get reflected on GCS. But when we include the gcs file
                 # path as part of dataproc, the file is copied to root and
                 # not the complete path is copied.
-                override_conf_path=self.conf.split("/")[-1]),
+                override_conf_path=extract_filename_from_path(
+                    self.conf) if self.conf else None),
                 additional_args=os.environ.get(
                     "CHRONON_CONFIG_ADDITIONAL_ARGS", ""
                 ),
             )
+            local_files_to_upload_to_gcs = []
+            if self.conf:
+                local_files_to_upload_to_gcs.append(self.conf)
+
+            # upload teams.json to gcs
+            local_files_to_upload_to_gcs.append(
+                get_teams_json_file_path(self.repo))
 
             dataproc_command = generate_dataproc_submitter_args(
                 # for now, self.conf is the only local file that requires uploading to gcs
-                local_files_to_upload_to_gcs=[self.conf],
+                local_files_to_upload_to_gcs=local_files_to_upload_to_gcs,
                 user_args=user_args
             )
             command = f"java -cp {dataproc_jar} {DATAPROC_ENTRY} {dataproc_command}"
             command_list.append(command)
+
         if len(command_list) > 1:
             # parallel backfill mode
             with multiprocessing.Pool(processes=int(self.parallelism)) as pool:
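
With the two blocks above, the Dataproc path now stages both the conf and the repo's teams.json before submission. A minimal sketch of the list handed to generate_dataproc_submitter_args, using hypothetical paths:

# Sketch only: the local files staged for GCS upload after this change.
import os

def get_teams_json_file_path(repo_path):
    return os.path.join(repo_path, "teams.json")

conf = "production/group_bys/my_team/my_group_by.v1"   # hypothetical; may be None
repo = "/path/to/zipline"                               # hypothetical repo root
local_files_to_upload_to_gcs = []
if conf:
    local_files_to_upload_to_gcs.append(conf)
local_files_to_upload_to_gcs.append(get_teams_json_file_path(repo))
# -> ['production/group_bys/my_team/my_group_by.v1', '/path/to/zipline/teams.json']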
@@ -632,16 +668,23 @@ def _gen_final_args(self, start_ds=None, end_ds=None, override_conf_path=None):
         base_args = MODE_ARGS[self.mode].format(
             conf_path=override_conf_path if override_conf_path else self.conf,
             ds=end_ds if end_ds else self.ds,
-            online_jar=self.online_jar,
-            online_class=self.online_class,
+            online_jar=self.online_jar if not self.dataproc else ZIPLINE_ONLINE_JAR_DEFAULT,
+            online_class=self.online_class if not self.dataproc else ZIPLINE_ONLINE_CLASS_DEFAULT,
         )
         override_start_partition_arg = (
             " --start-partition-override=" + start_ds if start_ds else ""
         )
-        final_args = base_args + " " + str(self.args) + override_start_partition_arg
+
+        final_args = base_args + " " + \
+            str(self.args) + override_start_partition_arg
+
         return final_args
 
 
+def extract_filename_from_path(path):
+    return path.split("/")[-1]
+
+
 def split_date_range(start_date, end_date, parallelism):
     start_date = datetime.strptime(start_date, "%Y-%m-%d")
     end_date = datetime.strptime(end_date, "%Y-%m-%d")
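
The _gen_final_args change above means Dataproc submissions fall back to the bundled GCP artifacts when the caller did not pass an online jar or class. A minimal sketch of that selection, with a simplified helper name that is not in the diff:

# Sketch only: mirrors the ternaries added to _gen_final_args.
ZIPLINE_ONLINE_JAR_DEFAULT = "cloud_gcp-assembly-0.1.0-SNAPSHOT.jar"
ZIPLINE_ONLINE_CLASS_DEFAULT = "ai.chronon.integrations.cloud_gcp.GcpApiImpl"

def resolve_online_artifacts(dataproc, online_jar, online_class):
    # Off Dataproc keep the user-provided values; on Dataproc use the GCP defaults.
    if dataproc:
        return ZIPLINE_ONLINE_JAR_DEFAULT, ZIPLINE_ONLINE_CLASS_DEFAULT
    return online_jar, online_class

print(resolve_online_artifacts(True, None, None))
# ('cloud_gcp-assembly-0.1.0-SNAPSHOT.jar', 'ai.chronon.integrations.cloud_gcp.GcpApiImpl')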
@@ -653,7 +696,8 @@ def split_date_range(start_date, end_date, parallelism):
 
     # Check if parallelism is greater than total_days
     if parallelism > total_days:
-        raise ValueError("Parallelism should be less than or equal to total days")
+        raise ValueError(
+            "Parallelism should be less than or equal to total days")
 
     split_size = total_days // parallelism
     date_ranges = []
@@ -710,24 +754,27 @@ def get_customer_id() -> str:
 def get_gcp_project_id() -> str:
     gcp_project_id = os.environ.get('ZIPLINE_GCP_PROJECT_ID')
     if not gcp_project_id:
-        raise ValueError('Please set ZIPLINE_GCP_PROJECT_ID environment variable')
+        raise ValueError(
+            'Please set ZIPLINE_GCP_PROJECT_ID environment variable')
     return gcp_project_id
 
 
 def generate_dataproc_submitter_args(local_files_to_upload_to_gcs: List[str], user_args: str):
     customer_warehouse_bucket_name = f"zipline-warehouse-{get_customer_id()}"
 
     gcs_files = []
-    for f in local_files_to_upload_to_gcs:
+    for source_file in local_files_to_upload_to_gcs:
         # upload to `metadata` folder
-        destination_file_path = f"metadata/{f}"
-        gcs_files.append(upload_gcs_blob(customer_warehouse_bucket_name, f, destination_file_path))
+        destination_file_path = f"metadata/{extract_filename_from_path(source_file)}"
+        gcs_files.append(upload_gcs_blob(
+            customer_warehouse_bucket_name, source_file, destination_file_path))
 
     # we also want the additional-confs included here. it should already be in the bucket
 
     zipline_artifacts_bucket_prefix = 'gs://zipline-artifacts'
 
-    gcs_files.append(f"{zipline_artifacts_bucket_prefix}-{get_customer_id()}/confs/additional-confs.yaml")
+    gcs_files.append(
+        f"{zipline_artifacts_bucket_prefix}-{get_customer_id()}/confs/additional-confs.yaml")
 
     gcs_file_args = ",".join(gcs_files)
 
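For context, the helper introduced in this commit keeps only the basename, so the GCS metadata destination no longer embeds the whole relative conf path. A minimal sketch with an illustrative path:

def extract_filename_from_path(path):
    return path.split("/")[-1]

conf = "production/group_bys/my_team/my_group_by.v1"       # hypothetical conf path
print(f"metadata/{conf}")                                   # old behavior: metadata/production/group_bys/...
print(f"metadata/{extract_filename_from_path(conf)}")       # new behavior: metadata/my_group_by.v1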
@@ -750,7 +797,8 @@ def download_dataproc_jar(destination_dir: str, customer_id: str):
     source_blob_name = f"jars/{file_name}"
     dataproc_jar_destination_path = f"{destination_dir}/{file_name}"
 
-    download_gcs_blob(bucket_name, source_blob_name, dataproc_jar_destination_path)
+    download_gcs_blob(bucket_name, source_blob_name,
+                      dataproc_jar_destination_path)
     return dataproc_jar_destination_path
 
 