Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 82ed82f

Browse files
drivanov authored and apeforest committed
Aggregated zero grad (#16446)
* Trigger CI * Aggregated zeroing of the gradients/arrays * New files for aggregated zeroing of the gradients/arrays * Adding possibility to reset the arrays of different types. * Minor cleanup
1 parent 8c22fac commit 82ed82f

File tree

5 files changed

+281
-12
lines changed

5 files changed

+281
-12
lines changed

python/mxnet/gluon/parameter.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
'ParameterDict', 'tensor_types']
2525

2626

27-
from collections import OrderedDict
27+
from collections import OrderedDict, defaultdict
2828
import warnings
2929
import numpy as np
30+
import mxnet as mx
3031

3132
from ..base import mx_real_t, MXNetError
3233
from .. import symbol, ndarray, initializer, context
@@ -887,8 +888,22 @@ def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False,
887888

888889
def zero_grad(self):
    """Sets all Parameters' gradient buffer to 0.

    Dense gradient arrays are grouped by context and cleared with a single
    aggregated ``reset_arrays`` call per context. Gradients stored as
    ``row_sparse`` are cleared one at a time with ``zeros_like``.
    """
    # Collect the dense gradient arrays of every parameter, keyed by context.
    arrays = defaultdict(list)
    for p in self.values():
        # Parameters without a gradient buffer have nothing to reset.
        if p.grad_req == 'null' or p._grad is None:
            continue
        for g in p.list_grad():
            if g.stype == 'row_sparse':
                # Sparse gradients are not handed to reset_arrays; zero in place.
                ndarray.zeros_like(g, out=g)
            else:
                arrays[g.context].append(g)

    # One operator call per context. Nothing to do when no dense gradient
    # was collected, because the loop body never runs.
    for arr in arrays.values():
        ndarray.reset_arrays(*arr, num_arrays=len(arr))
892907

