|
1 | 1 | from typing import Dict, List
|
| 2 | +import warnings |
2 | 3 |
|
3 | 4 | import torch
|
4 | 5 | from overrides import overrides
|
@@ -91,16 +92,42 @@ def forward(self, text_field_input: Dict[str, torch.Tensor], num_wrapping_dims:
|
91 | 92 | # This is some unusual logic, it needs a custom from_params.
|
@classmethod
def from_params(cls, vocab: Vocabulary, params: Params) -> 'BasicTextFieldEmbedder':  # type: ignore
    # pylint: disable=arguments-differ,bad-super-call
    """
    Build the embedder from ``params``, accepting both the current and the legacy layout.

    Current layout: the token embedders live in a dict under the ``'token_embedders'``
    key, which lines up with the constructor's ``token_embedders`` argument.

    Legacy layout: there is no ``'token_embedders'`` key, and every key still left in
    ``params`` (after ``embedder_to_indexer_map`` and ``allow_unmatched_keys`` have been
    popped) is treated as the name of a token embedder.  This layout still works but
    emits a ``DeprecationWarning``.

    In both cases ``params`` must be fully consumed by the time the instance is built;
    ``params.assert_empty`` is called to enforce that.
    """
    # Pop the optional settings first so the legacy branch below never mistakes
    # them for token embedder names.
    indexer_map = params.pop("embedder_to_indexer_map", None)
    indexer_map = indexer_map.as_dict(quiet=True) if indexer_map is not None else None
    unmatched_ok = params.pop_bool("allow_unmatched_keys", False)

    explicit_embedders = params.pop('token_embedders', None)

    if explicit_embedders is None:
        # Legacy layout: warn, then consume every remaining top-level key.
        deprecation_message = ("the token embedders for BasicTextFieldEmbedder should now "
                               "be specified as a dict under the 'token_embedders' key, "
                               "not as top-level key-value pairs")
        warnings.warn(DeprecationWarning(deprecation_message))

        # Snapshot the keys up front because ``params`` shrinks as entries are popped.
        embedders = {
            name: TokenEmbedder.from_params(vocab=vocab, params=params.pop(name))
            for name in list(params.keys())
        }
    else:
        # Current layout: the embedders were given explicitly, so use them as-is.
        embedders = {
            name: TokenEmbedder.from_params(subparams, vocab=vocab)
            for name, subparams in explicit_embedders.items()
        }

    params.assert_empty(cls.__name__)
    return cls(embedders, indexer_map, unmatched_ok)
|
0 commit comments