Skip to content

Commit ed9a811

Browse files
committed
working with bcr checkpoint
Signed-off-by: Peter St. John <[email protected]>
1 parent f18adf4 commit ed9a811

File tree

1 file changed

+21
-10
lines changed
  • sub-packages/bionemo-esm2/src/bionemo/esm2/model

1 file changed

+21
-10
lines changed

sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
import torch
2020
import typer
21-
from nemo.lightning import io, teardown
21+
from megatron.core.dist_checkpointing.validation import StrictHandling
22+
from nemo.lightning import MegatronStrategy, Trainer, io, teardown
2223
from nemo.lightning.pytorch.utils import dtype_from_hf
2324
from transformers import AutoConfig as HFAutoConfig
2425
from transformers import AutoModelForMaskedLM
@@ -123,9 +124,6 @@ def init(self, dtype: torch.dtype = torch.bfloat16) -> EsmForMaskedLM:
123124

124125
def apply(self, output_path: Path) -> Path:
125126
"""Applies the transformation."""
126-
from megatron.core.dist_checkpointing.validation import StrictHandling
127-
from nemo.lightning import MegatronStrategy, Trainer
128-
129127
cpu = not torch.distributed.is_initialized()
130128
trainer = Trainer(
131129
devices=1,
@@ -136,7 +134,12 @@ def apply(self, output_path: Path) -> Path:
136134
)
137135
source, _ = self.nemo_load(self, trainer=trainer, cpu=cpu)
138136

139-
target = self.init(source.dtype)
137+
dtype = torch.bfloat16 if source.config.bf16 else torch.float32
138+
139+
# Not sure why we need to do this, for some reason lm_head stays as fp32
140+
source.module.lm_head.to(dtype)
141+
142+
target = self.init(dtype)
140143
target = self.convert_state(source, target)
141144

142145
target = target.cpu()
@@ -342,30 +345,38 @@ def _import_qkv_bias(ctx: io.TransformCTX, query, key, value):
342345
return concat_biases
343346

344347

345-
app = typer.Typer()
348+
app = typer.Typer(pretty_exceptions_enable=False)
346349

347350

348351
@app.command()
349-
def convert_nemo_to_hf(nemo_path: str, output_path: str):
352+
def convert_nemo_to_hf(nemo_path: str, output_path: str, overwrite: bool = True):
350353
"""Convert a NeMo ESM-2 checkpoint to a HuggingFace checkpoint.
351354
352355
Args:
353356
nemo_path: Path to the NeMo checkpoint.
354357
output_path: Path to the output HuggingFace checkpoint.
358+
overwrite: Whether to overwrite the output path if it already exists.
355359
"""
356-
io.export_ckpt(Path(nemo_path), "hf", Path(output_path))
360+
io.export_ckpt(
361+
Path(nemo_path),
362+
"hf",
363+
Path(output_path),
364+
overwrite=overwrite,
365+
load_connector=lambda path, ext: BionemoLightningModule.exporter(ext, path),
366+
)
357367

358368

359369
@app.command()
360-
def convert_hf_to_nemo(hf_tag_or_path: str, output_path: str):
370+
def convert_hf_to_nemo(hf_tag_or_path: str, output_path: str, overwrite: bool = True):
361371
"""Convert a HuggingFace ESM-2 checkpoint to a NeMo ESM-2 checkpoint.
362372
363373
Args:
364374
hf_tag_or_path: Tag or path to the HuggingFace checkpoint.
365375
output_path: Path to the output NeMo checkpoint.
376+
overwrite: Whether to overwrite the output path if it already exists.
366377
"""
367378
module = biobert_lightning_module(config=ESM2Config(), post_process=True)
368-
io.import_ckpt(module, f"hf://{hf_tag_or_path}", Path(output_path))
379+
io.import_ckpt(module, f"hf://{hf_tag_or_path}", Path(output_path), overwrite=overwrite)
369380

370381

371382
if __name__ == "__main__":

0 commit comments

Comments
 (0)