893908
def reset_ctx(self, ctx):
894909
"""Re-assign all Parameters to other contexts.
src/operator/contrib/reset_arrays-inl.h

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file reset_arrays-inl.h
23+
* \brief setting all array element values to zeros
24+
* \author Moises Hernandez-Fernandez, Andrei Ivanov
25+
*/
26+
27+
#ifndef MXNET_OPERATOR_CONTRIB_RESET_ARRAYS_INL_H_
28+
#define MXNET_OPERATOR_CONTRIB_RESET_ARRAYS_INL_H_
29+
30+
#include <vector>
31+
#include "../tensor/init_op.h"
32+
33+
namespace mxnet {
34+
namespace op {
35+
36+
/*! \brief Parameters of the reset_arrays operator. */
struct ResetArraysParam : public dmlc::Parameter<ResetArraysParam> {
  /*! \brief number of arrays handed to the operator */
  int num_arrays;
  DMLC_DECLARE_PARAMETER(ResetArraysParam) {
    // The count is later converted to an unsigned number of inputs, so a
    // negative value is rejected at parse time instead of wrapping around.
    DMLC_DECLARE_FIELD(num_arrays)
    .set_lower_bound(0)
    .describe("number of input arrays.");
  }
};
43+
44+
inline bool ResetArraysShape(const NodeAttrs& attrs,
45+
std::vector<mxnet::TShape>* in_shape,
46+
std::vector<mxnet::TShape>* out_shape) {
47+
const auto& param = dmlc::get<ResetArraysParam>(attrs.parsed);
48+
CHECK_EQ(in_shape->size(), param.num_arrays);
49+
for (auto s : *in_shape) {
50+
if (s.ndim() == 0)
51+
return false;
52+
}
53+
54+
return true;
55+
}
56+
57+
inline bool ResetArraysType(const NodeAttrs& attrs,
58+
std::vector<int>* in_type,
59+
std::vector<int>* out_type) {
60+
const auto& param = dmlc::get<ResetArraysParam>(attrs.parsed);
61+
CHECK_EQ(in_type->size(), param.num_arrays);
62+
for (size_t i = 0; i < in_type->size(); ++i) {
63+
if ((*in_type)[i] == -1)
64+
return false;
65+
}
66+
67+
return true;
68+
}
69+
70+
template<typename xpu>
71+
void ResetMemory(void *pntr, size_t len, mshadow::Stream<xpu> *s);
72+
73+
template<typename xpu>
74+
void ResetArrays(const nnvm::NodeAttrs& attrs,
75+
const OpContext &ctx,
76+
const std::vector<TBlob> &inputs,
77+
const std::vector<OpReqType> &req,
78+
const std::vector<TBlob> &outputs) {
79+
auto s = ctx.get_stream<xpu>();
80+
const auto& param = nnvm::get<ResetArraysParam>(attrs.parsed);
81+
for (int i = 0; i < param.num_arrays; i++) { // array index in inputs
82+
const size_t size = inputs[i].shape_.Size();
83+
MSHADOW_REAL_TYPE_SWITCH(inputs[i].type_flag_, DType,
84+
ResetMemory(inputs[i].FlatTo2D<xpu, DType>(s).dptr_, size * sizeof(DType), s);
85+
)
86+
}
87+
}
88+
89+
} // namespace op
90+
} // namespace mxnet
91+
92+
#endif // MXNET_OPERATOR_CONTRIB_RESET_ARRAYS_INL_H_

src/operator/contrib/reset_arrays.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file reset_arrays.cc
23+
* \brief setting all array element values to zeros
24+
* \author Moises Hernandez-Fernandez, Andrei Ivanov
25+
*/
26+
27+
#include "./reset_arrays-inl.h"
28+
29+
namespace mxnet {
30+
namespace op {
31+
32+
DMLC_REGISTER_PARAMETER(ResetArraysParam);

// reset_arrays takes `num_arrays` inputs, produces no outputs and declares
// every input as mutated.
NNVM_REGISTER_OP(reset_arrays)
.describe(R"code(Set to zero multiple arrays
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<ResetArraysParam>)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    const auto& param = dmlc::get<ResetArraysParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_arrays);
  })
.set_num_outputs(0)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    // Every input index is reported as mutated.
    const uint32_t count = dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays;
    std::vector<uint32_t> mutated;
    for (uint32_t idx = 0; idx < count; ++idx) {
      mutated.push_back(idx);
    }
    return mutated;
  })
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    // Inputs are named array_0, array_1, ...
    const uint32_t count = dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays;
    std::vector<std::string> names;
    for (uint32_t idx = 0; idx < count; ++idx) {
      names.push_back(std::string("array_") + std::to_string(idx));
    }
    return names;
  })
.set_attr<mxnet::FInferShape>("FInferShape", ResetArraysShape)
.set_attr<nnvm::FInferType>("FInferType", ResetArraysType)
.add_argument("data", "NDArray-or-Symbol[]", "Arrays")
.add_arguments(ResetArraysParam::__FIELDS__());
64+
65+
NNVM_REGISTER_OP(reset_arrays)
66+
.set_attr<FCompute>("FCompute<cpu>", ResetArrays<cpu>);
67+
68+
template<>
69+
void ResetMemory<cpu>(void *pntr, size_t len, mshadow::Stream<cpu> *s) {
70+
memset(pntr, 0, len);
71+
}
72+
73+
} // namespace op
74+
} // namespace mxnet

src/operator/contrib/reset_arrays.cu

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2019 by Contributors
22+
* \file reset_arrays.cu
23+
* \brief setting all array element values to zeros
24+
* \author Moises Hernandez-Fernandez, Andrei Ivanov
25+
*/
26+
#include "./reset_arrays-inl.h"
27+
28+
namespace mxnet {
29+
namespace op {
30+
31+
/*!
 * \brief GPU specialization of ResetMemory: issues cudaMemsetAsync on the
 *        CUDA stream obtained from `s`.
 *
 * \param pntr start of the memory region to clear
 * \param len size of the region in bytes
 * \param s mshadow stream whose CUDA stream receives the memset
 */
template<>
void ResetMemory<gpu>(void *pntr, size_t len, mshadow::Stream<gpu> *s) {
  cudaStream_t cuda_stream = mshadow::Stream<gpu>::GetStream(s);
  CUDA_CALL(cudaMemsetAsync(pntr, 0, len, cuda_stream));
}

// Attach the GPU compute function to the operator.
NNVM_REGISTER_OP(reset_arrays)
.set_attr<FCompute>("FCompute<gpu>", ResetArrays<gpu>);
38+
39+
} // namespace op
40+
} // namespace mxnet

tests/python/unittest/test_gluon.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import warnings
3434
import json
3535
import unittest
36+
import random
3637

3738
@with_seed()
3839
def test_parameter():
@@ -1504,15 +1505,62 @@ def test_hybrid_multi_context():
15041505

15051506
@with_seed()
def test_zero_grad():
    """Checks ``ParameterDict.zero_grad`` and the ``reset_arrays`` operator."""
    def _test_grad_reset(ctx, dtype='float32', sparse=False, embeddingType=None):
        # Run one backward pass through an Embedding, zero the gradients and
        # verify that the weight gradient holds only zeros.
        data = mx.nd.random.uniform(shape=(3,3), dtype=dtype, ctx=ctx)
        if embeddingType is None:
            embeddingType = dtype
        net = nn.Embedding(3, 4, sparse_grad=sparse, prefix='test_zero_grad_', dtype=embeddingType)
        net.initialize(ctx=ctx)
        with mx.autograd.record():
            l = net(data)
            l.backward()
        net.collect_params().zero_grad()
        grad = net.collect_params()['test_zero_grad_weight'].grad()
        assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)

    def _test_multi_reset(nArrays, dtype, ctx):
        # Construct the list of non-zero arrays with random shapes.
        # `dtype` is either a single type name or a list to pick from per array.
        arr = []
        for _ in range(nArrays):
            arrType = random.choice(dtype) if isinstance(dtype, list) else dtype
            shape = ()
            for _dim in range(np.random.randint(1, 5)):
                shape = shape + (np.random.randint(1, 10),)
            arr.append(mx.nd.random.uniform(shape=shape, dtype=arrType, ctx=ctx))

        # Reset all arrays
        mx.nd.reset_arrays(*arr, num_arrays=len(arr))

        # Check results
        for i in range(nArrays):
            grad = arr[i].asnumpy()
            assert_almost_equal(grad, grad * 0)

    # Setting context for current test
    ctx = mx.context.current_context()

    # Launching _test_multi_reset 10 times with different types & randomly chosen nArrays
    testedTypes = ['float16', 'float32', 'float64']
    for _ in range(10):
        for dtype in [testedTypes] + testedTypes:
            _test_multi_reset(np.random.randint(1, 50), dtype, ctx)

    # Saving value of environment variable, if it was defined
    envVarKey = 'MXNET_STORAGE_FALLBACK_LOG_VERBOSE'
    envVarValue = os.environ.get(envVarKey)
    # Changing value of environment variable
    # NOTE(review): presumably silences storage-fallback log output; confirm.
    os.environ[envVarKey] = '0'
    try:
        for dtype in ['float16', 'float32', 'float64']:
            for embType in ['float32', 'float64']:
                for sparse in [True, False]:
                    _test_grad_reset(ctx, dtype=dtype, sparse=sparse, embeddingType=embType)
    finally:
        # Remove or restore the value of environment variable, even when an
        # assertion above fails, so that later tests see the original setting.
        if envVarValue is None:
            del os.environ[envVarKey]
        else:
            os.environ[envVarKey] = envVarValue
15161564

15171565
def check_hybrid_static_memory(**kwargs):
15181566
x = mx.nd.random.uniform(shape=(2, 3, 32, 32))

0 commit comments

Comments
 (0)