
Commit 2f4411b: Add test
Parent: 04b2bb1

File tree: 4 files changed (+123, -4 lines)


python/mxnet/gluon/block.py
Lines changed: 4 additions & 2 deletions

@@ -775,7 +775,8 @@ def _get_graph(self, *args):
            grouped_inputs = _regroup(inputs, self._in_format)[0]

            params = {i: j.var() for i, j in self._reg_params.items()}
-           out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
+           with self.name_scope():
+               out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
            out, self._out_format = _flatten(out, "output")

            self._cached_graph = inputs, symbol.Group(out, _check_same_symbol_type(out))

@@ -959,7 +960,8 @@ def forward(self, x, *args):
            "HybridBlock requires the first argument to forward be either " \
            "Symbol or NDArray, but got %s"%type(x)
        params = {i: j.var() for i, j in self._reg_params.items()}
-       return self.hybrid_forward(symbol, x, *args, **params)
+       with self.name_scope():
+           return self.hybrid_forward(symbol, x, *args, **params)

    def hybrid_forward(self, F, x, *args, **kwargs):
        """Overrides to construct symbolic graph for this `Block`.

python/mxnet/numpy/multiarray.py
Lines changed: 3 additions & 1 deletion

@@ -782,7 +782,9 @@ def copy(self, order='C'):  # pylint: disable=arguments-differ
        return super(ndarray, self).copy().as_np_ndarray()

    def dot(self, b, out=None):
-        raise NotImplementedError
+        """Dot product of two arrays.
+        Refer to ``numpy.dot`` for full documentation."""
+        return _mx_np_op.dot(self, b, out=out)

    def reshape(self, *args, **kwargs):  # pylint: disable=arguments-differ
        """Returns an array containing the same data with a new shape.

python/mxnet/symbol/numpy/_symbol.py
Lines changed: 3 additions & 1 deletion

@@ -218,7 +218,9 @@ def astype(self, dtype, **kwargs):  # pylint: disable=arguments-differ
        raise NotImplementedError

    def dot(self, b, out=None):
-        raise NotImplementedError
+        """Dot product of two arrays.
+        Refer to ``numpy.dot`` for full documentation."""
+        return _mx_np_op.dot(self, b, out=out)

    def reshape(self, *args, **kwargs):  # pylint: disable=arguments-differ
        """Returns an array containing the same data with a new shape.
New test file
Lines changed: 113 additions & 0 deletions

@@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
from __future__ import absolute_import
from __future__ import division

import mxnet as mx
from mxnet import gluon, autograd, np
from mxnet.test_utils import use_np


def test_create_np_param():
    M, K, N = 10, 9, 20

    def check_block_params(x, TestBlock, hybridize, expected_type):
        net = TestBlock()
        net.initialize()
        if hybridize:
            net.hybridize()
        net(x)
        params = net.collect_params()
        for k, v in params.items():
            assert type(v.data()) is expected_type

    class TestBlock1(gluon.HybridBlock):
        def __init__(self):
            super(TestBlock1, self).__init__()
            with self.name_scope():
                self.w = self.params.get('w', shape=(K, N), allow_deferred_init=True)

        def hybrid_forward(self, F, x, w):
            return F.dot(x, w)

    @use_np
    class TestBlock2(gluon.HybridBlock):
        def __init__(self):
            super(TestBlock2, self).__init__()
            with self.name_scope():
                self.w = self.params.get('w', shape=(K, N), allow_deferred_init=True)

        def hybrid_forward(self, F, x, w):
            return F.np.dot(x, w)

    x = mx.nd.random.uniform(shape=(M, K))
    check_block_params(x, TestBlock1, False, mx.nd.NDArray)
    check_block_params(x, TestBlock1, True, mx.nd.NDArray)
    check_block_params(x.as_np_ndarray(), TestBlock2, False, np.ndarray)
    check_block_params(x.as_np_ndarray(), TestBlock2, True, np.ndarray)


@use_np
def test_optimizer_with_np_ndarrays():
    class LinearRegression(gluon.HybridBlock):
        def __init__(self, num_input_dim=0, num_hidden_dim=100, num_output_dim=10):
            super(LinearRegression, self).__init__()
            with self.name_scope():
                self.w1 = self.params.get('w1', shape=(num_input_dim, num_hidden_dim),
                                          allow_deferred_init=True)
                self.w2 = self.params.get('w2', shape=(num_hidden_dim, num_output_dim),
                                          allow_deferred_init=True)

        def hybrid_forward(self, F, x, w1, w2):
            h = x.dot(w1)  # equivalent to F.np.dot(x, w1)
            h_relu = F.npx.relu(h)  # equivalent to F.relu(h) but generating np.ndarray
            y_pred = h_relu.dot(w2)  # equivalent to F.np.dot(h_relu, w2)
            return y_pred

    class TotalLoss(gluon.HybridBlock):
        def hybrid_forward(self, F, pred, label):
            return ((pred - label) ** 2).sum()  # equivalent to F.np.sum(F.np.square(pred - label))

    regressor = LinearRegression()
    regressor.initialize(mx.init.Uniform())
    regressor.hybridize()

    # Create random input and output data
    x = np.random.uniform(size=(64, 1000))  # x is of type mxnet.numpy.ndarray
    regressor(x)
    y = np.random.uniform(size=(64, 10))  # y is of type mxnet.numpy.ndarray

    total_loss = TotalLoss()
    total_loss.hybridize()

    trainer = gluon.Trainer(regressor.collect_params(),
                            'sgd',
                            {'learning_rate': 1e-3, 'momentum': 0.9})

    for t in range(2):
        with autograd.record():
            output = regressor(x)  # output is a type of np.ndarray because np.dot is the last op in the network
            loss = total_loss(output, y)  # loss is a scalar np.ndarray
        loss.backward()
        trainer.step(1)


if __name__ == '__main__':
    import nose
    nose.runmodule()
