Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit c629093

Browse files
epwalsh authored and DeNeutoy committed
CopyNet: replace in-place tensor operation with out-of-place equivalent (#2925)
* remove in-place operation
* oops, fixed
1 parent 89700de commit c629093

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

allennlp/models/encoder_decoders/copynet_seq2seq.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ def _gather_final_log_probs(self,
627627
source_token_ids = state["source_token_ids"]
628628

629629
# shape: [(batch_size, *)]
630-
modified_log_probs_list: List[torch.Tensor] = [generation_log_probs]
630+
modified_log_probs_list: List[torch.Tensor] = []
631631
for i in range(trimmed_source_length):
632632
# shape: (group_size,)
633633
copy_log_probs_slice = copy_log_probs[:, i]
@@ -648,7 +648,9 @@ def _gather_final_log_probs(self,
648648
selected_generation_log_probs = generation_log_probs.gather(1, source_to_target_slice.unsqueeze(-1))
649649
combined_scores = util.logsumexp(
650650
torch.cat((selected_generation_log_probs, copy_log_probs_to_add), dim=1))
651-
generation_log_probs.scatter_(-1, source_to_target_slice.unsqueeze(-1), combined_scores.unsqueeze(-1))
651+
generation_log_probs = generation_log_probs.scatter(-1,
652+
source_to_target_slice.unsqueeze(-1),
653+
combined_scores.unsqueeze(-1))
652654
# We have to combine copy scores for duplicate source tokens so that
653655
# we can find the overall most likely source token. So, if this is the first
654656
# occurence of this particular source token, we add the log_probs from all other
@@ -676,6 +678,7 @@ def _gather_final_log_probs(self,
676678
# shape: (group_size,)
677679
left_over_copy_log_probs = copy_log_probs_slice + (1.0 - copy_log_probs_to_add_mask + 1e-45).log()
678680
modified_log_probs_list.append(left_over_copy_log_probs.unsqueeze(-1))
681+
modified_log_probs_list.insert(0, generation_log_probs)
679682

680683
# shape: (group_size, target_vocab_size + trimmed_source_length)
681684
modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)

0 commit comments

Comments (0)