import torch
from nemo.lightning import io
from transformers import AutoModelForMaskedLM
+ from typer.testing import CliRunner

from bionemo.core.data.load import load
- from bionemo.esm2.model.convert import HFESM2Exporter, HFESM2Importer  # noqa: F401
+ from bionemo.esm2.model.convert import (
+     HFESM2Importer,  # noqa: F401
+     app,
+ )
from bionemo.esm2.model.model import ESM2Config
from bionemo.esm2.testing.compare import assert_esm2_equivalence
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.testing import megatron_parallel_state_utils


- # pytestmark = pytest.mark.xfail(
- #     reason="These tests are failing due to a bug in nemo global state when run in the same process as previous "
- #     "checkpoint save/load scripts."
- # )
+ def test_nemo2_conversion_equivalent_8m(tmp_path):
+     model_tag = "facebook/esm2_t6_8M_UR50D"
+     module = biobert_lightning_module(config=ESM2Config())
+     io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
+     with megatron_parallel_state_utils.distributed_model_parallel_state():
+         assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag)


- def test_nemo2_conversion_equivalent_8m(tmp_path):
+ def test_nemo2_conversion_equivalent_8m_with_local_path(tmp_path):
    model_tag = "facebook/esm2_t6_8M_UR50D"
+     hf_model = AutoModelForMaskedLM.from_pretrained(model_tag)
+     hf_model.save_pretrained(tmp_path / "hf_checkpoint")
+
    module = biobert_lightning_module(config=ESM2Config(), post_process=True)
-     io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
+     io.import_ckpt(module, f"hf://{tmp_path / 'hf_checkpoint'}", tmp_path / "nemo_checkpoint")
    with megatron_parallel_state_utils.distributed_model_parallel_state():
        assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag)


def test_nemo2_export_8m_weights_equivalent(tmp_path):
    ckpt_path = load("esm2/8m:2.0")
-     with megatron_parallel_state_utils.distributed_model_parallel_state():
-         output_path = io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")
+     output_path = io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")

    hf_model_from_nemo = AutoModelForMaskedLM.from_pretrained(output_path)
    hf_model_from_hf = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
@@ -56,19 +64,25 @@ def test_nemo2_export_8m_weights_equivalent(tmp_path):
        torch.testing.assert_close(
            hf_model_from_nemo.state_dict()[key],
            hf_model_from_hf.state_dict()[key],
-             atol=1e-2,
-             rtol=1e-2,
+             atol=1e-4,
+             rtol=1e-4,
            msg=lambda msg: f"{key}: {msg}",
        )


def test_nemo2_export_golden_values(tmp_path):
    ckpt_path = load("esm2/8m:2.0")
+     output_path = io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")
    with megatron_parallel_state_utils.distributed_model_parallel_state():
-         output_path = io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")
        assert_esm2_equivalence(ckpt_path, output_path, precision="bf16")


+ def test_nemo2_export_on_gpu(tmp_path):
+     ckpt_path = load("esm2/8m:2.0")
+     with megatron_parallel_state_utils.distributed_model_parallel_state():
+         io.export_ckpt(ckpt_path, "hf", tmp_path / "hf_checkpoint")
+
+
def test_nemo2_conversion_equivalent_8m_bf16(tmp_path):
    model_tag = "facebook/esm2_t6_8M_UR50D"
    module = biobert_lightning_module(config=ESM2Config())
@@ -84,3 +98,35 @@ def test_nemo2_conversion_equivalent_650m(tmp_path):
    io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
    with megatron_parallel_state_utils.distributed_model_parallel_state():
        assert_esm2_equivalence(tmp_path / "nemo_checkpoint", model_tag, atol=1e-4, rtol=1e-4)
+
+
+ def test_cli_nemo2_conversion_equivalent_8m(tmp_path):
+     """Test that the CLI conversion functions maintain model equivalence."""
+     model_tag = "facebook/esm2_t6_8M_UR50D"
+     runner = CliRunner()
+
+     # First convert HF to NeMo
+     nemo_path = tmp_path / "nemo_checkpoint"
+     result = runner.invoke(app, ["convert-hf-to-nemo", model_tag, str(nemo_path)])
+     assert result.exit_code == 0, f"CLI command failed: {result.output}"
+
+     # Then convert back to HF
+     hf_path = tmp_path / "hf_checkpoint"
+     result = runner.invoke(app, ["convert-nemo-to-hf", str(nemo_path), str(hf_path)])
+     assert result.exit_code == 0, f"CLI command failed: {result.output}"
+
+     hf_model_from_nemo = AutoModelForMaskedLM.from_pretrained(model_tag)
+     hf_model_from_hf = AutoModelForMaskedLM.from_pretrained(hf_path)
+
+     # These aren't initialized, so they're going to be different.
+     del hf_model_from_nemo.esm.contact_head
+     del hf_model_from_hf.esm.contact_head
+
+     for key in hf_model_from_nemo.state_dict().keys():
+         torch.testing.assert_close(
+             hf_model_from_nemo.state_dict()[key],
+             hf_model_from_hf.state_dict()[key],
+             atol=1e-4,
+             rtol=1e-4,
+             msg=lambda msg: f"{key}: {msg}",
+         )
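
For reference, the HF -> NeMo2 -> HF round-trip exercised by these tests can also be driven directly from the programmatic API. The sketch below is not part of the change; it only reuses calls that appear above (io.import_ckpt, io.export_ckpt, load("esm2/8m:2.0"), biobert_lightning_module), and the temporary output directory is a placeholder assumption.

import tempfile
from pathlib import Path

from nemo.lightning import io

from bionemo.core.data.load import load
from bionemo.esm2.model.convert import HFESM2Importer  # noqa: F401  (kept for its registration side effect, as in the tests above)
from bionemo.esm2.model.model import ESM2Config
from bionemo.llm.model.biobert.lightning import biobert_lightning_module

work_dir = Path(tempfile.mkdtemp())  # placeholder output location (assumption)

# HF -> NeMo2: import the 8M ESM-2 weights from the HuggingFace Hub into a NeMo2 checkpoint.
module = biobert_lightning_module(config=ESM2Config())
io.import_ckpt(module, "hf://facebook/esm2_t6_8M_UR50D", work_dir / "nemo_checkpoint")

# NeMo2 -> HF: export a released NeMo2 checkpoint back to a HuggingFace-format directory.
ckpt_path = load("esm2/8m:2.0")
output_path = io.export_ckpt(ckpt_path, "hf", work_dir / "hf_checkpoint")
print(f"HF checkpoint written to {output_path}")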