
Commit d69971a

Lamb optimizer update (apache#16715)

access2rohit authored and ptrendx committed

* initial commit lamb optimizer
* fixing base lamb optimizer
* adding API doc for Lamb Phase 1 and 2

1 parent c973f01; commit d69971a

File tree: 5 files changed, +397 -2 lines

python/mxnet/optimizer/optimizer.py (50 additions, 2 deletions)

@@ -34,14 +34,14 @@
     multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
     multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
     preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
-    preloaded_multi_mp_sgd_mom_update)
+    preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array

 __all__ = [
     'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
-    'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum',
+    'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
     'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
 ]

@@ -1244,6 +1244,54 @@ def update(self, index, weight, grad, state):
             kwargs = {}
         sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)

+
+@register
+class LAMB(Optimizer):
+    """LAMB Optimizer.
+    """
+    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
+                 lower_bound=None, upper_bound=None, bias_correction=True, **kwargs):
+        super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.lower_bound = lower_bound
+        self.upper_bound = upper_bound
+        self.bias_correction = bias_correction
+
+
+    def create_state(self, index, weight):
+        stype = weight.stype
+        dtype = weight.dtype
+        return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
+                zeros(weight.shape, weight.context, dtype=dtype, stype=stype))
+
+    def update(self, index, weight, grad, state):
+        assert(isinstance(weight, NDArray))
+        assert(isinstance(grad, NDArray))
+        self._update_count(index)
+        lr = self._get_lr(index)
+        wd = self._get_wd(index)
+        t = self._index_update_count[index]
+
+        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
+                  'bias_correction': self.bias_correction, 't': t,
+                  'rescale_grad': self.rescale_grad}
+        mean, var = state
+        if self.clip_gradient:
+            kwargs['clip_gradient'] = self.clip_gradient
+        g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs)
+
+        kwargs = {}
+        if self.lower_bound:
+            kwargs['lower_bound'] = self.lower_bound
+        if self.upper_bound:
+            kwargs['upper_bound'] = self.upper_bound
+        r_1 = weight.norm()
+        r_2 = g.norm()
+        lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight, **kwargs)
+
+
 # pylint: enable=line-too-long
 @register
 class DCASGD(Optimizer):
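
For context, a minimal usage sketch (not part of this commit) of how the newly added LAMB optimizer could be attached to a Gluon Trainer. The string name 'lamb' assumes MXNet's usual convention of registering optimizers under the lower-cased class name, and the hyper-parameter values are illustrative only:

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(10)
    net.initialize()
    # 'lamb' resolves to the LAMB class registered by the @register decorator above.
    trainer = gluon.Trainer(net.collect_params(), 'lamb',
                            {'learning_rate': 0.001, 'beta1': 0.9, 'beta2': 0.999,
                             'epsilon': 1e-6, 'bias_correction': True})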

src/operator/optimizer_op-inl.h (186 additions, 0 deletions)

@@ -1563,6 +1563,192 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
   }
 }

+struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
+  float beta1;
+  float beta2;
+  float epsilon;
+  float t;
+  bool bias_correction;
+  float wd;
+  float rescale_grad;
+  float clip_gradient;
+  DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
+    DMLC_DECLARE_FIELD(beta1)
+    .set_default(0.9f)
+    .describe("The decay rate for the 1st moment estimates.");
+    DMLC_DECLARE_FIELD(beta2)
+    .set_default(0.999f)
+    .describe("The decay rate for the 2nd moment estimates.");
+    DMLC_DECLARE_FIELD(epsilon)
+    .set_default(1e-6f)
+    .describe("A small constant for numerical stability.");
+    DMLC_DECLARE_FIELD(t)
+    .describe("Index update count.");
+    DMLC_DECLARE_FIELD(bias_correction)
+    .set_default(true)
+    .describe("Whether to use bias correction.");
+    DMLC_DECLARE_FIELD(wd)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+  }
+};
+
+struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
+  float lr;
+  float lower_bound;
+  float upper_bound;
+  DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(lower_bound)
+    .set_default(-1.0f)
+    .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
+    DMLC_DECLARE_FIELD(upper_bound)
+    .set_default(-1.0f)
+    .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
+  }
+};
+
+struct LambUpdatePhaseOneKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType beta1, const DType beta2, const DType wd,
+    const DType epsilon, const DType t,
+    bool bias_correction, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = grad_data[i] * rescale_grad;
+    if (clip_gradient >= 0.f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+    var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
+
+    DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];
+
+    if (bias_correction) {
+      DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
+      DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
+      g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
+    }
+    KERNEL_ASSIGN(out_data[i], req, g);
+  }
+};
+
+template<typename xpu>
+inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
+                               const OpContext &ctx,
+                               const std::vector<TBlob> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+    Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+      static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+      static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+  });
+}
+
+inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
+                                    mxnet::ShapeVector* in_attrs,
+                                    mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 4U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
+
+  mxnet::TShape& weight_shape = in_attrs->at(0);
+  mxnet::TShape& g_shape = in_attrs->at(1);
+  CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
+    << "total no. of dimensions for weights and g must match";
+  for (int i = 0; i < weight_shape.ndim(); ++i) {
+    CHECK_EQ(weight_shape[i], g_shape[i])
+      << "weight and g dimension size mismatch at " << i << "-th index";
+  }
+  mxnet::TShape& r1_shape = in_attrs->at(2);
+  mxnet::TShape& r2_shape = in_attrs->at(3);
+  CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
+  CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
+  for (int i = 0; i < expected_out.ndim(); ++i) {
+    expected_out[i] = weight_shape[i];
+  }
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
+  return shape_is_known(expected_out);
+}
+
+struct LambUpdatePhaseTwoKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    const DType* weight_data, const DType* g,
+    const DType* r1, const DType* r2,
+    DType lr, const DType lower_bound,
+    const DType upper_bound, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType new_r1 = r1[0];
+    if (lower_bound >= 0) {
+      new_r1 = maximum::Map(new_r1, lower_bound);
+    }
+    if (upper_bound >= 0) {
+      new_r1 = minimum::Map(new_r1, upper_bound);
+    }
+    if (new_r1 == 0.0f || r2[0] == 0.0f) {
+      lr = lr * 1.0f;
+    } else {
+      lr = lr * new_r1 / r2[0];
+    }
+
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
+  }
+};
+
+template<typename xpu>
+inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
+                               const OpContext &ctx,
+                               const std::vector<TBlob> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+    Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
+      static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
+      static_cast<DType>(param.upper_bound), req[0]);
+  });
+}
+
 // This RMSProp code follows the version in
 // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
 // by Alex Graves, 2013.
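
To make the element-wise math of the two kernels easier to follow, here is a hedged NumPy reference sketch. The helper names lamb_phase1_ref and lamb_phase2_ref are hypothetical and only mirror what LambUpdatePhaseOneKernel and LambUpdatePhaseTwoKernel compute; in the actual operators, r1 and r2 are supplied by the caller as the norms of the weight and of g.

    import numpy as np

    def lamb_phase1_ref(weight, grad, mean, var, beta1=0.9, beta2=0.999,
                        epsilon=1e-6, t=1, wd=0.0, rescale_grad=1.0,
                        clip_gradient=-1.0, bias_correction=True):
        # Rescale and optionally clip the gradient, as in LambUpdatePhaseOneKernel.
        g = grad * rescale_grad
        if clip_gradient >= 0:
            g = np.clip(g, -clip_gradient, clip_gradient)
        # Update the first and second moment estimates in place.
        mean[:] = beta1 * mean + (1 - beta1) * g
        var[:] = beta2 * var + (1 - beta2) * g * g
        m, v = mean, var
        if bias_correction:
            m = mean / (1 - beta1 ** t)
            v = var / (1 - beta2 ** t)
        # Adam-style direction plus decoupled weight decay.
        return m / (np.sqrt(v) + epsilon) + wd * weight

    def lamb_phase2_ref(weight, g, lr, lower_bound=-1.0, upper_bound=-1.0):
        # Layer-wise trust-ratio scaling, as in LambUpdatePhaseTwoKernel.
        r1 = np.linalg.norm(weight)
        r2 = np.linalg.norm(g)
        if lower_bound >= 0:
            r1 = max(r1, lower_bound)
        if upper_bound >= 0:
            r1 = min(r1, upper_bound)
        ratio = 1.0 if (r1 == 0 or r2 == 0) else r1 / r2
        return weight - lr * ratio * g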

src/operator/optimizer_op.cc (81 additions, 0 deletions)

@@ -43,6 +43,8 @@ DMLC_REGISTER_PARAMETER(FtrlParam);
 DMLC_REGISTER_PARAMETER(SignSGDParam);
 DMLC_REGISTER_PARAMETER(SignumParam);
 DMLC_REGISTER_PARAMETER(AdagradParam);
+DMLC_REGISTER_PARAMETER(LambUpdatePhaseOneParam);
+DMLC_REGISTER_PARAMETER(LambUpdatePhaseTwoParam);

 NNVM_REGISTER_OP(signsgd_update)
 .describe(R"code(Update function for SignSGD optimizer.
@@ -921,5 +923,84 @@ Note that non-zero values for the weight decay option are not supported.
 .add_argument("history", "NDArray-or-Symbol", "History")
 .add_arguments(AdagradParam::__FIELDS__());

+NNVM_REGISTER_OP(lamb_update_phase1)
+.describe(R"code(Phase I of the LAMB update; it performs the following operations and returns g.
+
+Link to paper: https://arxiv.org/pdf/1904.00962.pdf
+
+.. math::
+    \begin{gather*}
+    grad = grad * rescale_grad
+    if (grad < -clip_gradient)
+    then
+         grad = -clip_gradient
+    if (grad > clip_gradient)
+    then
+         grad = clip_gradient
+
+    mean = beta1 * mean + (1 - beta1) * grad;
+    variance = beta2 * variance + (1. - beta2) * grad ^ 2;
+
+    if (bias_correction)
+    then
+         mean_hat = mean / (1. - beta1^t);
+         var_hat = var / (1 - beta2^t);
+         g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight;
+    else
+         g = mean / (var^(1/2) + epsilon) + wd * weight;
+    \end{gather*}
+
+)code" ADD_FILELINE)
+.set_num_inputs(4)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseOne<cpu>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2, 3};
+  })
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("grad", "NDArray-or-Symbol", "Gradient")
+.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
+.add_argument("var", "NDArray-or-Symbol", "Moving variance")
+.add_arguments(LambUpdatePhaseOneParam::__FIELDS__());

+NNVM_REGISTER_OP(lamb_update_phase2)
+.describe(R"code(Phase II of the LAMB update; it performs the following operations and updates the weight.
+
+Link to paper: https://arxiv.org/pdf/1904.00962.pdf
+
+.. math::
+    \begin{gather*}
+    if (lower_bound >= 0)
+    then
+         r1 = max(r1, lower_bound)
+    if (upper_bound >= 0)
+    then
+         r1 = min(r1, upper_bound)
+
+    if (r1 == 0 or r2 == 0)
+    then
+         lr = lr
+    else
+         lr = lr * (r1/r2)
+    weight = weight - lr * g
+    \end{gather*}
+
+)code" ADD_FILELINE)
+.set_num_inputs(4)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", LambUpdatePhaseTwoShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseTwo<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("g", "NDArray-or-Symbol", "Output of lamb_update_phase1")
+.add_argument("r1", "NDArray-or-Symbol", "r1")
+.add_argument("r2", "NDArray-or-Symbol", "r2")
+.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());
+
 } // namespace op
 } // namespace mxnet
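
As a complement to the operator registrations above, a hedged sketch of calling the two imperative operators through the NDArray front end, mirroring what LAMB.update() in optimizer.py does; the shape and hyper-parameter values are illustrative only:

    import mxnet as mx

    shape = (5, 10)
    weight = mx.nd.random.uniform(shape=shape)
    grad = mx.nd.random.uniform(shape=shape)
    mean = mx.nd.zeros(shape)
    var = mx.nd.zeros(shape)

    # Phase 1: update the moments (mean and var are mutated) and return the raw direction g.
    g = mx.nd.lamb_update_phase1(weight, grad, mean, var,
                                 beta1=0.9, beta2=0.999, epsilon=1e-6, t=1,
                                 bias_correction=True, wd=0.01, rescale_grad=1.0)

    # Phase 2: apply the layer-wise trust ratio r1/r2 and write the result back into weight.
    r1 = weight.norm()
    r2 = g.norm()
    mx.nd.lamb_update_phase2(weight, g, r1, r2, lr=0.001, out=weight)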

src/operator/optimizer_op.cu (7 additions, 0 deletions)

@@ -277,5 +277,12 @@ NNVM_REGISTER_OP(ftrl_update)
 NNVM_REGISTER_OP(_sparse_adagrad_update)
 .set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);

+NNVM_REGISTER_OP(lamb_update_phase1)
+.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseOne<gpu>);
+
+NNVM_REGISTER_OP(lamb_update_phase2)
+.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseTwo<gpu>);
+
+
 } // namespace op
 } // namespace mxnet
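
Since this file only adds the GPU FCompute registrations, the same calls shown above for the CPU path should dispatch to these GPU kernels once the arrays live on a GPU context. A brief hedged sketch, assuming a CUDA-enabled MXNet build and an available device:

    import mxnet as mx

    ctx = mx.gpu(0)
    shape = (5, 10)
    weight = mx.nd.random.uniform(shape=shape, ctx=ctx)
    grad = mx.nd.random.uniform(shape=shape, ctx=ctx)
    mean = mx.nd.zeros(shape, ctx=ctx)
    var = mx.nd.zeros(shape, ctx=ctx)

    g = mx.nd.lamb_update_phase1(weight, grad, mean, var, beta1=0.9, beta2=0.999,
                                 epsilon=1e-6, t=1, bias_correction=True, wd=0.01)
    mx.nd.lamb_update_phase2(weight, g, weight.norm(), g.norm(), lr=0.001, out=weight)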
