Skip to content

Commit 9f1168f

Browse files
authored
7047 simplify resnet pretrained (#7095)
Fixes #7047 ### Description ResNet did not support a `True` value (not implemented) for its `pretrained` flag. Two behaviors are now implemented: - When `pretrained` is `True`, weights are downloaded from https://huggingface.co/TencentMedicalNet - When `pretrained` is a string, weights are loaded from the path defined by the string ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: vgrau98 <[email protected]>
1 parent 2c9f44c commit 9f1168f

File tree

8 files changed

+229
-13
lines changed

8 files changed

+229
-13
lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@ opencv-python-headless
3939
onnx>=1.13.0
4040
onnxruntime; python_version <= '3.10'
4141
zarr
42+
huggingface_hub

monai/networks/nets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
ResNet,
6060
ResNetBlock,
6161
ResNetBottleneck,
62+
get_medicalnet_pretrained_resnet_args,
63+
get_pretrained_resnet_medicalnet,
6264
resnet10,
6365
resnet18,
6466
resnet34,

monai/networks/nets/resnet.py

Lines changed: 123 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111

1212
from __future__ import annotations
1313

14+
import logging
15+
import re
1416
from collections.abc import Callable
1517
from functools import partial
18+
from pathlib import Path
1619
from typing import Any
1720

1821
import torch
@@ -21,7 +24,13 @@
2124
from monai.networks.layers.factories import Conv, Norm, Pool
2225
from monai.networks.layers.utils import get_pool_layer
2326
from monai.utils import ensure_tuple_rep
24-
from monai.utils.module import look_up_option
27+
from monai.utils.module import look_up_option, optional_import
28+
29+
hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download")
30+
EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name="EntryNotFoundError")
31+
32+
MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet"
33+
MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_"
2534

2635
__all__ = [
2736
"ResNet",
@@ -36,6 +45,8 @@
3645
"resnet200",
3746
]
3847

48+
logger = logging.getLogger(__name__)
49+
3950

4051
def get_inplanes():
4152
return [64, 128, 256, 512]
@@ -329,21 +340,54 @@ def _resnet(
329340
block: type[ResNetBlock | ResNetBottleneck],
330341
layers: list[int],
331342
block_inplanes: list[int],
332-
pretrained: bool,
343+
pretrained: bool | str,
333344
progress: bool,
334345
**kwargs: Any,
335346
) -> ResNet:
336347
model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)
337348
if pretrained:
338-
# Author of paper zipped the state_dict on googledrive,
339-
# so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
340-
# Would like to load dict from url but need somewhere to save the state dicts.
341-
raise NotImplementedError(
342-
"Currently not implemented. You need to manually download weights provided by the paper's author"
343-
" and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet"
344-
"Please ensure you pass the appropriate `shortcut_type` and `bias_downsample` args. as specified"
345-
"here: https://github.com/Tencent/MedicalNet/tree/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b#update20190730"
346-
)
349+
device = "cuda" if torch.cuda.is_available() else "cpu"
350+
if isinstance(pretrained, str):
351+
if Path(pretrained).exists():
352+
logger.info(f"Loading weights from {pretrained}...")
353+
model_state_dict = torch.load(pretrained, map_location=device)
354+
else:
355+
# Throw error
356+
raise FileNotFoundError("The pretrained checkpoint file is not found")
357+
else:
358+
# Also check bias downsample and shortcut.
359+
if kwargs.get("spatial_dims", 3) == 3:
360+
if kwargs.get("n_input_channels", 3) == 1 and kwargs.get("feed_forward", True) is False:
361+
search_res = re.search(r"resnet(\d+)", arch)
362+
if search_res:
363+
resnet_depth = int(search_res.group(1))
364+
else:
365+
raise ValueError("arch argument should be as 'resnet_{resnet_depth}")
366+
367+
# Check model bias_downsample and shortcut_type
368+
bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)
369+
if shortcut_type == kwargs.get("shortcut_type", "B") and (
370+
bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True
371+
):
372+
# Download the MedicalNet pretrained model
373+
model_state_dict = get_pretrained_resnet_medicalnet(
374+
resnet_depth, device=device, datasets23=True
375+
)
376+
else:
377+
raise NotImplementedError(
378+
f"Please set shortcut_type to {shortcut_type} and bias_downsample to"
379+
f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}"
380+
f"when using pretrained MedicalNet resnet{resnet_depth}"
381+
)
382+
else:
383+
raise NotImplementedError(
384+
"Please set n_input_channels to 1"
385+
"and feed_forward to False in order to use MedicalNet pretrained weights"
386+
)
387+
else:
388+
raise NotImplementedError("MedicalNet pretrained weights are only avalaible for 3D models")
389+
model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
390+
model.load_state_dict(model_state_dict, strict=True)
347391
return model
348392

