Commit 75f6641

[Wav2Vec2FeatureExtractor] Fix extractor.pad() dtype backwards compatibility (#13693)
* Force dtype, add tests
* Local torch imports
* Remove unused logic (always ndarray)
1 parent 8e908c8 commit 75f6641

File tree: 3 files changed, +29 -17 lines
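
The fix addresses a dtype drift: NumPy promotes plain Python floats to float64 by default, while the PyTorch speech models consume float32 inputs. A minimal standalone sketch (not part of the commit) of the mismatch, using the exact guard that pad() now applies:

    import numpy as np

    # Python float lists become float64 arrays by default.
    value = np.asarray([0.1, 0.2, 0.3])
    print(value.dtype)  # float64

    # The guard added in this commit downcasts to float32.
    if value.dtype is np.dtype(np.float64):
        value = value.astype(np.float32)
    print(value.dtype)  # float32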

src/transformers/feature_extraction_sequence_utils.py

Lines changed: 2 additions & 17 deletions
@@ -187,23 +187,6 @@ def pad(
         padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
 
         required_input = processed_features[self.model_input_names[0]]
-        if required_input and not isinstance(required_input[0], np.ndarray):
-            # truncation
-            processed_features = self._truncate(
-                processed_features,
-                max_length=max_length,
-                pad_to_multiple_of=pad_to_multiple_of,
-                truncation=truncation,
-            )
-            # padding
-            processed_features = self._pad(
-                processed_features,
-                max_length=max_length,
-                padding_strategy=padding_strategy,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_attention_mask=return_attention_mask,
-            )
-            return BatchFeature(processed_features, tensor_type=return_tensors)
 
         batch_size = len(required_input)
         if not all(len(v) == batch_size for v in processed_features.values()):
@@ -240,6 +223,8 @@ def pad(
             for key, value in outputs.items():
                 if key not in batch_outputs:
                     batch_outputs[key] = []
+                if value.dtype is np.dtype(np.float64):
+                    value = value.astype(np.float32)
                 batch_outputs[key].append(value)
 
         return BatchFeature(batch_outputs, tensor_type=return_tensors)
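
For illustration, a hedged end-to-end sketch of the restored behaviour; the checkpoint name is an assumption, and any Wav2Vec2 feature extractor configuration should behave the same:

    from transformers import Wav2Vec2FeatureExtractor

    # Checkpoint name assumed for illustration only.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

    # A plain Python list previously hit the removed early-return branch;
    # it now flows through the shared _truncate/_pad path and picks up
    # the new float64 -> float32 downcast.
    batch = feature_extractor.pad([{"input_values": [0.1, 0.2, 0.3]}], return_tensors="np")
    print(batch.input_values.dtype)  # float32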

tests/test_feature_extraction_speech_to_text.py

Lines changed: 13 additions & 0 deletions
@@ -235,3 +235,16 @@ def test_cepstral_mean_and_variance_normalization_trunc_longest(self):
 
         # make sure that if max_length < longest -> then pad to max_length
         self.assertEqual(input_features.shape, (3, 6, 24))
+
+    def test_double_precision_pad(self):
+        import torch
+
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
+        py_speech_inputs = np_speech_inputs.tolist()
+
+        for inputs in [py_speech_inputs, np_speech_inputs]:
+            np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
+            self.assertTrue(np_processed.input_features.dtype == np.float32)
+            pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
+            self.assertTrue(pt_processed.input_features.dtype == torch.float32)
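
The torch-side assertions hold because torch.tensor preserves a NumPy array's dtype, so the float32 cast inside pad() carries through return_tensors="pt" unchanged. A small sketch of that behaviour:

    import numpy as np
    import torch

    arr64 = np.zeros(4, dtype=np.float64)
    print(torch.tensor(arr64).dtype)                     # torch.float64 (old failure mode)
    print(torch.tensor(arr64.astype(np.float32)).dtype)  # torch.float32 (after the fix)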

tests/test_feature_extraction_wav2vec2.py

Lines changed: 14 additions & 0 deletions
@@ -196,6 +196,20 @@ def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
         # make sure that if max_length > longest -> then pad to longest
         self.assertTrue(input_values.shape == (3, 1200))
 
+    @require_torch
+    def test_double_precision_pad(self):
+        import torch
+
+        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
+        np_speech_inputs = np.random.rand(100).astype(np.float64)
+        py_speech_inputs = np_speech_inputs.tolist()
+
+        for inputs in [py_speech_inputs, np_speech_inputs]:
+            np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
+            self.assertTrue(np_processed.input_values.dtype == np.float32)
+            pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
+            self.assertTrue(pt_processed.input_values.dtype == torch.float32)
+
     @slow
     @require_torch
     def test_pretrained_checkpoints_are_set_correctly(self):
