
Commit 2e752ea

revert my changes
1 parent 785b5cf commit 2e752ea

1 file changed: 3 additions, 22 deletions


src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

@@ -1482,14 +1482,9 @@ def get_rope_index(
             if attention_mask is None:
                 attention_mask = torch.ones_like(total_input_ids)
             position_ids = torch.ones(
-                3,
-                input_ids.shape[0],
-                input_ids.shape[1],
-                dtype=input_ids.dtype,
-                device=input_ids.device,
+                3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
             )
             image_index, video_index = 0, 0
-            attention_mask = attention_mask.to(total_input_ids.device)
             for i, input_ids in enumerate(total_input_ids):
                 input_ids = input_ids[attention_mask[i] == 1]
                 image_nums, video_nums = 0, 0
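
For context, the unchanged lines around this hunk allocate a (3, batch, seq) position buffer, one row per M-RoPE axis (temporal, height, width), then walk each sample keeping only its unpadded tokens. A minimal standalone sketch of that pattern, with made-up sizes (nothing here is taken from the commit itself):

# Sketch with hypothetical sizes: allocate the (3, batch, seq) buffer and
# select unpadded tokens per sample, as the surrounding code does.
import torch

batch, seq_len = 2, 6
total_input_ids = torch.randint(0, 100, (batch, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])

position_ids = torch.ones(
    3, total_input_ids.shape[0], total_input_ids.shape[1], dtype=total_input_ids.dtype
)

for i, row in enumerate(total_input_ids):
    tokens = row[attention_mask[i] == 1]  # drop padded positions
    print(i, tokens.shape)  # sample 0 keeps 4 tokens, sample 1 keeps 6
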
@@ -1516,21 +1511,15 @@ def get_rope_index(
                             image_grid_thw[image_index][1],
                             image_grid_thw[image_index][2],
                         )
-                        second_per_grid_t = 0
                         image_index += 1
                         remain_images -= 1
                         ed = ed_image
-
                     else:
                         t, h, w = (
                             video_grid_thw[video_index][0],
                             video_grid_thw[video_index][1],
                             video_grid_thw[video_index][2],
                         )
-                        if second_per_grid_ts is not None:
-                            second_per_grid_t = second_per_grid_ts[video_index]
-                        else:
-                            second_per_grid_t = 1.0
                         video_index += 1
                         remain_videos -= 1
                         ed = ed_video
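
This hunk drops the per-block second_per_grid_t bookkeeping: images no longer set it to 0, and videos no longer look up a per-video value from second_per_grid_ts. A hedged sketch of the removed lookup, with hypothetical tensors and values:

# Sketch of the reverted bookkeeping (tensors and values are hypothetical):
# per-video seconds-per-temporal-grid, defaulting to 1.0 when no list is given.
import torch

video_grid_thw = torch.tensor([[4, 6, 8]])  # one video: (t, h, w) grid
second_per_grid_ts = [0.5]                  # assumed per-video value
video_index = 0

t, h, w = (
    video_grid_thw[video_index][0],
    video_grid_thw[video_index][1],
    video_grid_thw[video_index][2],
)
second_per_grid_t = (
    second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0
)
print(t.item(), h.item(), w.item(), second_per_grid_t)  # 4 6 8 0.5
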
@@ -1544,15 +1533,7 @@ def get_rope_index(
                     st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                     llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
 
-                    t_index = (
-                        (
-                            torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                            * second_per_grid_t
-                            * self.config.vision_config.tokens_per_second
-                        )
-                        .long()
-                        .flatten()
-                    )
+                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                     h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                     w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                     llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
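
The t_index change is the heart of the revert: instead of scaling the frame ramp by second_per_grid_t and tokens_per_second (absolute-time temporal positions), the restored code uses a plain per-frame ramp. A self-contained comparison, where the grid sizes and scaling values are assumptions rather than anything read from the model config:

# Compare the restored and reverted t_index variants (assumed values).
import torch

llm_grid_t, llm_grid_h, llm_grid_w = 2, 3, 4

# Restored behaviour: a plain 0..t-1 ramp, repeated for every spatial cell.
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()

# Reverted behaviour: the same ramp scaled into absolute-time units before
# flattening (second_per_grid_t and tokens_per_second are hypothetical here).
second_per_grid_t, tokens_per_second = 0.5, 2
t_index_scaled = (
    (
        torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
        * second_per_grid_t
        * tokens_per_second
    )
    .long()
    .flatten()
)

h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
print(torch.stack([t_index, h_index, w_index]).shape)  # torch.Size([3, 24])
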
@@ -1572,7 +1553,7 @@ def get_rope_index(
             if attention_mask is not None:
                 position_ids = attention_mask.long().cumsum(-1) - 1
                 position_ids.masked_fill_(attention_mask == 0, 1)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                 max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                 mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
             else:
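
The last hunk only swaps the .to(...) target from input_ids.device back to attention_mask.device in the text-only fallback, which derives position ids from the attention mask by cumulative sum. A minimal sketch of that fallback with a hypothetical mask:

# Text-only fallback: position ids from the mask via cumsum; padded slots
# are filled with 1, then the tensor is expanded to the three M-RoPE channels.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
print(position_ids[0])
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
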
