@@ -122,7 +122,7 @@ def helm_apply_cmds(
122
122
gcs_cmd = ""
123
123
if hypercomputer == "a3ultra" :
124
124
gcs_cmd = f" --set clusterName={ cluster_name } "
125
- # gcs_cmd += f" --set queue={kueue_name}"
125
+ gcs_cmd += f" --set queue={ kueue_name } "
126
126
gcs_cmd += f" --set volumes.gcsMounts[0].bucketName={ BUCKET_NAME } "
127
127
else :
128
128
gcs_cmd = f" --set workload.gcsBucketForDataCataPath={ BUCKET_NAME } "
@@ -325,3 +325,33 @@ def get_scheduled_time(hardware: str, model: str, framework: str):
325
325
return schedule_map [hardware ][model ][framework ]
326
326
327
327
return None # Return None if no schedule is found for the given combination
328
+
329
+
330
+ def get_docker_image (hardware : str , framework : str ):
331
+ """
332
+ Returns the appropriate Docker image based on the given hardware, model, and framework.
333
+
334
+ Args:
335
+ hardware: The hardware type (e.g., "a3ultra", "a3mega").
336
+ framework: The framework (e.g., "nemo", "maxtext").
337
+
338
+ Returns:
339
+ A Docker image string or None if no image is defined for the given combination.
340
+ """
341
+
342
+ image_map = {
343
+ "a3ultra" : {
344
+ "nemo" : "us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-gpu-nemo-nccl:nemo24.07-gib1.0.3-A3U" ,
345
+ "maxtext" : "us-central1-docker.pkg.dev/supercomputer-testing/gunjanjalori/maxtext-benchmark" ,
346
+ },
347
+ "a3mega" : {
348
+ "nemo" : "us-central1-docker.pkg.dev/supercomputer-testing/gunjanjalori/nemo_test/nemo_workload:24.07" ,
349
+ "maxtext" : "us-central1-docker.pkg.dev/supercomputer-testing/gunjanjalori/maxtext-benchmark" ,
350
+ },
351
+ }
352
+
353
+ if hardware in image_map :
354
+ if framework in image_map [hardware ]:
355
+ return image_map [hardware ][framework ]
356
+
357
+ return None # Return None if no image is found for the given combination
0 commit comments