
Commit 174f539: enable pickling for vocabulary (#2391)

Commit message:
* enable pickling for vocabulary
* pylint

Parent: 2d29736

File tree: 2 files changed, +47 -0 lines

allennlp/data/vocabulary.py (+33 lines)
@@ -4,6 +4,7 @@
 """

 import codecs
+import copy
 import logging
 import os
 from collections import defaultdict
@@ -232,6 +233,38 @@ def __init__(self,
                      tokens_to_add,
                      min_pretrained_embeddings)

+
+    def __getstate__(self):
+        """
+        Need to sanitize defaultdict and defaultdict-like objects
+        by converting them to vanilla dicts when we pickle the vocabulary.
+        """
+        state = copy.copy(self.__dict__)
+        state["_token_to_index"] = dict(state["_token_to_index"])
+        state["_index_to_token"] = dict(state["_index_to_token"])
+
+        if "_retained_counter" in state:
+            state["_retained_counter"] = {key: dict(value)
+                                          for key, value in state["_retained_counter"].items()}
+
+        return state
+
+    def __setstate__(self, state):
+        """
+        Conversely, when we unpickle, we need to reload the plain dicts
+        into our special DefaultDict subclasses.
+        """
+        # pylint: disable=attribute-defined-outside-init
+        self.__dict__ = copy.copy(state)
+        self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces,
+                                                        self._padding_token,
+                                                        self._oov_token)
+        self._token_to_index.update(state["_token_to_index"])
+        self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces,
+                                                        self._padding_token,
+                                                        self._oov_token)
+        self._index_to_token.update(state["_index_to_token"])
+
     def save_to_files(self, directory: str) -> None:
         """
         Persist this Vocabulary to files so it can be reloaded later.
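The change follows Python's standard __getstate__/__setstate__ pickle protocol: on the way out, the defaultdict-like namespace dicts are downgraded to vanilla dicts (their default_factory callbacks are lambdas, which pickle cannot serialize); on the way back in, the special subclasses are rebuilt and refilled. A minimal self-contained sketch of the same pattern, with a plain lambda-factory defaultdict standing in for _TokenToIndexDefaultDict (the Registry class is hypothetical, for illustration only):

    import copy
    import pickle
    from collections import defaultdict

    class Registry:
        """Toy stand-in for Vocabulary: the lambda default_factory below
        is exactly the kind of object pickle refuses to serialize."""

        def __init__(self):
            self._token_to_index = defaultdict(lambda: {})

        def __getstate__(self):
            # Sanitize: swap the unpicklable defaultdict for a vanilla dict.
            state = copy.copy(self.__dict__)
            state["_token_to_index"] = dict(state["_token_to_index"])
            return state

        def __setstate__(self, state):
            # Restore: rebuild the defaultdict, then reload the saved entries.
            self.__dict__ = copy.copy(state)
            self._token_to_index = defaultdict(lambda: {})
            self._token_to_index.update(state["_token_to_index"])

    registry = Registry()
    registry._token_to_index["tokens"]["hello"] = 0
    restored = pickle.loads(pickle.dumps(registry))
    assert dict(restored._token_to_index) == dict(registry._token_to_index)

Without the two hooks, pickle.dumps(registry) raises, because the lambda is a local object that pickle cannot look up by qualified name.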

allennlp/tests/data/vocabulary_test.py (+14 lines)
@@ -1,4 +1,5 @@
 import codecs
+import pickle
 import gzip
 import zipfile
 from copy import deepcopy
@@ -29,6 +30,19 @@ def setUp(self):
         self.dataset = Batch([self.instance])
         super(TestVocabulary, self).setUp()

+    def test_pickling(self):
+        vocab = Vocabulary.from_instances(self.dataset)
+
+        pickled = pickle.dumps(vocab)
+        unpickled = pickle.loads(pickled)
+
+        assert dict(unpickled._index_to_token) == dict(vocab._index_to_token)
+        assert dict(unpickled._token_to_index) == dict(vocab._token_to_index)
+        assert unpickled._non_padded_namespaces == vocab._non_padded_namespaces
+        assert unpickled._oov_token == vocab._oov_token
+        assert unpickled._padding_token == vocab._padding_token
+        assert unpickled._retained_counter == vocab._retained_counter
+
     def test_from_dataset_respects_max_vocab_size_single_int(self):
         max_vocab_size = 1
         vocab = Vocabulary.from_instances(self.dataset, max_vocab_size=max_vocab_size)
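The test verifies the round trip against the vocabulary's internal dicts. Assuming the public Vocabulary API of this AllenNLP release (add_token_to_namespace / get_token_index), the same guarantee can be exercised from user code, e.g. when a vocabulary has to cross a process boundary:

    import pickle

    from allennlp.data import Vocabulary

    vocab = Vocabulary()
    index = vocab.add_token_to_namespace("hello", namespace="tokens")

    # Round-trip through pickle; before this commit, dumps() would fail
    # on the vocabulary's defaultdict internals.
    restored = pickle.loads(pickle.dumps(vocab))
    assert restored.get_token_index("hello", namespace="tokens") == index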
