7
7
import logging
8
8
import copy
9
9
import math
10
- import warnings
11
10
12
11
import torch
13
12
@@ -517,7 +516,6 @@ def weighted_sum(matrix: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
517
516
def sequence_cross_entropy_with_logits (logits : torch .FloatTensor ,
518
517
targets : torch .LongTensor ,
519
518
weights : torch .FloatTensor ,
520
- batch_average : bool = None ,
521
519
average : str = "batch" ,
522
520
label_smoothing : float = None ) -> torch .FloatTensor :
523
521
"""
@@ -537,15 +535,6 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
537
535
index of the true class for each corresponding step.
538
536
weights : ``torch.FloatTensor``, required.
539
537
A ``torch.FloatTensor`` of size (batch, sequence_length)
540
- batch_average : bool, optional, (default = None).
541
- A bool indicating whether the loss should be averaged across the batch,
542
- or returned as a vector of losses per batch element.
543
-
544
- .. deprecated:: 0.6.2
545
- ``batch_average`` was deprecated and replaced with
546
- the more general ``average`` in version 0.6.2. It will be removed
547
- in version 0.8.
548
-
549
538
average: str, optional (default = "batch")
550
539
If "batch", average the loss across the batches. If "token", average
551
540
the loss across each item in the input. If ``None``, return a vector
@@ -563,18 +552,6 @@ def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
563
552
If ``average is None``, the returned loss is a vector of shape (batch_size,).
564
553
565
554
"""
566
- if batch_average is not None :
567
- # Maintain old behavior
568
- if batch_average :
569
- warnings .warn ("batch_average=True was deprecated and replaced "
570
- "with average='batch' in version 0.6.2. It will be "
571
- "removed in version 0.8." , DeprecationWarning )
572
- average = "batch"
573
- else :
574
- warnings .warn ("batch_average=False was deprecated and replaced "
575
- "with average=None in version 0.6.2. It will be "
576
- "removed in version 0.8." , DeprecationWarning )
577
- average = None
578
555
if average not in {None , "token" , "batch" }:
579
556
raise ValueError ("Got average f{average}, expected one of "
580
557
"None, 'token', or 'batch'" )
0 commit comments