|
4 | 4 | """
|
5 | 5 |
|
6 | 6 | import codecs
|
| 7 | +import copy |
7 | 8 | import logging
|
8 | 9 | import os
|
9 | 10 | from collections import defaultdict
|
@@ -232,6 +233,38 @@ def __init__(self,
|
232 | 233 | tokens_to_add,
|
233 | 234 | min_pretrained_embeddings)
|
234 | 235 |
|
| 236 | + |
def __getstate__(self):
    """
    Build the state that gets pickled for this object.

    The token/index mappings are defaultdict-like objects, so they are
    sanitized by converting them to vanilla dicts before pickling.

    Returns a shallow copy of ``self.__dict__`` in which ``_token_to_index``
    and ``_index_to_token`` are plain dicts and, when ``_retained_counter``
    is set, it is a plain dict whose values are plain dicts as well.
    The instance's own attributes are left untouched.
    """
    state = copy.copy(self.__dict__)
    state["_token_to_index"] = dict(state["_token_to_index"])
    state["_index_to_token"] = dict(state["_index_to_token"])

    # The attribute can exist without holding a mapping (e.g. ``None``);
    # checking only for the key would then fail on ``.items()``.
    retained_counter = state.get("_retained_counter")
    if retained_counter is not None:
        state["_retained_counter"] = {key: dict(value)
                                      for key, value in retained_counter.items()}

    return state
| 251 | + |
def __setstate__(self, state):
    """
    Restore this object from unpickled ``state``.

    The pickled state carries the token/index mappings as plain dicts, so
    they are loaded back into fresh ``_TokenToIndexDefaultDict`` and
    ``_IndexToTokenDefaultDict`` instances built from the restored
    non-padded namespaces, padding token and OOV token.
    """
    # pylint: disable=attribute-defined-outside-init
    self.__dict__ = copy.copy(state)

    # Both mapping classes take the same three constructor arguments.
    mapping_args = (self._non_padded_namespaces,
                    self._padding_token,
                    self._oov_token)

    token_to_index = _TokenToIndexDefaultDict(*mapping_args)
    token_to_index.update(state["_token_to_index"])
    self._token_to_index = token_to_index

    index_to_token = _IndexToTokenDefaultDict(*mapping_args)
    index_to_token.update(state["_index_to_token"])
    self._index_to_token = index_to_token
| 267 | + |
235 | 268 | def save_to_files(self, directory: str) -> None:
|
236 | 269 | """
|
237 | 270 | Persist this Vocabulary to files so it can be reloaded later.
|
|
0 commit comments