Lamb optimizer update #16715
@@ -1563,6 +1563,192 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
  }
}

struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
  float beta1;
  float beta2;
  float epsilon;
  float t;
Review comment: @eric-haibin-lin @access2rohit I find this issue when reading the code. Here, the step count `t` is declared as a float rather than an integer type.

Reply: we are using float here for an integer data type. @sxjscience can you explain how we will lose precision for the operation?
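As a rough illustration of the concern (a standalone sketch, not part of this diff): integer counts stored in floating point stop being exact once they exceed the mantissa range, which happens at 2^24 for float32 and already at 2^11 for float16, so casting `t` to a half-precision DType can make both the count and `power::Map(beta, t)` inexact for realistic step numbers.

```cpp
// Standalone illustration (hypothetical, not from the PR): a float32 step
// counter silently stops incrementing once it passes 2^24.
#include <cstdio>

int main() {
  float t = 16777216.0f;    // 2^24, the end of the exactly representable integers
  float t_next = t + 1.0f;  // rounds back to 16777216.0f
  std::printf("t = %.1f, t + 1 = %.1f\n", t, t_next);
  return 0;
}
```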
  bool bias_correction;
  float wd;
  float rescale_grad;
  float clip_gradient;
  DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
    DMLC_DECLARE_FIELD(beta1)
    .set_default(0.9f)
    .describe("The decay rate for the 1st moment estimates.");
    DMLC_DECLARE_FIELD(beta2)
    .set_default(0.999f)
    .describe("The decay rate for the 2nd moment estimates.");
    DMLC_DECLARE_FIELD(epsilon)
    .set_default(1e-6f)
    .describe("A small constant for numerical stability.");
    DMLC_DECLARE_FIELD(t)
    .describe("Index update count.");
    DMLC_DECLARE_FIELD(bias_correction)
    .set_default(true)
    .describe("Whether to use bias correction.");
    DMLC_DECLARE_FIELD(wd)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};
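For orientation, these parameters configure the per-element update implemented by LambUpdatePhaseOneKernel further down. Reading it off the kernel code (not quoted from the LAMB paper), with rescaled and optionally clipped gradient $\hat{g}_i$:

$$m_i \leftarrow \beta_1 m_i + (1-\beta_1)\,\hat{g}_i, \qquad v_i \leftarrow \beta_2 v_i + (1-\beta_2)\,\hat{g}_i^2$$

$$\text{out}_i = \frac{m_i / (1-\beta_1^t)}{\sqrt{v_i / (1-\beta_2^t)} + \epsilon} + wd \cdot w_i$$

When bias_correction is false, the uncorrected $m_i$ and $v_i$ are used in the same expression.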

struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
  float lr;
  float lower_bound;
  float upper_bound;
  DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(lower_bound)
    .set_default(-1.0f)
    .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
    DMLC_DECLARE_FIELD(upper_bound)
    .set_default(-1.0f)
    .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
  }
};

struct LambUpdatePhaseOneKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
    const DType clip_gradient, const DType rescale_grad,
    const DType beta1, const DType beta2, const DType wd,
    const DType epsilon, const DType t,
    bool bias_correction, const OpReqType req) {
    using namespace mshadow_op;

    DType grad_rescaled = grad_data[i] * rescale_grad;
    if (clip_gradient >= 0.f) {
      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
    }

    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
    var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;

    DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

    if (bias_correction) {
      DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
      DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
Review comment: Actually, in apex, it uses a float32 to calculate the power and then switches to float16.
      g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
    }
    KERNEL_ASSIGN(out_data[i], req, g);
  }
};
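Regarding the review comment above about apex computing the power in float32 before switching back to float16: one possible shape of that change (a hedged sketch against this kernel, not code from the PR or from apex) would be to evaluate the bias-correction denominators in float and only then divide the DType moments:

```cpp
// Hypothetical variant of the bias-correction branch above: compute
// (1 - beta^t) in float32 so a half-precision DType does not degrade the
// power, then cast the corrected moments back to DType.
if (bias_correction) {
  float mean_denom = 1.f - power::Map(static_cast<float>(beta1), static_cast<float>(t));
  float var_denom  = 1.f - power::Map(static_cast<float>(beta2), static_cast<float>(t));
  DType mean_hat = static_cast<DType>(static_cast<float>(mean_data[i]) / mean_denom);
  DType var_hat  = static_cast<DType>(static_cast<float>(var_data[i]) / var_denom);
  g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
}
```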

template<typename xpu>
inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
                               const OpContext &ctx,
                               const std::vector<TBlob> &inputs,
                               const std::vector<OpReqType> &req,
                               const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

    Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
      out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
      static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
      static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
  });
}

inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
                                    mxnet::ShapeVector* in_attrs,
                                    mxnet::ShapeVector* out_attrs) {
  CHECK_EQ(in_attrs->size(), 4U);
  CHECK_EQ(out_attrs->size(), 1U);

  mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);

  mxnet::TShape& weight_shape = in_attrs->at(0);
  mxnet::TShape& g_shape = in_attrs->at(1);
  CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
    << "total no. of dimensions for weights and g must match";
  for (int i = 0; i < weight_shape.ndim(); ++i) {
    CHECK_EQ(weight_shape[i], g_shape[i])
      << "weight and g dimension size mismatch at " << i << "-th index";
  }
  mxnet::TShape& r1_shape = in_attrs->at(2);
  mxnet::TShape& r2_shape = in_attrs->at(3);
  CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
  CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
  for (int i = 0; i < expected_out.ndim(); ++i) {
    expected_out[i] = weight_shape[i];
  }

  SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
  return shape_is_known(expected_out);
}

struct LambUpdatePhaseTwoKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data,
    const DType* weight_data, const DType* g,
    const DType* r1, const DType* r2,
    DType lr, const DType lower_bound,
    const DType upper_bound, const OpReqType req) {
    using namespace mshadow_op;

    DType new_r1 = r1[0];
    if (lower_bound >= 0) {
      new_r1 = maximum::Map(new_r1, lower_bound);
    }
    if (upper_bound >= 0) {
      new_r1 = minimum::Map(new_r1, upper_bound);
    }
    if (new_r1 == 0.0f || r2[0] == 0.0f) {
      lr = lr * 1.0f;
    } else {
      lr = lr * new_r1 / r2[0];
    }

    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
  }
};
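For reference, this kernel applies the LAMB trust-ratio scaling. Reading it off the code (with r1 presumably the weight norm and r2 the norm of the phase-one output, computed by the caller; that interpretation is an assumption, since the norm computation is not part of this hunk):

$$\text{lr}_{\text{eff}} = \begin{cases} \text{lr} \cdot \dfrac{\min(\max(r_1, \text{lower\_bound}), \text{upper\_bound})}{r_2} & \text{if } r_1 \neq 0 \text{ and } r_2 \neq 0 \\ \text{lr} & \text{otherwise} \end{cases} \qquad w_i \leftarrow w_i - \text{lr}_{\text{eff}} \cdot g_i$$

where each clamp on $r_1$ is applied only when the corresponding bound is non-negative.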

template<typename xpu>
inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
                               const OpContext &ctx,
                               const std::vector<TBlob> &inputs,
                               const std::vector<OpReqType> &req,
                               const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

    Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
      out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
      static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
      static_cast<DType>(param.upper_bound), req[0]);
  });
}

// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
// by Alex Graves, 2013.
Review comment: pls add doc

Reply: working on it now

Review comment: The name is clashing with the Gluon one, can we give it a different name?