18
18
19
19
import torch
20
20
import typer
21
- from nemo .lightning import io , teardown
21
+ from megatron .core .dist_checkpointing .validation import StrictHandling
22
+ from nemo .lightning import MegatronStrategy , Trainer , io , teardown
22
23
from nemo .lightning .pytorch .utils import dtype_from_hf
23
24
from transformers import AutoConfig as HFAutoConfig
24
25
from transformers import AutoModelForMaskedLM
@@ -123,9 +124,6 @@ def init(self, dtype: torch.dtype = torch.bfloat16) -> EsmForMaskedLM:
123
124
124
125
def apply (self , output_path : Path ) -> Path :
125
126
"""Applies the transformation."""
126
- from megatron .core .dist_checkpointing .validation import StrictHandling
127
- from nemo .lightning import MegatronStrategy , Trainer
128
-
129
127
cpu = not torch .distributed .is_initialized ()
130
128
trainer = Trainer (
131
129
devices = 1 ,
@@ -136,7 +134,12 @@ def apply(self, output_path: Path) -> Path:
136
134
)
137
135
source , _ = self .nemo_load (self , trainer = trainer , cpu = cpu )
138
136
139
- target = self .init (source .dtype )
137
+ dtype = torch .bfloat16 if source .config .bf16 else torch .float32
138
+
139
+ # NOTE: for reasons not yet understood, lm_head stays in fp32 here, so cast it to the target dtype explicitly.
140
+ source .module .lm_head .to (dtype )
141
+
142
+ target = self .init (dtype )
140
143
target = self .convert_state (source , target )
141
144
142
145
target = target .cpu ()
@@ -342,30 +345,38 @@ def _import_qkv_bias(ctx: io.TransformCTX, query, key, value):
342
345
return concat_biases
343
346
344
347
345
- app = typer .Typer ()
348
+ app = typer .Typer (pretty_exceptions_enable = False )
346
349
347
350
348
351
@app .command ()
349
- def convert_nemo_to_hf (nemo_path : str , output_path : str ):
352
+ def convert_nemo_to_hf (nemo_path : str , output_path : str , overwrite : bool = True ):
350
353
"""Convert a NeMo ESM-2 checkpoint to a HuggingFace checkpoint.
351
354
352
355
Args:
353
356
nemo_path: Path to the NeMo checkpoint.
354
357
output_path: Path to the output HuggingFace checkpoint.
358
+ overwrite: Whether to overwrite the output path if it already exists.
355
359
"""
356
- io .export_ckpt (Path (nemo_path ), "hf" , Path (output_path ))
360
+ io .export_ckpt (
361
+ Path (nemo_path ),
362
+ "hf" ,
363
+ Path (output_path ),
364
+ overwrite = overwrite ,
365
+ load_connector = lambda path , ext : BionemoLightningModule .exporter (ext , path ),
366
+ )
357
367
358
368
359
369
@app .command ()
360
- def convert_hf_to_nemo (hf_tag_or_path : str , output_path : str ):
370
+ def convert_hf_to_nemo (hf_tag_or_path : str , output_path : str , overwrite : bool = True ):
361
371
"""Convert a HuggingFace ESM-2 checkpoint to a NeMo ESM-2 checkpoint.
362
372
363
373
Args:
364
374
hf_tag_or_path: Tag or path to the HuggingFace checkpoint.
365
375
output_path: Path to the output NeMo checkpoint.
376
+ overwrite: Whether to overwrite the output path if it already exists.
366
377
"""
367
378
module = biobert_lightning_module (config = ESM2Config (), post_process = True )
368
- io .import_ckpt (module , f"hf://{ hf_tag_or_path } " , Path (output_path ))
379
+ io .import_ckpt (module , f"hf://{ hf_tag_or_path } " , Path (output_path ), overwrite = overwrite )
369
380
370
381
371
382
if __name__ == "__main__" :
0 commit comments