
Commit 6a1e440

access2rohit authored and ptrendx committed
Multi Precision Lamb Update operator (apache#16885)
* multi-precision lamb update operator
* removing multi-tensor code from lamb
* doing operation beta^t outside of kernel call
* removing unnecessary functions from PyLAMB
1 parent ff27b4b commit 6a1e440

6 files changed, +304 -24 lines changed

python/mxnet/optimizer/optimizer.py

Lines changed: 43 additions & 16 deletions
@@ -34,7 +34,8 @@
                       multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
                       multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
                       preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
-                      preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2)
+                      preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2,
+                      mp_lamb_update_phase1, mp_lamb_update_phase2)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array
@@ -1262,11 +1263,10 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
 
     def create_state(self, index, weight):
         stype = weight.stype
-        dtype = weight.dtype
-        return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
-                zeros(weight.shape, weight.context, dtype=dtype, stype=stype))
+        return (zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype),
+                zeros(weight.shape, weight.context, dtype=numpy.float32, stype=stype))
 
-    def update(self, index, weight, grad, state):
+    def _update_impl(self, index, weight, grad, state, multi_precision=False):
         assert(isinstance(weight, NDArray))
         assert(isinstance(grad, NDArray))
         self._update_count(index)
@@ -1277,19 +1277,46 @@ def update(self, index, weight, grad, state):
         kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
                   'bias_correction': self.bias_correction, 't': t,
                   'rescale_grad': self.rescale_grad}
-        mean, var = state
+
         if self.clip_gradient:
             kwargs['clip_gradient'] = self.clip_gradient
-        g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs)
-
-        kwargs = {}
-        if self.lower_bound:
-            kwargs['lower_bound'] = self.lower_bound
-        if self.upper_bound:
-            kwargs['upper_bound'] = self.upper_bound
-        r_1 = weight.norm()
-        r_2 = g.norm()
-        lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight, **kwargs)
+
+        if multi_precision:
+            mean, var = state[1]
+            weight32 = state[0]
+            g = mp_lamb_update_phase1(weight, grad, mean, var, weight32, wd=wd, **kwargs)
+
+            kwargs = {}
+            if self.lower_bound:
+                kwargs['lower_bound'] = self.lower_bound
+            if self.upper_bound:
+                kwargs['upper_bound'] = self.upper_bound
+            r_1 = weight32.norm()
+            r_2 = g.norm()
+            mp_lamb_update_phase2(weight, g, r_1, r_2, weight32, lr=lr, out=weight, **kwargs)
+        else:
+            mean, var = state
+            g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs)
+
+            kwargs = {}
+            if self.lower_bound:
+                kwargs['lower_bound'] = self.lower_bound
+            if self.upper_bound:
+                kwargs['upper_bound'] = self.upper_bound
+            r_1 = weight.norm()
+            r_2 = g.norm()
+            lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight, **kwargs)
+
+    def update(self, index, weight, grad, state):
+        self._update_impl(index, weight, grad, state, multi_precision=False)
+
+    def update_multi_precision(self, index, weight, grad, state):
+        if not isinstance(index, (tuple, list)):
+            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
+        else:
+            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
+        self._update_impl(index, weight, grad, state,
+                          multi_precision=use_multi_precision)
 
 
 # pylint: enable=line-too-long
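
For context, the two new operators compose into one mixed-precision LAMB step exactly as _update_impl does above: phase 1 updates mean/var in place and produces the fp32 direction g, and phase 2 applies the trust ratio, updates the fp32 master weight in place, and writes the fp16 copy. A minimal NDArray-level sketch; the shapes, hyperparameter values, and random initial data are illustrative assumptions, not part of the commit.

import mxnet as mx
import numpy as np

shape = (4, 8)
weight = mx.nd.random.uniform(shape=shape).astype(np.float16)   # fp16 model weight
grad = mx.nd.random.uniform(shape=shape).astype(np.float16)     # fp16 gradient
weight32 = weight.astype(np.float32)                            # fp32 master copy (state[0])
mean = mx.nd.zeros(shape, dtype=np.float32)                     # fp32 first moment (state[1][0])
var = mx.nd.zeros(shape, dtype=np.float32)                      # fp32 second moment (state[1][1])

# Phase 1: updates mean/var in place and returns the fp32 LAMB direction g.
g = mx.nd.mp_lamb_update_phase1(weight, grad, mean, var, weight32,
                                beta1=0.9, beta2=0.999, epsilon=1e-6,
                                bias_correction=True, t=1,
                                wd=0.01, rescale_grad=1.0)

# Phase 2: scales lr by the trust ratio r_1/r_2, updates weight32 in place,
# and writes the new fp16 weights into `out`.
r_1 = weight32.norm()
r_2 = g.norm()
mx.nd.mp_lamb_update_phase2(weight, g, r_1, r_2, weight32, lr=0.001, out=weight)

Note that r_1 is taken from the fp32 master weight rather than the fp16 copy, which is the main behavioural difference from the single-precision path above.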

src/operator/optimizer_op-inl.h

Lines changed: 158 additions & 1 deletion
@@ -1751,6 +1751,164 @@ inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<int n_in, int n_out, int total_in>
+inline bool MPLambPhaseOneType(const nnvm::NodeAttrs& attrs,
+                               std::vector<int> *in_attrs,
+                               std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
+  CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+  for (int i = 0; i < n_in; ++i) {
+    TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat16);
+  }
+  for (int i = n_in; i < total_in; ++i) {
+    TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
+  }
+  for (int i = 0; i < n_out; ++i) {
+    TYPE_ASSIGN_CHECK(*out_attrs, i, mshadow::kFloat32);
+  }
+  return true;
+}
+
+struct MPLambUpdatePhaseOneKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, float* out_data,
+    float* mean_data, float* var_data, const DType* weight_data,
+    const DType* grad_data, const float* weight32_data,
+    const float clip_gradient, const float rescale_grad,
+    const float beta1_t, const float beta1,
+    const float beta2_t, const float beta2,
+    const float wd, const float epsilon, const int t,
+    bool bias_correction, const OpReqType req) {
+    using namespace mshadow_op;
+
+    float grad_rescaled = grad_data[i] * rescale_grad;
+    if (clip_gradient >= 0.f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+    var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
+
+    float g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight32_data[i];
+
+    if (bias_correction) {
+      float mean_hat = mean_data[i] / (1. - beta1_t);
+      float var_hat = var_data[i] / (1 - beta2_t);
+      g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight32_data[i];
+    }
+    KERNEL_ASSIGN(out_data[i], req, g);
+  }
+};
+
+template<typename xpu>
+inline void MPLambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
+                                 const OpContext &ctx,
+                                 const std::vector<TBlob> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    float beta1_t = std::pow(param.beta1, param.t);
+    float beta2_t = std::pow(param.beta2, param.t);
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s);
+
+    Kernel<MPLambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
+      param.clip_gradient, param.rescale_grad, beta1_t, param.beta1, beta2_t, param.beta2,
+      param.wd, param.epsilon, param.t, param.bias_correction, req[0]);
+  });
+}
+
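
As a plain-Python cross-check of the element-wise math above, here is a NumPy sketch of what one MPLambUpdatePhaseOne call computes. Note that beta1_t = beta1^t and beta2_t = beta2^t are computed on the host before the kernel launch (the "beta^t outside of kernel call" item in the commit message). The reference function name and its in-place convention are illustrative assumptions, not part of the operator API.

import numpy as np

def mp_lamb_phase1_reference(grad16, mean, var, weight32,
                             beta1, beta2, epsilon, t, wd,
                             rescale_grad=1.0, clip_gradient=-1.0,
                             bias_correction=True):
    # All arithmetic is done in fp32, as in the kernel; only grad arrives as fp16.
    grad = grad16.astype(np.float32) * rescale_grad
    if clip_gradient >= 0:
        grad = np.clip(grad, -clip_gradient, clip_gradient)

    # mean/var are the fp32 optimizer state, updated in place (FMutateInputs {2, 3}).
    mean[:] = beta1 * mean + (1.0 - beta1) * grad
    var[:] = beta2 * var + (1.0 - beta2) * grad * grad

    g = mean / (np.sqrt(var) + epsilon) + wd * weight32
    if bias_correction:
        mean_hat = mean / (1.0 - beta1 ** t)   # beta1_t, host-side in the operator
        var_hat = var / (1.0 - beta2 ** t)     # beta2_t, host-side in the operator
        g = mean_hat / (np.sqrt(var_hat) + epsilon) + wd * weight32
    return g  # fp32 output consumed by phase 2
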
+inline bool MPLambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
+                                      mxnet::ShapeVector* in_attrs,
+                                      mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 5U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
+
+  mxnet::TShape& weight_shape = in_attrs->at(0);
+  mxnet::TShape& g_shape = in_attrs->at(1);
+  mxnet::TShape& weight32_shape = in_attrs->at(4);
+  CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
+    << "total no. of dimensions for weight and g must match";
+  CHECK_EQ(weight_shape.ndim(), weight32_shape.ndim())
+    << "total no. of dimensions for weight and weight32 must match";
+  for (int i = 0; i < weight_shape.ndim(); ++i) {
+    CHECK_EQ(weight_shape[i], g_shape[i])
+      << "weight and g dimension size mismatch at " << i << "-th index";
+    CHECK_EQ(weight_shape[i], weight32_shape[i])
+      << "weight and weight32 dimension size mismatch at " << i << "-th index";
+  }
+  mxnet::TShape& r1_shape = in_attrs->at(2);
+  mxnet::TShape& r2_shape = in_attrs->at(3);
+  CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
+  CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
+  for (int i = 0; i < expected_out.ndim(); ++i) {
+    expected_out[i] = weight_shape[i];
+  }
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
+  return shape_is_known(expected_out);
+}
+
+struct MPLambUpdatePhaseTwoKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    const DType* weight_data, const float* g,
+    const float* r1, const float* r2, const float* weight32_data,
+    float lr, const float lower_bound,
+    const float upper_bound, const OpReqType req) {
+    using namespace mshadow_op;
+
+    float new_r1 = r1[0];
+    if (lower_bound >= 0) {
+      new_r1 = maximum::Map(new_r1, lower_bound);
+    }
+    if (upper_bound >= 0) {
+      new_r1 = minimum::Map(new_r1, upper_bound);
+    }
+    if (new_r1 == 0.0f || r2[0] == 0.0f) {
+      lr = lr * 1.0f;
+    } else {
+      lr = lr * new_r1 / r2[0];
+    }
+
+    KERNEL_ASSIGN(out_data[i], req, weight32_data[i] - lr * g[i]);
+  }
+};
+
+template<typename xpu>
+inline void MPLambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
+                                 const OpContext &ctx,
+                                 const std::vector<TBlob> &inputs,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, float> g = inputs[1].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, float> r1 = inputs[2].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, float> r2 = inputs[3].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+    Kernel<MPLambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_, weight32.dptr_,
+      param.lr, param.lower_bound,
+      param.upper_bound, req[0]);
+  });
+}
+
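
The corresponding sketch for phase two: the trust ratio r1/r2 is only applied when both norms are non-zero, negative lower_bound/upper_bound values disable the clamping (matching the >= 0 checks in the kernel), the fp32 master weight is updated in place, and the result is cast back to the fp16 weight dtype. Again a hedged reference with illustrative names:

import numpy as np

def mp_lamb_phase2_reference(weight16, g, r1, r2, weight32, lr,
                             lower_bound=-1.0, upper_bound=-1.0):
    new_r1 = float(r1)
    if lower_bound >= 0:
        new_r1 = max(new_r1, lower_bound)
    if upper_bound >= 0:
        new_r1 = min(new_r1, upper_bound)
    if new_r1 != 0.0 and float(r2) != 0.0:
        lr = lr * new_r1 / float(r2)          # trust-ratio-scaled learning rate

    weight32[:] = weight32 - lr * g           # fp32 master update (FMutateInputs {4})
    return weight32.astype(weight16.dtype)    # fp16 result written to `out`
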
 // This RMSProp code follows the version in
 // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
 // by Alex Graves, 2013.
@@ -2493,5 +2651,4 @@ inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
 } // namespace op
 } // namespace mxnet
 
-
 #endif // MXNET_OPERATOR_OPTIMIZER_OP_INL_H_

src/operator/optimizer_op.cc

Lines changed: 89 additions & 1 deletion
@@ -947,7 +947,7 @@ Link to paper: https://arxiv.org/pdf/1904.00962.pdf
          var_hat = var / (1 - beta2^t);
          g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight;
     else
-         g = mean / (var_data^(1/2) + epsilon) + wd * weight_data[i];
+         g = mean / (var_data^(1/2) + epsilon) + wd * weight;
 \end{gather*}
 
 )code" ADD_FILELINE)
@@ -1002,5 +1002,93 @@ Link to paper: https://arxiv.org/pdf/1904.00962.pdf
 .add_argument("r2", "NDArray-or-Symbol", "r2")
 .add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());
 
+NNVM_REGISTER_OP(mp_lamb_update_phase1)
+.describe(R"code(Mixed Precision version of Phase I of lamb update
+it performs the following operations and returns g:
+
+Link to paper: https://arxiv.org/pdf/1904.00962.pdf
+
+.. math::
+    \begin{gather*}
+    grad32 = grad(float16) * rescale_grad
+    if (grad < -clip_gradient)
+    then
+         grad = -clip_gradient
+    if (grad > clip_gradient)
+    then
+         grad = clip_gradient
+
+    mean = beta1 * mean + (1 - beta1) * grad;
+    variance = beta2 * variance + (1. - beta2) * grad ^ 2;
+
+    if (bias_correction)
+    then
+         mean_hat = mean / (1. - beta1^t);
+         var_hat = var / (1 - beta2^t);
+         g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight32;
+    else
+         g = mean / (var_data^(1/2) + epsilon) + wd * weight32;
+    \end{gather*}
+
+)code" ADD_FILELINE)
+.set_num_inputs(5)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
+.set_attr<nnvm::FInferType>("FInferType", MPLambPhaseOneType<2, 1, 5>)
+.set_attr<FCompute>("FCompute<cpu>", MPLambUpdatePhaseOne<cpu>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2, 3};
+  })
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
+.add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
+.add_arguments(LambUpdatePhaseOneParam::__FIELDS__());
+
+NNVM_REGISTER_OP(mp_lamb_update_phase2)
+.describe(R"code(Mixed Precision version of Phase II of lamb update
+it performs the following operations and updates the weight.
+
+Link to paper: https://arxiv.org/pdf/1904.00962.pdf
+
+.. math::
+    \begin{gather*}
+    if (lower_bound >= 0)
+    then
+         r1 = max(r1, lower_bound)
+    if (upper_bound >= 0)
+    then
+         r1 = min(r1, upper_bound)
+
+    if (r1 == 0 or r2 == 0)
+    then
+         lr = lr
+    else
+         lr = lr * (r1/r2)
+    weight32 = weight32 - lr * g
+    weight(float16) = weight32
+    \end{gather*}
+
+)code" ADD_FILELINE)
+.set_num_inputs(5)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", MPLambUpdatePhaseTwoShape)
+.set_attr<nnvm::FInferType>("FInferType", MP_InferType<1, 1, 5>)
+.set_attr<FCompute>("FCompute<cpu>", MPLambUpdatePhaseTwo<cpu>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{4};
+  })
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("g", "NDArray-or-Symbol", "Output of mp_lamb_update_phase1")
+.add_argument("r1", "NDArray-or-Symbol", "r1")
+.add_argument("r2", "NDArray-or-Symbol", "r2")
+.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
+.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());
+
 } // namespace op
 } // namespace mxnet
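
The registrations wire the pieces together: MPLambPhaseOneType<2, 1, 5> pins the first two inputs (weight, grad) to float16 and the remaining three (mean, var, weight32) plus the output to float32, while FMutateInputs marks mean/var (phase 1) and weight32 (phase 2) as mutated in place. A quick way to inspect the inferred types from the symbolic API, sketched under the assumption that t and wd must be supplied explicitly:

import mxnet as mx

weight = mx.sym.Variable('weight')
grad = mx.sym.Variable('grad')
mean = mx.sym.Variable('mean')
var = mx.sym.Variable('var')
weight32 = mx.sym.Variable('weight32')

out = mx.sym.mp_lamb_update_phase1(weight, grad, mean, var, weight32, t=1, wd=0.0)
arg_types, out_types, _ = out.infer_type()
print(dict(zip(out.list_arguments(), arg_types)))  # weight/grad -> float16, the rest -> float32
print(out_types)                                   # single float32 output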

src/operator/optimizer_op.cu

Lines changed: 5 additions & 0 deletions
@@ -283,6 +283,11 @@ NNVM_REGISTER_OP(lamb_update_phase1)
 NNVM_REGISTER_OP(lamb_update_phase2)
 .set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseTwo<gpu>);
 
+NNVM_REGISTER_OP(mp_lamb_update_phase1)
+.set_attr<FCompute>("FCompute<gpu>", MPLambUpdatePhaseOne<gpu>);
+
+NNVM_REGISTER_OP(mp_lamb_update_phase2)
+.set_attr<FCompute>("FCompute<gpu>", MPLambUpdatePhaseTwo<gpu>);
 
 } // namespace op
 } // namespace mxnet
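
Because these lines register the same FCompute implementations for GPU, the NDArray sketch shown after the optimizer.py diff runs unchanged on a CUDA device; only the array context changes. A minimal illustration, assuming a CUDA build of MXNet:

import mxnet as mx

ctx = mx.gpu(0)  # all five inputs live on the GPU, so dispatch goes to FCompute<gpu>
weight = mx.nd.zeros((4, 8), ctx=ctx, dtype='float16')
weight32 = mx.nd.zeros((4, 8), ctx=ctx, dtype='float32')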

tests/python/gpu/test_operator_gpu.py

Lines changed: 1 addition & 0 deletions
@@ -422,6 +422,7 @@ def test_preloaded_multi_sgd():
         shapes = [np.random.randint(1, maxdim + 1, size=maxndim) for i in range(nparam)]
         check_preloaded_multi_sgd(dtype, shapes, momentum, use_master_weights)
 
+
 @with_seed()
 def test_batchnorm_with_type():
     ctx_list_v1_2D = [
