
Commit bd7eedf

zhreshold authored and ptrendx committed

Fix #17164 symbolblock with BatchNorm inside during cast to fp16 (#17212)

* fix symbolblock with bn+fp16
* add unittest
* fix
* remove unused
* fix lint

1 parent 1612533 · commit bd7eedf
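For context, a minimal sketch (not part of the commit) of the scenario this change addresses, assuming an MXNet build with GPU support; the Convolution/BatchNorm names mirror the unit test added below. Before this fix, casting a SymbolBlock that wraps a symbolic BatchNorm to float16 also cast the BatchNorm scale/shift and moving statistics, which BatchNorm requires in float32, so the cast block could fail at forward time.

import mxnet as mx

# Build a symbolic graph that contains a BatchNorm node.
data = mx.sym.var('data')
conv = mx.sym.Convolution(data, kernel=(3, 3), num_filter=16)
bn = mx.sym.BatchNorm(conv, name='bn_test')

# Wrap the graph in a SymbolBlock and cast the whole block to float16.
net = mx.gluon.nn.SymbolBlock([bn], [data])
net.initialize(ctx=mx.gpu(0))
net.cast('float16')

# With this commit, cast() keeps the BatchNorm parameters in float32,
# so a float16 forward pass behaves as expected.
x = mx.nd.zeros((1, 3, 32, 32), dtype='float16', ctx=mx.gpu(0))
y = net(x)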

File tree

2 files changed: +42 −1 lines changed


python/mxnet/gluon/block.py

Lines changed: 24 additions & 1 deletion

@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable= arguments-differ, too-many-lines
+# pylint: disable= arguments-differ, too-many-lines, reimported
 """Base container class for all neural network models."""
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 
@@ -25,6 +25,7 @@
 import warnings
 import re
 from collections import OrderedDict, defaultdict
+import numpy as np
 
 from ..base import mx_real_t, MXNetError
 from .. import symbol, ndarray, initializer, np_symbol
@@ -1353,6 +1354,28 @@ def _clear_cached_op(self):
     def cast(self, dtype):
         self._clear_cached_op()
         super(SymbolBlock, self).cast(dtype)
+        if np.dtype(dtype).name == 'float16':
+            # correct BatchNorm types back to float32 due to its special requirement
+            out = self._cached_graph[1]
+            params_list = out.get_internals().list_inputs()
+            for node in params_list:
+                if node.endswith('running_var'):
+                    prefix = node[:-11]
+                    sibs = [prefix + t for t in ('running_mean', 'gamma', 'beta')]
+                    is_bn = all(p in params_list for p in sibs)
+                    if is_bn:
+                        self.params.get(node).cast('float32')
+                        for sib in sibs:
+                            self.params.get(sib).cast('float32')
+                if node.endswith('moving_var'):
+                    # another convention used
+                    prefix = node[:-10]
+                    sibs = [prefix + t for t in ('moving_mean', 'gamma', 'beta')]
+                    is_bn = all(p in params_list for p in sibs)
+                    if is_bn:
+                        self.params.get(node).cast('float32')
+                        for sib in sibs:
+                            self.params.get(sib).cast('float32')
 
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
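As an illustration of what the new cast() logic is intended to guarantee, the following sketch (not part of the commit; the 'bn_test' name follows the unit test below) inspects parameter dtypes after casting a SymbolBlock containing a symbolic BatchNorm to float16:

import mxnet as mx

data = mx.sym.var('data')
conv = mx.sym.Convolution(data, kernel=(3, 3), num_filter=16)
bn = mx.sym.BatchNorm(conv, name='bn_test')

net = mx.gluon.nn.SymbolBlock([bn], [data])
net.initialize()
net.cast('float16')

# With the fix above, BatchNorm parameters following either naming convention
# (gamma/beta plus running_mean/running_var or moving_mean/moving_var) are cast
# back to float32, while other parameters such as the convolution weight stay float16.
for name, param in net.collect_params().items():
    print(name, param.dtype)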

tests/python/gpu/test_gluon_gpu.py

Lines changed: 18 additions & 0 deletions

@@ -594,6 +594,24 @@ def hybrid_forward(self, F, a, b):
     assert_raises(ValueError, lambda: foo_hybrid(mx.nd.ones((10,), ctx=mx.gpu()),
                                                  mx.nd.ones((10,), ctx=mx.cpu())))
 
+@with_seed()
+def test_symbol_block_symbolic_bn_fp16_cast():
+    with mx.gpu(0):
+        net = mx.gluon.nn.HybridSequential()
+        sym = mx.sym.var('data')
+        conv = mx.sym.Convolution(sym, kernel=(3, 3), num_filter=16)
+        bn = mx.sym.BatchNorm(conv, name='bn_test')
+        internals = bn.get_internals()
+        net.add(mx.gluon.nn.SymbolBlock([internals['bn_test_output']], [mx.sym.var('data')]))
+        net.add(mx.gluon.nn.Conv2D(10, kernel_size=1))
+        net.initialize()
+        x = mx.nd.zeros((1, 3, 32, 32), dtype='float32')
+        y = net(x)
+        assert np.dtype(y.dtype).name == 'float32'
+        net.cast('float16')
+        x = x.astype('float16')
+        y1 = net(x)
+        assert np.dtype(y1.dtype).name == 'float16'
 
 if __name__ == '__main__':
     import nose
