Skip to content

Commit 1c38f1a

Browse files
michaelbenayoun authored and amyeroberts committed
[FX] _generate_dummy_input supports audio-classification models for labels (huggingface#18580)
* Support audio classification architectures for labels generation, as well as provides a flag to print warnings or not
* Use ENV_VARS_TRUE_VALUES
1 parent 9d87c2d commit 1c38f1a

File tree

1 file changed

+14
-9
lines changed
  • src/transformers/utils

1 file changed

+14
-9
lines changed

src/transformers/utils/fx.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import inspect
2020
import math
2121
import operator
22+
import os
2223
import random
2324
import warnings
2425
from typing import Any, Callable, Dict, List, Optional, Type, Union
@@ -48,11 +49,12 @@
4849
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
4950
MODEL_MAPPING_NAMES,
5051
)
51-
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
52+
from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
5253
from ..utils.versions import importlib_metadata
5354

5455

5556
logger = logging.get_logger(__name__)
57+
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
5658

5759

5860
def _generate_supported_model_class_names(
@@ -678,7 +680,12 @@ def _generate_dummy_input(
678680
if input_name in ["labels", "start_positions", "end_positions"]:
679681

680682
batch_size = shape[0]
681-
if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
683+
if model_class_name in [
684+
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
685+
*get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
686+
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
687+
*get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
688+
]:
682689
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
683690
elif model_class_name in [
684691
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
@@ -710,11 +717,6 @@ def _generate_dummy_input(
710717
)
711718
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
712719

713-
elif model_class_name in [
714-
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
715-
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
716-
]:
717-
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
718720
elif model_class_name in [
719721
*get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
720722
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
@@ -725,7 +727,9 @@ def _generate_dummy_input(
725727
]:
726728
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
727729
else:
728-
raise NotImplementedError(f"{model_class_name} not supported yet.")
730+
raise NotImplementedError(
731+
f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
732+
)
729733
elif "pixel_values" in input_name:
730734
batch_size = shape[0]
731735
image_size = getattr(model.config, "image_size", None)
@@ -846,7 +850,8 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr
846850
raise ValueError("Don't support composite output yet")
847851
rv.install_metadata(meta_out)
848852
except Exception as e:
849-
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
853+
if _IS_IN_DEBUG_MODE:
854+
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
850855

851856
return rv
852857

0 commit comments

Comments (0)