|
18 | 18 | # limitations under the License.
|
19 | 19 |
|
20 | 20 | from google.cloud import storage
|
| 21 | +import base64 |
21 | 22 | import click
|
| 23 | +import google_crc32c |
22 | 24 | import json
|
23 | 25 | import logging
|
24 | 26 | import multiprocessing
|
25 | 27 | import os
|
26 | 28 | import re
|
27 | 29 | import subprocess
|
28 |
| -import tempfile |
29 | 30 | import time
|
30 | 31 | from typing import List
|
31 | 32 | import xml.etree.ElementTree as ET
|
|
# Fully-qualified class name of the GCP online API implementation passed to the runner.
ZIPLINE_ONLINE_CLASS_DEFAULT = "ai.chronon.integrations.cloud_gcp.GcpApiImpl"
# Default jar names fetched from the per-customer GCS artifacts bucket
# (see download_zipline_jar below).
ZIPLINE_FLINK_JAR_DEFAULT = "flink_assembly_deploy.jar"
ZIPLINE_DATAPROC_SUBMITTER_JAR = "cloud_gcp_submitter_deploy.jar"
ZIPLINE_SERVICE_JAR = "service-0.1.0-SNAPSHOT.jar"

# Local cache directory for downloaded jars; persists across runs so jars whose
# CRC32C hash still matches GCS are not re-downloaded.
ZIPLINE_DIRECTORY = "/tmp/zipline"
137 | 141 |
|
138 | 142 |
|
139 | 143 | class DataprocJobType(Enum):
|
@@ -861,46 +865,27 @@ def generate_dataproc_submitter_args(user_args: str, job_type: DataprocJobType =
|
861 | 865 | raise ValueError(f"Invalid job type: {job_type}")
|
862 | 866 |
|
863 | 867 |
|
864 |
| -def download_dataproc_submitter_jar(destination_dir: str, customer_id: str): |
865 |
| - print("Downloading dataproc submitter jar from GCS...") |
| 868 | +def download_zipline_jar(destination_dir: str, customer_id: str, jar_name: str): |
866 | 869 | bucket_name = f"zipline-artifacts-{customer_id}"
|
867 | 870 |
|
868 |
| - file_name = ZIPLINE_DATAPROC_SUBMITTER_JAR |
869 |
| - |
870 |
| - source_blob_name = f"jars/{file_name}" |
871 |
| - dataproc_jar_destination_path = f"{destination_dir}/{file_name}" |
872 |
| - |
873 |
| - download_gcs_blob(bucket_name, source_blob_name, |
874 |
| - dataproc_jar_destination_path) |
875 |
| - return dataproc_jar_destination_path |
876 |
| - |
877 |
| - |
878 |
| -def download_chronon_gcp_jar(destination_dir: str, customer_id: str): |
879 |
| - print("Downloading chronon gcp jar from GCS...") |
880 |
| - bucket_name = f"zipline-artifacts-{customer_id}" |
881 |
| - |
882 |
| - file_name = ZIPLINE_ONLINE_JAR_DEFAULT |
883 |
| - |
884 |
| - source_blob_name = f"jars/{file_name}" |
885 |
| - chronon_gcp_jar_destination_path = f"{destination_dir}/{file_name}" |
886 |
| - |
887 |
| - download_gcs_blob(bucket_name, source_blob_name, |
888 |
| - chronon_gcp_jar_destination_path) |
889 |
| - return chronon_gcp_jar_destination_path |
890 |
| - |
891 |
| - |
892 |
| -def download_service_jar(destination_dir: str, customer_id: str): |
893 |
| - print("Downloading service jar from GCS...") |
894 |
| - bucket_name = f"zipline-artifacts-{customer_id}" |
| 871 | + source_blob_name = f"jars/{jar_name}" |
| 872 | + destination_path = f"{destination_dir}/{jar_name}" |
895 | 873 |
|
896 |
| - file_name = "service-0.1.0-SNAPSHOT.jar" |
| 874 | + are_identical = compare_gcs_and_local_file_hashes(bucket_name, source_blob_name, |
| 875 | + destination_path) if os.path.exists( |
| 876 | + destination_path) else False |
897 | 877 |
|
898 |
| - source_blob_name = f"jars/{file_name}" |
899 |
| - service_jar_destination_path = f"{destination_dir}/{file_name}" |
| 878 | + if are_identical: |
| 879 | + print( |
| 880 | + f"{destination_path} matches GCS {bucket_name}/{source_blob_name}") |
| 881 | + else: |
| 882 | + print( |
| 883 | + f"{destination_path} does NOT match GCS {bucket_name}/{source_blob_name}") |
| 884 | + print(f"Downloading {jar_name} from GCS...") |
900 | 885 |
|
901 |
| - download_gcs_blob(bucket_name, source_blob_name, |
902 |
| - service_jar_destination_path) |
903 |
| - return service_jar_destination_path |
| 886 | + download_gcs_blob(bucket_name, source_blob_name, |
| 887 | + destination_path) |
| 888 | + return destination_path |
904 | 889 |
|
905 | 890 |
|
906 | 891 | @retry_decorator(retries=2, backoff=5)
|
@@ -938,6 +923,66 @@ def upload_gcs_blob(bucket_name, source_file_name, destination_blob_name):
|
938 | 923 | raise RuntimeError(f"Failed to upload {source_file_name}: {str(e)}")
|
939 | 924 |
|
940 | 925 |
|
def get_gcs_file_hash(bucket_name: str, blob_name: str) -> str:
    """Return the server-side CRC32C hash of a GCS object.

    Args:
        bucket_name: Name of the GCS bucket.
        blob_name: Name/path of the blob in the bucket.

    Returns:
        The base64-encoded CRC32C checksum GCS records for the object.

    Raises:
        FileNotFoundError: if the blob does not exist in the bucket.
    """
    client = storage.Client(project=get_gcp_project_id())
    # get_blob returns None (rather than raising) when the object is absent.
    blob = client.bucket(bucket_name).get_blob(blob_name)
    if blob is None:
        raise FileNotFoundError(f"File {blob_name} not found in bucket {bucket_name}")
    return blob.crc32c
| 939 | + |
def get_local_file_hash(file_path: str) -> str:
    """Calculate CRC32C hash of a local file.

    Args:
        file_path: Path to the local file

    Returns:
        Base64-encoded string of the file's CRC32C hash
    """
    checksum = google_crc32c.Checksum()

    with open(file_path, "rb") as fh:
        # Stream the file in 4 KiB chunks so large jars never sit fully in memory.
        while True:
            chunk = fh.read(4096)
            if not chunk:
                break
            checksum.update(chunk)

    # Base64-encode to match the format GCS reports for blob.crc32c.
    return base64.b64encode(checksum.digest()).decode('utf-8')
| 960 | + |
def compare_gcs_and_local_file_hashes(bucket_name: str, blob_name: str, local_file_path: str) -> bool:
    """Compare hashes of a GCS file and a local file to check if they're identical.

    Best-effort: any failure (missing blob, unreadable local file, network
    error) is reported and treated as "not identical" rather than raised.

    Args:
        bucket_name: Name of the GCS bucket
        blob_name: Name/path of the blob in the bucket
        local_file_path: Path to the local file

    Returns:
        True if files are identical, False otherwise
    """
    try:
        remote = get_gcs_file_hash(bucket_name, blob_name)
        local = get_local_file_hash(local_file_path)
    except Exception as e:
        print(f"Error comparing files: {str(e)}")
        return False

    print(f"Local hash of {local_file_path}: {local}. GCS file {blob_name} hash: {remote}")
    return remote == local
| 985 | + |
941 | 986 | @click.command(name="run", context_settings=dict(allow_extra_args=True, ignore_unknown_options=True))
|
942 | 987 | @click.option("--conf", required=False, help="Conf param - required for every mode except fetch")
|
943 | 988 | @click.option("--env", required=False, default="dev", help="Running environment - default to be dev")
|
@@ -987,17 +1032,18 @@ def main(ctx, conf, env, mode, dataproc, ds, app_name, start_ds, end_ds, paralle
|
987 | 1032 | set_defaults(ctx)
|
988 | 1033 | extra_args = (" " + online_args) if mode in ONLINE_MODES and online_args else ""
|
989 | 1034 | ctx.params["args"] = " ".join(unknown_args) + extra_args
|
990 |
| - with tempfile.TemporaryDirectory() as temp_dir: |
991 |
| - if dataproc: |
992 |
| - jar_path = download_dataproc_submitter_jar(temp_dir, get_customer_id()) |
993 |
| - elif chronon_jar: |
994 |
| - jar_path = chronon_jar |
995 |
| - else: |
996 |
| - service_jar_path = download_service_jar(temp_dir, get_customer_id()) |
997 |
| - chronon_gcp_jar_path = download_chronon_gcp_jar(temp_dir, get_customer_id()) |
998 |
| - jar_path = f"{service_jar_path}:{chronon_gcp_jar_path}" |
| 1035 | + os.makedirs(ZIPLINE_DIRECTORY, exist_ok=True) |
| 1036 | + |
| 1037 | + if dataproc: |
| 1038 | + jar_path = download_zipline_jar(ZIPLINE_DIRECTORY, get_customer_id(), ZIPLINE_DATAPROC_SUBMITTER_JAR) |
| 1039 | + elif chronon_jar: |
| 1040 | + jar_path = chronon_jar |
| 1041 | + else: |
| 1042 | + service_jar_path = download_zipline_jar(ZIPLINE_DIRECTORY, get_customer_id(), ZIPLINE_SERVICE_JAR) |
| 1043 | + chronon_gcp_jar_path = download_zipline_jar(ZIPLINE_DIRECTORY, get_customer_id(), ZIPLINE_ONLINE_JAR_DEFAULT) |
| 1044 | + jar_path = f"{service_jar_path}:{chronon_gcp_jar_path}" |
999 | 1045 |
|
1000 |
| - Runner(ctx.params, os.path.expanduser(jar_path)).run() |
| 1046 | + Runner(ctx.params, os.path.expanduser(jar_path)).run() |
1001 | 1047 |
|
1002 | 1048 |
|
1003 | 1049 | if __name__ == "__main__":
|
|
0 commit comments