349393

@@ -429,3 +473,71 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
429473
progress (bool): If True, displays a progress bar of the download to stderr
430474
"""
431475
return _resnet("resnet200", ResNetBottleneck, [3, 24, 36, 3], get_inplanes(), pretrained, progress, **kwargs)
476+
477+
478+
def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True):
    """
    Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet

    Args:
        resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
        device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example.
        datasets23: if True, get the weights trained on more datasets (23).
            Not all depths are available. If not, standard weights are returned.

    Returns:
        Pretrained state dict, i.e. the ``"state_dict"`` entry of the downloaded checkpoint.

    Raises:
        huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub
        NotImplementedError: if `resnet_depth` is not supported
    """
    medicalnet_huggingface_repo_basename = "TencentMedicalNet/MedicalNet-Resnet"
    medicalnet_huggingface_files_basename = "resnet_"
    supported_depth = [10, 18, 34, 50, 101, 152, 200]

    # Validate first so an unsupported depth fails before anything is logged or downloaded.
    if resnet_depth not in supported_depth:
        raise NotImplementedError(f"Supported resnet_depth are: {supported_depth}")

    repo_id = f"{medicalnet_huggingface_repo_basename}{resnet_depth}"
    standard_filename = f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth"
    if datasets23:
        filename = f"{medicalnet_huggingface_files_basename}{resnet_depth}_23dataset.pth"
    else:
        filename = standard_filename

    logger.info(f"Loading MedicalNet pretrained model from https://huggingface.co/{repo_id}")

    try:
        pretrained_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as err:
        if not datasets23:
            # Keep the documented exception type, but chain the original failure so the real
            # cause (network error, missing dependency, ...) stays visible in the traceback.
            raise EntryNotFoundError(f"{filename} not found on {repo_id}") from err
        # The 23-dataset weights are not published for every depth: fall back to the standard file.
        logger.info(f"{filename} not available for resnet{resnet_depth}")
        filename = standard_filename
        logger.info(f"Trying with {filename}")
        pretrained_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # NOTE(review): torch.load unpickles a file fetched from the network; consider passing
    # `weights_only=True` once the minimum supported torch version provides it.
    checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
    logger.info(f"{filename} downloaded")
    return checkpoint.get("state_dict")
531+
532+
533+
def get_medicalnet_pretrained_resnet_args(resnet_depth: int) -> tuple[int, str]:
    """
    Return correct shortcut_type and bias_downsample
    for pretrained MedicalNet weights according to resnet depth.

    Args:
        resnet_depth: depth of the pretrained resnet.

    Returns:
        A ``(bias_downsample, shortcut_type)`` tuple. ``bias_downsample`` is ``-1`` when
        either value is accepted (depths 18 and 34) and ``0`` (i.e. ``False``) otherwise.
        ``shortcut_type`` is ``"A"`` for depths 18 and 34 and ``"B"`` otherwise.
    """
    # After testing
    # bias_downsample False: 10, 50, 101, 152, 200
    # bias_downsample Any (encoded as -1): 18, 34
    if resnet_depth in (18, 34):
        return -1, "A"
    return 0, "B"

monai/networks/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
onnxreference, _ = optional_import("onnx.reference")
3838
onnxruntime, _ = optional_import("onnxruntime")
3939

40+
4041
__all__ = [
4142
"one_hot",
4243
"predict_segmentation",

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,4 @@ filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
5656
zarr
5757
lpips==0.1.4
5858
nvidia-ml-py
59+
huggingface_hub

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ all =
8383
zarr
8484
lpips==0.1.4
8585
nvidia-ml-py
86+
huggingface_hub
8687
nibabel =
8788
nibabel
8889
ninja =

tests/test_resnet.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,32 @@
1111

1212
from __future__ import annotations
1313

14+
import copy
15+
import os
16+
import re
17+
import sys
1418
import unittest
1519
from typing import TYPE_CHECKING
1620

1721
import torch
1822
from parameterized import parameterized
1923

2024
from monai.networks import eval_mode
21-
from monai.networks.nets import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
25+
from monai.networks.nets import (
26+
ResNet,
27+
get_medicalnet_pretrained_resnet_args,
28+
get_pretrained_resnet_medicalnet,
29+
resnet10,
30+
resnet18,
31+
resnet34,
32+
resnet50,
33+
resnet101,
34+
resnet152,
35+
resnet200,
36+
)
2237
from monai.networks.nets.resnet import ResNetBlock
2338
from monai.utils import optional_import
24-
from tests.utils import test_script_save
39+
from tests.utils import equal_state_dict, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, test_script_save
2540

2641
if TYPE_CHECKING:
2742
import torchvision
@@ -30,6 +45,10 @@
3045
else:
3146
torchvision, has_torchvision = optional_import("torchvision")
3247

48+
has_hf_modules = "huggingface_hub" in sys.modules and "huggingface_hub.utils._errors" in sys.modules
49+
50+
# from torchvision.models import ResNet50_Weights, resnet50
51+
3352
device = "cuda" if torch.cuda.is_available() else "cpu"
3453

3554
TEST_CASE_1 = [ # 3D, batch 3, 2 input channel
@@ -159,9 +178,11 @@
159178
]
160179

161180
TEST_CASES = []
181+
PRETRAINED_TEST_CASES = []
162182
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
163183
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
164184
TEST_CASES.append([model, *case])
185+
PRETRAINED_TEST_CASES.append([model, *case])
165186
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]:
166187
TEST_CASES.append([ResNet, *case])
167188

@@ -171,6 +192,16 @@
171192

172193

173194
class TestResNet(unittest.TestCase):
195+
def setUp(self):
196+
self.tmp_ckpt_filename = os.path.join("tests", "monai_unittest_tmp_ckpt.pth")
197+
198+
def tearDown(self):
199+
if os.path.exists(self.tmp_ckpt_filename):
200+
try:
201+
os.remove(self.tmp_ckpt_filename)
202+
except BaseException:
203+
pass
204+
174205
@parameterized.expand(TEST_CASES)
175206
def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
176207
net = model(**input_param).to(device)
@@ -181,6 +212,56 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
181212
else:
182213
self.assertTrue(result.shape in expected_shape)
183214

215+
@parameterized.expand(PRETRAINED_TEST_CASES)
216+
@skip_if_quick
217+
@skip_if_no_cuda
218+
def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape):
219+
net = model(**input_param).to(device)
220+
# Save ckpt
221+
torch.save(net.state_dict(), self.tmp_ckpt_filename)
222+
223+
cp_input_param = copy.copy(input_param)
224+
# Custom pretrained weights
225+
cp_input_param["pretrained"] = self.tmp_ckpt_filename
226+
pretrained_net = model(**cp_input_param)
227+
self.assertTrue(equal_state_dict(net.state_dict(), pretrained_net.state_dict()))
228+
229+
if has_hf_modules:
230+
# True flag
231+
cp_input_param["pretrained"] = True
232+
resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1))
233+
234+
bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)
235+
236+
# With orig. test cases
237+
if (
238+
input_param.get("spatial_dims", 3) == 3
239+
and input_param.get("n_input_channels", 3) == 1
240+
and input_param.get("feed_forward", True) is False
241+
and input_param.get("shortcut_type", "B") == shortcut_type
242+
and (
243+
input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True
244+
)
245+
):
246+
model(**cp_input_param)
247+
else:
248+
with self.assertRaises(NotImplementedError):
249+
model(**cp_input_param)
250+
251+
# forcing MedicalNet pretrained download for 3D tests cases
252+
cp_input_param["n_input_channels"] = 1
253+
cp_input_param["feed_forward"] = False
254+
cp_input_param["shortcut_type"] = shortcut_type
255+
cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True
256+
if cp_input_param.get("spatial_dims", 3) == 3:
257+
with skip_if_downloading_fails():
258+
pretrained_net = model(**cp_input_param).to(device)
259+
medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device)
260+
medicalnet_state_dict = {
261+
key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()
262+
}
263+
self.assertTrue(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict))
264+
184265
@parameterized.expand(TEST_SCRIPT_CASES)
185266
def test_script(self, model, input_param, input_shape, expected_shape):
186267
net = model(**input_param)

tests/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,23 @@ def command_line_tests(cmd, copy_env=True):
825825
raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e
826826

827827

828+
def equal_state_dict(st_1, st_2):
    """
    Compare 2 torch state dicts.

    Args:
        st_1: first state dict (mapping of parameter name to tensor).
        st_2: second state dict (mapping of parameter name to tensor).

    Returns:
        True if both state dicts have exactly the same keys and every pair of
        tensors is equal according to ``torch.equal``, False otherwise.
    """
    # A key present in only one of the two dicts means they are not equal; checking
    # only the keys of `st_1` would let extra entries in `st_2` go unnoticed.
    if st_1.keys() != st_2.keys():
        return False
    # Align devices so tensors living on different devices can still be compared by value.
    return all(torch.equal(val_st_1, st_2[key].to(val_st_1.device)) for key, val_st_1 in st_1.items())
843+
844+
828845
TEST_TORCH_TENSORS: tuple = (torch.as_tensor,)
829846
if torch.cuda.is_available():
830847
gpu_tensor: Callable = partial(torch.as_tensor, device="cuda")

0 commit comments

Comments
 (0)