@@ -1482,14 +1482,9 @@ def get_rope_index(
1482
1482
if attention_mask is None :
1483
1483
attention_mask = torch .ones_like (total_input_ids )
1484
1484
position_ids = torch .ones (
1485
- 3 ,
1486
- input_ids .shape [0 ],
1487
- input_ids .shape [1 ],
1488
- dtype = input_ids .dtype ,
1489
- device = input_ids .device ,
1485
+ 3 , input_ids .shape [0 ], input_ids .shape [1 ], dtype = input_ids .dtype , device = input_ids .device
1490
1486
)
1491
1487
image_index , video_index = 0 , 0
1492
- attention_mask = attention_mask .to (total_input_ids .device )
1493
1488
for i , input_ids in enumerate (total_input_ids ):
1494
1489
input_ids = input_ids [attention_mask [i ] == 1 ]
1495
1490
image_nums , video_nums = 0 , 0
@@ -1516,21 +1511,15 @@ def get_rope_index(
1516
1511
image_grid_thw [image_index ][1 ],
1517
1512
image_grid_thw [image_index ][2 ],
1518
1513
)
1519
- second_per_grid_t = 0
1520
1514
image_index += 1
1521
1515
remain_images -= 1
1522
1516
ed = ed_image
1523
-
1524
1517
else :
1525
1518
t , h , w = (
1526
1519
video_grid_thw [video_index ][0 ],
1527
1520
video_grid_thw [video_index ][1 ],
1528
1521
video_grid_thw [video_index ][2 ],
1529
1522
)
1530
- if second_per_grid_ts is not None :
1531
- second_per_grid_t = second_per_grid_ts [video_index ]
1532
- else :
1533
- second_per_grid_t = 1.0
1534
1523
video_index += 1
1535
1524
remain_videos -= 1
1536
1525
ed = ed_video
@@ -1544,15 +1533,7 @@ def get_rope_index(
1544
1533
st_idx = llm_pos_ids_list [- 1 ].max () + 1 if len (llm_pos_ids_list ) > 0 else 0
1545
1534
llm_pos_ids_list .append (torch .arange (text_len ).view (1 , - 1 ).expand (3 , - 1 ) + st_idx )
1546
1535
1547
- t_index = (
1548
- (
1549
- torch .arange (llm_grid_t ).view (- 1 , 1 ).expand (- 1 , llm_grid_h * llm_grid_w )
1550
- * second_per_grid_t
1551
- * self .config .vision_config .tokens_per_second
1552
- )
1553
- .long ()
1554
- .flatten ()
1555
- )
1536
+ t_index = torch .arange (llm_grid_t ).view (- 1 , 1 ).expand (- 1 , llm_grid_h * llm_grid_w ).flatten ()
1556
1537
h_index = torch .arange (llm_grid_h ).view (1 , - 1 , 1 ).expand (llm_grid_t , - 1 , llm_grid_w ).flatten ()
1557
1538
w_index = torch .arange (llm_grid_w ).view (1 , 1 , - 1 ).expand (llm_grid_t , llm_grid_h , - 1 ).flatten ()
1558
1539
llm_pos_ids_list .append (torch .stack ([t_index , h_index , w_index ]) + text_len + st_idx )
@@ -1572,7 +1553,7 @@ def get_rope_index(
1572
1553
if attention_mask is not None :
1573
1554
position_ids = attention_mask .long ().cumsum (- 1 ) - 1
1574
1555
position_ids .masked_fill_ (attention_mask == 0 , 1 )
1575
- position_ids = position_ids .unsqueeze (0 ).expand (3 , - 1 , - 1 ).to (input_ids .device )
1556
+ position_ids = position_ids .unsqueeze (0 ).expand (3 , - 1 , - 1 ).to (attention_mask .device )
1576
1557
max_position_ids = position_ids .max (0 , keepdim = False )[0 ].max (- 1 , keepdim = True )[0 ]
1577
1558
mrope_position_deltas = max_position_ids + 1 - attention_mask .shape [- 1 ]
1578
1559
else :
0 commit comments