Skip to content

Commit 3519ecf

Browse files
authored
disable metric when model parallel (NVIDIA#701)
### Description Previously metric logging is blocking model parallel. Disable instead of raising error. ### Type of changes - [x] Bug fix (non-breaking change which fixes an issue) Signed-off-by: sichu <[email protected]>
1 parent fca2cda commit 3519ecf

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,10 @@ def train_model(
282282
)
283283
# Configure the model
284284
train_metric = None
285-
if task_type == "regression":
285+
is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
286+
if is_model_parallel:
287+
valid_metric = None # metric logging under model parallelism is not supported yet
288+
elif task_type == "regression":
286289
valid_metric = TorchmetricsConfig(class_path="MeanSquaredError", task="regression", metric_name="val_mse")
287290
else:
288291
valid_metric = TorchmetricsConfig(
@@ -296,11 +299,6 @@ def train_model(
296299
metric_name="val_acc",
297300
)
298301

299-
if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
300-
train_metric is not None or valid_metric is not None
301-
):
302-
raise NotImplementedError("Metric logging under model parallelism is not supported yet.")
303-
304302
config = config_class(
305303
task_type=task_type,
306304
encoder_frozen=encoder_frozen,

sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,16 +265,16 @@ def main(
265265
)
266266
# Configure the model
267267
train_metric = None
268-
valid_metric = TorchmetricsConfig(
269-
class_path="text.Perplexity",
270-
task="pretraining",
271-
kwargs={"ignore_index": MLM_LOSS_IGNORE_INDEX},
272-
metric_name="val_ppl",
273-
)
274-
if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
275-
train_metric is not None or valid_metric is not None
276-
):
277-
raise NotImplementedError("Metric logging under model parallelism is not supported yet.")
268+
is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
269+
if is_model_parallel:
270+
valid_metric = None # metric logging under model parallelism is not supported yet
271+
else:
272+
valid_metric = TorchmetricsConfig(
273+
class_path="text.Perplexity",
274+
task="pretraining",
275+
kwargs={"ignore_index": MLM_LOSS_IGNORE_INDEX},
276+
metric_name="val_ppl",
277+
)
278278

279279
esm2_config = ESM2Config(
280280
seq_length=max_seq_length,

0 commit comments

Comments
 (0)