Faster label smoothing #5294

Merged
merged 7 commits on Jul 8, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Ensured `ensure_model_can_train_save_and_load` is consistently random.
 - Fixed weight tying logic in `T5` transformer module. Previously input/output embeddings were always tied. Now this is optional,
   and the default behavior is taken from the `config.tie_word_embeddings` value when instantiating `from_pretrained_module()`.
+- Implemented slightly faster label smoothing
 - Fixed the docs for `PytorchTransformerWrapper`
 - Fixed recovering training jobs with models that expect `get_metrics()` to not be called until they have seen at least one batch.

5 changes: 2 additions & 3 deletions allennlp/nn/util.py
@@ -822,10 +822,9 @@ def sequence_cross_entropy_with_logits(
         num_classes = logits.size(-1)
         smoothing_value = label_smoothing / num_classes
         # Fill all the correct indices with 1 - smoothing value.
-        one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
-            -1, targets_flat, 1.0 - label_smoothing
+        smoothed_targets = torch.full_like(log_probs_flat, smoothing_value).scatter_(
+            -1, targets_flat, 1.0 - label_smoothing + smoothing_value
         )
-        smoothed_targets = one_hot_targets + smoothing_value
         negative_log_likelihood_flat = -log_probs_flat * smoothed_targets
         negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
     else:
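For reference, here is a minimal self-contained sketch of the change (not taken from the PR; the shapes and values are illustrative) showing why the fused formulation is equivalent. Because `scatter_` overwrites values rather than accumulating, the correct-class entries written into the pre-filled tensor must include `smoothing_value` explicitly:

```python
import torch

# Illustrative sizes; in sequence_cross_entropy_with_logits these come from
# flattening (batch_size, sequence_length, num_classes) logits.
num_rows, num_classes = 6, 5
label_smoothing = 0.1
smoothing_value = label_smoothing / num_classes

log_probs_flat = torch.randn(num_rows, num_classes).log_softmax(-1)
targets_flat = torch.randint(num_classes, (num_rows, 1))

# Old formulation: allocate zeros, scatter, then an elementwise add
# (one extra temporary tensor and one extra elementwise kernel).
one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(
    -1, targets_flat, 1.0 - label_smoothing
)
old_smoothed = one_hot_targets + smoothing_value

# New formulation: pre-fill every class slot with the smoothing value, then
# overwrite the target entries in place. scatter_ replaces rather than adds,
# so the written value carries the smoothing term itself.
new_smoothed = torch.full_like(log_probs_flat, smoothing_value).scatter_(
    -1, targets_flat, 1.0 - label_smoothing + smoothing_value
)

assert torch.allclose(old_smoothed, new_smoothed)
# Each row is a proper distribution: (K - 1) * eps/K + (1 - eps + eps/K) = 1.
assert torch.allclose(new_smoothed.sum(-1), torch.ones(num_rows))
```

The saving is small but real: the fused version allocates one tensor instead of two and skips the broadcasted addition, consistent with the "slightly faster" wording in the changelog entry.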