|
11 | 11 |
|
12 | 12 | from __future__ import annotations
|
13 | 13 |
|
| 14 | +import logging |
| 15 | +import re |
14 | 16 | from collections.abc import Callable
|
15 | 17 | from functools import partial
|
| 18 | +from pathlib import Path |
16 | 19 | from typing import Any
|
17 | 20 |
|
18 | 21 | import torch
|
|
21 | 24 | from monai.networks.layers.factories import Conv, Norm, Pool
|
22 | 25 | from monai.networks.layers.utils import get_pool_layer
|
23 | 26 | from monai.utils import ensure_tuple_rep
|
24 |
| -from monai.utils.module import look_up_option |
| 27 | +from monai.utils.module import look_up_option, optional_import |
| 28 | + |
| 29 | +hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download") |
| 30 | +EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name="EntryNotFoundError") |
| 31 | + |
| 32 | +MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet" |
| 33 | +MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_" |
25 | 34 |
|
26 | 35 | __all__ = [
|
27 | 36 | "ResNet",
|
|
36 | 45 | "resnet200",
|
37 | 46 | ]
|
38 | 47 |
|
| 48 | +logger = logging.getLogger(__name__) |
| 49 | + |
39 | 50 |
|
40 | 51 | def get_inplanes():
|
41 | 52 | return [64, 128, 256, 512]
|
@@ -329,21 +340,54 @@ def _resnet(
|
329 | 340 | block: type[ResNetBlock | ResNetBottleneck],
|
330 | 341 | layers: list[int],
|
331 | 342 | block_inplanes: list[int],
|
332 |
| - pretrained: bool, |
| 343 | + pretrained: bool | str, |
333 | 344 | progress: bool,
|
334 | 345 | **kwargs: Any,
|
335 | 346 | ) -> ResNet:
|
336 | 347 | model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)
|
337 | 348 | if pretrained:
|
338 |
| - # Author of paper zipped the state_dict on googledrive, |
339 |
| - # so would need to download, unzip and read (2.8gb file for a ~150mb state dict). |
340 |
| - # Would like to load dict from url but need somewhere to save the state dicts. |
341 |
| - raise NotImplementedError( |
342 |
| - "Currently not implemented. You need to manually download weights provided by the paper's author" |
343 |
| - " and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet" |
344 |
| - "Please ensure you pass the appropriate `shortcut_type` and `bias_downsample` args. as specified" |
345 |
| - "here: https://github.com/Tencent/MedicalNet/tree/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b#update20190730" |
346 |
| - ) |
| 349 | + device = "cuda" if torch.cuda.is_available() else "cpu" |
| 350 | + if isinstance(pretrained, str): |
| 351 | + if Path(pretrained).exists(): |
| 352 | + logger.info(f"Loading weights from {pretrained}...") |
| 353 | + model_state_dict = torch.load(pretrained, map_location=device) |
| 354 | + else: |
| 355 | + # Throw error |
| 356 | + raise FileNotFoundError("The pretrained checkpoint file is not found") |
| 357 | + else: |
| 358 | + # Also check bias downsample and shortcut. |
| 359 | + if kwargs.get("spatial_dims", 3) == 3: |
| 360 | + if kwargs.get("n_input_channels", 3) == 1 and kwargs.get("feed_forward", True) is False: |
| 361 | + search_res = re.search(r"resnet(\d+)", arch) |
| 362 | + if search_res: |
| 363 | + resnet_depth = int(search_res.group(1)) |
| 364 | + else: |
| 365 | + raise ValueError("arch argument should be as 'resnet_{resnet_depth}") |
| 366 | + |
| 367 | + # Check model bias_downsample and shortcut_type |
| 368 | + bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth) |
| 369 | + if shortcut_type == kwargs.get("shortcut_type", "B") and ( |
| 370 | + bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True |
| 371 | + ): |
| 372 | + # Download the MedicalNet pretrained model |
| 373 | + model_state_dict = get_pretrained_resnet_medicalnet( |
| 374 | + resnet_depth, device=device, datasets23=True |
| 375 | + ) |
| 376 | + else: |
| 377 | + raise NotImplementedError( |
| 378 | + f"Please set shortcut_type to {shortcut_type} and bias_downsample to" |
| 379 | + f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}" |
| 380 | + f"when using pretrained MedicalNet resnet{resnet_depth}" |
| 381 | + ) |
| 382 | + else: |
| 383 | + raise NotImplementedError( |
| 384 | + "Please set n_input_channels to 1" |
| 385 | + "and feed_forward to False in order to use MedicalNet pretrained weights" |
| 386 | + ) |
| 387 | + else: |
| 388 | + raise NotImplementedError("MedicalNet pretrained weights are only avalaible for 3D models") |
| 389 | + model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} |
| 390 | + model.load_state_dict(model_state_dict, strict=True) |
347 | 391 | return model
|
348 | 392 |
|
349 | 393 |
|
@@ -429,3 +473,71 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
|
429 | 473 | progress (bool): If True, displays a progress bar of the download to stderr
|
430 | 474 | """
|
431 | 475 | return _resnet("resnet200", ResNetBottleneck, [3, 24, 36, 3], get_inplanes(), pretrained, progress, **kwargs)
|
| 476 | + |
| 477 | + |
| 478 | +def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True): |
| 479 | + """ |
| 480 | + Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet |
| 481 | +
|
| 482 | + Args: |
| 483 | + resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 |
| 484 | + device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example. |
| 485 | + datasets23: if True, get the weights trained on more datasets (23). |
| 486 | + Not all depths are available. If not, standard weights are returned. |
| 487 | +
|
| 488 | + Returns: |
| 489 | + Pretrained state dict |
| 490 | +
|
| 491 | + Raises: |
| 492 | + huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub |
| 493 | + NotImplementedError: if `resnet_depth` is not supported |
| 494 | + """ |
| 495 | + |
| 496 | + medicalnet_huggingface_repo_basename = "TencentMedicalNet/MedicalNet-Resnet" |
| 497 | + medicalnet_huggingface_files_basename = "resnet_" |
| 498 | + supported_depth = [10, 18, 34, 50, 101, 152, 200] |
| 499 | + |
| 500 | + logger.info( |
| 501 | + f"Loading MedicalNet pretrained model from https://huggingface.co/{medicalnet_huggingface_repo_basename}{resnet_depth}" |
| 502 | + ) |
| 503 | + |
| 504 | + if resnet_depth in supported_depth: |
| 505 | + filename = ( |
| 506 | + f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth" |
| 507 | + if not datasets23 |
| 508 | + else f"{medicalnet_huggingface_files_basename}{resnet_depth}_23dataset.pth" |
| 509 | + ) |
| 510 | + try: |
| 511 | + pretrained_path = hf_hub_download( |
| 512 | + repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename |
| 513 | + ) |
| 514 | + except Exception: |
| 515 | + if datasets23: |
| 516 | + logger.info(f"{filename} not available for resnet{resnet_depth}") |
| 517 | + filename = f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth" |
| 518 | + logger.info(f"Trying with {filename}") |
| 519 | + pretrained_path = hf_hub_download( |
| 520 | + repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename |
| 521 | + ) |
| 522 | + else: |
| 523 | + raise EntryNotFoundError( |
| 524 | + f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}" |
| 525 | + ) from None |
| 526 | + checkpoint = torch.load(pretrained_path, map_location=torch.device(device)) |
| 527 | + else: |
| 528 | + raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]") |
| 529 | + logger.info(f"{filename} downloaded") |
| 530 | + return checkpoint.get("state_dict") |
| 531 | + |
| 532 | + |
| 533 | +def get_medicalnet_pretrained_resnet_args(resnet_depth: int): |
| 534 | + """ |
| 535 | + Return correct shortcut_type and bias_downsample |
| 536 | + for pretrained MedicalNet weights according to resnet depth |
| 537 | + """ |
| 538 | + # After testing |
| 539 | + # False: 10, 50, 101, 152, 200 |
| 540 | + # Any: 18, 34 |
| 541 | + bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 |
| 542 | + shortcut_type = "A" if resnet_depth in [18, 34] else "B" |
| 543 | + return bias_downsample, shortcut_type |
0 commit comments