This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 31d7955

Rohit Kumar Srivastava committed
fixing base lamb optimizer
1 parent 4f5ebff commit 31d7955

File tree: 5 files changed, +173 -70 lines

python/mxnet/optimizer/optimizer.py

Lines changed: 15 additions & 7 deletions
@@ -34,7 +34,7 @@
                        multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
                        multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
                        preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
-                       preloaded_multi_mp_sgd_mom_update, lamb_update)
+                       preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array
@@ -1250,7 +1250,7 @@ class LAMB(Optimizer):
     """LAMB Optimizer.
     """
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
-                 lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
+                 lower_bound=None, upper_bound=None, bias_correction=False, **kwargs):
         super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
@@ -1259,13 +1259,14 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
         self.upper_bound = upper_bound
         self.bias_correction = bias_correction
 
+
     def create_state(self, index, weight):
         stype = weight.stype
         dtype = weight.dtype
         return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
                 zeros(weight.shape, weight.context, dtype=dtype, stype=stype))
 
-    def update(self, index, weight,grad, state):
+    def update(self, index, weight, grad, state):
         assert(isinstance(weight, NDArray))
         assert(isinstance(grad, NDArray))
         self._update_count(index)
@@ -1274,14 +1275,21 @@ def update(self, index, weight,grad, state):
         t = self._index_update_count[index]
 
         kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
-                  'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
                   'bias_correction': self.bias_correction, 't': t,
                   'rescale_grad': self.rescale_grad}
+        mean, var = state
         if self.clip_gradient:
             kwargs['clip_gradient'] = self.clip_gradient
-
-        mean, var = state
-        lamb_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)
+        g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs)
+
+        kwargs = {}
+        if self.lower_bound:
+            kwargs['lower_bound'] = self.lower_bound
+        if self.upper_bound:
+            kwargs['upper_bound'] = self.upper_bound
+        r_1 = weight.norm()
+        r_2 = g.norm()
+        lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight, **kwargs)
 
 
 # pylint: enable=line-too-long
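
The rewritten update() above composes the two new operators: lamb_update_phase1 updates the moment state and returns the raw update direction g, and lamb_update_phase2 rescales g by the trust ratio r_1/r_2 and writes the stepped weight. A minimal standalone sketch of the same composition (shapes and hyperparameter values below are illustrative only, and it assumes an MXNet build that contains this commit):

import mxnet as mx

weight = mx.nd.random.uniform(shape=(10,))
grad = mx.nd.random.uniform(shape=(10,))
mean = mx.nd.zeros(weight.shape)  # 1st-moment state, as built by create_state()
var = mx.nd.zeros(weight.shape)   # 2nd-moment state

# Phase 1: update mean/var in place and compute the raw direction g.
g = mx.nd.lamb_update_phase1(weight, grad, mean, var,
                             beta1=0.9, beta2=0.999, epsilon=1e-6,
                             t=1, bias_correction=True,
                             wd=0.01, rescale_grad=1.0)

# Phase 2: scale by the trust ratio ||weight|| / ||g|| and apply the step in place.
r_1 = weight.norm()
r_2 = g.norm()
mx.nd.lamb_update_phase2(weight, g, r_1, r_2, lr=0.001, out=weight)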

src/operator/optimizer_op-inl.h

Lines changed: 109 additions & 43 deletions
@@ -1563,21 +1563,16 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
-struct LAMBParam : public dmlc::Parameter<LAMBParam> {
-  float lr;
+struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
   float beta1;
   float beta2;
   float epsilon;
-  float lower_bound;
-  float upper_bound;
   float t;
   bool bias_correction;
   float wd;
   float rescale_grad;
   float clip_gradient;
-  DMLC_DECLARE_PARAMETER(LAMBParam) {
-    DMLC_DECLARE_FIELD(lr)
-    .describe("Learning rate");
+  DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
     DMLC_DECLARE_FIELD(beta1)
     .set_default(0.9f)
     .describe("The decay rate for the 1st moment estimates.");
@@ -1587,19 +1582,12 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
     DMLC_DECLARE_FIELD(epsilon)
     .set_default(1e-6f)
     .describe("A small constant for numerical stability.");
-    DMLC_DECLARE_FIELD(lower_bound)
-    .set_default(1e-3f)
-    .describe("Lower limit of norm of weight.");
-    DMLC_DECLARE_FIELD(upper_bound)
-    .set_default(10.0f)
-    .describe("Upper limit of norm of weight.");
     DMLC_DECLARE_FIELD(t)
     .describe("Index update count.");
     DMLC_DECLARE_FIELD(bias_correction)
     .set_default(false)
     .describe("Whether to use bias correction.");
     DMLC_DECLARE_FIELD(wd)
-    .set_default(0.0f)
     .describe("Weight decay augments the objective function with a "
               "regularization term that penalizes large weights. "
               "The penalty scales with the square of the magnitude of each weight.");
@@ -1614,74 +1602,152 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
   }
 };
 
-struct LAMBUpdateKernel {
+struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
+  float lr;
+  float lower_bound;
+  float upper_bound;
+  DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(lower_bound)
+    .set_default(-1.0f)
+    .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
+    DMLC_DECLARE_FIELD(upper_bound)
+    .set_default(-1.0f)
+    .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
+  }
+};
+
+struct LambUpdatePhaseOneKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data,
     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
     const DType clip_gradient, const DType rescale_grad,
-    const DType beta1, const DType beta2,
-    DType lr, const DType wd,
-    const DType epsilon, const DType lower_bound,
-    const DType upper_bound, const DType t,
+    const DType beta1, const DType beta2, const DType wd,
+    const DType epsilon, const DType t,
     bool bias_correction, const OpReqType req) {
     using namespace mshadow_op;
 
-    DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
+    DType grad_rescaled = grad_data[i] * rescale_grad;
     if (clip_gradient >= 0.f) {
       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
     }
 
     mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
     var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
 
-    DType r1 = square_root::Map(square::Map(weight_data[i]));
-
-    r1 = minimum::Map(maximum::Map(r1, lower_bound), upper_bound);
-    DType g = mean_data[i] / square_root::Map(var_data[i] + epsilon) + wd * weight_data[i];
+    DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];
 
     if (bias_correction) {
       DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
       DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
-      g = mean_hat / square_root::Map(var_hat + epsilon) + wd * weight_data[i];
-    }
-    DType r2 = square_root::Map(square::Map(g));
-    if (r1 == 0.0f || r2 == 0.0f) {
-      lr = lr * 1.0f;
-    } else {
-      lr = lr * r1 / r2;
+      g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
     }
-
-    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g);
+    KERNEL_ASSIGN(out_data[i], req, g);
   }
 };
 
 template<typename xpu>
-inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
+inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
-    using namespace mxnet_op;
-    const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
-    Stream<xpu>* s = ctx.get_stream<xpu>();
+  using namespace mxnet_op;
+  const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
 
-    Kernel<LAMBUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
+    Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
       out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
       static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
-      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
-      static_cast<DType>(param.epsilon), static_cast<DType>(param.lower_bound),
-      static_cast<DType>(param.upper_bound), static_cast<DType>(param.t),
-      static_cast<bool>(param.bias_correction), req[0]);
-    });
+      static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+      static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+  });
+}
+
+inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
+                                    mxnet::ShapeVector* in_attrs,
+                                    mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 4U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
+
+  mxnet::TShape& weight_shape = in_attrs->at(0);
+  mxnet::TShape& g_shape = in_attrs->at(1);
+  CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
+    << "total no. of dimensions for weights and g must match";
+  for (int i=0; i < weight_shape.ndim(); ++i) {
+    CHECK_EQ(weight_shape[i], g_shape[i])
+      << "weight and g dimension size mismatch at " << i << "-th index";
+  }
+  mxnet::TShape& r1_shape = in_attrs->at(2);
+  mxnet::TShape& r2_shape = in_attrs->at(3);
+  CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
+  CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
+  for (int i=0; i < expected_out.ndim(); ++i) {
+    expected_out[i] = weight_shape[i];
+  }
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
+  return shape_is_known(expected_out);
 }
 
+struct LambUpdatePhaseTwoKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    const DType* weight_data, const DType* g,
+    const DType* r1, const DType* r2,
+    DType lr, const DType lower_bound,
+    const DType upper_bound, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType new_r1 = r1[0];
+    if (lower_bound >= 0) {
+      new_r1 = maximum::Map(new_r1, lower_bound);
+    }
+    if (upper_bound >= 0) {
+      new_r1 = minimum::Map(new_r1, upper_bound);
+    }
+    if (new_r1 == 0.0f || r2[0] == 0.0f) {
+      lr = lr * 1.0f;
+    } else {
+      lr = lr * new_r1 / r2[0];
+    }
+
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
+  }
+};
+
+template<typename xpu>
+inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
+                               const OpContext &ctx,
+                               const std::vector<TBlob> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+    Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
+      static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
+      static_cast<DType>(param.upper_bound), req[0]);
+  });
+}
 
 // This RMSProp code follows the version in
 // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
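
For reference, the two kernels above split the LAMB step into the following stages (a sketch of the math as implemented here; \hat{g} denotes the rescaled and optionally clipped gradient, \lambda is wd, \eta is lr, and the bias-corrected moments replace m_t, v_t when bias_correction is set):

\begin{aligned}
\text{phase 1:}\quad & m_t = \beta_1 m_{t-1} + (1-\beta_1)\,\hat{g}, \qquad
                       v_t = \beta_2 v_{t-1} + (1-\beta_2)\,\hat{g}^2, \\
                     & g = \frac{m_t}{\sqrt{v_t} + \epsilon} + \lambda w \\
\text{phase 2:}\quad & r_1 = \lVert w \rVert_2 \ \text{(clipped to } [\text{lower\_bound}, \text{upper\_bound}] \text{ when those are set)}, \qquad
                       r_2 = \lVert g \rVert_2, \\
                     & w \leftarrow w - \eta\,\frac{r_1}{r_2}\, g, \quad
                       \text{with } \tfrac{r_1}{r_2} \text{ taken as } 1 \text{ if either norm is } 0.
\end{aligned}

The norms r_1 and r_2 are computed by the frontend (weight.norm() and g.norm() in optimizer.py above) and passed to phase 2 as one-element arrays.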

src/operator/optimizer_op.cc

Lines changed: 27 additions & 7 deletions
@@ -43,7 +43,8 @@ DMLC_REGISTER_PARAMETER(FtrlParam);
 DMLC_REGISTER_PARAMETER(SignSGDParam);
 DMLC_REGISTER_PARAMETER(SignumParam);
 DMLC_REGISTER_PARAMETER(AdagradParam);
-DMLC_REGISTER_PARAMETER(LAMBParam);
+DMLC_REGISTER_PARAMETER(LambUpdatePhaseOneParam);
+DMLC_REGISTER_PARAMETER(LambUpdatePhaseTwoParam);
 
 NNVM_REGISTER_OP(signsgd_update)
 .describe(R"code(Update function for SignSGD optimizer.
@@ -922,20 +923,39 @@ Note that non-zero values for the weight decay option are not supported.
 .add_argument("history", "NDArray-or-Symbol", "History")
 .add_arguments(AdagradParam::__FIELDS__());
 
-NNVM_REGISTER_OP(lamb_update)
+NNVM_REGISTER_OP(lamb_update_phase1)
 .describe(R"code(Update function for lamb optimizer.
 )code" ADD_FILELINE)
 .set_num_inputs(4)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<LAMBParam>)
-.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4,1>)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4,1>)
-.set_attr<FCompute>("FCompute<cpu>", LAMBUpdate<cpu>)
+.set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseOne<cpu>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2, 3};
+  })
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
-.add_arguments(LAMBParam::__FIELDS__());
+.add_arguments(LambUpdatePhaseOneParam::__FIELDS__());
+
+NNVM_REGISTER_OP(lamb_update_phase2)
+.describe(R"code(Update function for lamb optimizer.
+)code" ADD_FILELINE)
+.set_num_inputs(4)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", LambUpdatePhaseTwoShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseTwo<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("g", "NDArray-or-Symbol", "Output of lamb_update_phase 1")
+.add_argument("r1", "NDArray-or-Symbol", "r1")
+.add_argument("r2", "NDArray-or-Symbol", "r2")
+.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());
 
 } // namespace op
 } // namespace mxnet
src/operator/optimizer_op.cu

Lines changed: 6 additions & 2 deletions
@@ -277,8 +277,12 @@ NNVM_REGISTER_OP(ftrl_update)
 NNVM_REGISTER_OP(_sparse_adagrad_update)
 .set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);
 
-NNVM_REGISTER_OP(lamb_update)
-.set_attr<FCompute>("FCompute<gpu>", LambUpdate<gpu>);
+NNVM_REGISTER_OP(lamb_update_phase1)
+.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseOne<gpu>);
+
+NNVM_REGISTER_OP(lamb_update_phase2)
+.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseTwo<gpu>);
+
 
 } // namespace op
 } // namespace mxnet
