Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit a98481c

Browse files
authored
Remove deprecated batch_average argument to sequence ce with logits (#2247)
* Remove deprecated batch_average argument to sequence ce with logits
* Fix test
* Fix lint
1 parent cf67128 commit a98481c

File tree

2 files changed

+2
-24
lines changed

2 files changed

+2
-24
lines changed

allennlp/nn/util.py

-23
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import logging
88
import copy
99
import math
10-
import warnings
1110

1211
import torch
1312

@@ -517,7 +516,6 @@ def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
517516
def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
518517
targets: torch.LongTensor,
519518
weights: torch.FloatTensor,
520-
batch_average: bool = None,
521519
average: str = "batch",
522520
label_smoothing: float = None) -> torch.FloatTensor:
523521
"""
@@ -537,15 +535,6 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
537535
index of the true class for each corresponding step.
538536
weights : ``torch.FloatTensor``, required.
539537
A ``torch.FloatTensor`` of size (batch, sequence_length)
540-
batch_average : bool, optional, (default = None).
541-
A bool indicating whether the loss should be averaged across the batch,
542-
or returned as a vector of losses per batch element.
543-
544-
.. deprecated:: 0.6.2
545-
``batch_average`` was deprecated and replaced with
546-
the more general ``average`` in version 0.6.2. It will be removed
547-
in version 0.8.
548-
549538
average: str, optional (default = "batch")
550539
If "batch", average the loss across the batches. If "token", average
551540
the loss across each item in the input. If ``None``, return a vector
@@ -563,18 +552,6 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
563552
If ``average is None``, the returned loss is a vector of shape (batch_size,).
564553
565554
"""
566-
if batch_average is not None:
567-
# Maintain old behavior
568-
if batch_average:
569-
warnings.warn("batch_average=True was deprecated and replaced "
570-
"with average='batch' in version 0.6.2. It will be "
571-
"removed in version 0.8.", DeprecationWarning)
572-
average = "batch"
573-
else:
574-
warnings.warn("batch_average=False was deprecated and replaced "
575-
"with average=None in version 0.6.2. It will be "
576-
"removed in version 0.8.", DeprecationWarning)
577-
average = None
578555
if average not in {None, "token", "batch"}:
579556
raise ValueError("Got average f{average}, expected one of "
580557
"None, 'token', or 'batch'")

allennlp/tests/nn/util_test.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -533,7 +533,8 @@ def test_sequence_cross_entropy_with_logits_averages_token_correctly(self):
533533

534534
loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, average="token")
535535

536-
vector_loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights, batch_average=False)
536+
vector_loss = util.sequence_cross_entropy_with_logits(tensor, targets, weights,
537+
average=None)
537538
total_token_loss = (vector_loss * weights.float().sum(dim=-1)).sum()
538539
average_token_loss = (total_token_loss / weights.float().sum()).detach()
539540
assert_almost_equal(loss.detach().item(), average_token_loss.item())

0 commit comments

Comments (0)