
Commit f18adf4

Add CLI interface for ESM2 checkpoint conversion
Signed-off-by: Peter St. John <[email protected]>
1 parent 2efeba6 commit f18adf4

File tree

3 files changed (+106 / -31 lines)


sub-packages/bionemo-esm2/pyproject.toml

Lines changed: 5 additions & 6 deletions
@@ -18,21 +18,20 @@ dependencies = [
 ]

 [project.optional-dependencies]
-test = [
-    'bionemo-testing'
-]
+test = ['bionemo-testing']
 te = [
     # TE & Apex need to be installed after PyTorch, NVCC, and CUDA.
     # TODO(@pstjohn, @cspades): Figure out how to do this without post-installation.
-    'transformer_engine[pytorch]'
+    'transformer_engine[pytorch]',
 ]

 [project.scripts]
-bionemo-esm2-train= "bionemo.esm2.run.main:main"
-bionemo-esm2-recipe= "bionemo.esm2.run.recipes:main"
+bionemo-esm2-train = "bionemo.esm2.run.main:main"
+bionemo-esm2-recipe = "bionemo.esm2.run.recipes:main"
 infer_esm2 = "bionemo.esm2.scripts.infer_esm2:infer_esm2_entrypoint"
 train_esm2 = "bionemo.esm2.scripts.train_esm2:train_esm2_entrypoint"
 finetune_esm2 = "bionemo.esm2.scripts.finetune_esm2:finetune_esm2_entrypoint"
+convert_esm2 = "bionemo.esm2.model.convert:app"

 # Make sure that the tokenizer files are included along with the python files during installation.
 [tool.setuptools.package-data]
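
The new convert_esm2 entry point exposes the typer app defined in bionemo.esm2.model.convert as an installed console script, alongside the existing train/infer/finetune commands. A minimal sketch of what the generated script amounts to (illustrative only, not part of the commit):

# Rough equivalent of the convert_esm2 console script: import the typer app
# and let it parse the command-line arguments.
from bionemo.esm2.model.convert import app

if __name__ == "__main__":
    app()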

sub-packages/bionemo-esm2/src/bionemo/esm2/model/convert.py

Lines changed: 43 additions & 13 deletions
@@ -17,6 +17,7 @@
 from pathlib import Path

 import torch
+import typer
 from nemo.lightning import io, teardown
 from nemo.lightning.pytorch.utils import dtype_from_hf
 from transformers import AutoConfig as HFAutoConfig
@@ -122,19 +123,20 @@ def init(self, dtype: torch.dtype = torch.bfloat16) -> EsmForMaskedLM:

     def apply(self, output_path: Path) -> Path:
         """Applies the transformation."""
-        nemo_config = ESM2Config(
-            initial_ckpt_path=str(self),
-            include_embeddings=True,
-            include_hiddens=True,
-            params_dtype=torch.bfloat16,
-            autocast_dtype=torch.bfloat16,
-            bf16=True,
-            fp16=False,
+        from megatron.core.dist_checkpointing.validation import StrictHandling
+        from nemo.lightning import MegatronStrategy, Trainer
+
+        cpu = not torch.distributed.is_initialized()
+        trainer = Trainer(
+            devices=1,
+            accelerator="cpu" if cpu else "gpu",
+            strategy=MegatronStrategy(
+                ddp="pytorch", setup_optimizers=False, ckpt_load_strictness=StrictHandling.LOG_UNEXPECTED
+            ),
         )
+        source, _ = self.nemo_load(self, trainer=trainer, cpu=cpu)

-        source = nemo_config.configure_model(self.tokenizer)
-
-        target = self.init(torch.bfloat16)
+        target = self.init(source.dtype)
         target = self.convert_state(source, target)

         target = target.cpu()
@@ -169,8 +171,6 @@ def convert_state(self, nemo_module, target):
             "lm_head.layer_norm.bias": "lm_head.layer_norm.bias",
         }

-        nemo_module.lm_head.to(torch.bfloat16)
-
         return io.apply_transforms(
             nemo_module,
             target,
@@ -340,3 +340,33 @@ def _import_qkv_bias(ctx: io.TransformCTX, query, key, value):
     concat_biases = concat_biases.transpose(0, 1).contiguous()
     concat_biases = concat_biases.view(*input_shape)
     return concat_biases
+
+
+app = typer.Typer()
+
+
+@app.command()
+def convert_nemo_to_hf(nemo_path: str, output_path: str):
+    """Convert a NeMo ESM-2 checkpoint to a HuggingFace checkpoint.
+
+    Args:
+        nemo_path: Path to the NeMo checkpoint.
+        output_path: Path to the output HuggingFace checkpoint.
+    """
+    io.export_ckpt(Path(nemo_path), "hf", Path(output_path))
+
+
+@app.command()
+def convert_hf_to_nemo(hf_tag_or_path: str, output_path: str):
+    """Convert a HuggingFace ESM-2 checkpoint to a NeMo ESM-2 checkpoint.
+
+    Args:
+        hf_tag_or_path: Tag or path to the HuggingFace checkpoint.
+        output_path: Path to the output NeMo checkpoint.
+    """
+    module = biobert_lightning_module(config=ESM2Config(), post_process=True)
+    io.import_ckpt(module, f"hf://{hf_tag_or_path}", Path(output_path))
+
+
+if __name__ == "__main__":
+    app()
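
For reference, the two commands registered on the typer app can also be called directly as plain Python functions. A minimal round-trip sketch follows, assuming an installed bionemo-esm2 package; the model tag matches the one used in the tests, and the output directories are placeholders, not part of the commit:

# Illustrative sketch only: programmatic use of the new conversion helpers.
from bionemo.esm2.model.convert import convert_hf_to_nemo, convert_nemo_to_hf

# Import a HuggingFace ESM-2 checkpoint into NeMo format (placeholder output path)...
convert_hf_to_nemo("facebook/esm2_t6_8M_UR50D", "/tmp/esm2_8m_nemo")

# ...then export that NeMo checkpoint back to HuggingFace format (placeholder output path).
convert_nemo_to_hf("/tmp/esm2_8m_nemo", "/tmp/esm2_8m_hf")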

sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_convert.py

Lines changed: 58 additions & 12 deletions
@@ -18,33 +18,41 @@
 import torch
 from nemo.lightning import io
 from transformers import AutoModelForMaskedLM
+from typer.testing import CliRunner

 from bionemo.core.data.load import load
-from bionemo.esm2.model.convert import HFESM2Exporter, HFESM2Importer  # noqa: F401
+from bionemo.esm2.model.convert import (
+    HFESM2Importer,  # noqa: F401
+    app,
+)
 from bionemo.esm2.model.model import ESM2Config
 from bionemo.esm2.testing.compare import assert_esm2_equivalence
 from bionemo.llm.model.biobert.lightning import biobert_lightning_module
 from bionemo.testing import megatron_parallel_state_utils


-# pytestmark = pytest.mark.xfail(
-#     reason="These tests are failing due to a bug in nemo global state when run in the same process as previous "
-#     "checkpoint save/load scripts."
-# )
+def test_nemo2_conversion_equivalent_8m(tmp_path):
+    model_tag = "facebook/esm2_t6_8M_UR50D"
+    module = biobert_lightning_module(config=ESM2Config())
+    io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
+    with megatron_parallel_state_utils.distributed_model_parallel_state():
+        assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag)


-def test_nemo2_conversion_equivalent_8m(tmp_path):
+def test_nemo2_conversion_equivalent_8m_with_local_path(tmp_path):
     model_tag = "facebook/esm2_t6_8M_UR50D"
+    hf_model = AutoModelForMaskedLM.from_pretrained(model_tag)
+    hf_model.save_pretrained(tmp_path / "hf_checkpoint")
+
     module = biobert_lightning_module(config=ESM2Config(), post_process=True)
-    io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
+    io.import_ckpt(module, f"hf://{tmp_path / 'hf_checkpoint'}", tmp_path / "nemo_checkpoint")
     with megatron_parallel_state_utils.distributed_model_parallel_state():
         assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag)


 def test_nemo2_export_8m_weights_equivalent(tmp_path):
     ckpt_path = load("esm2/8m:2.0")
-    with megatron_parallel_state_utils.distributed_model_parallel_state():
-        output_path = io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")
+    output_path = io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")

     hf_model_from_nemo = AutoModelForMaskedLM.from_pretrained(output_path)
     hf_model_from_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
@@ -56,19 +64,25 @@ def test_nemo2_export_8m_weights_equivalent(tmp_path):
         torch.testing.assert_close(
             hf_model_from_nemo.state_dict()[key],
             hf_model_from_hf.state_dict()[key],
-            atol=1e-2,
-            rtol=1e-2,
+            atol=1e-4,
+            rtol=1e-4,
             msg=lambda msg: f"{key}: {msg}",
         )


 def test_nemo2_export_golden_values(tmp_path):
     ckpt_path = load("esm2/8m:2.0")
+    output_path = io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")
     with megatron_parallel_state_utils.distributed_model_parallel_state():
-        output_path = io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")
         assert_esm2_equivalence(ckpt_path, output_path, precision="bf16")


+def test_nemo2_export_on_gpu(tmp_path):
+    ckpt_path = load("esm2/8m:2.0")
+    with megatron_parallel_state_utils.distributed_model_parallel_state():
+        io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")
+
+
 def test_nemo2_conversion_equivalent_8m_bf16(tmp_path):
     model_tag = "facebook/esm2_t6_8M_UR50D"
     module = biobert_lightning_module(config=ESM2Config())
@@ -84,3 +98,35 @@ def test_nemo2_conversion_equivalent_650m(tmp_path):
     io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
     with megatron_parallel_state_utils.distributed_model_parallel_state():
         assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag, atol=1e-4, rtol=1e-4)
+
+
+def test_cli_nemo2_conversion_equivalent_8m(tmp_path):
+    """Test that the CLI conversion functions maintain model equivalence."""
+    model_tag = "facebook/esm2_t6_8M_UR50D"
+    runner = CliRunner()
+
+    # First convert HF to NeMo
+    nemo_path = tmp_path / "nemo_checkpoint"
+    result = runner.invoke(app, ["convert-hf-to-nemo", model_tag, str(nemo_path)])
+    assert result.exit_code == 0, f"CLI command failed: {result.output}"
+
+    # Then convert back to HF
+    hf_path = tmp_path / "hf_checkpoint"
+    result = runner.invoke(app, ["convert-nemo-to-hf", str(nemo_path), str(hf_path)])
+    assert result.exit_code == 0, f"CLI command failed: {result.output}"
+
+    hf_model_from_nemo = AutoModelForMaskedLM.from_pretrained(model_tag)
+    hf_model_from_hf = AutoModelForMaskedLM.from_pretrained(hf_path)
+
+    # These aren't initialized, so they're going to be different.
+    del hf_model_from_nemo.esm.contact_head
+    del hf_model_from_hf.esm.contact_head
+
+    for key in hf_model_from_nemo.state_dict().keys():
+        torch.testing.assert_close(
+            hf_model_from_nemo.state_dict()[key],
+            hf_model_from_hf.state_dict()[key],
+            atol=1e-4,
+            rtol=1e-4,
+            msg=lambda msg: f"{key}: {msg}",
+        )
