|
| 1 | +import argparse |
| 2 | +import csv |
| 3 | +import os |
| 4 | +import re |
| 5 | +import tempfile |
| 6 | +import time |
| 7 | +from pathlib import Path |
| 8 | +from typing import Callable |
| 9 | + |
| 10 | +import boto3 |
| 11 | + |
| 12 | + |
# Matches checkpoint weight files under any experiment's run/ directory:
# checkpoint*/pytorch_model.bin, checkpoint*/*.safetensors, and
# ckpt*.data-00000-of-00001 / ckpt*.index.  Compiled once, used per key.
_CHECKPOINT_RE = re.compile(
    r"^MT/experiments/.+/run/(checkpoint.*(pytorch_model\.bin|\.safetensors)$|ckpt.+\.(data-00000-of-00001|index)$)"
)


def sync_buckets(include_checkpoints: bool, dry_run: bool) -> None:
    """Mirror the MinIO "nlp-research" bucket into the B2 "silnlp" bucket.

    Objects missing from B2, or modified on MinIO more recently than on B2,
    are copied to B2; objects that exist only on B2 are deleted from it.
    Every action (or, in a dry run, every action that would be taken) is
    recorded in a timestamped CSV file in the current directory.

    Args:
        include_checkpoints: If False, model checkpoint files are excluded
            from the sync entirely (neither copied nor deleted).
        dry_run: If True, only report what would happen; nothing is copied
            or deleted.
    """
    minio_bucket = _minio_bucket()
    b2_bucket = _b2_bucket()

    print("Getting objects from MinIO")
    minio_objects = {obj.key: obj.last_modified for obj in minio_bucket.objects.all()}

    print("Getting objects from B2")
    b2_objects = {obj.key: obj.last_modified for obj in b2_bucket.objects.all()}

    if not include_checkpoints:
        print("Excluding model checkpoints from the sync")
        _drop_checkpoint_keys(minio_objects, b2_objects)

    output_csv = f"sync_output_{time.strftime('%Y%m%d-%H%M%S')}" + ("_dryrun" if dry_run else "") + ".csv"
    with open(output_csv, mode="w", newline="", encoding="utf-8") as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(["Filename", "Action"])

        # Objects in MinIO but not in B2, or newer on MinIO than on B2.
        objects_to_sync = [
            key
            for key, modified in minio_objects.items()
            if key not in b2_objects or modified > b2_objects[key]
        ]

        # Objects that exist only on B2 are removed so B2 mirrors MinIO.
        objects_to_delete = [key for key in b2_objects if key not in minio_objects]
        for key in objects_to_delete:
            csv_writer.writerow([key, "Deleted from B2" if not dry_run else "Would be deleted from B2"])
        if not dry_run:
            _delete_from_b2(b2_bucket, objects_to_delete)

        total = len(objects_to_sync)
        if not dry_run:
            print(f"Total objects to sync: {total}")
        else:
            print(f"Total objects that would be synced: {total}")
        for i, key in enumerate(objects_to_sync, start=1):
            if not dry_run:
                print(f"Syncing, {i}/{total}: {key}")
                _copy_minio_to_b2(minio_bucket, b2_bucket, key)
                csv_writer.writerow([key, "Synced to B2"])
            else:
                print(f"Would be syncing, {i}/{total}: {key}")
                csv_writer.writerow([key, "Would be synced to B2"])


def _minio_bucket():
    """Return the MinIO "nlp-research" bucket using env-var credentials."""
    endpoint = os.getenv("MINIO_ENDPOINT_URL")
    resource = boto3.resource(
        service_name="s3",
        endpoint_url=endpoint,
        aws_access_key_id=os.getenv("MINIO_ACCESS_KEY"),
        aws_secret_access_key=os.getenv("MINIO_SECRET_KEY"),
        # Verify is false if endpoint_url is an IP address. Aqua/Cheetah
        # connecting to MinIO need this disabled for now.  (`or ""` avoids a
        # TypeError from re.match when the env var is unset.)
        verify=not re.match(r"https://\d+\.\d+\.\d+\.\d+", endpoint or ""),
    )
    return resource.Bucket("nlp-research")


def _b2_bucket():
    """Return the B2 "silnlp" bucket using env-var credentials."""
    resource = boto3.resource(
        service_name="s3",
        endpoint_url=os.getenv("B2_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("B2_KEY_ID"),
        aws_secret_access_key=os.getenv("B2_APPLICATION_KEY"),
        verify=True,
    )
    return resource.Bucket("silnlp")


def _drop_checkpoint_keys(minio_objects: dict, b2_objects: dict) -> None:
    """Remove checkpoint keys (in place) from both maps so they are neither
    synced nor deleted.  Only keys present on MinIO are considered, matching
    the original behavior."""
    for key in [k for k in minio_objects if _CHECKPOINT_RE.match(k)]:
        minio_objects.pop(key, None)
        b2_objects.pop(key, None)


def _delete_from_b2(b2_bucket, keys: list) -> None:
    """Delete *keys* from B2 in batches of 1000 (the S3 API per-call limit)."""
    for i in range(0, len(keys), 1000):
        batch = keys[i : i + 1000]
        b2_bucket.delete_objects(Delete={"Objects": [{"Key": key} for key in batch]})


def _copy_minio_to_b2(minio_bucket, b2_bucket, key: str) -> None:
    """Download *key* from MinIO into a temp dir and upload it to B2,
    retrying each transfer on failure."""
    with tempfile.TemporaryDirectory() as temp_dir:
        obj_path = Path(temp_dir) / key
        obj_path.parent.mkdir(parents=True, exist_ok=True)
        # Bind loop-free defaults so each lambda is self-contained.
        try_n_times(lambda k=key, p=obj_path: minio_bucket.download_file(k, str(p)))
        try_n_times(lambda k=key, p=obj_path: b2_bucket.upload_file(str(p), k))
| 107 | + |
| 108 | + |
def try_n_times(func: Callable, n: int = 10):
    """Call *func*, retrying on any exception with exponential backoff.

    Waits 2**i seconds after the i-th failed attempt (i starting at 0), so
    up to *n* attempts are made in total.

    Args:
        func: Zero-argument callable to invoke.
        n: Maximum number of attempts; must be >= 1.

    Returns:
        The return value of the first successful call to *func*.

    Raises:
        Exception: Whatever *func* last raised, if all *n* attempts fail.
    """
    for i in range(n):
        try:
            return func()
        except Exception:
            if i < n - 1:
                print(f"Failed {i+1} of {n} times. Retrying.")
                time.sleep(2**i)
            else:
                # Bare `raise` preserves the original traceback.
                raise
| 120 | + |
| 121 | + |
def main() -> None:
    """Parse command-line flags and run the MinIO -> B2 bucket sync."""
    parser = argparse.ArgumentParser(description="Sync MinIO and B2 buckets")
    parser.add_argument(
        "--include-checkpoints",
        action="store_true",
        help="Include model checkpoints in the sync",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Don't sync any files, just report what would be synced",
    )
    options = parser.parse_args()
    sync_buckets(options.include_checkpoints, options.dry_run)
| 136 | + |
| 137 | + |
# Entry point: run the sync only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments