This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 5a3acba

eraldoluis and dirkgr authored

Implementation of Weighted CRF Tagger (handling unbalanced datasets) (#5676)

* Weighted CRF: scaled emission scores
* Fixed bug in ConditionalRandomField: self.label_weights is now created as a parameter so that it will be moved to GPU whenever the model moves.
* CRF weighting strategies
* Weighted CRF: refactoring of three methods
* Weighted CRF: refactoring of three methods
* Weighted CRF: black formatting
* Weighted CRF: moved classes to new module. Simplified Lannoy implementation.
* formatting and type checking
* Moved ConditionalRandomField to new module. Renamed module allennlp.modules.conditional_random_field_weight to ...conditional_random_files
* Updated changelog

Co-authored-by: Dirk Groeneveld <[email protected]>

1 parent 20df7cd, commit 5a3acba

8 files changed: +617 −17 lines
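
The GPU fix described in the message ("moved to GPU whenever the model moves") works by registering the weights with the module rather than holding a bare tensor attribute. A minimal sketch of that mechanism, assuming nothing beyond standard PyTorch (the `Demo` module here is illustrative, not part of this commit):

    import torch

    class Demo(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # A registered buffer travels with the module across devices
            # (model.cuda(), model.to(device)) but gets no gradient updates,
            # which suits fixed per-label weights.
            self.register_buffer("label_weights", torch.tensor([1.0, 2.0, 0.5]))

    model = Demo()
    if torch.cuda.is_available():
        model = model.cuda()
        assert model.label_weights.is_cuda  # moved together with the module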

CHANGELOG.md
+1

@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

 ### Added

 - Added metric `FBetaVerboseMeasure` which extends `FBetaMeasure` to ensure compatibility with logging plugins and add some options.
+- Added three sample weighting techniques to `ConditionalRandomField` by supplying three new subclasses: `ConditionalRandomFieldWeightEmission`, `ConditionalRandomFieldWeightTrans`, and `ConditionalRandomFieldWeightLannoy`.

 ### Fixed

allennlp/modules/conditional_random_field/__init__.py
+12

@@ -0,0 +1,12 @@
+from allennlp.modules.conditional_random_field.conditional_random_field import (
+    ConditionalRandomField,
+)
+from allennlp.modules.conditional_random_field.conditional_random_field_wemission import (
+    ConditionalRandomFieldWeightEmission,
+)
+from allennlp.modules.conditional_random_field.conditional_random_field_wtrans import (
+    ConditionalRandomFieldWeightTrans,
+)
+from allennlp.modules.conditional_random_field.conditional_random_field_wlannoy import (
+    ConditionalRandomFieldWeightLannoy,
+)
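
Assuming this new file is the package's `__init__.py` (its contents suggest so), downstream code can then import all four classes directly from the subpackage; a short sketch:

    from allennlp.modules.conditional_random_field import (
        ConditionalRandomField,
        ConditionalRandomFieldWeightEmission,
        ConditionalRandomFieldWeightTrans,
        ConditionalRandomFieldWeightLannoy,
    )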

allennlp/modules/conditional_random_field.py → allennlp/modules/conditional_random_field/conditional_random_field.py
+45 −14

@@ -214,10 +214,21 @@ def reset_parameters(self):
         torch.nn.init.normal_(self.start_transitions)
         torch.nn.init.normal_(self.end_transitions)

-    def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
-        """
-        Computes the (batch_size,) denominator term for the log-likelihood, which is the
-        sum of the likelihoods across all possible state sequences.
+    def _input_likelihood(
+        self, logits: torch.Tensor, transitions: torch.Tensor, mask: torch.BoolTensor
+    ) -> torch.Tensor:
+        """Computes the (batch_size,) denominator term $Z(x)$, per example, for the log-likelihood
+
+        This is the sum of the likelihoods across all possible state sequences.
+
+        Args:
+            logits (torch.Tensor): a (batch_size, sequence_length, num_tags) tensor of
+                unnormalized log-probabilities
+            transitions (torch.Tensor): a (num_tags, num_tags) tensor of transition scores
+            mask (torch.BoolTensor): a (batch_size, sequence_length) tensor of masking flags
+
+        Returns:
+            torch.Tensor: (batch_size,) denominator term $Z(x)$, per example, for the log-likelihood
         """
         batch_size, sequence_length, num_tags = logits.size()

@@ -239,7 +250,7 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
             # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
             emit_scores = logits[i].view(batch_size, 1, num_tags)
             # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
-            transition_scores = self.transitions.view(1, num_tags, num_tags)
+            transition_scores = transitions.view(1, num_tags, num_tags)
             # Alpha is for the current_tag, so we broadcast along the next_tag axis.
             broadcast_alpha = alpha.view(batch_size, num_tags, 1)

@@ -262,10 +273,23 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
         return util.logsumexp(stops)

     def _joint_likelihood(
-        self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor
+        self,
+        logits: torch.Tensor,
+        transitions: torch.Tensor,
+        tags: torch.Tensor,
+        mask: torch.BoolTensor,
     ) -> torch.Tensor:
-        """
-        Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
+        """Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
+
+        Args:
+            logits (torch.Tensor): a (batch_size, sequence_length, num_tags) tensor of unnormalized
+                log-probabilities
+            transitions (torch.Tensor): a (num_tags, num_tags) tensor of transition scores
+            tags (torch.Tensor): output tag sequences (batch_size, sequence_length) $y$ for each input sequence
+            mask (torch.BoolTensor): a (batch_size, sequence_length) tensor of masking flags
+
+        Returns:
+            torch.Tensor: numerator term for the log-likelihood, which is just score(inputs, tags)
         """
         batch_size, sequence_length, _ = logits.data.shape

@@ -286,7 +310,7 @@ def _joint_likelihood(
             current_tag, next_tag = tags[i], tags[i + 1]

             # The scores for transitioning from current_tag to next_tag
-            transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)]
+            transition_score = transitions[current_tag.view(-1), next_tag.view(-1)]

             # The score for using current_tag
             emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)

@@ -318,18 +342,25 @@
     def forward(
         self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None
     ) -> torch.Tensor:
-        """
-        Computes the log likelihood.
-        """
+        """Computes the log likelihood for the given batch of input sequences $(x,y)$
+
+        Args:
+            inputs (torch.Tensor): (batch_size, sequence_length, num_tags) tensor of logits for the inputs $x$
+            tags (torch.Tensor): (batch_size, sequence_length) tensor of tags $y$
+            mask (torch.BoolTensor, optional): (batch_size, sequence_length) tensor of masking flags.
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: (batch_size,) log likelihoods $log P(y|x)$ for each input
+        """
         if mask is None:
             mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device)
         else:
             # The code below fails in weird ways if this isn't a bool tensor, so we make sure.
             mask = mask.to(torch.bool)

-        log_denominator = self._input_likelihood(inputs, mask)
-        log_numerator = self._joint_likelihood(inputs, tags, mask)
+        log_denominator = self._input_likelihood(inputs, self.transitions, mask)
+        log_numerator = self._joint_likelihood(inputs, self.transitions, tags, mask)

         return torch.sum(log_numerator - log_denominator)
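
Threading `transitions` through `_input_likelihood` and `_joint_likelihood` (instead of reading `self.transitions` inside them) is what lets the weighted subclasses substitute modified transition scores without duplicating the dynamic-programming code. A hypothetical sketch of that pattern follows; this is not the actual `ConditionalRandomFieldWeightTrans` implementation (its diff is not shown here), and the destination-tag weighting is an assumption:

    from typing import List

    import torch

    from allennlp.modules.conditional_random_field import ConditionalRandomField

    class TransitionWeightedCRF(ConditionalRandomField):
        """Illustrative subclass: scales transition scores by per-tag weights."""

        def __init__(self, num_tags: int, label_weights: List[float]) -> None:
            super().__init__(num_tags)
            self.register_buffer("label_weights", torch.Tensor(label_weights))

        def forward(
            self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None
        ) -> torch.Tensor:
            if mask is None:
                mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device)
            else:
                mask = mask.to(torch.bool)
            # transitions is (current_tag, next_tag); weight each score by the
            # weight of its destination tag (an illustrative choice).
            transitions = self.transitions * self.label_weights.view(1, -1)
            log_denominator = self._input_likelihood(inputs, transitions, mask)
            log_numerator = self._joint_likelihood(inputs, transitions, tags, mask)
            return torch.sum(log_numerator - log_denominator)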

allennlp/modules/conditional_random_field/conditional_random_field_wemission.py
+91

@@ -0,0 +1,91 @@
+"""
+Conditional random field with emission-based weighting
+"""
+from typing import List, Tuple
+
+import torch
+
+from allennlp.common.checks import ConfigurationError
+from allennlp.modules.conditional_random_field.conditional_random_field import (
+    ConditionalRandomField,
+)
+
+
+class ConditionalRandomFieldWeightEmission(ConditionalRandomField):
+    """
+    This module uses the "forward-backward" algorithm to compute
+    the log-likelihood of its inputs assuming a conditional random field model.
+
+    See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf
+
+    This is a weighted version of `ConditionalRandomField` which accepts a `label_weights`
+    parameter to be used in the loss function in order to give different weights for each
+    token depending on its label. The method implemented here is based on the simple idea
+    of weighting emission scores using the weight given for the corresponding tag.
+
+    There are two other sample weighting methods implemented. You can find more details
+    about them in: https://eraldoluis.github.io/2022/05/10/weighted-crf.html
+
+    # Parameters
+
+    num_tags : `int`, required
+        The number of tags.
+    label_weights : `List[float]`, required
+        A list of weights to be used in the loss function in order to
+        give different weights for each token depending on its label.
+        `len(label_weights)` must be equal to `num_tags`. This is useful to
+        deal with highly unbalanced datasets. The method implemented here is
+        based on the simple idea of weighting emission scores using the weight
+        given for the corresponding tag.
+    constraints : `List[Tuple[int, int]]`, optional (default = `None`)
+        An optional list of allowed transitions (from_tag_id, to_tag_id).
+        These are applied to `viterbi_tags()` but do not affect `forward()`.
+        These should be derived from `allowed_transitions` so that the
+        start and end transitions are handled correctly for your tag type.
+    include_start_end_transitions : `bool`, optional (default = `True`)
+        Whether to include the start and end transition parameters.
+    """
+
+    def __init__(
+        self,
+        num_tags: int,
+        label_weights: List[float],
+        constraints: List[Tuple[int, int]] = None,
+        include_start_end_transitions: bool = True,
+    ) -> None:
+        super().__init__(num_tags, constraints, include_start_end_transitions)
+
+        if label_weights is None:
+            raise ConfigurationError("label_weights must be given")
+
+        self.register_buffer("label_weights", torch.Tensor(label_weights))
+
+    def forward(
+        self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None
+    ) -> torch.Tensor:
+        """Computes the log likelihood for the given batch of input sequences $(x,y)$
+
+        Args:
+            inputs (torch.Tensor): (batch_size, sequence_length, num_tags) tensor of logits for the inputs $x$
+            tags (torch.Tensor): (batch_size, sequence_length) tensor of tags $y$
+            mask (torch.BoolTensor, optional): (batch_size, sequence_length) tensor of masking flags.
+                Defaults to None.
+
+        Returns:
+            torch.Tensor: (batch_size,) log likelihoods $log P(y|x)$ for each input
+        """
+        if mask is None:
+            mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device)
+        else:
+            # The code below fails in weird ways if this isn't a bool tensor, so we make sure.
+            mask = mask.to(torch.bool)
+
+        label_weights = self.label_weights
+
+        # scale the logits for all examples and all time steps
+        inputs = inputs * label_weights.view(1, 1, -1)
+
+        log_denominator = self._input_likelihood(inputs, self.transitions, mask)
+        log_numerator = self._joint_likelihood(inputs, self.transitions, tags, mask)
+
+        return torch.sum(log_numerator - log_denominator)
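
A quick usage sketch of the emission-weighted variant; the tag inventory, weights, and tensor shapes below are made up for illustration:

    import torch

    from allennlp.modules.conditional_random_field import (
        ConditionalRandomFieldWeightEmission,
    )

    num_tags = 3  # e.g. O, B-ENT, I-ENT
    # Down-weight the dominant O tag relative to the rare entity tags.
    crf = ConditionalRandomFieldWeightEmission(num_tags, label_weights=[0.2, 1.0, 1.0])

    batch_size, seq_len = 2, 5
    logits = torch.randn(batch_size, seq_len, num_tags)    # emission scores from an encoder
    tags = torch.randint(num_tags, (batch_size, seq_len))  # gold tag indices
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

    loss = -crf(logits, tags, mask)  # negated log likelihood, summed over the batch
    loss.backward()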
