
Commit 3b09464

eustlb authored and ArthurZucker committed

Patch moonshine (#35731)

* update expected logits for T4 runners
* update doc
* correct order of the args for better readability
* remove generate wrap
* convert modular
1 parent b00807f commit 3b09464

File tree

4 files changed (+40, -74 lines)

docs/source/en/_toctree.yml

Lines changed: 2 additions & 2 deletions
```diff
@@ -506,8 +506,6 @@
         title: MobileBERT
       - local: model_doc/modernbert
         title: ModernBert
-      - local: model_doc/moonshine
-        title: moonshine
       - local: model_doc/mpnet
         title: MPNet
       - local: model_doc/mpt
@@ -770,6 +768,8 @@
         title: Mimi
       - local: model_doc/mms
         title: MMS
+      - local: model_doc/moonshine
+        title: Moonshine
       - local: model_doc/moshi
         title: Moshi
       - local: model_doc/musicgen
```

src/transformers/models/moonshine/modeling_moonshine.py

Lines changed: 12 additions & 29 deletions
```diff
@@ -1166,10 +1166,9 @@ def compute_num_masked_span(input_length):
             `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
             `input_values`, the [`AutoFeatureExtractor`] should be used for padding
             and conversion into a tensor of type `torch.FloatTensor`.
-        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
-            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
-            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
-            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        attention_mask (`torch.Tensor`)`, *optional*):
+            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
+            but it is not used.
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
@@ -1178,9 +1177,6 @@ def compute_num_masked_span(input_length):
             [`PreTrainedTokenizer.__call__`] for details.

             [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.Tensor`)`, *optional*):
-            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-            but it is not used.
         decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

@@ -1201,11 +1197,10 @@ def compute_num_masked_span(input_length):

             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
-        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.n_positions - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
+        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
         past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
@@ -1228,6 +1223,11 @@ def compute_num_masked_span(input_length):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
+        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
         use_cache (`bool`, *optional*):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
             `past_key_values`).
@@ -1549,22 +1549,5 @@ def forward(
             encoder_attentions=outputs.encoder_attentions,
         )

-    def generate(self, *args, **kwargs):
-        # TODO: @eustlb do it rather with a custom logits processor
-        token_limit_factor = 6.5 / 16000.0  # Maximum of 6.5 tokens per second
-        if kwargs.get("max_new_tokens") is None and kwargs.get("max_length") is None:
-            if kwargs.get("attention_mask") is not None:
-                seq_lens = kwargs["attention_mask"].sum(dim=-1)
-            else:
-                seq_lens = kwargs["input_values"].shape[-1]
-            max_length = int(seq_lens.max().item() * token_limit_factor)
-            logger.warning_once(
-                f"Based on the input length, Moonshine will generate up to {max_length} tokens (ratio of 6.5 tokens/second). "
-                "To specify a different length, set either `max_new_tokens` or `max_length`."
-            )
-            kwargs["max_length"] = max_length
-
-        return super().generate(*args, **kwargs)
-

 __all__ = ["MoonshineModel", "MoonshinePreTrainedModel", "MoonshineForConditionalGeneration"]
```

src/transformers/models/moonshine/modular_moonshine.py

Lines changed: 12 additions & 29 deletions
```diff
@@ -816,10 +816,9 @@ def forward(
             `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
             `input_values`, the [`AutoFeatureExtractor`] should be used for padding
             and conversion into a tensor of type `torch.FloatTensor`.
-        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
-            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
-            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
-            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        attention_mask (`torch.Tensor`)`, *optional*):
+            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
+            but it is not used.
         decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
@@ -828,9 +827,6 @@ def forward(
             [`PreTrainedTokenizer.__call__`] for details.

             [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.Tensor`)`, *optional*):
-            Moonshine does not support masking of the `input_values`, this argument is preserved for compatibility,
-            but it is not used.
         decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

@@ -851,11 +847,10 @@ def forward(

             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
-        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.n_positions - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
+        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
         past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
@@ -878,6 +873,11 @@ def forward(
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
+        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
         use_cache (`bool`, *optional*):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
             `past_key_values`).
@@ -1109,23 +1109,6 @@ def forward(
             encoder_attentions=outputs.encoder_attentions,
         )

-    def generate(self, *args, **kwargs):
-        # TODO: @eustlb do it rather with a custom logits processor
-        token_limit_factor = 6.5 / 16000.0  # Maximum of 6.5 tokens per second
-        if kwargs.get("max_new_tokens") is None and kwargs.get("max_length") is None:
-            if kwargs.get("attention_mask") is not None:
-                seq_lens = kwargs["attention_mask"].sum(dim=-1)
-            else:
-                seq_lens = kwargs["input_values"].shape[-1]
-            max_length = int(seq_lens.max().item() * token_limit_factor)
-            logger.warning_once(
-                f"Based on the input length, Moonshine will generate up to {max_length} tokens (ratio of 6.5 tokens/second). "
-                "To specify a different length, set either `max_new_tokens` or `max_length`."
-            )
-            kwargs["max_length"] = max_length
-
-        return super().generate(*args, **kwargs)
-

 __all__ = [
     "MoonshineConfig",
```

tests/models/moonshine/test_modeling_moonshine.py

Lines changed: 14 additions & 14 deletions
```diff
@@ -484,9 +484,9 @@ def test_tiny_logits_single(self):

         # fmt: off
         EXPECTED_LOGITS = torch.tensor([
-            -9.1107, 4.5538, 6.3902, -6.8141, -7.2459, -7.9077, -7.2842, -7.6045, -8.0387, -7.8354,
-            -7.3870, -7.2453, -7.7423, -7.3914, -7.3869, -7.6982, -7.6422, -7.0507, -7.3982, -7.2486,
-            -8.0799, -7.3303, -7.3675, -6.8769, -7.6879, -7.2684, -6.9868, -6.7459, -7.6858, -7.3052,
+            -9.1106, 4.5542, 6.3892, -6.8139, -7.2456, -7.9074, -7.2839, -7.6043, -8.0384, -7.8351,
+            -7.3867, -7.2450, -7.7420, -7.3912, -7.3866, -7.6979, -7.6420, -7.0504, -7.3979, -7.2483,
+            -8.0796, -7.3300, -7.3672, -6.8765, -7.6876, -7.2682, -6.9866, -6.7457, -7.6855, -7.3050,
         ])
         # fmt: on
         self.assertTrue(torch.allclose(outputs.logits[0][0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
@@ -502,9 +502,9 @@ def test_base_logits_single(self):

         # fmt: off
         EXPECTED_LOGITS = torch.tensor([
-            -6.7340, 1.9483, 5.2449, -8.0277, -7.9167, -7.8956, -7.9649, -7.9348, -8.1312, -8.0616,
-            -8.1070, -7.7696, -7.8809, -7.9451, -8.1013, -7.8177, -7.8598, -7.8257, -7.8729, -7.9657,
-            -7.9310, -8.1024, -7.8698, -7.8231, -8.0752, -7.9764, -7.8127, -8.0536, -7.9492, -7.9289,
+            -6.7336, 1.9482, 5.2448, -8.0277, -7.9167, -7.8956, -7.9649, -7.9348, -8.1312, -8.0616,
+            -8.1070, -7.7696, -7.8809, -7.9450, -8.1013, -7.8177, -7.8598, -7.8257, -7.8729, -7.9657,
+            -7.9310, -8.1024, -7.8699, -7.8231, -8.0752, -7.9764, -7.8127, -8.0536, -7.9492, -7.9290,
         ])
         # fmt: on
         self.assertTrue(torch.allclose(outputs.logits[0][0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
@@ -519,10 +519,10 @@ def test_tiny_logits_batch(self):
         outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True)
         # fmt: off
         EXPECTED_LOGITS = torch.tensor([
-            [-8.0098, 5.0239, 4.5986, -6.8125, -7.1676, -7.8782, -7.2152, -7.5188, -7.9078, -7.7394],
-            [-4.4394, -1.4429, 6.6715, -6.8927, -7.3748, -7.0967, -6.5255, -7.0255, -7.2583, -7.0007],
-            [-10.0088, 3.2862, 0.7342, -6.5558, -6.8514, -6.5309, -6.4173, -6.9485, -6.6215, -6.6230],
-            [-10.8083, 4.0034, -0.0635, -5.0501, -5.3903, -5.4587, -5.2416, -5.4742, -5.2662, -5.3154]
+            [-8.0109, 5.0241, 4.5979, -6.8125, -7.1675, -7.8783, -7.2152, -7.5188, -7.9077, -7.7394],
+            [-4.4399, -1.4422, 6.6710, -6.8929, -7.3751, -7.0969, -6.5257, -7.0257, -7.2585, -7.0008],
+            [-10.0086, 3.2859, 0.7345, -6.5557, -6.8514, -6.5308, -6.4172, -6.9484, -6.6214, -6.6229],
+            [-10.8078, 4.0030, -0.0633, -5.0505, -5.3906, -5.4590, -5.2420, -5.4746, -5.2665, -5.3158]
         ])
         # fmt: on
         self.assertTrue(torch.allclose(outputs.logits[0][:, :10].cpu(), EXPECTED_LOGITS, atol=1e-4))
@@ -538,10 +538,10 @@ def test_base_logits_batch(self):

         # fmt: off
         EXPECTED_LOGITS = torch.tensor([
-            [-7.7288, 1.4636, 5.2273, -7.7310, -7.6249, -7.6009, -7.6786, -7.6438, -7.8450, -7.7546],
-            [-6.2161, -0.5891, 7.9489, -7.0693, -6.9996, -6.9980, -7.0952, -7.0830, -7.1685, -7.0136],
-            [-7.3186, 3.1192, 3.8938, -5.7208, -5.8429, -5.7610, -5.9997, -5.8213, -5.8616, -5.8720],
-            [-9.5488, 1.0147, 4.1174, -5.9972, -6.0616, -6.0331, -6.2105, -6.0320, -6.0791, -6.0875]
+            [-7.7272, 1.4630, 5.2294, -7.7313, -7.6252, -7.6011, -7.6788, -7.6441, -7.8452, -7.7549],
+            [-6.2173, -0.5891, 7.9493, -7.0694, -6.9997, -6.9982, -7.0953, -7.0831, -7.1686, -7.0137],
+            [-7.3184, 3.1192, 3.8937, -5.7206, -5.8428, -5.7609, -5.9996, -5.8212, -5.8615, -5.8719],
+            [-9.5475, 1.0146, 4.1179, -5.9971, -6.0614, -6.0329, -6.2103, -6.0318, -6.0789, -6.0873]
         ])

         # fmt: on
```
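For context on why the expected logits were regenerated rather than the tolerance loosened: the assertions above use `torch.allclose(..., atol=1e-4)`, and the drift between runner types is up to roughly 1e-3 for some elements, which exceeds that tolerance. A small illustration using the first three old and new values from `test_tiny_logits_single`; the snippet is only a demonstration of the tolerance check, not part of the test suite.

```python
import torch

# First three of the old vs. new expected logits from test_tiny_logits_single.
old = torch.tensor([-9.1107, 4.5538, 6.3902])
new = torch.tensor([-9.1106, 4.5542, 6.3892])

print(torch.allclose(old, new, atol=1e-4))  # False: per-element drift exceeds the test tolerance
print(torch.allclose(old, new, atol=1e-2))  # True: a looser tolerance would absorb it
print((old - new).abs().max())              # on the order of 1e-3
```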
