Skip to content

Commit 23679a4

Browse files
committed
skip first and last in metrics calculation
1 parent 3f0c752 commit 23679a4

File tree

3 files changed

+54
-31
lines changed

3 files changed

+54
-31
lines changed

dags/map_reproducibility/utils/common_utils.py

+32-18
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,22 @@
3939
MAX_TFLOP = {"a3ultra": 989, "a3mega": 989, "a4": 2237}
4040

4141

42+
class Config:
    """Lightweight namespace-style configuration container.

    Keyword arguments passed to the constructor become instance
    attributes, so configuration values can be read with dot notation
    (``cfg.foo``) instead of dictionary indexing.
    """

    def __init__(self, **kwargs):
        # Store every keyword argument directly as an instance attribute.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        # Mirror the underlying attribute dict for easy debugging.
        return repr(vars(self))

    def __str__(self):
        return str(vars(self))
56+
57+
4258
# This is required to get auth to access
4359
def git_cookie_authdaemon():
4460
auth_cmds = (
@@ -401,13 +417,27 @@ def copy_bucket_cmds_maxtext(tmpdir, bucket_name=BUCKET_NAME):
401417
return cmds
402418

403419

404-
def calculate_maxtext_metrics(log_location: str, hardware: str = "a3ultra"):
420+
def get_profiler_skip_steps(
    config: "Config",
    default_dump_hlo: int = 1,
    default_profiler_steps: int = 5,
):
  """Compute how many leading steps the profiler should skip.

  The skip count is the sum of two optional attributes on ``config``:
  ``dump_hlo`` and ``profiler_steps``. A missing attribute falls back to
  the corresponding default, which preserves the previously hard-coded
  values (1 and 5).

  Args:
    config: Configuration object; attributes are read via ``getattr``.
    default_dump_hlo: Fallback used when ``config.dump_hlo`` is absent.
    default_profiler_steps: Fallback used when ``config.profiler_steps``
      is absent.

  Returns:
    The total number of initial steps to skip.
  """
  # NOTE(review): ``dump_hlo`` reads like it could be a boolean flag in
  # some configs; if it holds True/False, bool arithmetic still yields
  # 1/0 here — confirm that is the intended contribution.
  base_skip_steps = getattr(config, "dump_hlo", default_dump_hlo)
  additional_skip_steps = getattr(
      config, "profiler_steps", default_profiler_steps
  )
  return base_skip_steps + additional_skip_steps
425+
426+
427+
def calculate_maxtext_metrics(
428+
log_location: str, hardware: str = "a3ultra", skip_first=2, skip_last=2
429+
):
405430
metrics, _ = metric.read_from_tb(log_location, None, None)
406431

407432
print(f"metrics - {metrics}")
408433
step_time_metrics = metrics["perf/step_time_seconds"]
434+
435+
# Apply skip_first and skip_last when aggregating
409436
avg_step_time = metric.aggregate_metrics(
410-
step_time_metrics, metric_config.AggregationStrategy.AVERAGE
437+
step_time_metrics[skip_first:-skip_last]
438+
if skip_last > 0
439+
else step_time_metrics[skip_first:],
440+
metric_config.AggregationStrategy.AVERAGE,
411441
)
412442

413443
tflop_per_device_per_sec_metrics = metrics["perf/per_device_tflops_per_sec"]
@@ -707,22 +737,6 @@ def get_two_node_cmds(hypercomputer: str = "a3ultra"):
707737
return cmd
708738

709739

710-
class Config:
711-
"""
712-
A simple configuration class that allows dot notation access
713-
to dictionary keys.
714-
"""
715-
716-
def __init__(self, **kwargs):
717-
self.__dict__.update(kwargs)
718-
719-
def __repr__(self):
720-
return repr(self.__dict__)
721-
722-
def __str__(self):
723-
return str(self.__dict__)
724-
725-
726740
def parse_internal_config_filename(filename, config=None):
727741
"""
728742
Parse configuration values embedded in the filename.

dags/map_reproducibility/utils/internal_aotc_workload.py

+11-7
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
from dags.map_reproducibility.utils.common_utils import get_bq_writer_path
3535
from dags.map_reproducibility.utils.common_utils import get_recipe_repo_path, get_internal_recipe_repo_path
3636
from dags.map_reproducibility.utils.common_utils import get_cluster
37-
from dags.map_reproducibility.utils.common_utils import calculate_maxtext_metrics
37+
from dags.map_reproducibility.utils.common_utils import calculate_maxtext_metrics, get_profiler_skip_steps
3838
from dags.map_reproducibility.utils.common_utils import copy_bucket_cmds_maxtext, get_job_gcs_bucket_folder
3939
from dags.map_reproducibility.utils.common_utils import parse_internal_config_filename
4040
from dags.map_reproducibility.utils.common_utils import parse_internal_config_content
@@ -158,12 +158,6 @@ def run_internal_aotc_workload(
158158

159159
log_location = os.path.join(tmpdir, "tflog/metrics")
160160

161-
mfu, step_time = calculate_maxtext_metrics(
162-
log_location, config.HYPERCOMPUTER
163-
)
164-
165-
print(f"mfu: {mfu}")
166-
print(f"step_time: {step_time}")
167161
comment = (
168162
"internal recipes regression tests"
169163
if not backfill
@@ -173,6 +167,16 @@ def run_internal_aotc_workload(
173167
gcs_bucket = get_job_gcs_bucket_folder(job_name)
174168
print(f"GCS bucket is {gcs_bucket}")
175169

170+
# calculate mfu based on the config
171+
skip_first_n_steps_for_profiler = get_profiler_skip_steps(config)
172+
mfu, step_time = calculate_maxtext_metrics(
173+
log_location,
174+
config.HYPERCOMPUTER,
175+
skip_first=skip_first_n_steps_for_profiler,
176+
)
177+
print(f"mfu: {mfu}")
178+
print(f"step_time: {step_time}")
179+
176180
write_run(
177181
model_id=config.HELM_NAME_MODEL_ID,
178182
hardware_id=config.HYPERCOMPUTER,

dags/map_reproducibility/utils/sample_workload_utils.py

+11-6
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838
parse_internal_config_content,
3939
get_patheon_job_link,
4040
find_xprof_gcs_path,
41+
get_profiler_skip_steps,
4142
)
4243

4344
from dags.map_reproducibility.utils.benchmarkdb_utils import write_run
@@ -301,12 +302,6 @@ def run_internal_sample_aotc_workload(
301302
bq_writer_repo_root = get_bq_writer_path(tmpdir)
302303
log_location = os.path.join(tmpdir, "tflog/metrics")
303304

304-
mfu, step_time = calculate_maxtext_metrics(
305-
log_location, config.HYPERCOMPUTER
306-
)
307-
308-
print(f"mfu: {mfu}")
309-
print(f"step_time: {step_time}")
310305
comment = "sample benchmarking run"
311306
gcs_bucket = get_job_gcs_bucket_folder(
312307
job_name, bucket_name=sample_run_bucket_name
@@ -329,6 +324,16 @@ def run_internal_sample_aotc_workload(
329324
f"Profile command failed with error: {profiler_error_message}"
330325
)
331326

327+
# calculate mfu based on the config
328+
skip_first_n_steps_for_profiler = get_profiler_skip_steps(config)
329+
mfu, step_time = calculate_maxtext_metrics(
330+
log_location,
331+
config.HYPERCOMPUTER,
332+
skip_first=skip_first_n_steps_for_profiler,
333+
)
334+
print(f"mfu: {mfu}")
335+
print(f"step_time: {step_time}")
336+
332337
write_run(
333338
model_id=config.HELM_NAME_MODEL_ID,
334339
hardware_id=config.HYPERCOMPUTER,

0 commit comments

Comments
 (0)