Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 6c7e1e7

Browse files
committed
pickler override for np ndarrays
1 parent 746cbc5 commit 6c7e1e7

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

python/mxnet/gluon/data/dataloader.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,38 @@ def reduce_ndarray(data):
7474

7575
ForkingPickler.register(nd.NDArray, reduce_ndarray)
7676

77+
if sys.platform == 'darwin' or sys.platform == 'win32':
78+
def rebuild_ndarray(*args):
79+
"""Rebuild ndarray from pickled shared memory"""
80+
# pylint: disable=no-value-for-parameter
81+
return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(*args))
82+
83+
def reduce_ndarray(data):
84+
"""Reduce ndarray to shared memory handle"""
85+
return rebuild_ndarray, data._to_shared_mem()
86+
else:
87+
def rebuild_ndarray(pid, fd, shape, dtype):
88+
"""Rebuild ndarray from pickled shared memory"""
89+
# pylint: disable=no-value-for-parameter
90+
if sys.version_info[0] == 2:
91+
fd = multiprocessing.reduction.rebuild_handle(fd)
92+
else:
93+
fd = fd.detach()
94+
return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype))
95+
96+
def reduce_ndarray(data):
97+
"""Reduce ndarray to shared memory handle"""
98+
# keep a local ref before duplicating fd
99+
data = data.as_in_context(context.Context('cpu_shared', 0))
100+
pid, fd, shape, dtype = data._to_shared_mem()
101+
if sys.version_info[0] == 2:
102+
fd = multiprocessing.reduction.reduce_handle(fd)
103+
else:
104+
fd = multiprocessing.reduction.DupFd(fd)
105+
return rebuild_ndarray, (pid, fd, shape, dtype)
106+
107+
ForkingPickler.register(_mx_np.ndarray, reduce_ndarray)
108+
77109

78110
class ConnectionWrapper(object):
79111
"""Connection wrapper for multiprocessing that supports sending

tests/python/unittest/test_numpy_ndarray.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,18 @@ def test_np_get_dtype():
10981098
assert type(mx_arr.dtype) == type(np_arr.dtype)
10991099

11001100

1101+
@use_np
1102+
def test_np_ndarray_pickle():
1103+
a = np.random.uniform(size=(4, 5))
1104+
a_copy = a.copy()
1105+
import pickle
1106+
with open("np_ndarray_pickle_test_file", 'wb') as f:
1107+
pickle.dump(a_copy, f)
1108+
with open("np_ndarray_pickle_test_file", 'rb') as f:
1109+
a_load = pickle.load(f)
1110+
same(a.asnumpy(), a_load.asnumpy())
1111+
1112+
11011113
if __name__ == '__main__':
11021114
import nose
11031115
nose.runmodule()

0 commit comments

Comments
 (0)