@@ -19,6 +19,7 @@
 import inspect
 import math
 import operator
+import os
 import random
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Type, Union
@@ -48,11 +49,12 @@
     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
 )
-from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
+from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
 from ..utils.versions import importlib_metadata


 logger = logging.get_logger(__name__)
+_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES


 def _generate_supported_model_class_names(
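Note on the new flag: `_IS_IN_DEBUG_MODE` is evaluated once, when `transformers.utils.fx` is imported, by checking the `FX_DEBUG_MODE` environment variable against `ENV_VARS_TRUE_VALUES`. A minimal sketch of opting in (the accepted values are whatever `ENV_VARS_TRUE_VALUES` contains, e.g. "1", "ON", "YES", "TRUE"):

    # Set the variable before transformers.utils.fx is imported, since the flag is read at import time.
    import os

    os.environ["FX_DEBUG_MODE"] = "1"

    from transformers.utils import fx  # fx._IS_IN_DEBUG_MODE is now True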
@@ -678,7 +680,12 @@ def _generate_dummy_input(
     if input_name in ["labels", "start_positions", "end_positions"]:

         batch_size = shape[0]
-        if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
+        if model_class_name in [
+            *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
+            *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
+            *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
+            *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
+        ]:
             inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
         elif model_class_name in [
             *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
@@ -710,11 +717,6 @@ def _generate_dummy_input(
             )
             inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)

-        elif model_class_name in [
-            *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
-            *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
-        ]:
-            inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
         elif model_class_name in [
             *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
             *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
@@ -725,7 +727,9 @@ def _generate_dummy_input(
         ]:
             inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
         else:
-            raise NotImplementedError(f"{model_class_name} not supported yet.")
+            raise NotImplementedError(
+                f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
+            )
     elif "pixel_values" in input_name:
         batch_size = shape[0]
         image_size = getattr(model.config, "image_size", None)
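The branches above decide the shape of the dummy "labels" tensor: the mappings merged into the first branch (next sentence prediction, multiple choice, image and audio classification) get one label per example, while e.g. token classification keeps one label per token. A rough illustration of the two shapes (the concrete sizes are made up for the example):

    import torch

    batch_size, seq_len = 2, 8
    per_example_labels = torch.zeros(batch_size, dtype=torch.long)            # shape (2,)
    per_token_labels = torch.zeros((batch_size, seq_len), dtype=torch.long)   # shape (2, 8)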
@@ -846,7 +850,8 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr
                 raise ValueError("Don't support composite output yet")
             rv.install_metadata(meta_out)
         except Exception as e:
-            warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
+            if _IS_IN_DEBUG_MODE:
+                warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")


         return rv
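With this change, failures while recording metadata during tracing are silent by default and are only emitted as warnings when the debug flag is on. A hedged end-to-end sketch (the model class and input names are illustrative, not taken from this diff):

    import os

    os.environ["FX_DEBUG_MODE"] = "1"  # surface the "Could not compute metadata ..." warnings; leave unset to keep them silent

    from transformers import BertConfig, BertForSequenceClassification
    from transformers.utils.fx import symbolic_trace

    model = BertForSequenceClassification(BertConfig())
    traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])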