Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 947bd16

Browse files
authored
make api more pythonic (#1926)
* first stab at contextual encoder wrappers * contextual encoders * remove sru encoder * pr comments * replace _ElmoCharacterEncoder with CharacterEncoder * docs * sphinx stuff * address pr comments * address more PR comments * make sphinx happy * iterate * make parameters required * this is still wip * wip * bidirectional-lm proof of concept * progress * revert elmo * revert elmo test * revert elmo token embedder * cnn_highway_encoder -> seq2vec * remove contextual encoders * fix docs * remove print * address more feedback * replace none with identity function * fix docs + checks * fix tests * add comments * add top level imports * fix imports * unused import * progress * use brendan's dataset reader * make data interface more pythonic * remove unused import * fix pytest + pylint
1 parent 0e82106 commit 947bd16

15 files changed

+177
-9
lines changed

allennlp/data/fields/index_field.py

+8
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,11 @@ def empty_field(self):
5151

5252
def __str__(self) -> str:
5353
return f"IndexField with index: {self.sequence_index}."
54+
55+
def __eq__(self, other) -> bool:
56+
# Allow equality checks to ints that are the sequence index
57+
if isinstance(other, int):
58+
return self.sequence_index == other
59+
# Otherwise it has to be the same object
60+
else:
61+
return id(other) == id(self)

allennlp/data/fields/list_field.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# pylint: disable=no-self-use
2-
from typing import Dict, List
2+
from typing import Dict, List, Iterator
33

44
from overrides import overrides
55

@@ -31,6 +31,16 @@ def __init__(self, field_list: List[Field]) -> None:
3131
# Not sure why mypy has a hard time with this type...
3232
self.field_list: List[Field] = field_list
3333

34+
# Sequence[Field] methods
35+
def __iter__(self) -> Iterator[Field]:
36+
return iter(self.field_list)
37+
38+
def __getitem__(self, idx: int) -> Field:
39+
return self.field_list[idx]
40+
41+
def __len__(self) -> int:
42+
return len(self.field_list)
43+
3444
@overrides
3545
def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
3646
for field in self.field_list:

allennlp/data/fields/metadata_field.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# pylint: disable=no-self-use
2-
from typing import Any, Dict, List
2+
from typing import Any, Dict, List, Mapping
33

44
from overrides import overrides
55

66
from allennlp.data.fields.field import DataArray, Field
77

88

9-
class MetadataField(Field[DataArray]):
9+
class MetadataField(Field[DataArray], Mapping[str, Any]):
1010
"""
1111
A ``MetadataField`` is a ``Field`` that does not get converted into tensors. It just carries
1212
side information that might be needed later on, for computing some third-party metric, or
@@ -27,6 +27,24 @@ class MetadataField(Field[DataArray]):
2727
def __init__(self, metadata: Any) -> None:
2828
self.metadata = metadata
2929

30+
def __getitem__(self, key: str) -> Any:
31+
try:
32+
return self.metadata[key] # type: ignore
33+
except TypeError:
34+
raise TypeError("your metadata is not a dict")
35+
36+
def __iter__(self):
37+
try:
38+
return iter(self.metadata)
39+
except TypeError:
40+
raise TypeError("your metadata is not iterable")
41+
42+
def __len__(self):
43+
try:
44+
return len(self.metadata)
45+
except TypeError:
46+
raise TypeError("your metadata has no length")
47+
3048
@overrides
3149
def get_padding_lengths(self) -> Dict[str, int]:
3250
return {}

allennlp/data/fields/sequence_label_field.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Union, Set
1+
from typing import Dict, List, Union, Set, Iterator
22
import logging
33
import textwrap
44

@@ -75,6 +75,16 @@ def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
7575
self._label_namespace)
7676
self._already_warned_namespaces.add(label_namespace)
7777

78+
# Sequence methods
79+
def __iter__(self) -> Iterator[Union[str, int]]:
80+
return iter(self.labels)
81+
82+
def __getitem__(self, idx: int) -> Union[str, int]:
83+
return self.labels[idx]
84+
85+
def __len__(self) -> int:
86+
return len(self.labels)
87+
7888
@overrides
7989
def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
8090
if self._indexed_labels is None:

allennlp/data/fields/span_field.py

+6
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,9 @@ def empty_field(self):
5858

5959
def __str__(self) -> str:
6060
return f"SpanField with spans: ({self.span_start}, {self.span_end})."
61+
62+
def __eq__(self, other) -> bool:
63+
if isinstance(other, tuple) and len(other) == 2:
64+
return other == (self.span_start, self.span_end)
65+
else:
66+
return id(self) == id(other)

allennlp/data/fields/text_field.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
A ``TextField`` represents a string of text, the kind that you might want to represent with
33
standard word vectors, or pass through an LSTM.
44
"""
5-
from typing import Dict, List, Optional
5+
from typing import Dict, List, Optional, Iterator
66
import textwrap
77

88
from overrides import overrides
@@ -44,6 +44,16 @@ def __init__(self, tokens: List[Token], token_indexers: Dict[str, TokenIndexer])
4444
raise ConfigurationError("TextFields must be passed Tokens. "
4545
"Found: {} with types {}.".format(tokens, [type(x) for x in tokens]))
4646

47+
# Sequence[Token] methods
48+
def __iter__(self) -> Iterator[Token]:
49+
return iter(self.tokens)
50+
51+
def __getitem__(self, idx: int) -> Token:
52+
return self.tokens[idx]
53+
54+
def __len__(self) -> int:
55+
return len(self.tokens)
56+
4757
@overrides
4858
def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
4959
for indexer in self._token_indexers.values():

allennlp/data/instance.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import Dict, MutableMapping
1+
from typing import Dict, MutableMapping, Mapping
22

33
from allennlp.data.fields.field import DataArray, Field
44
from allennlp.data.vocabulary import Vocabulary
55

66

7-
class Instance:
7+
class Instance(Mapping[str, Field]):
88
"""
99
An ``Instance`` is a collection of :class:`~allennlp.data.fields.field.Field` objects,
1010
specifying the inputs and outputs to
@@ -26,6 +26,18 @@ def __init__(self, fields: MutableMapping[str, Field]) -> None:
2626
self.fields = fields
2727
self.indexed = False
2828

29+
# Add methods for ``Mapping``. Note, even though the fields are
30+
# mutable, we don't implement ``MutableMapping`` because we want
31+
# you to use ``add_field`` and supply a vocabulary.
32+
def __getitem__(self, key: str) -> Field:
33+
return self.fields[key]
34+
35+
def __iter__(self):
36+
return iter(self.fields)
37+
38+
def __len__(self) -> int:
39+
return len(self.fields)
40+
2941
def add_field(self, field_name: str, field: Field, vocab: Vocabulary = None) -> None:
3042
"""
3143
Add the field to the existing fields mapping.

allennlp/tests/data/fields/index_field_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,11 @@ def test_index_field_empty_field_works(self):
3131

3232
def test_printing_doesnt_crash(self):
3333
print(self.text)
34+
35+
def test_equality(self):
36+
index_field1 = IndexField(4, self.text)
37+
index_field2 = IndexField(4, self.text)
38+
39+
assert index_field1 == 4
40+
assert index_field1 == index_field1
41+
assert index_field1 != index_field2

allennlp/tests/data/fields/list_field_test.py

+7
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,10 @@ def test_as_tensor_can_handle_multiple_token_indexers_and_empty_fields(self):
181181
def test_printing_doesnt_crash(self):
182182
list_field = ListField([self.field1, self.field2])
183183
print(list_field)
184+
185+
def test_sequence_methods(self):
186+
list_field = ListField([self.field1, self.field2, self.field3])
187+
188+
assert len(list_field) == 3
189+
assert list_field[1] == self.field2
190+
assert [f for f in list_field] == [self.field1, self.field2, self.field3]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# pylint: disable=no-self-use,invalid-name
2+
import pytest
3+
4+
from allennlp.common.testing.test_case import AllenNlpTestCase
5+
from allennlp.data.fields import MetadataField
6+
7+
8+
class TestMetadataField(AllenNlpTestCase):
9+
def test_mapping_works_with_dict(self):
10+
field = MetadataField({"a": 1, "b": [0]})
11+
12+
assert "a" in field
13+
assert field["a"] == 1
14+
assert len(field) == 2
15+
16+
keys = {k for k in field}
17+
assert keys == {"a", "b"}
18+
19+
values = [v for v in field.values()]
20+
assert len(values) == 2
21+
assert 1 in values
22+
assert [0] in values
23+
24+
def test_mapping_raises_with_non_dict(self):
25+
field = MetadataField(0)
26+
27+
with pytest.raises(TypeError):
28+
_ = field[0]
29+
30+
with pytest.raises(TypeError):
31+
_ = len(field)
32+
33+
with pytest.raises(TypeError):
34+
_ = [x for x in field]

allennlp/tests/data/fields/sequence_label_field_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,11 @@ def test_printing_doesnt_crash(self):
8888
tags = ["B", "I", "O", "O", "O"]
8989
sequence_label_field = SequenceLabelField(tags, self.text, label_namespace="labels")
9090
print(sequence_label_field)
91+
92+
def test_sequence_methods(self):
93+
tags = ["B", "I", "O", "O", "O"]
94+
sequence_label_field = SequenceLabelField(tags, self.text, label_namespace="labels")
95+
96+
assert len(sequence_label_field) == 5
97+
assert sequence_label_field[1] == "I"
98+
assert [label for label in sequence_label_field] == tags

allennlp/tests/data/fields/span_field_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,11 @@ def test_empty_span_field_works(self):
4040
def test_printing_doesnt_crash(self):
4141
span_field = SpanField(2, 3, self.text)
4242
print(span_field)
43+
44+
def test_equality(self):
45+
span_field1 = SpanField(2, 3, self.text)
46+
span_field2 = SpanField(2, 3, self.text)
47+
48+
assert span_field1 == (2, 3)
49+
assert span_field1 == span_field1
50+
assert span_field1 != span_field2

allennlp/tests/data/fields/text_field_test.py

+7
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,10 @@ def test_token_embedder_returns_dict(self):
253253
assert list(tensors['additional_key'].shape) == [3]
254254
assert list(tensors['words'].shape) == [4]
255255
assert list(tensors['characters'].shape) == [4, 8]
256+
257+
def test_sequence_methods(self):
258+
field = TextField([Token(t) for t in ["This", "is", "a", "sentence", "."]], {})
259+
260+
assert len(field) == 5
261+
assert field[1].text == "is"
262+
assert [token.text for token in field] == ["This", "is", "a", "sentence", "."]

allennlp/tests/data/instance_test.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# pylint: disable=no-self-use,invalid-name
2+
from allennlp.common.testing import AllenNlpTestCase
3+
from allennlp.data import Instance
4+
from allennlp.data.fields import TextField, LabelField
5+
from allennlp.data.tokenizers import Token
6+
7+
class TestInstance(AllenNlpTestCase):
8+
def test_instance_implements_mutable_mapping(self):
9+
words_field = TextField([Token("hello")], {})
10+
label_field = LabelField(1, skip_indexing=True)
11+
instance = Instance({"words": words_field, "labels": label_field})
12+
13+
assert instance["words"] == words_field
14+
assert instance["labels"] == label_field
15+
assert len(instance) == 2
16+
17+
keys = {k for k, v in instance.items()}
18+
assert keys == {"words", "labels"}
19+
20+
values = [v for k, v in instance.items()]
21+
assert words_field in values
22+
assert label_field in values

allennlp/tests/data/iterators/basic_iterator_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ def test_shuffle(self):
155155
assert in_order_batches != shuffled_batches
156156

157157
# But not the counts of the instances.
158-
in_order_counts = Counter(instance for batch in in_order_batches for instance in batch)
159-
shuffled_counts = Counter(instance for batch in shuffled_batches for instance in batch)
158+
in_order_counts = Counter(id(instance) for batch in in_order_batches for instance in batch)
159+
shuffled_counts = Counter(id(instance) for batch in shuffled_batches for instance in batch)
160160
assert in_order_counts == shuffled_counts
161161

162162

0 commit comments

Comments
 (0)