3 files changed: +29 -17 lines changed
Original file line number | Diff line number | Diff line change
@@ -187,23 +187,6 @@ def pad(
187
187
padding_strategy = self ._get_padding_strategies (padding = padding , max_length = max_length )
188
188
189
189
required_input = processed_features [self .model_input_names [0 ]]
190
- if required_input and not isinstance (required_input [0 ], np .ndarray ):
191
- # truncation
192
- processed_features = self ._truncate (
193
- processed_features ,
194
- max_length = max_length ,
195
- pad_to_multiple_of = pad_to_multiple_of ,
196
- truncation = truncation ,
197
- )
198
- # padding
199
- processed_features = self ._pad (
200
- processed_features ,
201
- max_length = max_length ,
202
- padding_strategy = padding_strategy ,
203
- pad_to_multiple_of = pad_to_multiple_of ,
204
- return_attention_mask = return_attention_mask ,
205
- )
206
- return BatchFeature (processed_features , tensor_type = return_tensors )
207
190
208
191
batch_size = len (required_input )
209
192
if not all (len (v ) == batch_size for v in processed_features .values ()):
@@ -240,6 +223,8 @@ def pad(
240
223
for key , value in outputs .items ():
241
224
if key not in batch_outputs :
242
225
batch_outputs [key ] = []
226
+ if value .dtype is np .dtype (np .float64 ):
227
+ value = value .astype (np .float32 )
243
228
batch_outputs [key ].append (value )
244
229
245
230
return BatchFeature (batch_outputs , tensor_type = return_tensors )
Original file line number | Diff line number | Diff line change
@@ -235,3 +235,16 @@ def test_cepstral_mean_and_variance_normalization_trunc_longest(self):
235
235
236
236
# make sure that if max_length < longest -> then pad to max_length
237
237
self .assertEqual (input_features .shape , (3 , 6 , 24 ))
238
+
239
+ def test_double_precision_pad (self ):
240
+ import torch
241
+
242
+ feature_extractor = self .feature_extraction_class (** self .feat_extract_tester .prepare_feat_extract_dict ())
243
+ np_speech_inputs = np .random .rand (100 , 32 ).astype (np .float64 )
244
+ py_speech_inputs = np_speech_inputs .tolist ()
245
+
246
+ for inputs in [py_speech_inputs , np_speech_inputs ]:
247
+ np_processed = feature_extractor .pad ([{"input_features" : inputs }], return_tensors = "np" )
248
+ self .assertTrue (np_processed .input_features .dtype == np .float32 )
249
+ pt_processed = feature_extractor .pad ([{"input_features" : inputs }], return_tensors = "pt" )
250
+ self .assertTrue (pt_processed .input_features .dtype == torch .float32 )
Original file line number | Diff line number | Diff line change
@@ -196,6 +196,20 @@ def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):
196
196
# make sure that if max_length > longest -> then pad to longest
197
197
self .assertTrue (input_values .shape == (3 , 1200 ))
198
198
199
+ @require_torch
200
+ def test_double_precision_pad (self ):
201
+ import torch
202
+
203
+ feature_extractor = self .feature_extraction_class (** self .feat_extract_tester .prepare_feat_extract_dict ())
204
+ np_speech_inputs = np .random .rand (100 ).astype (np .float64 )
205
+ py_speech_inputs = np_speech_inputs .tolist ()
206
+
207
+ for inputs in [py_speech_inputs , np_speech_inputs ]:
208
+ np_processed = feature_extractor .pad ([{"input_values" : inputs }], return_tensors = "np" )
209
+ self .assertTrue (np_processed .input_values .dtype == np .float32 )
210
+ pt_processed = feature_extractor .pad ([{"input_values" : inputs }], return_tensors = "pt" )
211
+ self .assertTrue (pt_processed .input_values .dtype == torch .float32 )
212
+
199
213
@slow
200
214
@require_torch
201
215
def test_pretrained_checkpoints_are_set_correctly (self ):
You can’t perform that action at this time.
0 commit comments