
Commit d0df18f

Rocketknight1 authored and novice03 committed
Make the TF dummies even smaller (huggingface#24071)
* Let's see if we can use the smallest possible dummies
* Make GPT-2's dummies a little longer
* Just use (1,2) as the default shape
* Update other dummies in sync
* Correct imports for Keras 2.13
* Shrink the Wav2Vec2 dummies
1 parent 4dd6601 commit d0df18f

5 files changed, +12 -23 lines changed


src/transformers/modeling_tf_utils.py

Lines changed: 7 additions & 3 deletions
@@ -74,7 +74,7 @@
 if parse(tf.__version__).minor >= 13:
     from keras import backend as K
     from keras.__internal__ import KerasTensor
-    from keras.engine.base_layer_utils import call_context
+    from keras.src.engine.base_layer_utils import call_context
 elif parse(tf.__version__).minor >= 11:
     from keras import backend as K
     from keras.engine.base_layer_utils import call_context
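The import fix tracks Keras 2.13's repackaging of its internals under the new `keras.src` namespace. As a standalone illustration, here is a minimal sketch of the same version gate, assuming (as the code above does) that the bundled Keras minor version tracks TF's; older TF versions, which the library also handles, are elided:

```python
import tensorflow as tf
from packaging.version import parse

# Keras 2.13 moved its private modules under `keras.src`, so the path
# to call_context depends on the installed version.
if parse(tf.__version__).minor >= 13:
    from keras.src.engine.base_layer_utils import call_context
else:
    # Pre-2.13 location (TF >= 2.11 per the gate above).
    from keras.engine.base_layer_utils import call_context
```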
@@ -1125,15 +1125,19 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
         sig = self._prune_signature(self.input_signature)
         for key, spec in sig.items():
             # 2 is the most correct arbitrary size. I will not be taking questions
-            dummies[key] = tf.ones(shape=[dim if dim is not None else 2 for dim in spec.shape], dtype=spec.dtype)
+            dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
+            if spec.shape[0] is None:
+                # But let's make the batch size 1 to save memory anyway
+                dummy_shape[0] = 1
+            dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
             if key == "token_type_ids":
                 # Some models have token_type_ids but with a vocab_size of 1
                 dummies[key] = tf.zeros_like(dummies[key])
         if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
             if "encoder_hidden_states" not in dummies:
                 if self.main_input_name == "input_ids":
                     dummies["encoder_hidden_states"] = tf.ones(
-                        shape=(2, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
+                        shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
                     )
                 else:
                     raise NotImplementedError(
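Taken on its own, the new rule is: every free (`None`) dimension in a model's input signature resolves to 2, except a free batch dimension, which is pinned to 1. A minimal sketch of that rule in isolation (`build_dummy` is our name for illustration, not a library function):

```python
import tensorflow as tf

def build_dummy(spec: tf.TensorSpec) -> tf.Tensor:
    """Resolve a TensorSpec into the smallest useful dummy tensor."""
    # Free dimensions default to 2 ("the most correct arbitrary size")...
    dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
    # ...except a free batch dimension, which becomes 1 to save memory.
    if spec.shape[0] is None:
        dummy_shape[0] = 1
    return tf.ones(shape=dummy_shape, dtype=spec.dtype)

# A typical input_ids spec with free batch and sequence dimensions:
print(build_dummy(tf.TensorSpec((None, None), tf.int32)).shape)  # (1, 2)
```

This is the "(1,2) as the default shape" from the commit message.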

src/transformers/models/funnel/modeling_tf_funnel.py

Lines changed: 1 addition & 1 deletion
@@ -978,7 +978,7 @@ class TFFunnelPreTrainedModel(TFPreTrainedModel):
     @property
     def dummy_inputs(self):
         # Funnel misbehaves with very small inputs, so we override and make them a bit bigger
-        return {"input_ids": tf.ones((3, 3), dtype=tf.int32)}
+        return {"input_ids": tf.ones((1, 3), dtype=tf.int32)}
 
 
 @dataclass

src/transformers/models/sam/modeling_tf_sam.py

Lines changed: 0 additions & 15 deletions
@@ -1147,21 +1147,6 @@ class TFSamPreTrainedModel(TFPreTrainedModel):
     base_model_prefix = "sam"
     main_input_name = "pixel_values"
 
-    @property
-    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
-        # We override the default dummy inputs here because SAM has some really explosive memory usage in the
-        # attention layers, so we want to pass the smallest possible batches
-        VISION_DUMMY_INPUTS = tf.random.uniform(
-            shape=(
-                1,
-                self.config.vision_config.num_channels,
-                self.config.vision_config.image_size,
-                self.config.vision_config.image_size,
-            ),
-            dtype=tf.float32,
-        )
-        return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
-
 
 SAM_START_DOCSTRING = r"""
     This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
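With the override deleted, SAM falls back to the generic `dummy_inputs` in `modeling_tf_utils.py` above. Because that path now pins a free batch dimension to 1, the smallest-possible-batch behavior the removed comment asked for is preserved without any model-specific code.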

src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py

Lines changed: 2 additions & 2 deletions
@@ -1194,8 +1194,8 @@ def input_signature(self):
     @property
     def dummy_inputs(self):
         return {
-            "input_values": tf.random.uniform(shape=(1, 16000), dtype=tf.float32),
-            "attention_mask": tf.ones(shape=(1, 16000), dtype=tf.float32),
+            "input_values": tf.random.uniform(shape=(1, 500), dtype=tf.float32),
+            "attention_mask": tf.ones(shape=(1, 500), dtype=tf.float32),
         }
 
     def __init__(self, config, *inputs, **kwargs):
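Assuming the 16 kHz sampling rate that Wav2Vec2 checkpoints typically expect (not stated in the diff), the old dummies covered a full second of audio, while the new ones cover 500 / 16000 ≈ 31 ms. That is still longer than the feature encoder's roughly 400-sample receptive field in the default config, so the dummies still produce at least one output frame.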

src/transformers/models/whisper/modeling_tf_whisper.py

Lines changed: 2 additions & 2 deletions
@@ -481,9 +481,9 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
         """
         return {
             self.main_input_name: tf.random.uniform(
-                [2, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
+                [1, self.config.num_mel_bins, self.config.max_source_positions * 2 - 1], dtype=tf.float32
             ),
-            "decoder_input_ids": tf.constant([[2, 3]], dtype=tf.int32),
+            "decoder_input_ids": tf.constant([[1, 3]], dtype=tf.int32),
         }
 
     @property
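For a sense of scale, a short sketch of the new dummy feature shape under assumed default Whisper config values (80 mel bins and 1500 source positions; neither number is taken from this diff):

```python
import tensorflow as tf

# Assumed Whisper defaults, not read from the diff:
num_mel_bins = 80
max_source_positions = 1500

# Batch of 1 instead of the previous 2 halves the dummy's memory footprint.
features = tf.random.uniform([1, num_mel_bins, max_source_positions * 2 - 1], dtype=tf.float32)
print(features.shape)  # (1, 80, 2999); the old dummies were (2, 80, 2999)
```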
