This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Refactor span extractors and unify forward. #5160

Merged: 7 commits, May 3, 2021
CHANGELOG.md (2 changes: 1 addition & 1 deletion)
@@ -20,8 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
You can do this by setting the parameter `load_weights` to `False`.
See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details.

- Added `SpanExtractorWithSpanWidthEmbedding`, a base class that unifies the constructor arguments and the shared span-width-embedding code, while each extractor's specific span computation moves into its `_embed_spans` method. `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor`, and `SelfAttentiveSpanExtractor` were modified accordingly. `SelfAttentiveSpanExtractor` can now also embed span widths.

## Unreleased

### Fixed

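For context on the changelog entry above: the new `span_extractor_with_span_width_embedding.py` module itself is not part of the diff shown on this page, so the following is only an approximate sketch of the base class, reconstructed from the shared code that the three subclasses delete below. Exact attribute names and the width computation are assumptions.

```python
# Approximate sketch of the new base class; reconstructed from the shared code
# removed from the subclasses in this PR, not copied from the actual file.
from typing import Optional

import torch

from allennlp.common.checks import ConfigurationError
from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.nn import util


class SpanExtractorWithSpanWidthEmbedding(SpanExtractor):
    def __init__(
        self,
        input_dim: int,
        num_width_embeddings: int = None,
        span_width_embedding_dim: int = None,
        bucket_widths: bool = False,
    ) -> None:
        super().__init__()
        self._input_dim = input_dim
        self._num_width_embeddings = num_width_embeddings
        self._bucket_widths = bucket_widths

        # The width-embedding setup that each subclass used to duplicate.
        self._span_width_embedding: Optional[Embedding] = None
        if num_width_embeddings is not None and span_width_embedding_dim is not None:
            self._span_width_embedding = Embedding(
                num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
            )
        elif num_width_embeddings is not None or span_width_embedding_dim is not None:
            raise ConfigurationError(
                "To use a span width embedding representation, you must "
                "specify both num_width_embeddings and span_width_embedding_dim."
            )

    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.FloatTensor:
        # Subclass-specific representation, e.g. endpoints or self-attention.
        span_embeddings = self._embed_spans(
            sequence_tensor, span_indices, sequence_mask, span_indices_mask
        )

        if self._span_width_embedding is not None:
            # Span widths, optionally bucketed into a fixed number of buckets.
            span_widths = span_indices[..., 1] - span_indices[..., 0]
            if self._bucket_widths:
                span_widths = util.bucket_values(
                    span_widths, num_total_buckets=self._num_width_embeddings
                )
            span_width_embeddings = self._span_width_embedding(span_widths)
            span_embeddings = torch.cat([span_embeddings, span_width_embeddings], -1)

        if span_indices_mask is not None:
            # Zero out spans that were passed in as padding.
            return span_embeddings * span_indices_mask.unsqueeze(-1)
        return span_embeddings

    def get_input_dim(self) -> int:
        return self._input_dim

    def _embed_spans(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.Tensor:
        raise NotImplementedError
```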
allennlp/modules/span_extractors/bidirectional_endpoint_span_extractor.py
@@ -1,17 +1,16 @@
from typing import Optional

import torch
from overrides import overrides
from torch.nn.parameter import Parameter

from allennlp.common.checks import ConfigurationError
from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.nn import util


@SpanExtractor.register("bidirectional_endpoint")
class BidirectionalEndpointSpanExtractor(SpanExtractor):
class BidirectionalEndpointSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
"""
Represents spans from a bidirectional encoder as a concatenation of two different
representations of the span endpoints, one for the forward direction of the encoder
@@ -79,12 +78,14 @@ def __init__(
bucket_widths: bool = False,
use_sentinels: bool = True,
) -> None:
super().__init__()
self._input_dim = input_dim
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)
self._forward_combination = forward_combination
self._backward_combination = backward_combination
self._num_width_embeddings = num_width_embeddings
self._bucket_widths = bucket_widths

if self._input_dim % 2 != 0:
raise ConfigurationError(
@@ -93,25 +94,11 @@ def __init__(
"is bidirectional (and hence divisible by 2)."
)

self._span_width_embedding: Optional[Embedding] = None
if num_width_embeddings is not None and span_width_embedding_dim is not None:
self._span_width_embedding = Embedding(
num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
)
elif num_width_embeddings is not None or span_width_embedding_dim is not None:
raise ConfigurationError(
"To use a span width embedding representation, you must"
"specify both num_width_buckets and span_width_embedding_dim."
)

self._use_sentinels = use_sentinels
if use_sentinels:
self._start_sentinel = Parameter(torch.randn([1, 1, int(input_dim / 2)]))
self._end_sentinel = Parameter(torch.randn([1, 1, int(input_dim / 2)]))

def get_input_dim(self) -> int:
return self._input_dim

def get_output_dim(self) -> int:
unidirectional_dim = int(self._input_dim / 2)
forward_combined_dim = util.get_combined_dim(
@@ -128,8 +115,7 @@ def get_output_dim(self) -> int:
)
return forward_combined_dim + backward_combined_dim

@overrides
def forward(
def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
@@ -238,18 +224,4 @@ def forward(
# Shape (batch_size, num_spans, forward_combination_dim + backward_combination_dim)
span_embeddings = torch.cat([forward_spans, backward_spans], -1)

if self._span_width_embedding is not None:
# Embed the span widths and concatenate to the rest of the representations.
if self._bucket_widths:
span_widths = util.bucket_values(
span_ends - span_starts, num_total_buckets=self._num_width_embeddings # type: ignore
)
else:
span_widths = span_ends - span_starts

span_width_embeddings = self._span_width_embedding(span_widths)
return torch.cat([span_embeddings, span_width_embeddings], -1)

if span_indices_mask is not None:
return span_embeddings * span_indices_mask.unsqueeze(-1)
return span_embeddings
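A short usage sketch for the refactored `BidirectionalEndpointSpanExtractor`: the constructor arguments mirror the new `super().__init__` call above, while the tensor sizes and span indices are purely illustrative.

```python
# Illustrative only: dimensions and span indices are made up.
import torch

from allennlp.modules.span_extractors import BidirectionalEndpointSpanExtractor

extractor = BidirectionalEndpointSpanExtractor(
    input_dim=8,  # must be even, since it is split into forward/backward halves of 4
    num_width_embeddings=10,
    span_width_embedding_dim=3,
)

sequence = torch.randn(2, 6, 8)  # (batch_size, sequence_length, input_dim)
spans = torch.tensor(
    [[[0, 1], [2, 4]],
     [[1, 3], [0, 0]]]
)  # (batch_size, num_spans, 2), inclusive start/end indices

span_representations = extractor(sequence, spans)

# The default combinations ("y-x" forward, "x-y" backward) each keep the
# unidirectional size of 4, so 4 + 4 + 3 width dimensions = 11.
assert extractor.get_output_dim() == 11
assert span_representations.shape == (2, 2, 11)
```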
allennlp/modules/span_extractors/endpoint_span_extractor.py (51 changes: 11 additions & 40 deletions)
@@ -1,17 +1,15 @@
from typing import Optional

import torch
from torch.nn.parameter import Parameter
from overrides import overrides

from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.nn import util
from allennlp.common.checks import ConfigurationError


@SpanExtractor.register("endpoint")
class EndpointSpanExtractor(SpanExtractor):
class EndpointSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
"""
Represents spans as a combination of the embeddings of their endpoints. Additionally,
the width of the spans can be embedded and concatenated on to the final combination.
@@ -61,38 +59,25 @@ def __init__(
bucket_widths: bool = False,
use_exclusive_start_indices: bool = False,
) -> None:
super().__init__()
self._input_dim = input_dim
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)
self._combination = combination
self._num_width_embeddings = num_width_embeddings
self._bucket_widths = bucket_widths

self._use_exclusive_start_indices = use_exclusive_start_indices
if use_exclusive_start_indices:
self._start_sentinel = Parameter(torch.randn([1, 1, int(input_dim)]))

self._span_width_embedding: Optional[Embedding] = None
if num_width_embeddings is not None and span_width_embedding_dim is not None:
self._span_width_embedding = Embedding(
num_embeddings=num_width_embeddings, embedding_dim=span_width_embedding_dim
)
elif num_width_embeddings is not None or span_width_embedding_dim is not None:
raise ConfigurationError(
"To use a span width embedding representation, you must"
"specify both num_width_buckets and span_width_embedding_dim."
)

def get_input_dim(self) -> int:
return self._input_dim

def get_output_dim(self) -> int:
combined_dim = util.get_combined_dim(self._combination, [self._input_dim, self._input_dim])
if self._span_width_embedding is not None:
return combined_dim + self._span_width_embedding.get_output_dim()
return combined_dim

@overrides
def forward(
def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
@@ -148,19 +133,5 @@ def forward(
combined_tensors = util.combine_tensors(
self._combination, [start_embeddings, end_embeddings]
)
if self._span_width_embedding is not None:
# Embed the span widths and concatenate to the rest of the representations.
if self._bucket_widths:
span_widths = util.bucket_values(
span_ends - span_starts, num_total_buckets=self._num_width_embeddings # type: ignore
)
else:
span_widths = span_ends - span_starts

span_width_embeddings = self._span_width_embedding(span_widths)
combined_tensors = torch.cat([combined_tensors, span_width_embeddings], -1)

if span_indices_mask is not None:
return combined_tensors * span_indices_mask.unsqueeze(-1)

return combined_tensors
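To make the new division of labor concrete, here is a toy subclass (not part of this PR): after the refactor, `_embed_spans` only has to produce the extractor-specific representation, and the width-embedding and `span_indices_mask` handling deleted above are applied in the base class's `forward`, per the changelog entry and the deletions in this diff. Attribute names such as `_input_dim` and `_span_width_embedding` follow the diff; everything else is illustrative.

```python
# Illustrative toy subclass (not from this PR): _embed_spans returns only the
# extractor-specific representation; width embedding and masking happen in
# SpanExtractorWithSpanWidthEmbedding.forward.
import torch

from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
    SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.nn import util


class ConcatEndpointsSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
    def get_output_dim(self) -> int:
        combined_dim = 2 * self._input_dim  # plain "x,y" concatenation of endpoints
        if self._span_width_embedding is not None:
            return combined_dim + self._span_width_embedding.get_output_dim()
        return combined_dim

    def _embed_spans(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        sequence_mask: torch.BoolTensor = None,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.Tensor:
        # (batch_size, num_spans) start and end indices.
        span_starts, span_ends = [index.squeeze(-1) for index in span_indices.split(1, dim=-1)]
        # (batch_size, num_spans, input_dim) embeddings at each endpoint.
        start_embeddings = util.batched_index_select(sequence_tensor, span_starts)
        end_embeddings = util.batched_index_select(sequence_tensor, span_ends)
        # "x,y" mirrors EndpointSpanExtractor's default combination.
        return util.combine_tensors("x,y", [start_embeddings, end_embeddings])
```

Registering such a class with `@SpanExtractor.register(...)` would only be needed if it were meant to be constructed from configuration files.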
allennlp/modules/span_extractors/self_attentive_span_extractor.py (45 changes: 29 additions & 16 deletions)
@@ -1,13 +1,15 @@
import torch
from overrides import overrides

from allennlp.modules.span_extractors.span_extractor import SpanExtractor
from allennlp.modules.span_extractors.span_extractor_with_span_width_embedding import (
SpanExtractorWithSpanWidthEmbedding,
)
from allennlp.modules.time_distributed import TimeDistributed
from allennlp.nn import util


@SpanExtractor.register("self_attentive")
class SelfAttentiveSpanExtractor(SpanExtractor):
class SelfAttentiveSpanExtractor(SpanExtractorWithSpanWidthEmbedding):
"""
Computes span representations by generating an unnormalized attention score for each
word in the document. Span representations are computed with respect to these
@@ -23,6 +25,14 @@ class SelfAttentiveSpanExtractor(SpanExtractor):

input_dim : `int`, required.
The final dimension of the `sequence_tensor`.
num_width_embeddings : `int`, optional (default = `None`).
Specifies the number of buckets to use when representing
span width features.
span_width_embedding_dim : `int`, optional (default = `None`).
The embedding size for the span_width features.
bucket_widths : `bool`, optional (default = `False`).
Whether to bucket the span widths into log-space buckets. If `False`,
the raw span widths are used.

# Returns

@@ -33,22 +43,31 @@ class SelfAttentiveSpanExtractor(SpanExtractor):
over which they are normalized.
"""

def __init__(self, input_dim: int) -> None:
super().__init__()
self._input_dim = input_dim
def __init__(
self,
input_dim: int,
num_width_embeddings: int = None,
span_width_embedding_dim: int = None,
bucket_widths: bool = False,
) -> None:
super().__init__(
input_dim=input_dim,
num_width_embeddings=num_width_embeddings,
span_width_embedding_dim=span_width_embedding_dim,
bucket_widths=bucket_widths,
)
self._global_attention = TimeDistributed(torch.nn.Linear(input_dim, 1))

def get_input_dim(self) -> int:
return self._input_dim

def get_output_dim(self) -> int:
if self._span_width_embedding is not None:
return self._input_dim + self._span_width_embedding.get_output_dim()
return self._input_dim

@overrides
def forward(
def _embed_spans(
self,
sequence_tensor: torch.FloatTensor,
span_indices: torch.LongTensor,
sequence_mask: torch.BoolTensor = None,
span_indices_mask: torch.BoolTensor = None,
) -> torch.FloatTensor:
# shape (batch_size, sequence_length, 1)
Expand All @@ -72,10 +91,4 @@ def forward(
# Shape: (batch_size, num_spans, embedding_dim)
attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

if span_indices_mask is not None:
# Above we were masking the widths of spans with respect to the max
# span width in the batch. Here we are masking the spans which were
# originally passed in as padding.
return attended_text_embeddings * span_indices_mask.unsqueeze(-1)

return attended_text_embeddings
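Since `SelfAttentiveSpanExtractor` now accepts the span-width arguments, a minimal usage sketch follows (shapes and numbers are illustrative, not taken from the PR's tests).

```python
# Illustrative only: exercises the new width arguments on SelfAttentiveSpanExtractor.
import torch

from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor

extractor = SelfAttentiveSpanExtractor(
    input_dim=16,
    num_width_embeddings=10,
    span_width_embedding_dim=5,
    bucket_widths=True,  # widths are mapped into log-space buckets before embedding
)

sequence = torch.randn(2, 7, 16)  # (batch_size, sequence_length, input_dim)
spans = torch.tensor(
    [[[0, 2], [3, 6]],
     [[1, 1], [2, 5]]]
)  # (batch_size, num_spans, 2), inclusive start/end indices

span_representations = extractor(sequence, spans)

# Output is the attended span representation plus the width embedding: 16 + 5 = 21.
assert extractor.get_output_dim() == 21
assert span_representations.shape == (2, 2, 21)
```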