File tree (expand / collapse): 2 files changed, +14 -16 lines changed
sub-packages/bionemo-esm2/src/bionemo/esm2/scripts (file tree, expand / collapse): 2 files changed, +14 -16 lines changed
Columns: original file line number | diff line number | diff line change
@@ -282,7 +282,10 @@ def train_model(
282
282
)
283
283
# Configure the model
284
284
train_metric = None
285
- if task_type == "regression" :
285
+ is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
286
+ if is_model_parallel :
287
+ valid_metric = None # metric logging under model parallelism is not supported yet
288
+ elif task_type == "regression" :
286
289
valid_metric = TorchmetricsConfig (class_path = "MeanSquaredError" , task = "regression" , metric_name = "val_mse" )
287
290
else :
288
291
valid_metric = TorchmetricsConfig (
@@ -296,11 +299,6 @@ def train_model(
296
299
metric_name = "val_acc" ,
297
300
)
298
301
299
- if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
300
- train_metric is not None or valid_metric is not None
301
- ):
302
- raise NotImplementedError ("Metric logging under model parallelism is not supported yet." )
303
-
304
302
config = config_class (
305
303
task_type = task_type ,
306
304
encoder_frozen = encoder_frozen ,
Columns: original file line number | diff line number | diff line change
@@ -265,16 +265,16 @@ def main(
265
265
)
266
266
# Configure the model
267
267
train_metric = None
268
- valid_metric = TorchmetricsConfig (
269
- class_path = "text.Perplexity" ,
270
- task = "pretraining" ,
271
- kwargs = { "ignore_index" : MLM_LOSS_IGNORE_INDEX },
272
- metric_name = "val_ppl" ,
273
- )
274
- if tensor_model_parallel_size * pipeline_model_parallel_size > 1 and (
275
- train_metric is not None or valid_metric is not None
276
- ):
277
- raise NotImplementedError ( "Metric logging under model parallelism is not supported yet." )
268
+ is_model_parallel = tensor_model_parallel_size * pipeline_model_parallel_size > 1
269
+ if is_model_parallel :
270
+ valid_metric = None # metric logging under model parallelism is not supported yet
271
+ else :
272
+ valid_metric = TorchmetricsConfig (
273
+ class_path = "text.Perplexity" ,
274
+ task = "pretraining" ,
275
+ kwargs = { "ignore_index" : MLM_LOSS_IGNORE_INDEX },
276
+ metric_name = "val_ppl" ,
277
+ )
278
278
279
279
esm2_config = ESM2Config (
280
280
seq_length = max_seq_length ,
You can’t perform that action at this time.
0 commit comments