This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 03eaedc

np.broadcast_to extension

Extends the shape inference of np.broadcast_to so that a -2 entry in the target shape means "keep the corresponding dimension of the input array"; the target shape must be fully known after this substitution.

1 parent 22c7ef7 commit 03eaedc

File tree: 2 files changed, +36 -5 lines changed

src/operator/numpy/np_broadcast_reduce_op_value.cc

Lines changed: 9 additions & 5 deletions
@@ -467,17 +467,21 @@ bool NumpyBroadcastToShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape& ishape = (*in_attrs)[0];
   if (!mxnet::shape_is_known(ishape)) return false;
   const BroadcastToParam& param = nnvm::get<BroadcastToParam>(attrs.parsed);
-  CHECK(mxnet::shape_is_known(param.shape))
-      << "the objective shape for broadcasting array must be known";
   CHECK_LE(ishape.ndim(), param.shape.ndim())
       << "shape " << ishape << " is not broadcastable to " << param.shape;
+  TShape pshape = param.shape;
   for (int i = param.shape.ndim() - 1; i >= 0; --i) {
     int j = i - param.shape.ndim() + ishape.ndim();
     if (j < 0) break;
-    CHECK(ishape[j] == param.shape[i] || ishape[j] == 1)
-        << "shape " << ishape << " is not broadcastable to " << param.shape;
+    if (pshape[i] == -2) {
+      pshape[i] = ishape[j];
+    }
+    CHECK(ishape[j] == pshape[i] || ishape[j] == 1)
+        << "shape " << ishape << " is not broadcastable to " << pshape;
   }
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, param.shape);
+  CHECK(mxnet::shape_is_known(pshape))
+      << "the objective shape for broadcasting array must be known";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, pshape);
   return true;
 }
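For reference, here is the shape-inference rule this hunk implements, sketched in plain Python. This is an illustrative re-implementation, not MXNet code; the helper name infer_broadcast_to_shape is hypothetical. A -2 entry in the target shape is resolved to the corresponding (right-aligned) input dimension, and the resolved shape must be fully known:

def infer_broadcast_to_shape(ishape, pshape):
    # Sketch of NumpyBroadcastToShape: resolve -2 entries in the
    # target shape against the input shape, aligned from the right.
    assert len(ishape) <= len(pshape), \
        "shape %s is not broadcastable to %s" % (ishape, pshape)
    pshape = list(pshape)
    offset = len(pshape) - len(ishape)
    for i in range(len(pshape) - 1, -1, -1):
        j = i - offset
        if j < 0:
            break
        if pshape[i] == -2:  # -2: keep the input dimension
            pshape[i] = ishape[j]
        assert ishape[j] == pshape[i] or ishape[j] == 1, \
            "shape %s is not broadcastable to %s" % (ishape, tuple(pshape))
    # the objective shape must be fully known after substitution
    assert all(dim >= 0 for dim in pshape), \
        "the objective shape for broadcasting array must be known"
    return tuple(pshape)

# These mirror cases from the tests added below:
assert infer_broadcast_to_shape((5,), (3, 4, -2)) == (3, 4, 5)
assert infer_broadcast_to_shape((1, 0), (2, -2, -2)) == (2, 1, 0)

Note that a -2 in a leading axis with no matching input axis is left unresolved and fails the final known-shape check, matching the C++ behavior above.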

tests/python/unittest/test_numpy_op.py

Lines changed: 27 additions & 0 deletions
@@ -1550,6 +1550,7 @@ def hybrid_forward(self, F, x):
         ((4, 1), (1, 2, 3, 4, 5)),
         ((4, 1), (1, 0, 3, 4, 5))
     ]
+
     for src_shape, dst_shape in shapes:
         for hybridize in [True, False]:
             test_broadcast_to = TestBroadcastTo(dst_shape)
@@ -1578,6 +1579,32 @@ def hybrid_forward(self, F, x):
     ret = test_scalar_broadcast_to(np.empty(()))
     assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-5, atol=1e-6, use_broadcast=False)
 
+    # Test npx functionality
+    shapes = [
+        ((5,), (3, 4, -2), (3, 4, 5)),
+        ((5,), (0, -2), (0, 5)),
+        ((1, 0), (2, -2, -2), (2, 1, 0)),
+        ((3, 4), (1, 2, 3, -2), (1, 2, 3, 4)),
+        ((3, 4), (1, 0, -2, 4), (1, 0, 3, 4))
+    ]
+
+    for src_shape, npx_dst_shape, np_dst_shape in shapes:
+        for hybridize in [True, False]:
+            test_broadcast_to = TestBroadcastTo(npx_dst_shape)
+            if hybridize:
+                test_broadcast_to.hybridize()
+
+            a = _np.random.uniform(size=src_shape).astype(np.float32)
+            expected_ret = _np.broadcast_to(a, np_dst_shape)
+            a_mx = np.array(a, dtype=a.dtype)
+            a_mx.attach_grad()
+            with mx.autograd.record():
+                ret = test_broadcast_to(a_mx)
+            assert_almost_equal(ret.asnumpy(), expected_ret, rtol=1e-5, atol=1e-6, use_broadcast=False)
+            ret.backward()
+            expected_grad = collapse_sum_like(_np.ones_like(expected_ret), src_shape)
+            assert_almost_equal(a_mx.grad.asnumpy(), expected_grad, rtol=1e-5, atol=1e-6, use_broadcast=False)
+
 
 @with_seed()
 @use_np
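A usage sketch of what the new tests exercise, assuming an MXNet build that includes this commit and NumPy-compatible mode enabled via npx.set_np():

import mxnet as mx
from mxnet import np, npx

npx.set_np()  # enable NumPy-compatible semantics

a = np.ones((3, 4))
a.attach_grad()
with mx.autograd.record():
    # -2 keeps the matching input dimension, so the (3, 4) input
    # is broadcast to (1, 2, 3, 4) here.
    b = np.broadcast_to(a, (1, 2, 3, -2))
print(b.shape)  # (1, 2, 3, 4)

b.backward()
# The gradient of broadcast_to sums the incoming gradient over the
# broadcast axes (this is what collapse_sum_like computes in the test):
# each input element was replicated 1 * 2 = 2 times, so the gradient
# is all twos.
print(a.grad)  # [[2. 2. 2. 2.] [2. 2. 2. 2.] [2. 2. 2. 2.]]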

0 commit comments
