This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Refactor span extractors and unify forward. #5160

Merged · 7 commits · May 3, 2021
Changes from 2 commits
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -15,8 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.sanity_checks.task_checklists` module.

### Changed

- Refactored all span extractors: subclass-specific span embedding computations now live in the `_embed_spans` method, while the common logic stays in `SpanExtractor.forward`, unifying the arguments across extractors. The `SelfAttentiveSpanExtractor` can now also embed span widths.

## Unreleased

### Fixed

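The changelog entry above describes the new split between `SpanExtractor.forward` and `_embed_spans`. As an editor's illustration (not part of this PR), here is a minimal sketch of what a custom extractor might look like under the refactored API; `FirstTokenSpanExtractor` and its behaviour are hypothetical. Only `_embed_spans` plus the two dimension accessors need to be supplied, since width embedding and mask handling now live in the base class `forward`:

```python
import torch

from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.nn import util


class FirstTokenSpanExtractor(SpanExtractor):
    """Hypothetical extractor: represents each span by its first token's embedding."""

    def get_input_dim(self) -> int:
        return self._input_dim

    def get_output_dim(self) -> int:
        if self._span_width_embedding is not None:
            return self._input_dim + self._span_width_embedding.get_output_dim()
        return self._input_dim

    def _embed_spans(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.Tensor:
        # Shape: (batch_size, num_spans, embedding_dim); the shared `forward`
        # appends width embeddings and applies `span_indices_mask` afterwards.
        return util.batched_index_select(sequence_tensor, span_indices[..., 0])
```

Constructing it as `FirstTokenSpanExtractor(input_dim=16, num_width_embeddings=10, span_width_embedding_dim=4)` would then yield 20-dimensional span embeddings, with the width features appended by the inherited `forward`.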
allennlp/modules/span_extractors/bidirectional_endpoint_span_extractor.py
@@ -1,12 +1,8 @@
from typing import Optional

import torch
from overrides import overrides
from torch.nn.parameter import Parameter

from allennlp.common.checks import ConfigurationError
from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.nn import util


@@ -79,12 +75,14 @@ def __init__(
bucket_widths: bool = False,
use_sentinels: bool = True,
) -> None:
super().__init__()
self._input_dim = input_dim
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)
self._forward_combination = forward_combination
self._backward_combination = backward_combination
self._num_width_embeddings = num_width_embeddings
self._bucket_widths = bucket_widths

if self._input_dim % 2 != 0:
raise ConfigurationError(
@@ -93,17 +91,6 @@ def __init__(
"is bidirectional (and hence divisible by 2)."
)

self._span_width_embedding: Optional[Embedding] = None
if num_width_embeddings is not None and span_width_embedding_dim is not None:
self._span_width_embedding = Embedding(
num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
)
elif num_width_embeddings is not None or span_width_embedding_dim is not None:
raise ConfigurationError(
"To use a span width embedding representation, you must"
"specify both num_width_buckets and span_width_embedding_dim."
)

self._use_sentinels = use_sentinels
if use_sentinels:
self._start_sentinel = Parameter(torch.randn([1, 1, int(input_dim / 2)]))
@@ -128,8 +115,7 @@ def get_output_dim(self) -> int:
)
return forward_combined_dim + backward_combined_dim

@overrides
def forward(
def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
@@ -238,18 +224,4 @@ def forward(
# Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
span_embeddings = torch.cat([forward_spans, backward_spans], -1)

if self._span_width_embedding is not None:
# Embed the span widths and concatenate to the rest of the representations.
if self._bucket_widths:
span_widths = util.bucket_values(
span_ends - span_starts, num_total_buckets=self._num_width_embeddings # type: ignore
)
else:
span_widths = span_ends - span_starts

span_width_embeddings = self._span_width_embedding(span_widths)
return torch.cat([span_embeddings, span_width_embeddings], -1)

if span_indices_mask is not None:
return span_embeddings * span_indices_mask.unsqueeze(-1)
return span_embeddings
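A short usage sketch (editor's illustration, with assumed shapes and values, not part of the diff): width embeddings are now configured through the base-class arguments forwarded by `super().__init__` and concatenated in `SpanExtractor.forward` rather than inside this subclass.

```python
import torch
from allennlp.modules.span_extractors import BidirectionalEndpointSpanExtractor

extractor = BidirectionalEndpointSpanExtractor(
    input_dim=8,                  # must be even: split into forward/backward halves
    num_width_embeddings=10,
    span_width_embedding_dim=5,
)
sequence = torch.randn(2, 7, 8)              # (batch_size, sequence_length, input_dim)
spans = torch.tensor([[[0, 2], [3, 6]],
                      [[1, 1], [2, 5]]])     # inclusive (start, end) indices
embedded = extractor(sequence, spans)
# Default combinations "y-x" / "x-y" each give input_dim // 2 = 4 dims,
# plus the 5 width-embedding dims appended by the base class -> (2, 2, 13)
print(embedded.shape, extractor.get_output_dim())
```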
43 changes: 7 additions & 36 deletions allennlp/modules/span_extractors/endpoint_span_extractor.py
@@ -1,13 +1,8 @@
from typing import Optional

import torch
from torch.nn.parameter import Parameter
from overrides import overrides

from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.nn import util
from allennlp.common.checks import ConfigurationError


@SpanExtractor.register("endpoint")
@@ -61,27 +56,18 @@ def __init__(
bucket_widths: bool = False,
use_exclusive_start_indices: bool = False,
) -> None:
super().__init__()
self._input_dim = input_dim
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)
self._combination = combination
self._num_width_embeddings = num_width_embeddings
self._bucket_widths = bucket_widths

self._use_exclusive_start_indices = use_exclusive_start_indices
if use_exclusive_start_indices:
self._start_sentinel = Parameter(torch.randn([1, 1, int(input_dim)]))

self._span_width_embedding: Optional[Embedding] = None
if num_width_embeddings is not None and span_width_embedding_dim is not None:
self._span_width_embedding = Embedding(
num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
)
elif num_width_embeddings is not None or span_width_embedding_dim is not None:
raise ConfigurationError(
"To use a span width embedding representation, you must"
"specify both num_width_buckets and span_width_embedding_dim."
)

def get_input_dim(self) -> int:
return self._input_dim

@@ -91,8 +77,7 @@ def get_output_dim(self) -> int:
return combined_dim + self._span_width_embedding.get_output_dim()
return combined_dim

@overrides
def forward(
def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
@@ -148,19 +133,5 @@ def forward(
combined_tensors = util.combine_tensors(
self._combination, [start_embeddings, end_embeddings]
)
if self._span_width_embedding is not None:
# Embed the span widths and concatenate to the rest of the representations.
if self._bucket_widths:
span_widths = util.bucket_values(
span_ends - span_starts, num_total_buckets=self._num_width_embeddings # type: ignore
)
else:
span_widths = span_ends - span_starts

span_width_embeddings = self._span_width_embedding(span_widths)
combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

if span_indices_mask is not None:
return combined_tensors * span_indices_mask.unsqueeze(-1)

return combined_tensors
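Another hedged usage sketch (values assumed by the editor, not taken from the PR): the `combination` string behaves as before, with `x` and `y` denoting the start and end embeddings, while width embedding and masking are now appended by the shared `forward`.

```python
import torch
from allennlp.modules.span_extractors import EndpointSpanExtractor

extractor = EndpointSpanExtractor(
    input_dim=4,
    combination="x,y",             # concatenate start and end embeddings -> 8 dims
    num_width_embeddings=6,
    span_width_embedding_dim=3,
)
sequence = torch.randn(1, 5, 4)
spans = torch.tensor([[[0, 3], [2, 4]]])     # inclusive (start, end) indices
out = extractor(sequence, spans)
print(out.shape)                   # expected: (1, 2, 8 + 3) = (1, 2, 11)
print(extractor.get_output_dim())  # 11
```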
37 changes: 25 additions & 12 deletions allennlp/modules/span_extractors/self_attentive_span_extractor.py
@@ -1,5 +1,4 @@
import torch
from overrides import overrides

from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.time_distributed import TimeDistributed
@@ -23,6 +22,14 @@ class SelfAttentiveSpanExtractor(SpanExtractor):

input_dim : `int`, required.
The final dimension of the `sequence_tensor`.
num_width_embeddings : `int`, optional (default = `None`).
Specifies the number of buckets to use when representing
span width features.
span_width_embedding_dim : `int`, optional (default = `None`).
The embedding size for the span_width features.
bucket_widths : `bool`, optional (default = `False`).
Whether to bucket the span widths into log-space buckets. If `False`,
the raw span widths are used.

# Returns

@@ -33,22 +40,34 @@ class SelfAttentiveSpanExtractor(SpanExtractor):
over which they are normalized.
"""

def __init__(self, input_dim: int) -> None:
super().__init__()
self._input_dim = input_dim
def __init__(
self,
input_dim: int,
num_width_embeddings: int = None,
span_width_embedding_dim: int = None,
bucket_widths: bool = False,
) -> None:
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)
self._global_attention = TimeDistributed(torch.nn.Linear(input_dim, 1))

def get_input_dim(self) -> int:
return self._input_dim

def get_output_dim(self) -> int:
if self._span_width_embedding is not None:
return self._input_dim + self._span_width_embedding.get_output_dim()
return self._input_dim

@overrides
def forward(
def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
sequence_mask: torch.BoolTensor = None,
span_indices_mask: torch.BoolTensor = None,
) -> torch.FloatTensor:
# shape (batch_size, sequence_length, 1)
@@ -72,10 +91,4 @@ def forward(
# Shape: (batch_size, num_spans, embedding_dim)
attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

if span_indices_mask is not None:
# Above we were masking the widths of spans with respect to the max
# span width in the batch. Here we are masking the spans which were
# originally passed in as padding.
return attended_text_embeddings * span_indices_mask.unsqueeze(-1)

return attended_text_embeddings
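The new constructor arguments enable the feature called out in the changelog: `SelfAttentiveSpanExtractor` can now embed span widths. A hedged usage sketch with shapes assumed by the editor:

```python
import torch
from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor

extractor = SelfAttentiveSpanExtractor(
    input_dim=6,
    num_width_embeddings=8,
    span_width_embedding_dim=2,
    bucket_widths=True,            # map widths to semi-log-scale buckets
)
sequence = torch.randn(2, 10, 6)
spans = torch.tensor([[[0, 0], [4, 9]],
                      [[1, 3], [5, 5]]])     # inclusive (start, end) indices
out = extractor(sequence, spans)
print(out.shape)                   # (2, 2, 6 + 2): attended text + width embedding
print(extractor.get_output_dim())  # 8
```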
85 changes: 84 additions & 1 deletion allennlp/modules/span_extractors/span_extractor.py
@@ -1,7 +1,12 @@
import torch
from typing import Optional
from overrides import overrides

import torch

from allennlp.common.checks import ConfigurationError
from allennlp.common.registrable import Registrable
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.nn import util


class SpanExtractor(torch.nn.Module, Registrable):
@@ -14,8 +19,50 @@ class SpanExtractor(torch.nn.Module, Registrable):
and indices of shape (batch_size, num_spans, 2) and return a tensor of
shape (batch_size, num_spans, ...), forming some representation of the
spans.

# Parameters

input_dim : `int`, required.
The final dimension of the `sequence_tensor`.
num_width_embeddings : `int`, optional (default = `None`).
Specifies the number of buckets to use when representing
span width features.
span_width_embedding_dim : `int`, optional (default = `None`).
The embedding size for the span_width features.
bucket_widths : `bool`, optional (default = `False`).
Whether to bucket the span widths into log-space buckets. If `False`,
the raw span widths are used.

# Returns

span_embeddings : `torch.FloatTensor`.
A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
where `embedded_span_size` depends on the way spans are represented.
"""

def __init__(
self,
input_dim: int,
num_width_embeddings: int = None,
span_width_embedding_dim: int = None,
bucket_widths: bool = False,
) -> None:
super().__init__()
self._input_dim = input_dim
self._num_width_embeddings = num_width_embeddings
self._bucket_widths = bucket_widths

self._span_width_embedding: Optional[Embedding] = None
if num_width_embeddings is not None and span_width_embedding_dim is not None:
self._span_width_embedding = Embedding(
num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
)
elif num_width_embeddings is not None or span_width_embedding_dim is not None:
raise ConfigurationError(
"To use a span width embedding representation, you must"
"specify both num_width_embeddings and span_width_embedding_dim."
)

@overrides
def forward(
self,
@@ -53,6 +100,42 @@ def forward(
A tensor of shape `(batch_size, num_spans, embedded_span_size)`,
where `embedded_span_size` depends on the way spans are represented.
"""
# shape (batch_size, num_spans, embedding_dim)
span_embeddings = self._embed_spans(
sequence_tensor, span_indices, sequence_mask, span_indices_mask
)
if self._span_width_embedding is not None:
# width = end_index - start_index + 1 since `SpanField` uses inclusive indices.
# But here we do not add 1 beacuse we offen initiate the span width
Member commented:
Typo

@izhx (Contributor Author) · Apr 28, 2021:
Changed offen to often. 🤣

# embedding matrix with `num_width_embeddings = max_span_width`
# shape (batch_size, num_spans)
widths_minus_one = span_indices[..., 1] - span_indices[..., 0]

if self._bucket_widths:
widths_minus_one = util.bucket_values(
widths_minus_one, num_total_buckets=self._num_width_embeddings # type: ignore
)

# Embed the span widths and concatenate to the rest of the representations.
span_width_embeddings = self._span_width_embedding(widths_minus_one)
span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)

if span_indices_mask is not None:
# Here we are masking the spans which were originally passed in as padding.
return span_embeddings * span_indices_mask.unsqueeze(-1)

return span_embeddings
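A small worked check of the width indexing above (editor's example with assumed values): because `SpanField` uses inclusive end indices, a span `(2, 5)` covers four tokens, but the lookup index is `end - start = 3`; with `num_width_embeddings = max_span_width`, the widest allowed span then maps to the last row of the embedding table rather than falling out of range.

```python
import torch
from allennlp.nn import util

span_indices = torch.tensor([[[2, 5], [0, 0], [1, 7]]])     # (batch=1, num_spans=3, 2)
widths_minus_one = span_indices[..., 1] - span_indices[..., 0]
print(widths_minus_one)                                      # tensor([[3, 0, 6]])

# With `bucket_widths=True`, the raw values are instead mapped to
# semi-log-scale buckets before the embedding lookup.
print(util.bucket_values(widths_minus_one, num_total_buckets=10))
```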

def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
sequence_mask: torch.BoolTensor = None,
span_indices_mask: torch.BoolTensor = None,
) -> torch.Tensor:
"""
Returns the span embeddings computed in many different ways.
"""
raise NotImplementedError

def get_input_dim(self) -> int: