Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Add missing default axis value to symbol.squeeze op #15707

Merged
merged 3 commits into from
Jul 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2539,7 +2539,7 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def squeeze(self, axis, inplace=False, **kwargs): # pylint: disable=unused-argument
def squeeze(self, axis=None, inplace=False, **kwargs): # pylint: disable=unused-argument
"""Convenience fluent method for :py:func:`squeeze`.

The arguments are the same as for :py:func:`squeeze`, with
Expand Down
24 changes: 20 additions & 4 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_parameter_dict():
params1.get('w1', shape=(10, 10), stype='row_sparse')
params1.load('test_parameter_dict.params', ctx)
trainer1 = mx.gluon.Trainer(params1, 'sgd')

# compare the values before and after save/load
cur_w0 = params1.get('w0').data(ctx)
cur_w1 = params1.get('w1').row_sparse_data(all_row_ids)
Expand All @@ -134,7 +134,7 @@ def test_parameter_dict():
cur_w1 = params2.get('w1').data(ctx)
mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())

# test the dtype casting functionality
params0 = gluon.ParameterDict('')
params0.get('w0', shape=(10, 10), dtype='float32')
Expand Down Expand Up @@ -386,7 +386,7 @@ def hybrid_forward(self, F, x):
if 'conv' in param_name and 'weight' in param_name:
break
assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)

# 3.b Verify same functionnality with the imports API
net_fp_64 = mx.gluon.SymbolBlock.imports(sym_file, 'data', params_file, ctx=ctx)

Expand Down Expand Up @@ -2788,7 +2788,7 @@ def test_gluon_param_load():
net.cast('float16')
net.load_parameters('test_gluon_param_load.params', cast_dtype=True)
mx.nd.waitall()

@with_seed()
def test_gluon_param_load_dtype_source():
net = mx.gluon.nn.Dense(10, in_units=10)
Expand All @@ -2800,6 +2800,22 @@ def test_gluon_param_load_dtype_source():
assert net.weight.dtype == np.float16
mx.nd.waitall()

@with_seed()
def test_squeeze_consistency():
class Foo(gluon.HybridBlock):
def __init__(self, inplace, **kwargs):
super(Foo, self).__init__(**kwargs)
self.inplace = inplace

def forward(self, x):
return x.squeeze(inplace=self.inplace)

for inplace in (True, False):
block = Foo(inplace)
block.hybridize()
shape = (np.random.randint(1, 10), np.random.randint(1, 10), 1)
block(mx.nd.ones(shape))

if __name__ == '__main__':
import nose
nose.runmodule()
1 change: 1 addition & 0 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
check_fluent_regular('reshape', {'shape': (17, 1, 5)})
check_fluent_regular('broadcast_to', {'shape': (5, 17, 47)})
check_fluent_regular('squeeze', {'axis': (1, 3)}, shape=(2, 1, 3, 1, 4))
check_fluent_regular('squeeze', {}, shape=(2, 1, 3, 1, 4))

def check_symbol_consistency(sym1, sym2, ctx, skip_grad=False, equal_nan=False):
assert sym1.list_arguments() == sym2.list_arguments()
Expand Down