This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Lamb optimizer update #16715

Merged
merged 3 commits on Nov 24, 2019
Changes from all commits
52 changes: 50 additions & 2 deletions python/mxnet/optimizer/optimizer.py
@@ -34,14 +34,14 @@
multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
preloaded_multi_mp_sgd_mom_update)
preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2)
from ..ndarray import sparse
from ..random import normal
from ..util import is_np_array

__all__ = [
'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum',
'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
]

@@ -1244,6 +1244,54 @@ def update(self, index, weight, grad, state):
kwargs = {}
sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)


@register
class LAMB(Optimizer):
"""LAMB Optimizer.

[Review thread]
Member: pls add doc
Contributor (author): working on it now
Contributor: The name is clashing with the Gluon one, can we give it a different name?

"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
lower_bound=None, upper_bound=None, bias_correction=True, **kwargs):
super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.bias_correction = bias_correction


def create_state(self, index, weight):
stype = weight.stype
dtype = weight.dtype
return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
zeros(weight.shape, weight.context, dtype=dtype, stype=stype))

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
t = self._index_update_count[index]

kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'bias_correction': self.bias_correction, 't': t,
'rescale_grad': self.rescale_grad}
mean, var = state
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient
g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs)

kwargs = {}
if self.lower_bound:
kwargs['lower_bound'] = self.lower_bound
if self.upper_bound:
kwargs['upper_bound'] = self.upper_bound
r_1 = weight.norm()
r_2 = g.norm()
lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight, **kwargs)


# pylint: enable=line-too-long
@register
class DCASGD(Optimizer):
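For reference, a minimal usage sketch of the optimizer registered above, driven through a Gluon Trainer. This is illustrative only: it assumes an MXNet build that contains this change, and the tiny network, loss and random data are placeholders.

# Minimal sketch: the LAMB optimizer above, resolved by name through gluon.Trainer.
# Assumes an MXNet build that includes this PR; network and data are placeholders.
import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(10)
net.initialize()

# 'lamb' maps to the LAMB class via the @register decorator above.
trainer = gluon.Trainer(net.collect_params(), 'lamb',
                        {'learning_rate': 0.001, 'beta1': 0.9, 'beta2': 0.999,
                         'epsilon': 1e-6, 'bias_correction': True})

x = mx.nd.random.uniform(shape=(32, 20))
y = mx.nd.random.uniform(shape=(32, 10))
loss_fn = gluon.loss.L2Loss()

with autograd.record():
    loss = loss_fn(net(x), y)
loss.backward()
trainer.step(batch_size=32)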
186 changes: 186 additions & 0 deletions src/operator/optimizer_op-inl.h
@@ -1563,6 +1563,192 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
}
}

struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
float beta1;
float beta2;
float epsilon;
float t;

[Review thread]
Member: @eric-haibin-lin @access2rohit I find this issue when reading the code. Here, t should be the number of updates and should not be stored as a float, which will lose precision. I think we need to store it as index_t.
Contributor (author): we are using float here for an integer value. @sxjscience can you explain how we would lose precision for the operation beta^t?
(See the short numeric illustration after this struct.)


bool bias_correction;
float wd;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
DMLC_DECLARE_FIELD(beta1)
.set_default(0.9f)
.describe("The decay rate for the 1st moment estimates.");
DMLC_DECLARE_FIELD(beta2)
.set_default(0.999f)
.describe("The decay rate for the 2nd moment estimates.");
DMLC_DECLARE_FIELD(epsilon)
.set_default(1e-6f)
.describe("A small constant for numerical stability.");
DMLC_DECLARE_FIELD(t)
.describe("Index update count.");
DMLC_DECLARE_FIELD(bias_correction)
.set_default(true)
.describe("Whether to use bias correction.");
DMLC_DECLARE_FIELD(wd)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
}
};
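
On the review thread above about the dtype of t: a small NumPy illustration of the two points being discussed, namely that float32 stores integer step counts exactly up to 2^24 while evaluating beta^t from a half-precision beta drifts noticeably. This is an illustration only, not part of the diff.

import numpy as np

# float32 has a 24-bit significand, so an integer update count t is stored
# exactly as long as t <= 2**24 (16,777,216 steps).
print(np.float32(2**24) == 2**24)          # True  -> exact
print(np.float32(2**24 + 1) == 2**24 + 1)  # False -> rounded to a neighbouring integer

# A separate concern (cf. the apex link further down): beta**t computed from a
# float16 beta drifts from the float32 result, because beta itself is rounded
# to 10 fraction bits.
beta2, t = 0.999, 2000
print(np.float32(beta2) ** np.float32(t))  # ~0.135
print(np.float16(beta2) ** np.float16(t))  # noticeably different from the float32 value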

struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
float lr;
float lower_bound;
float upper_bound;
DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(lower_bound)
.set_default(-1.0f)
.describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
DMLC_DECLARE_FIELD(upper_bound)
.set_default(-1.0f)
.describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
}
};

struct LambUpdatePhaseOneKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
const DType clip_gradient, const DType rescale_grad,
const DType beta1, const DType beta2, const DType wd,
const DType epsilon, const DType t,
bool bias_correction, const OpReqType req) {
using namespace mshadow_op;

DType grad_rescaled = grad_data[i] * rescale_grad;
if (clip_gradient >= 0.f) {
grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
}

mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;

DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

if (bias_correction) {
DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
DType var_hat = var_data[i] / (1 - power::Map(beta2, t));

[Review thread]
Member: Actually, in apex, it uses float32 to calculate the power and then switches to float16:
https://github.com/NVIDIA/apex/blob/325f5a0bec542701edba1628ad34f3b2ea47c556/csrc/multi_tensor_lamb.cu#L231-L249


g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
}
KERNEL_ASSIGN(out_data[i], req, g);
}
};

template<typename xpu>
inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
});
}

inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 4U);
CHECK_EQ(out_attrs->size(), 1U);

mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);

mxnet::TShape& weight_shape = in_attrs->at(0);
mxnet::TShape& g_shape = in_attrs->at(1);
CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
<< "total no. of dimensions for weights and g must match";
for (int i=0; i < weight_shape.ndim(); ++i) {
CHECK_EQ(weight_shape[i], g_shape[i])
<< "weight and g dimension size mismatch at " << i << "-th index";
}
mxnet::TShape& r1_shape = in_attrs->at(2);
mxnet::TShape& r2_shape = in_attrs->at(3);
CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
for (int i=0; i < expected_out.ndim(); ++i) {
expected_out[i] = weight_shape[i];
}

SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
return shape_is_known(expected_out);
}

struct LambUpdatePhaseTwoKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const DType* weight_data, const DType* g,
const DType* r1, const DType* r2,
DType lr, const DType lower_bound,
const DType upper_bound, const OpReqType req) {
using namespace mshadow_op;

DType new_r1 = r1[0];
if (lower_bound >= 0) {
new_r1 = maximum::Map(new_r1, lower_bound);
}
if (upper_bound >= 0) {
new_r1 = minimum::Map(new_r1, upper_bound);
}
if (new_r1 == 0.0f || r2[0] == 0.0f) {
lr = lr * 1.0f;
} else {
lr = lr * new_r1 / r2[0];
}

KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
}
};
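
For clarity, a NumPy sketch of what LambUpdatePhaseTwoKernel computes: clamp the weight norm when bounds are set, and fall back to a trust ratio of 1 when either norm is zero. The function name and defaults are illustrative (useful as a reference for a unit test), not part of the diff.

import numpy as np

def lamb_phase2_ref(weight, g, r1, r2, lr, lower_bound=-1.0, upper_bound=-1.0):
    # r1 = ||weight||, r2 = ||g|| (the phase-1 output), both scalars.
    if lower_bound >= 0:
        r1 = max(r1, lower_bound)
    if upper_bound >= 0:
        r1 = min(r1, upper_bound)
    ratio = 1.0 if (r1 == 0.0 or r2 == 0.0) else r1 / r2
    return weight - lr * ratio * g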

template<typename xpu>
inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
static_cast<DType>(param.upper_bound), req[0]);
});
}

// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
// by Alex Graves, 2013.
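Putting the two kernels above together, the update they implement can be summarized as follows, as a reading aid in the notation of the code above and of the LAMB paper (not part of the diff):

.. math::
   \begin{gather*}
   g_t = \mathrm{clip}(\mathrm{rescale\_grad} \cdot \nabla L(w_{t-1}),\ \mathrm{clip\_gradient}) \\
   m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \qquad
   v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
   \hat{m}_t = m_t / (1 - \beta_1^t) \qquad
   \hat{v}_t = v_t / (1 - \beta_2^t) \qquad \text{(if bias\_correction)} \\
   r_t = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + wd \cdot w_{t-1}
   \qquad \text{(the tensor g returned by phase 1)} \\
   w_t = w_{t-1} - lr \cdot \frac{\phi(\lVert w_{t-1} \rVert)}{\lVert r_t \rVert} \, r_t
   \qquad \text{(phase 2, where } \phi \text{ clamps to [lower\_bound, upper\_bound];} \\
   \text{the ratio falls back to 1 if either norm is zero)}
   \end{gather*}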
81 changes: 81 additions & 0 deletions src/operator/optimizer_op.cc
@@ -43,6 +43,8 @@ DMLC_REGISTER_PARAMETER(FtrlParam);
DMLC_REGISTER_PARAMETER(SignSGDParam);
DMLC_REGISTER_PARAMETER(SignumParam);
DMLC_REGISTER_PARAMETER(AdagradParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseOneParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseTwoParam);

NNVM_REGISTER_OP(signsgd_update)
.describe(R"code(Update function for SignSGD optimizer.
@@ -921,5 +923,84 @@ Note that non-zero values for the weight decay option are not supported.
.add_argument("history", "NDArray-or-Symbol", "History")
.add_arguments(AdagradParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_update_phase1)
.describe(R"code(Phase I of lamb update it performs the following operations and returns g:.

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math::
\begin{gather*}
grad = grad * rescale_grad
if (grad < -clip_gradient)
then
grad = -clip_gradient
if (grad > clip_gradient)
then
grad = clip_gradient

mean = beta1 * mean + (1 - beta1) * grad;
variance = beta2 * variance + (1. - beta2) * grad ^ 2;

if (bias_correction)
then
mean_hat = mean / (1. - beta1^t);
var_hat = var / (1 - beta2^t);
g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight;
else
g = mean / (var^(1/2) + epsilon) + wd * weight;
\end{gather*}

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseOne<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(LambUpdatePhaseOneParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_update_phase2)
.describe(R"code(Phase II of lamb update it performs the following operations and updates grad.

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math::
\begin{gather*}
if (lower_bound >= 0)
then
r1 = max(r1, lower_bound)
if (upper_bound >= 0)
then
r1 = min(r1, upper_bound)

if (r1 == 0 or r2 == 0)
then
lr = lr
else
lr = lr * (r1/r2)
weight = weight - lr * g
\end{gather*}

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
.set_attr<mxnet::FInferShape>("FInferShape", LambUpdatePhaseTwoShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseTwo<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("g", "NDArray-or-Symbol", "Output of lamb_update_phase 1")
.add_argument("r1", "NDArray-or-Symbol", "r1")
.add_argument("r2", "NDArray-or-Symbol", "r2")
.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());

} // namespace op
} // namespace mxnet
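
To see how the two operators registered above compose, here is a sketch that mirrors LAMB.update() from the Python change by calling them directly through the NDArray API. Shapes and hyper-parameters are illustrative; it assumes a build that contains this PR.

# Sketch: one LAMB step via the two new operators, mirroring LAMB.update() above.
import mxnet as mx

shape = (4, 5)
weight = mx.nd.random.uniform(shape=shape)
grad = mx.nd.random.uniform(shape=shape)
mean = mx.nd.zeros(shape)  # 1st-moment state, as produced by create_state()
var = mx.nd.zeros(shape)   # 2nd-moment state

t = 1  # update count for this parameter index
g = mx.nd.lamb_update_phase1(weight, grad, mean, var,
                             beta1=0.9, beta2=0.999, epsilon=1e-6,
                             t=t, bias_correction=True, wd=0.01,
                             rescale_grad=1.0)

r1 = weight.norm()  # ||w||; phase 2 clamps it if lower_bound/upper_bound are set
r2 = g.norm()       # ||g|| from phase 1
mx.nd.lamb_update_phase2(weight, g, r1, r2, lr=0.001, out=weight)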
7 changes: 7 additions & 0 deletions src/operator/optimizer_op.cu
@@ -277,5 +277,12 @@ NNVM_REGISTER_OP(ftrl_update)
NNVM_REGISTER_OP(_sparse_adagrad_update)
.set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);

NNVM_REGISTER_OP(lamb_update_phase1)
.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseOne<gpu>);

NNVM_REGISTER_OP(lamb_update_phase2)
.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseTwo<gpu>);


} // namespace op
} // namespace mxnet