Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 5a928f9

Browse files
haojin2 authored and reminisce committed
numpy infra residual part
1 parent 196d1f4 commit 5a928f9

File tree

15 files changed

+318
-107
lines changed

15 files changed

+318
-107
lines changed

python/mxnet/contrib/text/embedding.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
from ... import ndarray as nd
3636
from ... import registry
3737
from ... import base
38+
from ...util import is_np_array
39+
from ... import numpy as _mx_np
40+
from ... import numpy_extension as _mx_npx
3841

3942

4043
def register(embedding_cls):
@@ -295,12 +298,15 @@ def _load_embedding(self, pretrained_file_path, elem_delim, init_unknown_vec, en
295298
tokens.add(token)
296299

297300
self._vec_len = vec_len
298-
self._idx_to_vec = nd.array(all_elems).reshape((-1, self.vec_len))
301+
array_fn = _mx_np.array if is_np_array() else nd.array
302+
self._idx_to_vec = array_fn(all_elems).reshape((-1, self.vec_len))
299303

300304
if loaded_unknown_vec is None:
301-
self._idx_to_vec[C.UNKNOWN_IDX] = init_unknown_vec(shape=self.vec_len)
305+
init_val = init_unknown_vec(shape=self.vec_len)
306+
self._idx_to_vec[C.UNKNOWN_IDX] =\
307+
init_val.as_np_ndarray() if is_np_array() else init_val
302308
else:
303-
self._idx_to_vec[C.UNKNOWN_IDX] = nd.array(loaded_unknown_vec)
309+
self._idx_to_vec[C.UNKNOWN_IDX] = array_fn(loaded_unknown_vec)
304310

305311
def _index_tokens_from_vocabulary(self, vocabulary):
306312
self._token_to_idx = vocabulary.token_to_idx.copy() \
@@ -328,7 +334,8 @@ def _set_idx_to_vec_by_embeddings(self, token_embeddings, vocab_len, vocab_idx_t
328334
"""
329335

330336
new_vec_len = sum(embed.vec_len for embed in token_embeddings)
331-
new_idx_to_vec = nd.zeros(shape=(vocab_len, new_vec_len))
337+
zeros_fn = _mx_np.zeros if is_np_array() else nd.zeros
338+
new_idx_to_vec = zeros_fn(shape=(vocab_len, new_vec_len))
332339

333340
col_start = 0
334341
# Concatenate all the embedding vectors in token_embeddings.
@@ -397,7 +404,13 @@ def get_vecs_by_tokens(self, tokens, lower_case_backup=False):
397404
else self.token_to_idx.get(token.lower(), C.UNKNOWN_IDX)
398405
for token in tokens]
399406

400-
vecs = nd.Embedding(nd.array(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
407+
if is_np_array():
408+
embedding_fn = _mx_npx.embedding
409+
array_fn = _mx_np.array
410+
else:
411+
embedding_fn = nd.Embedding
412+
array_fn = nd.array
413+
vecs = embedding_fn(array_fn(indices), self.idx_to_vec, self.idx_to_vec.shape[0],
401414
self.idx_to_vec.shape[1])
402415

403416
return vecs[0] if to_reduce else vecs
@@ -425,7 +438,8 @@ def update_token_vectors(self, tokens, new_vectors):
425438
if not isinstance(tokens, list):
426439
tokens = [tokens]
427440
if len(new_vectors.shape) == 1:
428-
new_vectors = new_vectors.expand_dims(0)
441+
expand_dims_fn = _mx_np.expand_dims if is_np_array() else nd.expand_dims
442+
new_vectors = expand_dims_fn(new_vectors, axis=0)
429443

430444
else:
431445
assert isinstance(new_vectors, nd.NDArray) and len(new_vectors.shape) == 2, \
@@ -444,7 +458,8 @@ def update_token_vectors(self, tokens, new_vectors):
444458
'`unknown_token` %s in `tokens`. This is to avoid unintended '
445459
'updates.' % (token, self.idx_to_token[C.UNKNOWN_IDX]))
446460

447-
self._idx_to_vec[nd.array(indices)] = new_vectors
461+
array_fn = _mx_np.array if is_np_array() else nd.array
462+
self._idx_to_vec[array_fn(indices)] = new_vectors
448463

449464
@classmethod
450465
def _check_pretrained_file_names(cls, pretrained_file_name):

python/mxnet/gluon/block.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -775,8 +775,7 @@ def _get_graph(self, *args):
775775
grouped_inputs = _regroup(inputs, self._in_format)[0]
776776

777777
params = {i: j.var() for i, j in self._reg_params.items()}
778-
with self.name_scope():
779-
out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter
778+
out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter
780779
out, self._out_format = _flatten(out, "output")
781780

782781
self._cached_graph = inputs, symbol.Group(out, _check_same_symbol_type(out))
@@ -960,8 +959,7 @@ def forward(self, x, *args):
960959
"HybridBlock requires the first argument to forward be either " \
961960
"Symbol or NDArray, but got %s"%type(x)
962961
params = {i: j.var() for i, j in self._reg_params.items()}
963-
with self.name_scope():
964-
return self.hybrid_forward(symbol, x, *args, **params)
962+
return self.hybrid_forward(symbol, x, *args, **params)
965963

966964
def hybrid_forward(self, F, x, *args, **kwargs):
967965
"""Overrides to construct symbolic graph for this `Block`.

python/mxnet/gluon/data/dataloader.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# coding: utf-8
1919
# pylint: disable=ungrouped-imports
2020
"""Dataset generator."""
21+
from __future__ import absolute_import
2122
__all__ = ['DataLoader']
2223

2324
import pickle
@@ -37,6 +38,8 @@
3738

3839
from . import sampler as _sampler
3940
from ... import nd, context
41+
from ...util import is_np_shape, is_np_array, set_np
42+
from ... import numpy as _mx_np
4043

4144
if sys.platform == 'darwin' or sys.platform == 'win32':
4245
def rebuild_ndarray(*args):
@@ -128,27 +131,33 @@ def __init__(self, *args, **kwargs):
128131
def default_batchify_fn(data):
    """Collate data into batch.

    Dispatches on the type of the first sample:

    * ``nd.NDArray`` samples are stacked, with ``_mx_np.stack`` when
      ``is_np_array()`` is true and with ``nd.stack`` otherwise.
    * ``tuple`` samples are transposed with ``zip`` and every resulting
      field is collated recursively; a list of batches is returned.
    * anything else is converted with ``np.asarray`` and wrapped in an
      array that keeps the resulting dtype.
    """
    first = data[0]

    if isinstance(first, nd.NDArray):
        if is_np_array():
            return _mx_np.stack(data)
        return nd.stack(*data)

    if isinstance(first, tuple):
        # One entry per field of the sample tuples.
        return [default_batchify_fn(field) for field in zip(*data)]

    host_data = np.asarray(data)
    make_array = _mx_np.array if is_np_array() else nd.array
    return make_array(host_data, dtype=host_data.dtype)
138142

139143

140144
def default_mp_batchify_fn(data):
    """Collate data into batch. Use shared memory for stacking.

    Same dispatch as the single-process collate function, except that every
    array produced here is created with ``context.Context('cpu_shared', 0)``:

    * ``nd.NDArray`` samples are stacked into a pre-allocated output array.
    * ``tuple`` samples are transposed and collated field by field.
    * anything else goes through ``np.asarray`` and is wrapped in an array
      that keeps the resulting dtype.
    """
    first = data[0]

    if isinstance(first, nd.NDArray):
        allocate = _mx_np.empty if is_np_array() else nd.empty
        # Output buffer: batch dimension followed by the per-sample shape.
        batch_shape = (len(data),) + first.shape
        out = allocate(batch_shape, dtype=first.dtype,
                       ctx=context.Context('cpu_shared', 0))
        if is_np_array():
            return _mx_np.stack(data, out=out)
        return nd.stack(*data, out=out)

    if isinstance(first, tuple):
        # One entry per field of the sample tuples.
        return [default_mp_batchify_fn(field) for field in zip(*data)]

    host_data = np.asarray(data)
    make_array = _mx_np.array if is_np_array() else nd.array
    return make_array(host_data, dtype=host_data.dtype,
                      ctx=context.Context('cpu_shared', 0))
153162

154163

@@ -384,14 +393,20 @@ def __len__(self):
384393
return len(self._batch_sampler)
385394

386395

396+
def _thread_worker_initializer(active_shape, active_array):
    """Initializer for ThreadPool.

    Forwards the two flags to ``set_np`` so that the worker thread is
    configured with the values supplied by whoever created the pool.

    Parameters
    ----------
    active_shape :
        Value passed to ``set_np`` as ``shape``.
    active_array :
        Value passed to ``set_np`` as ``array``.
    """
    set_np(shape=active_shape, array=active_array)
399+
400+
387401
_worker_dataset = None
388-
def _worker_initializer(dataset):
402+
def _worker_initializer(dataset, active_shape, active_array):
    """Initializer for processing pool.

    Stores ``dataset`` in the module-level ``_worker_dataset`` and then
    forwards the two flags to ``set_np``.

    Parameters
    ----------
    dataset :
        Object to publish as the worker's global dataset.
    active_shape :
        Value passed to ``set_np`` as ``shape``.
    active_array :
        Value passed to ``set_np`` as ``array``.
    """
    # The global dataset is per-process and only available in worker
    # processes. It is only necessary to handle MXIndexedRecordIO; otherwise
    # the dataset can be passed as an argument.
    global _worker_dataset
    _worker_dataset = dataset
    set_np(shape=active_shape, array=active_array)
395410

396411
def _worker_fn(samples, batchify_fn, dataset=None):
397412
"""Function for processing data in worker process."""
@@ -558,10 +573,13 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
558573
self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
559574
if self._num_workers > 0:
560575
if self._thread_pool:
561-
self._worker_pool = ThreadPool(self._num_workers)
576+
self._worker_pool = ThreadPool(self._num_workers,
577+
initializer=_thread_worker_initializer,
578+
initargs=(is_np_shape(), is_np_array()))
562579
else:
563580
self._worker_pool = multiprocessing.Pool(
564-
self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
581+
self._num_workers, initializer=_worker_initializer,
582+
initargs=[self._dataset, is_np_shape(), is_np_array()])
565583
if batchify_fn is None:
566584
if num_workers > 0:
567585
self._batchify_fn = default_mp_batchify_fn

python/mxnet/gluon/data/vision/datasets.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from .. import dataset
3232
from ...utils import download, check_sha1, _get_repo_file_url
3333
from .... import nd, image, recordio, base
34+
from .... import numpy as _mx_np # pylint: disable=reimported
35+
from ....util import is_np_array
3436

3537

3638
class MNIST(dataset._DownloadedDataset):
@@ -81,13 +83,16 @@ def _get_data(self):
8183
with gzip.open(label_file, 'rb') as fin:
8284
struct.unpack(">II", fin.read(8))
8385
label = np.frombuffer(fin.read(), dtype=np.uint8).astype(np.int32)
86+
if is_np_array():
87+
label = _mx_np.array(label, dtype=label.dtype)
8488

8589
with gzip.open(data_file, 'rb') as fin:
8690
struct.unpack(">IIII", fin.read(16))
8791
data = np.frombuffer(fin.read(), dtype=np.uint8)
8892
data = data.reshape(len(label), 28, 28, 1)
8993

90-
self._data = nd.array(data, dtype=data.dtype)
94+
array_fn = _mx_np.array if is_np_array() else nd.array
95+
self._data = array_fn(data, dtype=data.dtype)
9196
self._label = label
9297

9398

@@ -183,8 +188,9 @@ def _get_data(self):
183188
data = np.concatenate(data)
184189
label = np.concatenate(label)
185190

186-
self._data = nd.array(data, dtype=data.dtype)
187-
self._label = label
191+
array_fn = _mx_np.array if is_np_array() else nd.array
192+
self._data = array_fn(data, dtype=data.dtype)
193+
self._label = array_fn(label, dtype=label.dtype) if is_np_array() else label
188194

189195

190196
class CIFAR100(CIFAR10):

python/mxnet/gluon/data/vision/transforms.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ...nn import Sequential, HybridSequential
2424
from .... import image
2525
from ....base import numeric_types
26+
from ....util import is_np_array
2627

2728

2829
class Compose(Sequential):
@@ -92,6 +93,8 @@ def __init__(self, dtype='float32'):
9293
self._dtype = dtype
9394

9495
def hybrid_forward(self, F, x):
    """Cast ``x`` to the dtype stored in ``self._dtype``.

    The ``cast`` operator is looked up on ``F.npx`` when ``is_np_array()``
    is true, and on ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.cast(x, self._dtype)
9699

97100

@@ -134,6 +137,8 @@ def __init__(self):
134137
super(ToTensor, self).__init__()
135138

136139
def hybrid_forward(self, F, x):
    """Return ``image.to_tensor`` applied to ``x``.

    The operator namespace is ``F.npx`` when ``is_np_array()`` is true,
    and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.to_tensor(x)
138143

139144

@@ -187,6 +192,8 @@ def __init__(self, mean=0.0, std=1.0):
187192
self._std = std
188193

189194
def hybrid_forward(self, F, x):
    """Return ``image.normalize`` applied to ``x`` with the stored mean and std.

    ``self._mean`` and ``self._std`` are passed positionally. The operator
    namespace is ``F.npx`` when ``is_np_array()`` is true, and ``F`` itself
    otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.normalize(x, self._mean, self._std)
191198

192199

@@ -369,6 +376,8 @@ def __init__(self, size, keep_ratio=False, interpolation=1):
369376
self._interpolation = interpolation
370377

371378
def hybrid_forward(self, F, x):
    """Return ``image.resize`` applied to ``x`` with the stored settings.

    ``self._size``, ``self._keep`` and ``self._interpolation`` are passed
    positionally, in that order. The operator namespace is ``F.npx`` when
    ``is_np_array()`` is true, and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.resize(x, self._size, self._keep, self._interpolation)
373382

374383
class RandomFlipLeftRight(HybridBlock):
@@ -385,6 +394,8 @@ def __init__(self):
385394
super(RandomFlipLeftRight, self).__init__()
386395

387396
def hybrid_forward(self, F, x):
    """Return ``image.random_flip_left_right`` applied to ``x``.

    The operator namespace is ``F.npx`` when ``is_np_array()`` is true,
    and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.random_flip_left_right(x)
389400

390401

@@ -402,6 +413,8 @@ def __init__(self):
402413
super(RandomFlipTopBottom, self).__init__()
403414

404415
def hybrid_forward(self, F, x):
    """Return ``image.random_flip_top_bottom`` applied to ``x``.

    The operator namespace is ``F.npx`` when ``is_np_array()`` is true,
    and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.random_flip_top_bottom(x)
406419

407420

@@ -427,6 +440,8 @@ def __init__(self, brightness):
427440
self._args = (max(0, 1-brightness), 1+brightness)
428441

429442
def hybrid_forward(self, F, x):
    """Return ``image.random_brightness`` applied to ``x``.

    The values in ``self._args`` are unpacked as the remaining positional
    arguments. The operator namespace is ``F.npx`` when ``is_np_array()``
    is true, and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.random_brightness(x, *self._args)
431446

432447

@@ -452,6 +467,8 @@ def __init__(self, contrast):
452467
self._args = (max(0, 1-contrast), 1+contrast)
453468

454469
def hybrid_forward(self, F, x):
    """Return ``image.random_contrast`` applied to ``x``.

    The values in ``self._args`` are unpacked as the remaining positional
    arguments. The operator namespace is ``F.npx`` when ``is_np_array()``
    is true, and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.random_contrast(x, *self._args)
456473

457474

@@ -477,6 +494,8 @@ def __init__(self, saturation):
477494
self._args = (max(0, 1-saturation), 1+saturation)
478495

479496
def hybrid_forward(self, F, x):
    """Return ``image.random_saturation`` applied to ``x``.

    The values in ``self._args`` are unpacked as the remaining positional
    arguments. The operator namespace is ``F.npx`` when ``is_np_array()``
    is true, and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.random_saturation(x, *self._args)
481500

482501

@@ -502,6 +521,8 @@ def __init__(self, hue):
502521
self._args = (max(0, 1-hue), 1+hue)
503522

504523
def hybrid_forward(self, F, x):
    """Return ``image.random_hue`` applied to ``x``.

    The values in ``self._args`` are unpacked as the remaining positional
    arguments. The operator namespace is ``F.npx`` when ``is_np_array()``
    is true, and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.random_hue(x, *self._args)
506527

507528

@@ -536,6 +557,8 @@ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
536557
self._args = (brightness, contrast, saturation, hue)
537558

538559
def hybrid_forward(self, F, x):
    """Return ``image.random_color_jitter`` applied to ``x``.

    The values in ``self._args`` are unpacked as the remaining positional
    arguments. The operator namespace is ``F.npx`` when ``is_np_array()``
    is true, and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.random_color_jitter(x, *self._args)
540563

541564

@@ -559,4 +582,6 @@ def __init__(self, alpha):
559582
self._alpha = alpha
560583

561584
def hybrid_forward(self, F, x):
    """Return ``image.random_lighting`` applied to ``x`` with ``self._alpha``.

    The operator namespace is ``F.npx`` when ``is_np_array()`` is true,
    and ``F`` itself otherwise.
    """
    ops = F.npx if is_np_array() else F
    return ops.image.random_lighting(x, self._alpha)

0 commit comments

Comments
 (0)