@@ -627,7 +627,7 @@ def _gather_final_log_probs(self,
         source_token_ids = state["source_token_ids"]
 
         # shape: [(batch_size, *)]
-        modified_log_probs_list: List[torch.Tensor] = [generation_log_probs]
+        modified_log_probs_list: List[torch.Tensor] = []
         for i in range(trimmed_source_length):
             # shape: (group_size,)
             copy_log_probs_slice = copy_log_probs[:, i]
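This first hunk stops seeding `modified_log_probs_list` with `generation_log_probs`. It pairs with the later hunks: once the loop rebinds `generation_log_probs` to a new tensor on each update (out-of-place `scatter`), a reference stored in the list before the loop would go stale. A minimal sketch of that pitfall, with toy names and values that are not from the model code:

```python
import torch

# Toy tensors; names and values are illustrative only.
log_probs = torch.log(torch.tensor([[0.5, 0.5]]))
parts = [log_probs]  # old approach: seed the list before the loop
# Out-of-place scatter returns a NEW tensor and rebinds the name...
log_probs = log_probs.scatter(-1, torch.tensor([[0]]), torch.tensor([[0.0]]))
# ...so the list still holds the stale original:
print(parts[0][0, 0].item())  # log(0.5), not the updated 0.0
```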
@@ -648,7 +648,9 @@ def _gather_final_log_probs(self,
             selected_generation_log_probs = generation_log_probs.gather(1, source_to_target_slice.unsqueeze(-1))
             combined_scores = util.logsumexp(
                     torch.cat((selected_generation_log_probs, copy_log_probs_to_add), dim=1))
-            generation_log_probs.scatter_(-1, source_to_target_slice.unsqueeze(-1), combined_scores.unsqueeze(-1))
+            generation_log_probs = generation_log_probs.scatter(-1,
+                                                                source_to_target_slice.unsqueeze(-1),
+                                                                combined_scores.unsqueeze(-1))
             # We have to combine copy scores for duplicate source tokens so that
             # we can find the overall most likely source token. So, if this is the first
             # occurrence of this particular source token, we add the log_probs from all other
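The second hunk swaps the in-place `scatter_` for the out-of-place `scatter`, which leaves the original tensor untouched and returns an updated copy that is rebound to `generation_log_probs`. A hedged sketch of the merge step, using `torch.logsumexp` in place of AllenNLP's `util.logsumexp` and toy shapes of my own choosing:

```python
import torch

# Toy setup: one group member, a 4-word target vocab, and one copy score
# that points at target id 2. All names and values are illustrative.
gen = torch.log(torch.full((1, 4), 0.25))    # generation log-probs
copy_lp = torch.log(torch.tensor([[0.25]]))  # copy log-prob for the token
idx = torch.tensor([[2]])                    # target id of the source token

selected = gen.gather(1, idx)                # shape: (1, 1)
combined = torch.logsumexp(torch.cat((selected, copy_lp), dim=1), dim=1)
# Out-of-place write-back: the tensor `gen` pointed to before this line
# is left unmodified.
gen = gen.scatter(-1, idx, combined.unsqueeze(-1))
print(gen.exp())  # entry 2 is now 0.25 + 0.25 = 0.5
```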
@@ -676,6 +678,7 @@ def _gather_final_log_probs(self,
             # shape: (group_size,)
             left_over_copy_log_probs = copy_log_probs_slice + (1.0 - copy_log_probs_to_add_mask + 1e-45).log()
             modified_log_probs_list.append(left_over_copy_log_probs.unsqueeze(-1))
+        modified_log_probs_list.insert(0, generation_log_probs)
 
         # shape: (group_size, target_vocab_size + trimmed_source_length)
         modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)
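With the loop finished, the third hunk inserts the final, rebound `generation_log_probs` at position 0, so the concatenated layout keeps the target-vocabulary block first, followed by one column per source position, matching the shape comment above. A self-contained sketch of that assembly with made-up dimensions:

```python
import torch

# Made-up sizes standing in for group_size, target_vocab_size, and
# trimmed_source_length; the copy columns are random stand-ins.
group_size, vocab_size, src_len = 2, 5, 3
generation_log_probs = torch.randn(group_size, vocab_size)

modified_log_probs_list = []  # now starts empty
for _ in range(src_len):
    modified_log_probs_list.append(torch.randn(group_size, 1))
modified_log_probs_list.insert(0, generation_log_probs)  # final version first

modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)
assert modified_log_probs.shape == (group_size, vocab_size + src_len)
```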