Skip to content

Commit 02215e7

Browse files
Sreyan88SreyanG-NVIDIA
authored andcommitted
fixed calculation of ctc loss in TFWav2Vec2ForCTC (huggingface#18014)
Co-authored-by: Sreyan-G@NVIDIA <[email protected]>
1 parent 99ea21d commit 02215e7

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@
2525

2626
from ...activations_tf import get_tf_activation
2727
from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
28-
from ...modeling_tf_utils import TFPreTrainedModel, booleans_processing, get_initializer, keras_serializable
28+
from ...modeling_tf_utils import (
29+
TFPreTrainedModel,
30+
booleans_processing,
31+
get_initializer,
32+
keras_serializable,
33+
unpack_inputs,
34+
)
2935
from ...tf_utils import shape_list, stable_softmax
3036
from ...utils import (
3137
ModelOutput,
@@ -1580,6 +1586,7 @@ def freeze_feature_encoder(self):
15801586
"""
15811587
self.wav2vec2.feature_extractor.trainable = False
15821588

1589+
@unpack_inputs
15831590
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
15841591
@replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
15851592
def call(
@@ -1702,6 +1709,8 @@ def call(
17021709
loss = tf.reduce_sum(loss)
17031710
if self.config.ctc_loss_reduction == "mean":
17041711
loss = tf.reduce_mean(loss)
1712+
1713+
loss = tf.reshape(loss, (1,))
17051714
else:
17061715
loss = None
17071716

0 commit comments

Comments
 (0)