Conditionally download jars only if hashes do not match #333


Merged · 4 commits · Feb 8, 2025
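In short: each jar-download helper now computes a CRC32C checksum of any existing local copy and compares it against the checksum GCS stores for the blob, downloading only on a mismatch. Here is the pattern condensed into a standalone sketch (bucket, blob, and destination names are illustrative placeholders, not taken from this PR; the real helpers appear in the diff below):

import base64
import os

import google_crc32c
from google.cloud import storage


def cached_download(bucket_name: str, blob_name: str, dest_path: str) -> str:
    """Fetch a GCS blob, skipping the download if a matching local copy exists."""
    blob = storage.Client().bucket(bucket_name).get_blob(blob_name)
    if blob is None:
        raise FileNotFoundError(f"{blob_name} not found in {bucket_name}")

    if os.path.exists(dest_path):
        crc = google_crc32c.Checksum()
        with open(dest_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                crc.update(chunk)
        # blob.crc32c holds the base64-encoded big-endian CRC32C of the object
        if base64.b64encode(crc.digest()).decode("utf-8") == blob.crc32c:
            return dest_path  # cache hit, skip the download

    blob.download_to_filename(dest_path)
    return dest_path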
138 changes: 92 additions & 46 deletions api/py/ai/chronon/repo/run.py
@@ -18,14 +18,15 @@
 # limitations under the License.
 
 from google.cloud import storage
+import base64
 import click
+import google_crc32c
 import json
 import logging
 import multiprocessing
 import os
 import re
 import subprocess
-import tempfile
 import time
 from typing import List
 import xml.etree.ElementTree as ET
@@ -134,6 +135,9 @@
 ZIPLINE_ONLINE_CLASS_DEFAULT = "ai.chronon.integrations.cloud_gcp.GcpApiImpl"
 ZIPLINE_FLINK_JAR_DEFAULT = "flink-assembly-0.1.0-SNAPSHOT.jar"
 ZIPLINE_DATAPROC_SUBMITTER_JAR = "cloud_gcp_submitter_deploy.jar"
+ZIPLINE_SERVICE_JAR = "service-0.1.0-SNAPSHOT.jar"
+
+ZIPLINE_DIRECTORY = "/tmp/zipline"
 
 
 class DataprocJobType(Enum):
@@ -861,46 +865,27 @@ def generate_dataproc_submitter_args(user_args: str, job_type: DataprocJobType =
         raise ValueError(f"Invalid job type: {job_type}")
 
 
-def download_dataproc_submitter_jar(destination_dir: str, customer_id: str):
-    print("Downloading dataproc submitter jar from GCS...")
-    bucket_name = f"zipline-artifacts-{customer_id}"
-
-    file_name = ZIPLINE_DATAPROC_SUBMITTER_JAR
-
-    source_blob_name = f"jars/{file_name}"
-    dataproc_jar_destination_path = f"{destination_dir}/{file_name}"
-
-    download_gcs_blob(bucket_name, source_blob_name,
-                      dataproc_jar_destination_path)
-    return dataproc_jar_destination_path
-
-
-def download_chronon_gcp_jar(destination_dir: str, customer_id: str):
-    print("Downloading chronon gcp jar from GCS...")
-    bucket_name = f"zipline-artifacts-{customer_id}"
-
-    file_name = ZIPLINE_ONLINE_JAR_DEFAULT
-
-    source_blob_name = f"jars/{file_name}"
-    chronon_gcp_jar_destination_path = f"{destination_dir}/{file_name}"
-
-    download_gcs_blob(bucket_name, source_blob_name,
-                      chronon_gcp_jar_destination_path)
-    return chronon_gcp_jar_destination_path
-
-
-def download_service_jar(destination_dir: str, customer_id: str):
-    print("Downloading service jar from GCS...")
-    bucket_name = f"zipline-artifacts-{customer_id}"
-
-    file_name = "service-0.1.0-SNAPSHOT.jar"
-
-    source_blob_name = f"jars/{file_name}"
-    service_jar_destination_path = f"{destination_dir}/{file_name}"
-
-    download_gcs_blob(bucket_name, source_blob_name,
-                      service_jar_destination_path)
-    return service_jar_destination_path
+def download_zipline_jar(destination_dir: str, customer_id: str, jar_name: str):
+    bucket_name = f"zipline-artifacts-{customer_id}"
+    source_blob_name = f"jars/{jar_name}"
+    destination_path = f"{destination_dir}/{jar_name}"
+
+    are_identical = compare_gcs_and_local_file_hashes(bucket_name, source_blob_name,
+                                                      destination_path) if os.path.exists(
+        destination_path) else False
+
+    if are_identical:
+        print(
+            f"{destination_path} matches GCS {bucket_name}/{source_blob_name}")
+    else:
+        print(
+            f"{destination_path} does NOT match GCS {bucket_name}/{source_blob_name}")
+        print(f"Downloading {jar_name} from GCS...")
+
+        download_gcs_blob(bucket_name, source_blob_name,
+                          destination_path)
+    return destination_path
 
 
 @retry_decorator(retries=2, backoff=5)
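A sketch of how the consolidated helper behaves across invocations (the customer id "acme" is a placeholder; the real call sites are in main() further down):

# First run: /tmp/zipline/service-0.1.0-SNAPSHOT.jar does not exist yet,
# so the jar is fetched from gs://zipline-artifacts-acme/jars/.
path = download_zipline_jar(ZIPLINE_DIRECTORY, "acme", ZIPLINE_SERVICE_JAR)

# Subsequent runs: the local copy's CRC32C matches the GCS blob's checksum,
# so compare_gcs_and_local_file_hashes returns True and the download is skipped.
path = download_zipline_jar(ZIPLINE_DIRECTORY, "acme", ZIPLINE_SERVICE_JAR)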
@@ -938,6 +923,66 @@ def upload_gcs_blob(bucket_name, source_file_name, destination_blob_name):
         raise RuntimeError(f"Failed to upload {source_file_name}: {str(e)}")
 
 
+def get_gcs_file_hash(bucket_name: str, blob_name: str) -> str:
+    """
+    Get the CRC32C hash of a file stored in Google Cloud Storage.
+    """
+    storage_client = storage.Client(project=get_gcp_project_id())
+    bucket = storage_client.bucket(bucket_name)
+    blob = bucket.get_blob(blob_name)
+
+    if not blob:
+        raise FileNotFoundError(f"File {blob_name} not found in bucket {bucket_name}")
+
+    return blob.crc32c
+
+
+def get_local_file_hash(file_path: str) -> str:
+    """
+    Calculate the CRC32C hash of a local file.
+
+    Args:
+        file_path: Path to the local file
+
+    Returns:
+        Base64-encoded string of the file's CRC32C hash
+    """
+    crc32c = google_crc32c.Checksum()
+
+    with open(file_path, "rb") as f:
+        # Read the file in chunks to handle large files efficiently
+        for chunk in iter(lambda: f.read(4096), b""):
+            crc32c.update(chunk)
+
+    # Convert to base64 to match the format of the GCS crc32c property
+    return base64.b64encode(crc32c.digest()).decode('utf-8')
+
+
+def compare_gcs_and_local_file_hashes(bucket_name: str, blob_name: str, local_file_path: str) -> bool:
+    """
+    Compare the hashes of a GCS file and a local file to check whether they are identical.
+
+    Args:
+        bucket_name: Name of the GCS bucket
+        blob_name: Name/path of the blob in the bucket
+        local_file_path: Path to the local file
+
+    Returns:
+        True if the files are identical, False otherwise
+    """
+    try:
+        gcs_hash = get_gcs_file_hash(bucket_name, blob_name)
+        local_hash = get_local_file_hash(local_file_path)
+
+        print(f"Local hash of {local_file_path}: {local_hash}. GCS file {blob_name} hash: {gcs_hash}")
+
+        return gcs_hash == local_hash
+
+    except Exception as e:
+        print(f"Error comparing files: {str(e)}")
+        return False
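For context on the base64 step: GCS exposes an object's crc32c property as the base64 encoding of the big-endian CRC32C digest, and google_crc32c.Checksum.digest() returns those same big-endian bytes, so the two sides compare cleanly. A minimal sketch, assuming the one-shot google_crc32c.value() helper is available in the installed google-crc32c release:

import base64
import google_crc32c

data = b"example payload"

# Streaming form, as used in get_local_file_hash above.
crc = google_crc32c.Checksum()
crc.update(data)
streaming = base64.b64encode(crc.digest()).decode("utf-8")

# One-shot form: value() returns the checksum as an int.
one_shot = base64.b64encode(
    google_crc32c.value(data).to_bytes(4, byteorder="big")).decode("utf-8")

assert streaming == one_shot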
Comment on lines +961 to +984

🛠️ Refactor suggestion

Improve error handling in hash comparison.

The function silently returns False on errors, which could mask issues like permission problems or network errors.

-    except Exception as e:
-        print(f"Error comparing files: {str(e)}")
-        return False
+    except FileNotFoundError as e:
+        print(f"File not found error: {str(e)}")
+        return False
+    except (storage.exceptions.NotFound, storage.exceptions.Forbidden) as e:
+        print(f"GCS error: {str(e)}")
+        return False
+    except Exception as e:
+        print(f"Unexpected error comparing files: {str(e)}")
+        raise
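One caveat on the suggestion as written: google-cloud-storage does not, to my knowledge, expose a storage.exceptions namespace; NotFound and Forbidden normally come from google.api_core.exceptions (also re-exported by google.cloud.exceptions). A variant of the except clauses under that assumption:

# At module top (hypothetical import path, per the note above):
from google.api_core.exceptions import Forbidden, NotFound

# Inside compare_gcs_and_local_file_hashes:
    except FileNotFoundError as e:
        print(f"File not found error: {str(e)}")
        return False
    except (NotFound, Forbidden) as e:
        print(f"GCS error: {str(e)}")
        return False
    except Exception as e:
        print(f"Unexpected error comparing files: {str(e)}")
        raise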


@click.command(name="run", context_settings=dict(allow_extra_args=True, ignore_unknown_options=True))
@click.option("--conf", required=False, help="Conf param - required for every mode except fetch")
@click.option("--env", required=False, default="dev", help="Running environment - default to be dev")
Expand Down Expand Up @@ -987,17 +1032,18 @@ def main(ctx, conf, env, mode, dataproc, ds, app_name, start_ds, end_ds, paralle
set_defaults(ctx)
extra_args = (" " + online_args) if mode in ONLINE_MODES and online_args else ""
ctx.params["args"] = " ".join(unknown_args) + extra_args
with tempfile.TemporaryDirectory() as temp_dir:
if dataproc:
jar_path = download_dataproc_submitter_jar(temp_dir, get_customer_id())
elif chronon_jar:
jar_path = chronon_jar
else:
service_jar_path = download_service_jar(temp_dir, get_customer_id())
chronon_gcp_jar_path = download_chronon_gcp_jar(temp_dir, get_customer_id())
jar_path = f"{service_jar_path}:{chronon_gcp_jar_path}"
os.makedirs(ZIPLINE_DIRECTORY, exist_ok=True)

if dataproc:
jar_path = download_zipline_jar(ZIPLINE_DIRECTORY, get_customer_id(), ZIPLINE_DATAPROC_SUBMITTER_JAR)
elif chronon_jar:
jar_path = chronon_jar
else:
service_jar_path = download_zipline_jar(ZIPLINE_DIRECTORY, get_customer_id(), ZIPLINE_SERVICE_JAR)
chronon_gcp_jar_path = download_zipline_jar(ZIPLINE_DIRECTORY, get_customer_id(), ZIPLINE_ONLINE_JAR_DEFAULT)
jar_path = f"{service_jar_path}:{chronon_gcp_jar_path}"

Runner(ctx.params, os.path.expanduser(jar_path)).run()
Runner(ctx.params, os.path.expanduser(jar_path)).run()


if __name__ == "__main__":
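A note on the switch away from tempfile in this hunk: the old with tempfile.TemporaryDirectory() block deleted the download directory on exit, so every invocation re-fetched the jars. Persisting them under /tmp/zipline is what allows the CRC32C comparison above to short-circuit repeat downloads across runs.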