This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 32f26cc
Author: Rohit Kumar Srivastava
Commit message: fixing base lamb optimizer
1 parent: 5812d83
File tree: 5 files changed, +160 -59 lines changed

python/mxnet/optimizer/optimizer.py

Lines changed: 16 additions & 8 deletions
@@ -26,15 +26,15 @@
 import os
 import numpy
 from ..base import py_str
-from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply,
+from ..ndarray import (NDArray, zeros, clip, sqrt, cast, minimum, maximum, abs as NDabs, array, multiply,
                        multi_sum_sq, multi_lars, norm as NDnorm)
 from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
                        mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
                        signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update,
                        multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
                        multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
                        preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
-                       preloaded_multi_mp_sgd_mom_update, lamb_update)
+                       preloaded_multi_mp_sgd_mom_update, lamb_update, lamb_weight_update)
 from ..ndarray import sparse
 from ..random import normal
 from ..util import is_np_array
@@ -1250,7 +1250,7 @@ class LAMB(Optimizer):
     """LAMB Optimizer.
     """
     def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
-                 lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
+                 lower_bound=None, upper_bound=None, bias_correction=False, **kwargs):
         super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
         self.beta1 = beta1
         self.beta2 = beta2
@@ -1259,13 +1259,14 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
         self.upper_bound = upper_bound
         self.bias_correction = bias_correction

+
     def create_state(self, index, weight):
         stype = weight.stype
         dtype = weight.dtype
         return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
                 zeros(weight.shape, weight.context, dtype=dtype, stype=stype))

-    def update(self, index, weight,grad, state):
+    def update(self, index, weight, grad, state):
         assert(isinstance(weight, NDArray))
         assert(isinstance(grad, NDArray))
         self._update_count(index)
@@ -1274,14 +1275,21 @@ def update(self, index, weight,grad, state):
         t = self._index_update_count[index]

         kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
-                  'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
                   'bias_correction': self.bias_correction, 't': t,
                   'rescale_grad': self.rescale_grad}
+        mean, var = state
         if self.clip_gradient:
             kwargs['clip_gradient'] = self.clip_gradient
-
-        mean, var = state
-        lamb_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)
+        g = lamb_update(weight, grad, mean, var, wd=wd, **kwargs)
+
+        kwargs = {}
+        if self.lower_bound:
+            kwargs['lower_bound'] = self.lower_bound
+        if self.upper_bound:
+            kwargs['upper_bound'] = self.upper_bound
+        r_1 = weight.norm()
+        r_2 = g.norm()
+        lamb_weight_update(weight, g, r_1, r_2, lr = lr, out=weight, **kwargs)


 # pylint: enable=line-too-long
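For context, the change above splits the LAMB step into two phases: lamb_update produces the raw update direction g from the moment estimates, and lamb_weight_update applies the layer-wise trust ratio r_1/r_2 to the learning rate and writes the new weight. A minimal NumPy sketch of the combined step, mirroring the PyLAMB reference in the test file further down; the function name lamb_step and the dense float32 arrays are illustrative only:

import numpy as np

def lamb_step(weight, grad, mean, var, lr, t, beta1=0.9, beta2=0.999,
              epsilon=1e-6, wd=0.0, lower_bound=None, upper_bound=None,
              bias_correction=False, rescale_grad=1.0, clip_gradient=None):
    # Phase 1 (lamb_update): update the moments in place and form the raw direction g.
    g = grad * rescale_grad
    if clip_gradient is not None:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mean[:] = beta1 * mean + (1.0 - beta1) * g
    var[:] = beta2 * var + (1.0 - beta2) * g * g
    mean_hat, var_hat = mean, var
    if bias_correction:
        mean_hat = mean / (1.0 - beta1 ** t)
        var_hat = var / (1.0 - beta2 ** t)
    g = mean_hat / (np.sqrt(var_hat) + epsilon) + wd * weight

    # Phase 2 (lamb_weight_update): scale lr by the trust ratio r1/r2 and apply it.
    r1 = np.linalg.norm(weight)
    if lower_bound is not None:
        r1 = max(r1, lower_bound)
    if upper_bound is not None:
        r1 = min(r1, upper_bound)
    r2 = np.linalg.norm(g)
    trust = 1.0 if r1 == 0.0 or r2 == 0.0 else r1 / r2
    weight[:] -= lr * trust * g
    return weight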

src/operator/optimizer_op-inl.h

Lines changed: 102 additions & 38 deletions
@@ -1564,20 +1564,15 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
 }

 struct LAMBParam : public dmlc::Parameter<LAMBParam> {
-  float lr;
   float beta1;
   float beta2;
   float epsilon;
-  float lower_bound;
-  float upper_bound;
   float t;
   bool bias_correction;
   float wd;
   float rescale_grad;
   float clip_gradient;
   DMLC_DECLARE_PARAMETER(LAMBParam) {
-    DMLC_DECLARE_FIELD(lr)
-    .describe("Learning rate");
     DMLC_DECLARE_FIELD(beta1)
     .set_default(0.9f)
     .describe("The decay rate for the 1st moment estimates.");
@@ -1587,19 +1582,12 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
     DMLC_DECLARE_FIELD(epsilon)
     .set_default(1e-6f)
     .describe("A small constant for numerical stability.");
-    DMLC_DECLARE_FIELD(lower_bound)
-    .set_default(1e-3f)
-    .describe("Lower limit of norm of weight.");
-    DMLC_DECLARE_FIELD(upper_bound)
-    .set_default(10.0f)
-    .describe("Upper limit of norm of weight.");
     DMLC_DECLARE_FIELD(t)
     .describe("Index update count.");
     DMLC_DECLARE_FIELD(bias_correction)
     .set_default(false)
     .describe("Whether to use bias correction.");
     DMLC_DECLARE_FIELD(wd)
-    .set_default(0.0f)
     .describe("Weight decay augments the objective function with a "
               "regularization term that penalizes large weights. "
               "The penalty scales with the square of the magnitude of each weight.");
@@ -1614,44 +1602,48 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
   }
 };

+struct LAMBWeightParam : public dmlc::Parameter<LAMBWeightParam> {
+  float lr;
+  float lower_bound;
+  float upper_bound;
+  DMLC_DECLARE_PARAMETER(LAMBWeightParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(lower_bound)
+    .set_default(1e-3f)
+    .describe("Lower limit of norm of weight.");
+    DMLC_DECLARE_FIELD(upper_bound)
+    .set_default(10.0f)
+    .describe("Upper limit of norm of weight.");
+  }
+};
+
 struct LAMBUpdateKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data,
     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
     const DType clip_gradient, const DType rescale_grad,
-    const DType beta1, const DType beta2,
-    DType lr, const DType wd,
-    const DType epsilon, const DType lower_bound,
-    const DType upper_bound, const DType t,
+    const DType beta1, const DType beta2,const DType wd,
+    const DType epsilon, const DType t,
     bool bias_correction, const OpReqType req) {
     using namespace mshadow_op;

-    DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
+    DType grad_rescaled = grad_data[i] * rescale_grad;
     if (clip_gradient >= 0.f) {
       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
     }

     mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
     var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;

-    DType r1 = square_root::Map(square::Map(weight_data[i]));
-
-    r1 = minimum::Map(maximum::Map(r1, lower_bound), upper_bound);
-    DType g = mean_data[i] / square_root::Map(var_data[i] + epsilon) + wd * weight_data[i];
+    DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

     if (bias_correction) {
       DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
       DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
-      g = mean_hat / square_root::Map(var_hat + epsilon) + wd * weight_data[i];
-    }
-    DType r2 = square_root::Map(square::Map(g));
-    if (r1 == 0.0f || r2 == 0.0f) {
-      lr = lr * 1.0f;
-    } else {
-      lr = lr * r1 / r2;
+      g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
     }
-
-    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g);
+    KERNEL_ASSIGN(out_data[i], req, g);
   }
 };

@@ -1661,9 +1653,9 @@ inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
-    using namespace mxnet_op;
-    const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
-    Stream<xpu>* s = ctx.get_stream<xpu>();
+  using namespace mxnet_op;
+  const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
@@ -1675,13 +1667,85 @@ inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
       out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
       static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
-      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
-      static_cast<DType>(param.epsilon), static_cast<DType>(param.lower_bound),
-      static_cast<DType>(param.upper_bound), static_cast<DType>(param.t),
-      static_cast<bool>(param.bias_correction), req[0]);
-  });
+      static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+      static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+  });
+}
+
+inline bool LambWeightShape(const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector* in_attrs,
+                            mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 4U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
+
+  mxnet::TShape& weight_shape = in_attrs->at(0);
+  mxnet::TShape& g_shape = in_attrs->at(1);
+  CHECK_EQ(weight_shape.ndim(), g_shape.ndim()) << "total no. of dimensions for weights and g must match";
+  for (int i=0; i < weight_shape.ndim(); ++i){
+    CHECK_EQ(weight_shape[i], g_shape[i]) << "weight and g dimension size mismatch at " << i << "-th index";
+  }
+  mxnet::TShape& r1_shape = in_attrs->at(2);
+  mxnet::TShape& r2_shape = in_attrs->at(3);
+  CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
+  CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
+  for (int i=0; i < expected_out.ndim(); ++i) {
+    expected_out[i] = weight_shape[i];
+  }
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
+  return shape_is_known(expected_out);
 }

+struct LAMBWeightUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    const DType* weight_data, const DType* g,
+    const DType* r1, const DType* r2,
+    DType lr, const DType lower_bound,
+    const DType upper_bound, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType new_r1 = r1[0];
+    if (lower_bound >= 0) {
+      new_r1 = maximum::Map(new_r1, lower_bound);
+    }
+    if (upper_bound >= 0) {
+      new_r1 = minimum::Map(new_r1, upper_bound);
+    }
+    if (new_r1 == 0.0f || r2[0] == 0.0f) {
+      lr = lr * 1.0f;
+    } else {
+      lr = lr * new_r1 / r2[0];
+    }
+
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
+  }
+};
+
+template<typename xpu>
+inline void LAMBWeightUpdate(const nnvm::NodeAttrs& attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LAMBWeightParam& param = nnvm::get<LAMBWeightParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+    Kernel<LAMBWeightUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
+      static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
+      static_cast<DType>(param.upper_bound), req[0]);
+  });
+}

 // This RMSProp code follows the version in
 // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
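For reference, the math the two kernels now split between them, written out as a sketch derived from the kernel bodies above: g_t is the rescaled and optionally clipped gradient, lambda the weight decay wd, eta the learning rate lr, and the hatted moments reduce to m_t and v_t when bias_correction is false.

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t, &
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,\\
\hat m_t &= \frac{m_t}{1-\beta_1^t}, &
\hat v_t &= \frac{v_t}{1-\beta_2^t},\\
u_t &= \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon} + \lambda\,w_{t-1}, &
r_1 &= \operatorname{clip}\!\left(\lVert w_{t-1}\rVert,\ \text{lower\_bound},\ \text{upper\_bound}\right),\\
r_2 &= \lVert u_t\rVert, &
w_t &= w_{t-1} - \eta\,\frac{r_1}{r_2}\,u_t.
\end{aligned}
$$

LAMBUpdateKernel computes u_t (returned as g), while LAMBWeightUpdateKernel applies the trust ratio, clipping r_1 only for bounds that are actually set (non-negative) and leaving eta unscaled when r_1 = 0 or r_2 = 0.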

src/operator/optimizer_op.cc

Lines changed: 22 additions & 2 deletions
@@ -44,6 +44,7 @@ DMLC_REGISTER_PARAMETER(SignSGDParam);
 DMLC_REGISTER_PARAMETER(SignumParam);
 DMLC_REGISTER_PARAMETER(AdagradParam);
 DMLC_REGISTER_PARAMETER(LAMBParam);
+DMLC_REGISTER_PARAMETER(LAMBWeightParam);

 NNVM_REGISTER_OP(signsgd_update)
 .describe(R"code(Update function for SignSGD optimizer.
@@ -928,14 +929,33 @@ NNVM_REGISTER_OP(lamb_update)
 .set_num_inputs(4)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<LAMBParam>)
-.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4,1>)
-.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4,1>)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
 .set_attr<FCompute>("FCompute<cpu>", LAMBUpdate<cpu>)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<uint32_t>{2, 3};
+  })
 .add_argument("weight", "NDArray-or-Symbol", "Weight")
 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
 .add_arguments(LAMBParam::__FIELDS__());

+NNVM_REGISTER_OP(lamb_weight_update)
+.describe(R"code(Update function for lamb optimizer.
+)code" ADD_FILELINE)
+.set_num_inputs(4)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<LAMBWeightParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", LambWeightShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
+.set_attr<FCompute>("FCompute<cpu>", LAMBWeightUpdate<cpu>)
+.add_argument("weight", "NDArray-or-Symbol", "Weight")
+.add_argument("g", "NDArray-or-Symbol", "g")
+.add_argument("r1", "NDArray-or-Symbol", "r1")
+.add_argument("r2", "NDArray-or-Symbol", "r2")
+.add_arguments(LAMBWeightParam::__FIELDS__());
+
 } // namespace op
 } // namespace mxnet
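The two registered operators can be exercised directly from the NDArray API (a sketch, assuming they are exposed under mx.nd as the import in optimizer.py above suggests; the shape and hyperparameter values are purely illustrative):

import mxnet as mx

shape = (4, 8)                           # illustrative shape
weight = mx.nd.random.uniform(shape=shape)
grad = mx.nd.random.uniform(shape=shape)
mean = mx.nd.zeros(shape)
var = mx.nd.zeros(shape)

# Phase 1: lamb_update mutates mean/var in place (inputs 2 and 3, per FMutateInputs)
# and returns the raw update direction g.
g = mx.nd.lamb_update(weight, grad, mean, var,
                      beta1=0.9, beta2=0.999, epsilon=1e-6,
                      t=1, bias_correction=True, wd=0.01, rescale_grad=1.0)

# Phase 2: lamb_weight_update scales lr by the trust ratio r1/r2 and writes the
# updated weight back into `weight` via out=.
r1 = weight.norm()
r2 = g.norm()
mx.nd.lamb_weight_update(weight, g, r1, r2, lr=0.001, out=weight)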

src/operator/optimizer_op.cu

Lines changed: 5 additions & 1 deletion
@@ -278,7 +278,11 @@ NNVM_REGISTER_OP(_sparse_adagrad_update)
 .set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);

 NNVM_REGISTER_OP(lamb_update)
-.set_attr<FCompute>("FCompute<gpu>", LambUpdate<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", LAMBUpdate<gpu>);
+
+NNVM_REGISTER_OP(lamb_weight_update)
+.set_attr<FCompute>("FCompute<gpu>", LAMBWeightUpdate<gpu>);
+

 } // namespace op
 } // namespace mxnet

tests/python/unittest/test_optimizer.py

Lines changed: 15 additions & 10 deletions
@@ -454,31 +454,34 @@ def update(self, index, weight, grad, state):

         grad *= self.rescale_grad
         if self.clip_gradient is not None:
-            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
+            grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient)

         mean, var = state
         mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
         var[:] = self.beta2 * var + (1. - self.beta2) * mx.nd.square(grad)

+        mean_hat = mean
+        var_hat = var
         r1 = weight.norm()
-        if not self.bias_correction:
-            r1 = mx.nd.minimum(mx.nd.maximum(r1, self.lower_bound), self.upper_bound)
-            g = mean / (mx.nd.sqrt(var) + self.epsilon) + wd * weight
-
-        else:
+        if self.lower_bound:
+            r1 = mx.nd.maximum(r1, self.lower_bound)
+        if self.upper_bound:
+            r1 = mx.nd.minimum(r1, self.upper_bound)
+        if self.bias_correction:
             mean_hat = mean / (1. - mx.nd.power(self.beta1, t))
             var_hat = var / (1. - mx.nd.power(self.beta2, t))
-            g = mean_hat / mx.nd.sqrt(var_hat + self.epsilon) + wd * weight

+        g = mean_hat / (mx.nd.sqrt(var_hat) + self.epsilon) + wd * weight
         r2 = g.norm()
-
         # calculate lamb_trust_ratio
         r = 1. if r1 == 0. or r2 == 0. else r1 / r2
         lr *= r
-
         # update weight
         weight[:] -= lr * g

+    def update_multi_precision(self, index, weight, grad, state):
+        self.update(index, weight, grad, state)
+
 @with_seed()
 def test_lamb():
     opt1 = PyLAMB
@@ -488,7 +491,9 @@ def test_lamb():
     rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
     wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
     bc_options = [{}, {'bias_correction': False}, {'bias_correction': True}]
-    for params in itertools.product(cg_options, rg_options, wd_options, bc_options):
+    lb_options = [{}, {'lower_bound': None}]
+    ub_options = [{}, {'upper_bound': None}]
+    for params in itertools.product(cg_options, rg_options, wd_options, bc_options, lb_options, ub_options):
         kwarg = {k: v for param in params for k, v in param.items()}
         compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)
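The new lb_options and ub_options only exercise the default (None) path. A hypothetical extension of the same sweep to finite bounds, not part of this commit and with purely illustrative values, would follow the existing pattern:

    lb_options = [{}, {'lower_bound': None}, {'lower_bound': 1e-3}]
    ub_options = [{}, {'upper_bound': None}, {'upper_bound': 10.0}]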
